1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ValidateHal"
18 
19 #include "ValidateHal.h"
20 #include "NeuralNetworks.h"
21 #include "OperationsUtils.h"
22 #include "Tracing.h"
23 #include "Utils.h"
24 
25 #include <android-base/logging.h>
26 
27 namespace android {
28 namespace nn {
29 
30 template <class T_Model>
31 struct ModelToHalVersion;
32 template <>
33 struct ModelToHalVersion<V1_0::Model> {
34     static constexpr HalVersion version = HalVersion::V1_0;
35 };
36 template <>
37 struct ModelToHalVersion<V1_1::Model> {
38     static constexpr HalVersion version = HalVersion::V1_1;
39 };
40 template <>
41 struct ModelToHalVersion<V1_2::Model> {
42     static constexpr HalVersion version = HalVersion::V1_2;
43 };
44 
45 class MemoryAccessVerifier {
46 public:
MemoryAccessVerifier(const hidl_vec<hidl_memory> & pools)47     MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
48         : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
49         for (size_t i = 0; i < mPoolCount; i++) {
50             mPoolSizes[i] = pools[i].size();
51         }
52     }
validate(const DataLocation & location)53     bool validate(const DataLocation& location) {
54         if (location.poolIndex >= mPoolCount) {
55             LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
56             return false;
57         }
58         const size_t size = mPoolSizes[location.poolIndex];
59         // Do the addition using size_t to avoid potential wrap-around problems.
60         if (static_cast<size_t>(location.offset) + location.length > size) {
61             LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
62                        << location.offset << " and length " << location.length
63                        << " exceeds pool size of " << size;
64             return false;
65         }
66         return true;
67     }
68 
69 private:
70     size_t mPoolCount;
71     std::vector<size_t> mPoolSizes;
72 };
73 
validateOperandExtraParams(const V1_2::Operand & operand,uint32_t index)74 static bool validateOperandExtraParams(const V1_2::Operand& operand, uint32_t index) {
75     switch (operand.type) {
76         case OperandType::FLOAT32:
77         case OperandType::INT32:
78         case OperandType::UINT32:
79         case OperandType::BOOL:
80         case OperandType::TENSOR_FLOAT32:
81         case OperandType::TENSOR_FLOAT16:
82         case OperandType::TENSOR_INT32:
83         case OperandType::TENSOR_QUANT8_ASYMM:
84         case OperandType::TENSOR_QUANT8_SYMM:
85         case OperandType::TENSOR_QUANT16_ASYMM:
86         case OperandType::TENSOR_QUANT16_SYMM:
87         case OperandType::TENSOR_BOOL8: {
88             NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
89                          V1_2::Operand::ExtraParams::hidl_discriminator::none)
90                     << "Operand " << index << ": Operand of type "
91                     << getOperandTypeName(operand.type)
92                     << " has incorrect extraParams: " << toString(operand.extraParams);
93         } break;
94         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
95             NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
96                          V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant)
97                     << "Operand " << index << ": Operand of type "
98                     << getOperandTypeName(operand.type) << " without a Channel Quantization params";
99             auto& channelQuant = operand.extraParams.channelQuant();
100 
101             size_t count = operand.dimensions.size();
102             NN_RET_CHECK_LT(channelQuant.channelDim, count)
103                     << "Operand " << index << ": Operand of type "
104                     << getOperandTypeName(operand.type)
105                     << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
106                     << ", must be valid dimension index in range [0, " << count << ")";
107             uint32_t expected = operand.dimensions[channelQuant.channelDim];
108             NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
109                     << "Operand " << index << ": Operand of type "
110                     << getOperandTypeName(operand.type) << " with a wrong-sized scales, "
111                     << "expected " << expected << " was " << channelQuant.scales.size();
112             NN_RET_CHECK_NE(expected, 0)
113                     << "Operand " << index << ": Operand of type "
114                     << getOperandTypeName(operand.type) << " channel dimension "
115                     << channelQuant.channelDim << " is underspecified (can't be 0)";
116             for (uint32_t i = 0; i < expected; ++i) {
117                 NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
118                         << "Operand " << index << ": Operand of type "
119                         << getOperandTypeName(operand.type) << " with a negative value in scales["
120                         << i << "]=" << channelQuant.scales[i];
121             }
122         } break;
123         default: {
124             if (isExtensionOperandType(operand.type)) {
125                 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
126                                      V1_2::Operand::ExtraParams::hidl_discriminator::extension ||
127                              operand.extraParams.getDiscriminator() ==
128                                      V1_2::Operand::ExtraParams::hidl_discriminator::none)
129                         << "Operand " << index << ": Extension operand of type "
130                         << getOperandTypeName(operand.type)
131                         << " has incorrect extraParams: " << toString(operand.extraParams);
132             }
133             // No validation for OEM types.
134         } break;
135     }
136     return true;
137 }
138 
139 template <typename VersionedOperand>
validateOperands(const hidl_vec<VersionedOperand> & operands,const hidl_vec<uint8_t> & operandValues,const hidl_vec<hidl_memory> & pools,bool allowUnspecifiedRank)140 static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
141                              const hidl_vec<uint8_t>& operandValues,
142                              const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) {
143     uint32_t index = 0;
144     MemoryAccessVerifier poolVerifier(pools);
145     for (auto& versionedOperand : operands) {
146         if (!validOperandType(versionedOperand.type)) {
147             LOG(ERROR) << "Operand is not supported by this version: "
148                        << toString(versionedOperand.type);
149             return false;
150         }
151         // Once we are sure the operand is supported by its version, it is safe
152         // to convert it to the latest version for the rest of the validations.
153         V1_2::Operand operand = convertToV1_2(versionedOperand);
154         // Validate type and dimensions.
155         switch (operand.type) {
156             case OperandType::FLOAT16:
157             case OperandType::FLOAT32:
158             case OperandType::INT32:
159             case OperandType::UINT32:
160             case OperandType::BOOL:
161             case OperandType::OEM: {
162                 size_t count = operand.dimensions.size();
163                 if (count != 0) {
164                     LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
165                                << count;
166                     return false;
167                 }
168                 break;
169             }
170             case OperandType::TENSOR_FLOAT16:
171             case OperandType::TENSOR_FLOAT32:
172             case OperandType::TENSOR_INT32:
173             case OperandType::TENSOR_QUANT8_ASYMM:
174             case OperandType::TENSOR_QUANT8_SYMM:
175             case OperandType::TENSOR_QUANT16_ASYMM:
176             case OperandType::TENSOR_QUANT16_SYMM:
177             case OperandType::TENSOR_BOOL8:
178             case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
179             case OperandType::TENSOR_OEM_BYTE: {
180                 if ((!allowUnspecifiedRank || operand.lifetime == OperandLifeTime::CONSTANT_COPY ||
181                      operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE) &&
182                     operand.dimensions.size() == 0) {
183                     LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
184                     return false;
185                 }
186                 break;
187             }
188             default: {
189                 if (!isExtensionOperandType(operand.type)) {
190                     LOG(ERROR) << "Operand " << index << ": Invalid operand type "
191                                << toString(operand.type);
192                     return false;
193                 }
194             } break;
195         }
196 
197         // TODO Validate the numberOfConsumers.
198         // TODO Since we have to validate it, there was no point in including it. For the next
199         // release, consider removing unless we have an additional process in system space
200         // that creates this value. In that case, it would not have to be validated.
201 
202         // Validate the scale.
203         switch (operand.type) {
204             case OperandType::FLOAT16:
205             case OperandType::FLOAT32:
206             case OperandType::INT32:
207             case OperandType::UINT32:
208             case OperandType::BOOL:
209             case OperandType::TENSOR_FLOAT16:
210             case OperandType::TENSOR_FLOAT32:
211             case OperandType::TENSOR_BOOL8:
212             case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
213                 if (operand.scale != 0.f) {
214                     LOG(ERROR) << "Operand " << index << ": Operand of type "
215                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
216                                << operand.scale << ")";
217                     return false;
218                 }
219                 break;
220             case OperandType::TENSOR_INT32:
221                 // TENSOR_INT32 may be used with or without scale, depending on the operation.
222                 if (operand.scale < 0.f) {
223                     LOG(ERROR) << "Operand " << index << ": Operand of type "
224                                << getOperandTypeName(operand.type) << " with a negative scale";
225                     return false;
226                 }
227                 break;
228             case OperandType::TENSOR_QUANT8_ASYMM:
229             case OperandType::TENSOR_QUANT8_SYMM:
230             case OperandType::TENSOR_QUANT16_ASYMM:
231             case OperandType::TENSOR_QUANT16_SYMM:
232                 if (operand.scale <= 0.f) {
233                     LOG(ERROR) << "Operand " << index << ": Operand of type "
234                                << getOperandTypeName(operand.type) << " with a non-positive scale";
235                     return false;
236                 }
237                 break;
238             default:
239                 if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
240                     LOG(ERROR) << "Operand " << index << ": Operand of type "
241                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
242                                << operand.scale << ")";
243                     return false;
244                 }
245                 // No validation for OEM types.
246                 // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
247                 break;
248         }
249 
250         // Validate the zeroPoint.
251         switch (operand.type) {
252             case OperandType::FLOAT16:
253             case OperandType::FLOAT32:
254             case OperandType::INT32:
255             case OperandType::UINT32:
256             case OperandType::BOOL:
257             case OperandType::TENSOR_FLOAT16:
258             case OperandType::TENSOR_FLOAT32:
259             case OperandType::TENSOR_INT32:
260             case OperandType::TENSOR_BOOL8:
261             case OperandType::TENSOR_QUANT8_SYMM:
262             case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
263                 if (operand.zeroPoint != 0) {
264                     LOG(ERROR) << "Operand " << index << ": Operand of type "
265                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
266                                << operand.zeroPoint;
267                     return false;
268                 }
269                 break;
270             case OperandType::TENSOR_QUANT8_ASYMM:
271                 if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
272                     LOG(ERROR) << "Operand " << index << ": Operand of type "
273                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
274                                << operand.zeroPoint << ", must be in range [0, 255]";
275                     return false;
276                 }
277                 break;
278             case OperandType::TENSOR_QUANT16_ASYMM:
279                 if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
280                     LOG(ERROR) << "Operand " << index << ": Operand of type "
281                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
282                                << operand.zeroPoint << ", must be in range [0, 65535]";
283                     return false;
284                 }
285                 break;
286             case OperandType::TENSOR_QUANT16_SYMM:
287                 if (operand.zeroPoint != 0) {
288                     LOG(ERROR) << "Operand " << index << ": Operand of type "
289                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
290                                << operand.zeroPoint;
291                     return false;
292                 }
293                 break;
294             default:
295                 if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
296                     LOG(ERROR) << "Operand " << index << ": Operand of type "
297                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
298                                << operand.zeroPoint;
299                     return false;
300                 }
301                 // No validation for OEM types.
302                 break;
303         }
304 
305         NN_RET_CHECK(validateOperandExtraParams(operand, index));
306 
307         // Validate the lifetime and the location.
308         const DataLocation& location = operand.location;
309         switch (operand.lifetime) {
310             case OperandLifeTime::CONSTANT_COPY:
311                 if (location.poolIndex != 0) {
312                     LOG(ERROR) << "Operand " << index
313                                << ": CONSTANT_COPY with a non-zero poolIndex "
314                                << location.poolIndex;
315                     return false;
316                 }
317                 // Do the addition using size_t to avoid potential wrap-around problems.
318                 if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
319                     LOG(ERROR) << "Operand " << index
320                                << ": OperandValue location out of range.  Starts at "
321                                << location.offset << ", length " << location.length << ", max "
322                                << operandValues.size();
323                     return false;
324                 }
325                 break;
326             case OperandLifeTime::CONSTANT_REFERENCE:
327                 if (!poolVerifier.validate(location)) {
328                     return false;
329                 }
330                 break;
331             case OperandLifeTime::TEMPORARY_VARIABLE:
332             case OperandLifeTime::MODEL_INPUT:
333             case OperandLifeTime::MODEL_OUTPUT:
334             case OperandLifeTime::NO_VALUE:
335                 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
336                     LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
337                                << location.poolIndex << ", offset " << location.offset
338                                << ", or length " << location.length << " for operand of lifetime "
339                                << toString(operand.lifetime);
340                     return false;
341                 }
342                 break;
343             default:
344                 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
345                            << toString(operand.lifetime);
346                 return false;
347         }
348 
349         // For constants, validate that the length is as expected. The other lifetimes
350         // expect the length to be 0. Don't validate for OEM types.
351         if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
352             operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
353             if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM &&
354                 operand.type != OperandType::TENSOR_OEM_BYTE) {
355                 uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
356                 if (location.length != expectedLength) {
357                     LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
358                                << " expected a size of " << expectedLength << " but got "
359                                << location.length;
360                     return false;
361                 }
362             }
363         }
364 
365         index++;
366     }
367     return true;
368 }
369 
getHalVersion(const V1_0::Operation &)370 static HalVersion getHalVersion(const V1_0::Operation&) {
371     return HalVersion::V1_0;
372 }
373 
getHalVersion(const V1_1::Operation &)374 static HalVersion getHalVersion(const V1_1::Operation&) {
375     return HalVersion::V1_1;
376 }
377 
getHalVersion(const V1_2::Operation &)378 static HalVersion getHalVersion(const V1_2::Operation&) {
379     return HalVersion::V1_2;
380 }
381 
382 template <typename VersionedOperation>
validateOperations(const hidl_vec<VersionedOperation> & operations,const hidl_vec<Operand> & operands)383 static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
384                                const hidl_vec<Operand>& operands) {
385     const size_t operandCount = operands.size();
386     // This vector keeps track of whether there's an operation that writes to
387     // each operand. It is used to validate that temporary variables and
388     // model outputs will be written to.
389     std::vector<bool> writtenTo(operandCount, false);
390     for (auto& op : operations) {
391         // TODO Validate the shapes and any known values. This is currently
392         // done in CpuExecutor but should be done here for all drivers.
393         int error = validateOperation(
394                 static_cast<int32_t>(op.type), op.inputs.size(),
395                 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
396                 op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op));
397         if (error != ANEURALNETWORKS_NO_ERROR) {
398             LOG(ERROR) << "Invalid operation " << toString(op.type);
399             return false;
400         }
401 
402         for (uint32_t i : op.outputs) {
403             const Operand& operand = operands[i];
404             if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE &&
405                 operand.lifetime != OperandLifeTime::MODEL_OUTPUT) {
406                 LOG(ERROR) << "Writing to an operand with incompatible lifetime "
407                            << toString(operand.lifetime);
408                 return false;
409             }
410 
411             // Check that we only write once to an operand.
412             if (writtenTo[i]) {
413                 LOG(ERROR) << "Operand " << i << " written a second time";
414                 return false;
415             }
416             writtenTo[i] = true;
417         }
418     }
419     for (size_t i = 0; i < operandCount; i++) {
420         if (!writtenTo[i]) {
421             const Operand& operand = operands[i];
422             if (operand.lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
423                 operand.lifetime == OperandLifeTime::MODEL_OUTPUT) {
424                 LOG(ERROR) << "Operand " << i << " with lifetime " << toString(operand.lifetime)
425                            << " is not being written to.";
426                 return false;
427             }
428         }
429     }
430     // TODO More whole graph verifications are possible, for example that an
431     // operand is not use as input & output for the same op, and more
432     // generally that it is acyclic.
433     return true;
434 }
435 
validatePool(const hidl_memory & pool,HalVersion ver)436 bool validatePool(const hidl_memory& pool, HalVersion ver) {
437     const auto& name = pool.name();
438     if (name != "ashmem" && name != "mmap_fd" &&
439         ((ver < HalVersion::V1_2) ||
440          (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
441         LOG(ERROR) << "Unsupported memory type " << name;
442         return false;
443     }
444     if (pool.handle() == nullptr) {
445         LOG(ERROR) << "Memory of type " << name << " is null";
446         return false;
447     }
448     return true;
449 }
450 
validatePools(const hidl_vec<hidl_memory> & pools,HalVersion ver)451 static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) {
452     return std::all_of(pools.begin(), pools.end(),
453                        [ver](const hidl_memory& pool) { return validatePool(pool, ver); });
454 }
455 
validateModelInputOutputs(const hidl_vec<uint32_t> indexes,const hidl_vec<Operand> & operands,OperandLifeTime lifetime)456 static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
457                                       const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
458     const size_t operandCount = operands.size();
459     for (uint32_t i : indexes) {
460         if (i >= operandCount) {
461             LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
462             return false;
463         }
464         const Operand& operand = operands[i];
465         if (operand.lifetime != lifetime) {
466             LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime)
467                        << " instead of the expected " << toString(lifetime);
468             return false;
469         }
470     }
471 
472     std::vector<uint32_t> sortedIndexes = indexes;
473     std::sort(sortedIndexes.begin(), sortedIndexes.end());
474     auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
475     if (adjacentI != sortedIndexes.end()) {
476         LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
477         return false;
478     }
479     return true;
480 }
481 
482 template <class T_Model>
validateModel(const T_Model & model)483 bool validateModel(const T_Model& model) {
484     NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
485     HalVersion version = ModelToHalVersion<T_Model>::version;
486     if (model.operations.size() == 0 || model.operands.size() == 0) {
487         LOG(ERROR) << "Invalid empty model.";
488         return false;
489     }
490     // We only need versioned operands for their validation. For all the other
491     // validations we can use operands upcasted to the latest version.
492     const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands);
493     return (validateOperands(model.operands, model.operandValues, model.pools,
494                              /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
495             validateOperations(model.operations, latestVersionOperands) &&
496             validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
497                                       OperandLifeTime::MODEL_INPUT) &&
498             validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
499                                       OperandLifeTime::MODEL_OUTPUT) &&
500             validatePools(model.pools, version));
501 }
502 
503 template bool validateModel<V1_0::Model>(const V1_0::Model& model);
504 template bool validateModel<V1_1::Model>(const V1_1::Model& model);
505 template bool validateModel<V1_2::Model>(const V1_2::Model& model);
506 
507 // Validates the arguments of a request. type is either "input" or "output" and is used
508 // for printing error messages. The operandIndexes is the appropriate array of input
509 // or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs.
validateRequestArguments(const hidl_vec<RequestArgument> & requestArguments,const hidl_vec<uint32_t> & operandIndexes,const hidl_vec<Operand> & operands,const hidl_vec<hidl_memory> & pools,bool allowUnspecified,const char * type)510 static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
511                                      const hidl_vec<uint32_t>& operandIndexes,
512                                      const hidl_vec<Operand>& operands,
513                                      const hidl_vec<hidl_memory>& pools, bool allowUnspecified,
514                                      const char* type) {
515     MemoryAccessVerifier poolVerifier(pools);
516     // The request should specify as many arguments as were described in the model.
517     const size_t requestArgumentCount = requestArguments.size();
518     if (requestArgumentCount != operandIndexes.size()) {
519         LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
520                    << "s but the model has " << operandIndexes.size();
521         return false;
522     }
523     for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
524          requestArgumentIndex++) {
525         const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
526         const DataLocation& location = requestArgument.location;
527         // Get the operand index for this argument. We extract it from the list
528         // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
529         // We assume in this function that the model has been validated already.
530         const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
531         const Operand& operand = operands[operandIndex];
532         if (requestArgument.hasNoValue) {
533             if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
534                 requestArgument.dimensions.size() != 0) {
535                 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
536                            << " has no value yet has details.";
537                 return false;
538             }
539         } else {
540             // Validate the location.
541             if (!poolVerifier.validate(location)) {
542                 return false;
543             }
544             // If the argument specified a dimension, validate it.
545             uint32_t rank = requestArgument.dimensions.size();
546             if (rank == 0) {
547                 if (!allowUnspecified) {
548                     // Validate that all the dimensions are specified in the model.
549                     for (size_t i = 0; i < operand.dimensions.size(); i++) {
550                         if (operand.dimensions[i] == 0) {
551                             LOG(ERROR) << "Model has dimension " << i
552                                        << " set to 0 but the request does specify the dimension.";
553                             return false;
554                         }
555                     }
556                 }
557             } else {
558                 if (rank != operand.dimensions.size()) {
559                     LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
560                                << " has number of dimensions (" << rank
561                                << ") different than the model's (" << operand.dimensions.size()
562                                << ")";
563                     return false;
564                 }
565                 for (size_t i = 0; i < rank; i++) {
566                     if (requestArgument.dimensions[i] != operand.dimensions[i] &&
567                         operand.dimensions[i] != 0) {
568                         LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
569                                    << " has dimension " << i << " of "
570                                    << requestArgument.dimensions[i]
571                                    << " different than the model's " << operand.dimensions[i];
572                         return false;
573                     }
574                     if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
575                         LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
576                                    << " has dimension " << i << " of zero";
577                         return false;
578                     }
579                 }
580             }
581         }
582     }
583     return true;
584 }
585 
586 template <class T_Model>
validateRequest(const Request & request,const T_Model & model)587 bool validateRequest(const Request& request, const T_Model& model) {
588     HalVersion version = ModelToHalVersion<T_Model>::version;
589     return (validateRequestArguments(request.inputs, model.inputIndexes,
590                                      convertToV1_2(model.operands), request.pools,
591                                      /*allowUnspecified=*/false, "input") &&
592             validateRequestArguments(request.outputs, model.outputIndexes,
593                                      convertToV1_2(model.operands), request.pools,
594                                      /*allowUnspecified=*/version >= HalVersion::V1_2, "output") &&
595             validatePools(request.pools, version));
596 }
597 
598 template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model);
599 template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model);
600 template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model);
601 
validateExecutionPreference(ExecutionPreference preference)602 bool validateExecutionPreference(ExecutionPreference preference) {
603     return preference == ExecutionPreference::LOW_POWER ||
604            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
605            preference == ExecutionPreference::SUSTAINED_SPEED;
606 }
607 
validOperandType(V1_0::OperandType operandType)608 bool validOperandType(V1_0::OperandType operandType) {
609     switch (operandType) {
610         case V1_0::OperandType::FLOAT32:
611         case V1_0::OperandType::INT32:
612         case V1_0::OperandType::UINT32:
613         case V1_0::OperandType::TENSOR_FLOAT32:
614         case V1_0::OperandType::TENSOR_INT32:
615         case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
616         case V1_0::OperandType::OEM:
617         case V1_0::OperandType::TENSOR_OEM_BYTE:
618             return true;
619         default:
620             return false;
621     }
622 }
623 
validOperandType(V1_2::OperandType operandType)624 bool validOperandType(V1_2::OperandType operandType) {
625     switch (operandType) {
626         case V1_2::OperandType::FLOAT16:
627         case V1_2::OperandType::FLOAT32:
628         case V1_2::OperandType::INT32:
629         case V1_2::OperandType::UINT32:
630         case V1_2::OperandType::BOOL:
631         case V1_2::OperandType::TENSOR_FLOAT16:
632         case V1_2::OperandType::TENSOR_FLOAT32:
633         case V1_2::OperandType::TENSOR_INT32:
634         case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
635         case V1_2::OperandType::TENSOR_QUANT8_SYMM:
636         case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
637         case V1_2::OperandType::TENSOR_QUANT16_SYMM:
638         case V1_2::OperandType::TENSOR_BOOL8:
639         case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
640         case V1_2::OperandType::OEM:
641         case V1_2::OperandType::TENSOR_OEM_BYTE:
642             return true;
643         default:
644             return isExtensionOperandType(operandType);
645     }
646 }
647 
648 }  // namespace nn
649 }  // namespace android
650