1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ValidateHal"
18
#include "ValidateHal.h"
#include "NeuralNetworks.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "Utils.h"

#include <android-base/logging.h>

#include <algorithm>
#include <cstdint>
#include <vector>
26
27 namespace android {
28 namespace nn {
29
30 template <class T_Model>
31 struct ModelToHalVersion;
32 template <>
33 struct ModelToHalVersion<V1_0::Model> {
34 static constexpr HalVersion version = HalVersion::V1_0;
35 };
36 template <>
37 struct ModelToHalVersion<V1_1::Model> {
38 static constexpr HalVersion version = HalVersion::V1_1;
39 };
40 template <>
41 struct ModelToHalVersion<V1_2::Model> {
42 static constexpr HalVersion version = HalVersion::V1_2;
43 };
44
45 class MemoryAccessVerifier {
46 public:
MemoryAccessVerifier(const hidl_vec<hidl_memory> & pools)47 MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
48 : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
49 for (size_t i = 0; i < mPoolCount; i++) {
50 mPoolSizes[i] = pools[i].size();
51 }
52 }
validate(const DataLocation & location)53 bool validate(const DataLocation& location) {
54 if (location.poolIndex >= mPoolCount) {
55 LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
56 return false;
57 }
58 const size_t size = mPoolSizes[location.poolIndex];
59 // Do the addition using size_t to avoid potential wrap-around problems.
60 if (static_cast<size_t>(location.offset) + location.length > size) {
61 LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
62 << location.offset << " and length " << location.length
63 << " exceeds pool size of " << size;
64 return false;
65 }
66 return true;
67 }
68
69 private:
70 size_t mPoolCount;
71 std::vector<size_t> mPoolSizes;
72 };
73
validateOperandExtraParams(const V1_2::Operand & operand,uint32_t index)74 static bool validateOperandExtraParams(const V1_2::Operand& operand, uint32_t index) {
75 switch (operand.type) {
76 case OperandType::FLOAT32:
77 case OperandType::INT32:
78 case OperandType::UINT32:
79 case OperandType::BOOL:
80 case OperandType::TENSOR_FLOAT32:
81 case OperandType::TENSOR_FLOAT16:
82 case OperandType::TENSOR_INT32:
83 case OperandType::TENSOR_QUANT8_ASYMM:
84 case OperandType::TENSOR_QUANT8_SYMM:
85 case OperandType::TENSOR_QUANT16_ASYMM:
86 case OperandType::TENSOR_QUANT16_SYMM:
87 case OperandType::TENSOR_BOOL8: {
88 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
89 V1_2::Operand::ExtraParams::hidl_discriminator::none)
90 << "Operand " << index << ": Operand of type "
91 << getOperandTypeName(operand.type)
92 << " has incorrect extraParams: " << toString(operand.extraParams);
93 } break;
94 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
95 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
96 V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant)
97 << "Operand " << index << ": Operand of type "
98 << getOperandTypeName(operand.type) << " without a Channel Quantization params";
99 auto& channelQuant = operand.extraParams.channelQuant();
100
101 size_t count = operand.dimensions.size();
102 NN_RET_CHECK_LT(channelQuant.channelDim, count)
103 << "Operand " << index << ": Operand of type "
104 << getOperandTypeName(operand.type)
105 << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
106 << ", must be valid dimension index in range [0, " << count << ")";
107 uint32_t expected = operand.dimensions[channelQuant.channelDim];
108 NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
109 << "Operand " << index << ": Operand of type "
110 << getOperandTypeName(operand.type) << " with a wrong-sized scales, "
111 << "expected " << expected << " was " << channelQuant.scales.size();
112 NN_RET_CHECK_NE(expected, 0)
113 << "Operand " << index << ": Operand of type "
114 << getOperandTypeName(operand.type) << " channel dimension "
115 << channelQuant.channelDim << " is underspecified (can't be 0)";
116 for (uint32_t i = 0; i < expected; ++i) {
117 NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
118 << "Operand " << index << ": Operand of type "
119 << getOperandTypeName(operand.type) << " with a negative value in scales["
120 << i << "]=" << channelQuant.scales[i];
121 }
122 } break;
123 default: {
124 if (isExtensionOperandType(operand.type)) {
125 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
126 V1_2::Operand::ExtraParams::hidl_discriminator::extension ||
127 operand.extraParams.getDiscriminator() ==
128 V1_2::Operand::ExtraParams::hidl_discriminator::none)
129 << "Operand " << index << ": Extension operand of type "
130 << getOperandTypeName(operand.type)
131 << " has incorrect extraParams: " << toString(operand.extraParams);
132 }
133 // No validation for OEM types.
134 } break;
135 }
136 return true;
137 }
138
139 template <typename VersionedOperand>
validateOperands(const hidl_vec<VersionedOperand> & operands,const hidl_vec<uint8_t> & operandValues,const hidl_vec<hidl_memory> & pools,bool allowUnspecifiedRank)140 static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
141 const hidl_vec<uint8_t>& operandValues,
142 const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) {
143 uint32_t index = 0;
144 MemoryAccessVerifier poolVerifier(pools);
145 for (auto& versionedOperand : operands) {
146 if (!validOperandType(versionedOperand.type)) {
147 LOG(ERROR) << "Operand is not supported by this version: "
148 << toString(versionedOperand.type);
149 return false;
150 }
151 // Once we are sure the operand is supported by its version, it is safe
152 // to convert it to the latest version for the rest of the validations.
153 V1_2::Operand operand = convertToV1_2(versionedOperand);
154 // Validate type and dimensions.
155 switch (operand.type) {
156 case OperandType::FLOAT16:
157 case OperandType::FLOAT32:
158 case OperandType::INT32:
159 case OperandType::UINT32:
160 case OperandType::BOOL:
161 case OperandType::OEM: {
162 size_t count = operand.dimensions.size();
163 if (count != 0) {
164 LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
165 << count;
166 return false;
167 }
168 break;
169 }
170 case OperandType::TENSOR_FLOAT16:
171 case OperandType::TENSOR_FLOAT32:
172 case OperandType::TENSOR_INT32:
173 case OperandType::TENSOR_QUANT8_ASYMM:
174 case OperandType::TENSOR_QUANT8_SYMM:
175 case OperandType::TENSOR_QUANT16_ASYMM:
176 case OperandType::TENSOR_QUANT16_SYMM:
177 case OperandType::TENSOR_BOOL8:
178 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
179 case OperandType::TENSOR_OEM_BYTE: {
180 if ((!allowUnspecifiedRank || operand.lifetime == OperandLifeTime::CONSTANT_COPY ||
181 operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE) &&
182 operand.dimensions.size() == 0) {
183 LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
184 return false;
185 }
186 break;
187 }
188 default: {
189 if (!isExtensionOperandType(operand.type)) {
190 LOG(ERROR) << "Operand " << index << ": Invalid operand type "
191 << toString(operand.type);
192 return false;
193 }
194 } break;
195 }
196
197 // TODO Validate the numberOfConsumers.
198 // TODO Since we have to validate it, there was no point in including it. For the next
199 // release, consider removing unless we have an additional process in system space
200 // that creates this value. In that case, it would not have to be validated.
201
202 // Validate the scale.
203 switch (operand.type) {
204 case OperandType::FLOAT16:
205 case OperandType::FLOAT32:
206 case OperandType::INT32:
207 case OperandType::UINT32:
208 case OperandType::BOOL:
209 case OperandType::TENSOR_FLOAT16:
210 case OperandType::TENSOR_FLOAT32:
211 case OperandType::TENSOR_BOOL8:
212 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
213 if (operand.scale != 0.f) {
214 LOG(ERROR) << "Operand " << index << ": Operand of type "
215 << getOperandTypeName(operand.type) << " with a non-zero scale ("
216 << operand.scale << ")";
217 return false;
218 }
219 break;
220 case OperandType::TENSOR_INT32:
221 // TENSOR_INT32 may be used with or without scale, depending on the operation.
222 if (operand.scale < 0.f) {
223 LOG(ERROR) << "Operand " << index << ": Operand of type "
224 << getOperandTypeName(operand.type) << " with a negative scale";
225 return false;
226 }
227 break;
228 case OperandType::TENSOR_QUANT8_ASYMM:
229 case OperandType::TENSOR_QUANT8_SYMM:
230 case OperandType::TENSOR_QUANT16_ASYMM:
231 case OperandType::TENSOR_QUANT16_SYMM:
232 if (operand.scale <= 0.f) {
233 LOG(ERROR) << "Operand " << index << ": Operand of type "
234 << getOperandTypeName(operand.type) << " with a non-positive scale";
235 return false;
236 }
237 break;
238 default:
239 if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
240 LOG(ERROR) << "Operand " << index << ": Operand of type "
241 << getOperandTypeName(operand.type) << " with a non-zero scale ("
242 << operand.scale << ")";
243 return false;
244 }
245 // No validation for OEM types.
246 // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
247 break;
248 }
249
250 // Validate the zeroPoint.
251 switch (operand.type) {
252 case OperandType::FLOAT16:
253 case OperandType::FLOAT32:
254 case OperandType::INT32:
255 case OperandType::UINT32:
256 case OperandType::BOOL:
257 case OperandType::TENSOR_FLOAT16:
258 case OperandType::TENSOR_FLOAT32:
259 case OperandType::TENSOR_INT32:
260 case OperandType::TENSOR_BOOL8:
261 case OperandType::TENSOR_QUANT8_SYMM:
262 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
263 if (operand.zeroPoint != 0) {
264 LOG(ERROR) << "Operand " << index << ": Operand of type "
265 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
266 << operand.zeroPoint;
267 return false;
268 }
269 break;
270 case OperandType::TENSOR_QUANT8_ASYMM:
271 if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
272 LOG(ERROR) << "Operand " << index << ": Operand of type "
273 << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
274 << operand.zeroPoint << ", must be in range [0, 255]";
275 return false;
276 }
277 break;
278 case OperandType::TENSOR_QUANT16_ASYMM:
279 if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
280 LOG(ERROR) << "Operand " << index << ": Operand of type "
281 << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
282 << operand.zeroPoint << ", must be in range [0, 65535]";
283 return false;
284 }
285 break;
286 case OperandType::TENSOR_QUANT16_SYMM:
287 if (operand.zeroPoint != 0) {
288 LOG(ERROR) << "Operand " << index << ": Operand of type "
289 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
290 << operand.zeroPoint;
291 return false;
292 }
293 break;
294 default:
295 if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
296 LOG(ERROR) << "Operand " << index << ": Operand of type "
297 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
298 << operand.zeroPoint;
299 return false;
300 }
301 // No validation for OEM types.
302 break;
303 }
304
305 NN_RET_CHECK(validateOperandExtraParams(operand, index));
306
307 // Validate the lifetime and the location.
308 const DataLocation& location = operand.location;
309 switch (operand.lifetime) {
310 case OperandLifeTime::CONSTANT_COPY:
311 if (location.poolIndex != 0) {
312 LOG(ERROR) << "Operand " << index
313 << ": CONSTANT_COPY with a non-zero poolIndex "
314 << location.poolIndex;
315 return false;
316 }
317 // Do the addition using size_t to avoid potential wrap-around problems.
318 if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
319 LOG(ERROR) << "Operand " << index
320 << ": OperandValue location out of range. Starts at "
321 << location.offset << ", length " << location.length << ", max "
322 << operandValues.size();
323 return false;
324 }
325 break;
326 case OperandLifeTime::CONSTANT_REFERENCE:
327 if (!poolVerifier.validate(location)) {
328 return false;
329 }
330 break;
331 case OperandLifeTime::TEMPORARY_VARIABLE:
332 case OperandLifeTime::MODEL_INPUT:
333 case OperandLifeTime::MODEL_OUTPUT:
334 case OperandLifeTime::NO_VALUE:
335 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
336 LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
337 << location.poolIndex << ", offset " << location.offset
338 << ", or length " << location.length << " for operand of lifetime "
339 << toString(operand.lifetime);
340 return false;
341 }
342 break;
343 default:
344 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
345 << toString(operand.lifetime);
346 return false;
347 }
348
349 // For constants, validate that the length is as expected. The other lifetimes
350 // expect the length to be 0. Don't validate for OEM types.
351 if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
352 operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
353 if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM &&
354 operand.type != OperandType::TENSOR_OEM_BYTE) {
355 uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
356 if (location.length != expectedLength) {
357 LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
358 << " expected a size of " << expectedLength << " but got "
359 << location.length;
360 return false;
361 }
362 }
363 }
364
365 index++;
366 }
367 return true;
368 }
369
getHalVersion(const V1_0::Operation &)370 static HalVersion getHalVersion(const V1_0::Operation&) {
371 return HalVersion::V1_0;
372 }
373
getHalVersion(const V1_1::Operation &)374 static HalVersion getHalVersion(const V1_1::Operation&) {
375 return HalVersion::V1_1;
376 }
377
getHalVersion(const V1_2::Operation &)378 static HalVersion getHalVersion(const V1_2::Operation&) {
379 return HalVersion::V1_2;
380 }
381
382 template <typename VersionedOperation>
validateOperations(const hidl_vec<VersionedOperation> & operations,const hidl_vec<Operand> & operands)383 static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
384 const hidl_vec<Operand>& operands) {
385 const size_t operandCount = operands.size();
386 // This vector keeps track of whether there's an operation that writes to
387 // each operand. It is used to validate that temporary variables and
388 // model outputs will be written to.
389 std::vector<bool> writtenTo(operandCount, false);
390 for (auto& op : operations) {
391 // TODO Validate the shapes and any known values. This is currently
392 // done in CpuExecutor but should be done here for all drivers.
393 int error = validateOperation(
394 static_cast<int32_t>(op.type), op.inputs.size(),
395 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
396 op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op));
397 if (error != ANEURALNETWORKS_NO_ERROR) {
398 LOG(ERROR) << "Invalid operation " << toString(op.type);
399 return false;
400 }
401
402 for (uint32_t i : op.outputs) {
403 const Operand& operand = operands[i];
404 if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE &&
405 operand.lifetime != OperandLifeTime::MODEL_OUTPUT) {
406 LOG(ERROR) << "Writing to an operand with incompatible lifetime "
407 << toString(operand.lifetime);
408 return false;
409 }
410
411 // Check that we only write once to an operand.
412 if (writtenTo[i]) {
413 LOG(ERROR) << "Operand " << i << " written a second time";
414 return false;
415 }
416 writtenTo[i] = true;
417 }
418 }
419 for (size_t i = 0; i < operandCount; i++) {
420 if (!writtenTo[i]) {
421 const Operand& operand = operands[i];
422 if (operand.lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
423 operand.lifetime == OperandLifeTime::MODEL_OUTPUT) {
424 LOG(ERROR) << "Operand " << i << " with lifetime " << toString(operand.lifetime)
425 << " is not being written to.";
426 return false;
427 }
428 }
429 }
430 // TODO More whole graph verifications are possible, for example that an
431 // operand is not use as input & output for the same op, and more
432 // generally that it is acyclic.
433 return true;
434 }
435
validatePool(const hidl_memory & pool,HalVersion ver)436 bool validatePool(const hidl_memory& pool, HalVersion ver) {
437 const auto& name = pool.name();
438 if (name != "ashmem" && name != "mmap_fd" &&
439 ((ver < HalVersion::V1_2) ||
440 (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
441 LOG(ERROR) << "Unsupported memory type " << name;
442 return false;
443 }
444 if (pool.handle() == nullptr) {
445 LOG(ERROR) << "Memory of type " << name << " is null";
446 return false;
447 }
448 return true;
449 }
450
validatePools(const hidl_vec<hidl_memory> & pools,HalVersion ver)451 static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) {
452 return std::all_of(pools.begin(), pools.end(),
453 [ver](const hidl_memory& pool) { return validatePool(pool, ver); });
454 }
455
validateModelInputOutputs(const hidl_vec<uint32_t> indexes,const hidl_vec<Operand> & operands,OperandLifeTime lifetime)456 static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
457 const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
458 const size_t operandCount = operands.size();
459 for (uint32_t i : indexes) {
460 if (i >= operandCount) {
461 LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
462 return false;
463 }
464 const Operand& operand = operands[i];
465 if (operand.lifetime != lifetime) {
466 LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime)
467 << " instead of the expected " << toString(lifetime);
468 return false;
469 }
470 }
471
472 std::vector<uint32_t> sortedIndexes = indexes;
473 std::sort(sortedIndexes.begin(), sortedIndexes.end());
474 auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
475 if (adjacentI != sortedIndexes.end()) {
476 LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
477 return false;
478 }
479 return true;
480 }
481
482 template <class T_Model>
validateModel(const T_Model & model)483 bool validateModel(const T_Model& model) {
484 NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
485 HalVersion version = ModelToHalVersion<T_Model>::version;
486 if (model.operations.size() == 0 || model.operands.size() == 0) {
487 LOG(ERROR) << "Invalid empty model.";
488 return false;
489 }
490 // We only need versioned operands for their validation. For all the other
491 // validations we can use operands upcasted to the latest version.
492 const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands);
493 return (validateOperands(model.operands, model.operandValues, model.pools,
494 /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
495 validateOperations(model.operations, latestVersionOperands) &&
496 validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
497 OperandLifeTime::MODEL_INPUT) &&
498 validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
499 OperandLifeTime::MODEL_OUTPUT) &&
500 validatePools(model.pools, version));
501 }
502
503 template bool validateModel<V1_0::Model>(const V1_0::Model& model);
504 template bool validateModel<V1_1::Model>(const V1_1::Model& model);
505 template bool validateModel<V1_2::Model>(const V1_2::Model& model);
506
507 // Validates the arguments of a request. type is either "input" or "output" and is used
508 // for printing error messages. The operandIndexes is the appropriate array of input
509 // or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs.
validateRequestArguments(const hidl_vec<RequestArgument> & requestArguments,const hidl_vec<uint32_t> & operandIndexes,const hidl_vec<Operand> & operands,const hidl_vec<hidl_memory> & pools,bool allowUnspecified,const char * type)510 static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
511 const hidl_vec<uint32_t>& operandIndexes,
512 const hidl_vec<Operand>& operands,
513 const hidl_vec<hidl_memory>& pools, bool allowUnspecified,
514 const char* type) {
515 MemoryAccessVerifier poolVerifier(pools);
516 // The request should specify as many arguments as were described in the model.
517 const size_t requestArgumentCount = requestArguments.size();
518 if (requestArgumentCount != operandIndexes.size()) {
519 LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
520 << "s but the model has " << operandIndexes.size();
521 return false;
522 }
523 for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
524 requestArgumentIndex++) {
525 const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
526 const DataLocation& location = requestArgument.location;
527 // Get the operand index for this argument. We extract it from the list
528 // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
529 // We assume in this function that the model has been validated already.
530 const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
531 const Operand& operand = operands[operandIndex];
532 if (requestArgument.hasNoValue) {
533 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
534 requestArgument.dimensions.size() != 0) {
535 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
536 << " has no value yet has details.";
537 return false;
538 }
539 } else {
540 // Validate the location.
541 if (!poolVerifier.validate(location)) {
542 return false;
543 }
544 // If the argument specified a dimension, validate it.
545 uint32_t rank = requestArgument.dimensions.size();
546 if (rank == 0) {
547 if (!allowUnspecified) {
548 // Validate that all the dimensions are specified in the model.
549 for (size_t i = 0; i < operand.dimensions.size(); i++) {
550 if (operand.dimensions[i] == 0) {
551 LOG(ERROR) << "Model has dimension " << i
552 << " set to 0 but the request does specify the dimension.";
553 return false;
554 }
555 }
556 }
557 } else {
558 if (rank != operand.dimensions.size()) {
559 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
560 << " has number of dimensions (" << rank
561 << ") different than the model's (" << operand.dimensions.size()
562 << ")";
563 return false;
564 }
565 for (size_t i = 0; i < rank; i++) {
566 if (requestArgument.dimensions[i] != operand.dimensions[i] &&
567 operand.dimensions[i] != 0) {
568 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
569 << " has dimension " << i << " of "
570 << requestArgument.dimensions[i]
571 << " different than the model's " << operand.dimensions[i];
572 return false;
573 }
574 if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
575 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
576 << " has dimension " << i << " of zero";
577 return false;
578 }
579 }
580 }
581 }
582 }
583 return true;
584 }
585
586 template <class T_Model>
validateRequest(const Request & request,const T_Model & model)587 bool validateRequest(const Request& request, const T_Model& model) {
588 HalVersion version = ModelToHalVersion<T_Model>::version;
589 return (validateRequestArguments(request.inputs, model.inputIndexes,
590 convertToV1_2(model.operands), request.pools,
591 /*allowUnspecified=*/false, "input") &&
592 validateRequestArguments(request.outputs, model.outputIndexes,
593 convertToV1_2(model.operands), request.pools,
594 /*allowUnspecified=*/version >= HalVersion::V1_2, "output") &&
595 validatePools(request.pools, version));
596 }
597
598 template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model);
599 template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model);
600 template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model);
601
validateExecutionPreference(ExecutionPreference preference)602 bool validateExecutionPreference(ExecutionPreference preference) {
603 return preference == ExecutionPreference::LOW_POWER ||
604 preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
605 preference == ExecutionPreference::SUSTAINED_SPEED;
606 }
607
validOperandType(V1_0::OperandType operandType)608 bool validOperandType(V1_0::OperandType operandType) {
609 switch (operandType) {
610 case V1_0::OperandType::FLOAT32:
611 case V1_0::OperandType::INT32:
612 case V1_0::OperandType::UINT32:
613 case V1_0::OperandType::TENSOR_FLOAT32:
614 case V1_0::OperandType::TENSOR_INT32:
615 case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
616 case V1_0::OperandType::OEM:
617 case V1_0::OperandType::TENSOR_OEM_BYTE:
618 return true;
619 default:
620 return false;
621 }
622 }
623
validOperandType(V1_2::OperandType operandType)624 bool validOperandType(V1_2::OperandType operandType) {
625 switch (operandType) {
626 case V1_2::OperandType::FLOAT16:
627 case V1_2::OperandType::FLOAT32:
628 case V1_2::OperandType::INT32:
629 case V1_2::OperandType::UINT32:
630 case V1_2::OperandType::BOOL:
631 case V1_2::OperandType::TENSOR_FLOAT16:
632 case V1_2::OperandType::TENSOR_FLOAT32:
633 case V1_2::OperandType::TENSOR_INT32:
634 case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
635 case V1_2::OperandType::TENSOR_QUANT8_SYMM:
636 case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
637 case V1_2::OperandType::TENSOR_QUANT16_SYMM:
638 case V1_2::OperandType::TENSOR_BOOL8:
639 case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
640 case V1_2::OperandType::OEM:
641 case V1_2::OperandType::TENSOR_OEM_BYTE:
642 return true;
643 default:
644 return isExtensionOperandType(operandType);
645 }
646 }
647
648 } // namespace nn
649 } // namespace android
650