/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ValidateHal"

#include "ValidateHal.h"

#include <android-base/logging.h>

#include <algorithm>
#include <functional>
#include <set>
#include <utility>
#include <vector>

#include "NeuralNetworks.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "Utils.h"
#include "nnapi/TypeUtils.h"

namespace android {
namespace nn {

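// Maps each versioned HAL Model type to the corresponding HalVersion at compile time.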
template <class T_Model>
struct ModelToHalVersion;
template <>
struct ModelToHalVersion<V1_0::Model> {
    static constexpr HalVersion version = HalVersion::V1_0;
};
template <>
struct ModelToHalVersion<V1_1::Model> {
    static constexpr HalVersion version = HalVersion::V1_1;
};
template <>
struct ModelToHalVersion<V1_2::Model> {
    static constexpr HalVersion version = HalVersion::V1_2;
};
template <>
struct ModelToHalVersion<V1_3::Model> {
    static constexpr HalVersion version = HalVersion::V1_3;
};

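// Verifies that DataLocation references stay within the bounds of the declared memory pools.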
class MemoryAccessVerifier {
   public:
    MemoryAccessVerifier(const hardware::hidl_vec<hardware::hidl_memory>& pools)
        : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
        for (size_t i = 0; i < mPoolCount; i++) {
            mPoolSizes[i] = pools[i].size();
        }
    }
    MemoryAccessVerifier(const hardware::hidl_vec<V1_3::Request::MemoryPool>& pools)
        : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
        for (size_t i = 0; i < mPoolCount; i++) {
            switch (pools[i].getDiscriminator()) {
                case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
                    mPoolSizes[i] = pools[i].hidlMemory().size();
                    break;
                case V1_3::Request::MemoryPool::hidl_discriminator::token:
                    // Set size to 0 to enforce length == 0 && offset == 0.
                    mPoolSizes[i] = 0;
                    break;
            }
        }
    }
    bool validate(const V1_0::DataLocation& location) const {
        if (location.poolIndex >= mPoolCount) {
            LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
            return false;
        }
        const size_t size = mPoolSizes[location.poolIndex];
        // Do the addition using size_t to avoid potential wrap-around problems.
        if (static_cast<size_t>(location.offset) + location.length > size) {
            LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
                       << location.offset << " and length " << location.length
                       << " exceeds pool size of " << size;
            return false;
        }
        return true;
    }

   private:
    size_t mPoolCount;
    std::vector<size_t> mPoolSizes;
};

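// Validates that an operand's extraParams field is consistent with its type; for
// TENSOR_QUANT8_SYMM_PER_CHANNEL this includes the per-channel quantization parameters.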
static bool validateOperandExtraParams(const V1_3::Operand& operand, uint32_t index) {
    switch (operand.type) {
        case V1_3::OperandType::FLOAT32:
        case V1_3::OperandType::INT32:
        case V1_3::OperandType::UINT32:
        case V1_3::OperandType::BOOL:
        case V1_3::OperandType::SUBGRAPH:
        case V1_3::OperandType::TENSOR_FLOAT32:
        case V1_3::OperandType::TENSOR_FLOAT16:
        case V1_3::OperandType::TENSOR_INT32:
        case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case V1_3::OperandType::TENSOR_QUANT8_SYMM:
        case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
        case V1_3::OperandType::TENSOR_QUANT16_SYMM:
        case V1_3::OperandType::TENSOR_BOOL8: {
            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                         V1_2::Operand::ExtraParams::hidl_discriminator::none)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " has incorrect extraParams: " << toString(operand.extraParams);
        } break;
        case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                         V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type) << " without channel quantization params";
            auto& channelQuant = operand.extraParams.channelQuant();

            size_t count = operand.dimensions.size();
            NN_RET_CHECK_LT(channelQuant.channelDim, count)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
                    << ", must be a valid dimension index in range [0, " << count << ")";
            uint32_t expected = operand.dimensions[channelQuant.channelDim];
            NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type) << " with wrong-sized scales, "
                    << "expected " << expected << " was " << channelQuant.scales.size();
            NN_RET_CHECK_NE(expected, 0u)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type) << " channel dimension "
                    << channelQuant.channelDim << " is underspecified (can't be 0)";
            for (uint32_t i = 0; i < expected; ++i) {
                NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
                        << "Operand " << index << ": Operand of type "
                        << getOperandTypeName(operand.type)
                        << " with a non-positive value in scales[" << i
                        << "]=" << channelQuant.scales[i];
            }
        } break;
        default: {
            if (isExtensionOperandType(operand.type)) {
                NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                                     V1_2::Operand::ExtraParams::hidl_discriminator::extension ||
                             operand.extraParams.getDiscriminator() ==
                                     V1_2::Operand::ExtraParams::hidl_discriminator::none)
                        << "Operand " << index << ": Extension operand of type "
                        << getOperandTypeName(operand.type)
                        << " has incorrect extraParams: " << toString(operand.extraParams);
            }
            // No validation for OEM types.
        } break;
    }
    return true;
}

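// Validates each operand's type, dimensions, scale, zeroPoint, extraParams, lifetime, and
// data location against the model's operand values, memory pools, and referenced subgraphs.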
template <typename VersionedOperand>
static bool validateOperands(const hardware::hidl_vec<VersionedOperand>& operands,
                             const hardware::hidl_vec<uint8_t>& operandValues,
                             const hardware::hidl_vec<hardware::hidl_memory>& pools,
                             const hardware::hidl_vec<V1_3::Subgraph>& subgraphs,
                             bool allowUnspecifiedRank) {
    uint32_t index = 0;
    MemoryAccessVerifier poolVerifier(pools);
    for (auto& versionedOperand : operands) {
        if (!validOperandType(versionedOperand.type)) {
            LOG(ERROR) << "Operand is not supported by this version: "
                       << toString(versionedOperand.type);
            return false;
        }
        // Once we are sure the operand is supported by its version, it is safe
        // to convert it to the latest version for the rest of the validations.
        V1_3::Operand operand = convertToV1_3(versionedOperand);
        // Validate type and dimensions.
        switch (operand.type) {
            case V1_3::OperandType::FLOAT16:
            case V1_3::OperandType::FLOAT32:
            case V1_3::OperandType::INT32:
            case V1_3::OperandType::UINT32:
            case V1_3::OperandType::BOOL:
            case V1_3::OperandType::SUBGRAPH:
            case V1_3::OperandType::OEM: {
                size_t count = operand.dimensions.size();
                if (count != 0) {
                    LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
                               << count;
                    return false;
                }
                break;
            }
            case V1_3::OperandType::TENSOR_FLOAT16:
            case V1_3::OperandType::TENSOR_FLOAT32:
            case V1_3::OperandType::TENSOR_INT32:
            case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
            case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            case V1_3::OperandType::TENSOR_QUANT8_SYMM:
            case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
            case V1_3::OperandType::TENSOR_QUANT16_SYMM:
            case V1_3::OperandType::TENSOR_BOOL8:
            case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            case V1_3::OperandType::TENSOR_OEM_BYTE: {
                if ((!allowUnspecifiedRank ||
                     operand.lifetime == V1_3::OperandLifeTime::CONSTANT_COPY ||
                     operand.lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE) &&
                    operand.dimensions.size() == 0) {
                    LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
                    return false;
                }
                break;
            }
            default: {
                if (!isExtensionOperandType(operand.type)) {
                    LOG(ERROR) << "Operand " << index << ": Invalid operand type "
                               << toString(operand.type);
                    return false;
                }
            } break;
        }

        // Validate the scale.
        switch (operand.type) {
            case V1_3::OperandType::FLOAT16:
            case V1_3::OperandType::FLOAT32:
            case V1_3::OperandType::INT32:
            case V1_3::OperandType::UINT32:
            case V1_3::OperandType::BOOL:
            case V1_3::OperandType::SUBGRAPH:
            case V1_3::OperandType::TENSOR_FLOAT16:
            case V1_3::OperandType::TENSOR_FLOAT32:
            case V1_3::OperandType::TENSOR_BOOL8:
            case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                if (operand.scale != 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
                               << operand.scale << ")";
                    return false;
                }
                break;
            case V1_3::OperandType::TENSOR_INT32:
                // TENSOR_INT32 may be used with or without scale, depending on the operation.
                if (operand.scale < 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a negative scale";
                    return false;
                }
                break;
            case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
            case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            case V1_3::OperandType::TENSOR_QUANT8_SYMM:
            case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
            case V1_3::OperandType::TENSOR_QUANT16_SYMM:
                if (operand.scale <= 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-positive scale";
                    return false;
                }
                break;
            default:
                if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
                               << operand.scale << ")";
                    return false;
                }
                // No validation for OEM types.
                // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
                break;
        }

        // Validate the zeroPoint.
        switch (operand.type) {
            case V1_3::OperandType::FLOAT16:
            case V1_3::OperandType::FLOAT32:
            case V1_3::OperandType::INT32:
            case V1_3::OperandType::UINT32:
            case V1_3::OperandType::BOOL:
            case V1_3::OperandType::SUBGRAPH:
            case V1_3::OperandType::TENSOR_FLOAT16:
            case V1_3::OperandType::TENSOR_FLOAT32:
            case V1_3::OperandType::TENSOR_INT32:
            case V1_3::OperandType::TENSOR_BOOL8:
            case V1_3::OperandType::TENSOR_QUANT8_SYMM:
            case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                if (operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                break;
            case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
                if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [0, 255]";
                    return false;
                }
                break;
            case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
                if (operand.zeroPoint < -128 || operand.zeroPoint > 127) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [-128, 127]";
                    return false;
                }
                break;
            case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
                if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [0, 65535]";
                    return false;
                }
                break;
            case V1_3::OperandType::TENSOR_QUANT16_SYMM:
                if (operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                break;
            default:
                if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                // No validation for OEM types.
                break;
        }

        NN_RET_CHECK(validateOperandExtraParams(operand, index));

        // Validate the lifetime and the location.
        const V1_0::DataLocation& location = operand.location;
        switch (operand.lifetime) {
            case V1_3::OperandLifeTime::CONSTANT_COPY:
                if (location.poolIndex != 0) {
                    LOG(ERROR) << "Operand " << index
                               << ": CONSTANT_COPY with a non-zero poolIndex "
                               << location.poolIndex;
                    return false;
                }
                // Do the addition using size_t to avoid potential wrap-around problems.
                if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
                    LOG(ERROR) << "Operand " << index
                               << ": OperandValue location out of range.  Starts at "
                               << location.offset << ", length " << location.length << ", max "
                               << operandValues.size();
                    return false;
                }
                break;
            case V1_3::OperandLifeTime::CONSTANT_REFERENCE:
                if (!poolVerifier.validate(location)) {
                    return false;
                }
                break;
            case V1_3::OperandLifeTime::TEMPORARY_VARIABLE:
            case V1_3::OperandLifeTime::SUBGRAPH_INPUT:
            case V1_3::OperandLifeTime::SUBGRAPH_OUTPUT:
            case V1_3::OperandLifeTime::NO_VALUE:
                if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
                    LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
                               << location.poolIndex << ", offset " << location.offset
                               << ", or length " << location.length << " for operand of lifetime "
                               << toString(operand.lifetime);
                    return false;
                }
                break;
            case V1_3::OperandLifeTime::SUBGRAPH: {
                if (location.poolIndex != 0) {
                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
                               << location.poolIndex;
                    return false;
                }
                if (location.offset >= subgraphs.size()) {
                    LOG(ERROR) << "Model::Subgraph index out of range: " << location.offset
                               << " >= " << subgraphs.size();
                    return false;
                }
                if (location.length != 0) {
                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
                               << location.length;
                    return false;
                }
            } break;
            default:
                LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
                           << toString(operand.lifetime);
                return false;
        }

        // Make sure SUBGRAPH operand type and lifetime always go together.
        if ((operand.type == V1_3::OperandType::SUBGRAPH) !=
            (operand.lifetime == V1_3::OperandLifeTime::SUBGRAPH)) {
            LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
                       << " cannot have lifetime " << toString(operand.lifetime);
            return false;
        }

        // For constants, validate that the length is as expected. The other lifetimes
        // expect the length to be 0. Don't validate for OEM types.
        if (operand.lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE ||
            operand.lifetime == V1_3::OperandLifeTime::CONSTANT_COPY) {
            if (!isExtensionOperandType(operand.type) && operand.type != V1_3::OperandType::OEM &&
                operand.type != V1_3::OperandType::TENSOR_OEM_BYTE) {
                uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
                if (location.length != expectedLength) {
                    LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
                               << " expected a size of " << expectedLength << " but got "
                               << location.length;
                    return false;
                }
            }
        }

        index++;
    }
    return true;
}

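// Maps each versioned Operation type to the HAL version it belongs to.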
static HalVersion getHalVersion(const V1_0::Operation&) {
    return HalVersion::V1_0;
}

static HalVersion getHalVersion(const V1_1::Operation&) {
    return HalVersion::V1_1;
}

static HalVersion getHalVersion(const V1_2::Operation&) {
    return HalVersion::V1_2;
}

static HalVersion getHalVersion(const V1_3::Operation&) {
    return HalVersion::V1_3;
}

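// Validates each operation against the HAL version it comes from, including the subgraphs
// referenced by control flow operations, and checks that every operation output is written
// to an operand with a TEMPORARY_VARIABLE or SUBGRAPH_OUTPUT lifetime.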
template <typename VersionedOperation>
static bool validateOperations(const hardware::hidl_vec<VersionedOperation>& operations,
                               const hardware::hidl_vec<V1_3::Operand>& operands,
                               const hardware::hidl_vec<V1_3::Subgraph>& subgraphs,
                               ValidationMode mode) {
    auto canonicalSubgraphs = uncheckedConvert(subgraphs);
    auto isValidSubgraphReference = [&canonicalSubgraphs](const Operand& modelOperand) -> bool {
        NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
                << "Unexpected operand type: " << modelOperand.type;
        NN_RET_CHECK_LT(modelOperand.location.offset, canonicalSubgraphs.size())
                << "Invalid subgraph reference";
        return true;
    };
    auto getSubgraph =
            [&canonicalSubgraphs](const Operand& modelOperand) -> const Model::Subgraph* {
        CHECK_LT(modelOperand.location.offset, canonicalSubgraphs.size());
        return &canonicalSubgraphs[modelOperand.location.offset];
    };
    auto getInputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
        return getSubgraph(modelOperand)->inputIndexes.size();
    };
    auto getOutputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
        return getSubgraph(modelOperand)->outputIndexes.size();
    };
    auto getInputOperand = [&getSubgraph](const Operand& modelOperand,
                                          uint32_t index) -> const Operand* {
        const Model::Subgraph& subgraph = *getSubgraph(modelOperand);
        CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
        return &subgraph.operands[subgraph.inputIndexes[index]];
    };
    auto getOutputOperand = [&getSubgraph](const Operand& modelOperand,
                                           uint32_t index) -> const Operand* {
        const Model::Subgraph& subgraph = *getSubgraph(modelOperand);
        CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
        return &subgraph.operands[subgraph.outputIndexes[index]];
    };
    for (auto& op : operations) {
        // TODO Validate the shapes and any known values. This is currently
        // done in CpuExecutor but should be done here for all drivers.
        int error = validateOperation(static_cast<int32_t>(op.type), op.inputs.size(),
                                      op.inputs.size() > 0 ? op.inputs.data() : nullptr,
                                      op.outputs.size(),
                                      op.outputs.size() > 0 ? op.outputs.data() : nullptr,
                                      uncheckedConvert(operands), getHalVersion(op),
                                      {.isValidSubgraphReference = isValidSubgraphReference,
                                       .getSubgraphInputCount = getInputCount,
                                       .getSubgraphOutputCount = getOutputCount,
                                       .getSubgraphInputOperand = getInputOperand,
                                       .getSubgraphOutputOperand = getOutputOperand,
                                       // 1.3 HAL does not support CF operations with operands of
                                       // unknown size. See http://b/132458982#comment63.
                                       .allowControlFlowOperationWithOperandOfUnknownSize =
                                               mode == ValidationMode::RUNTIME});
        if (error != ANEURALNETWORKS_NO_ERROR) {
            LOG(ERROR) << "Invalid operation " << toString(op.type);
            return false;
        }

        // This is redundant because of the checks in validateGraph(),
        // but it is retained here in order to emit more informative
        // error messages.
        for (uint32_t i : op.outputs) {
            const V1_3::Operand& operand = operands[i];
            if (operand.lifetime != V1_3::OperandLifeTime::TEMPORARY_VARIABLE &&
                operand.lifetime != V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) {
                LOG(ERROR) << "Writing to operand " << i << " with incompatible lifetime "
                           << toString(operand.lifetime);
                return false;
            }
        }
    }
    return true;
}

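// Checks that a memory pool is of a type supported by the given HAL version and has a
// non-null handle.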
bool validatePool(const hardware::hidl_memory& pool, HalVersion ver) {
    const auto& name = pool.name();
    if (name != "ashmem" && name != "mmap_fd" &&
        ((ver < HalVersion::V1_2) ||
         (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
        LOG(ERROR) << "Unsupported memory type " << name;
        return false;
    }
    if (pool.handle() == nullptr) {
        LOG(ERROR) << "Memory of type " << name << " is null";
        return false;
    }
    return true;
}

bool validatePool(const V1_3::Request::MemoryPool& pool, HalVersion ver) {
    switch (pool.getDiscriminator()) {
        case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
            return validatePool(pool.hidlMemory(), ver);
        case V1_3::Request::MemoryPool::hidl_discriminator::token:
            return pool.token() > 0;
    }
    LOG(FATAL) << "unknown MemoryPool discriminator";
    return false;
}

template <class T_MemoryPool>
static bool validatePools(const hardware::hidl_vec<T_MemoryPool>& pools, HalVersion ver) {
    return std::all_of(pools.begin(), pools.end(),
                       [ver](const auto& pool) { return validatePool(pool, ver); });
}

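// Checks that every index refers to an existing operand with the expected lifetime, that no
// index appears more than once, and that every operand with that lifetime is listed.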
static bool validateModelInputOutputs(const hardware::hidl_vec<uint32_t> indexes,
                                      const hardware::hidl_vec<V1_3::Operand>& operands,
                                      V1_3::OperandLifeTime lifetime) {
    const size_t operandCount = operands.size();
    for (uint32_t i : indexes) {
        if (i >= operandCount) {
            LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
            return false;
        }
        const V1_3::Operand& operand = operands[i];
        if (operand.lifetime != lifetime) {
            LOG(ERROR) << "Model input or output operand " << i << " has lifetime of "
                       << toString(operand.lifetime) << " instead of the expected "
                       << toString(lifetime);
            return false;
        }
    }

    std::vector<uint32_t> sortedIndexes = indexes;
    std::sort(sortedIndexes.begin(), sortedIndexes.end());
    auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
    if (adjacentI != sortedIndexes.end()) {
        LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
        return false;
    }

    for (size_t i = 0; i < operands.size(); ++i) {
        if (operands[i].lifetime == lifetime &&
            !binary_search(sortedIndexes.begin(), sortedIndexes.end(), i)) {
            LOG(ERROR) << "Operand " << i << " marked as " << toString(lifetime)
                       << " but is not included in Model input or output indexes";
            return false;
        }
    }

    return true;
}

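// Checks that the operations of a model or subgraph are sorted into execution order (every
// input is written before it is read), that no operand is written more than once or left
// unwritten, and that each operand's numberOfConsumers matches the actual consumer count.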
template <typename VersionedModelOrSubgraph>
static bool validateGraph(const VersionedModelOrSubgraph& model) {
    // set up counts
    std::vector<uint32_t> operandNumberOfConsumers(model.operands.size(), 0);
    //     Either the operand has a known value before model execution
    //     begins, or we've seen a writer for this operand while
    //     walking operands in execution order.
    std::vector<bool> operandValueKnown(model.operands.size(), false);

    // mark known operands
    for (size_t i = 0; i < model.operands.size(); ++i) {
        const auto& operand = model.operands[i];
        const V1_3::OperandLifeTime lifetime = convertToV1_3(operand.lifetime);
        operandValueKnown[i] = lifetime == V1_3::OperandLifeTime::SUBGRAPH_INPUT ||
                               lifetime == V1_3::OperandLifeTime::CONSTANT_COPY ||
                               lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE ||
                               lifetime == V1_3::OperandLifeTime::NO_VALUE ||
                               lifetime == V1_3::OperandLifeTime::SUBGRAPH;
    }

    // Validate that operations are sorted into execution order.
    //
    // If there is a cycle in the graph, the operations will not
    // appear to be sorted into execution order: Some operation will
    // have an input for which operandValueKnown[] is false.
    for (size_t i = 0; i < model.operations.size(); ++i) {
        const auto& operation = model.operations[i];

        for (size_t j = 0; j < operation.inputs.size(); ++j) {
            uint32_t k = operation.inputs[j];
            if (!operandValueKnown[k]) {
                LOG(ERROR) << "Operation " << i << " input " << j << " (operand " << k
                           << ") is read before it is written";
                return false;
            }
            operandNumberOfConsumers[k]++;
        }

        for (size_t j = 0; j < operation.outputs.size(); ++j) {
            uint32_t k = operation.outputs[j];
            if (operandValueKnown[k]) {
                // Assuming validateOperations() has returned true, we
                // know that this output is TEMPORARY_VARIABLE or
                // MODEL_OUTPUT, and so the only way
                // operandValueKnown[k] can be true is if we've
                // already seen a writer for this operand.
                LOG(ERROR) << "Operation " << i << " output " << j << " (operand " << k
                           << ") has already been written";
                return false;
            }
            operandValueKnown[k] = true;
        }
    }

    // validate number of consumers
    //
    // TODO Because we have to validate it, there was no point in including it
    // in struct Operand. For the next release, consider removing unless we have
    // an additional process in system space that creates this value. In that
    // case, it would not have to be validated.
    for (size_t i = 0; i < model.operands.size(); ++i) {
        if (model.operands[i].numberOfConsumers != operandNumberOfConsumers[i]) {
            LOG(ERROR) << "Operand " << i << " has incorrect number of consumers "
                       << model.operands[i].numberOfConsumers << ", expected "
                       << operandNumberOfConsumers[i];
            return false;
        }
    }

    // verify all operands are written
    for (size_t i = 0; i < model.operands.size(); ++i) {
        if (!operandValueKnown[i]) {
            LOG(ERROR) << "Operand " << i << " is never written";
            return false;
        }
    }

    return true;
}

// Makes sure the model does not contain subgraph reference cycles.
static bool checkNoReferenceCycles(const V1_3::Model& model, const V1_3::Subgraph& subgraph,
                                   std::set<const V1_3::Subgraph*>* path) {
    auto [_, isNew] = path->insert(&subgraph);
    if (!isNew) {
        LOG(ERROR) << "Model contains a circular subgraph reference";
        return false;
    }
    for (const V1_3::Operand& operand : subgraph.operands) {
        if (operand.lifetime == V1_3::OperandLifeTime::SUBGRAPH) {
            uint32_t refSubgraphIndex = operand.location.offset;
            if (!checkNoReferenceCycles(model, model.referenced[refSubgraphIndex], path)) {
                return false;
            }
        }
    }
    path->erase(&subgraph);
    return true;
}

static bool checkNoReferenceCycles(const V1_3::Model& model) {
    std::set<const V1_3::Subgraph*> path;
    return checkNoReferenceCycles(model, model.main, &path);
}

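// Validates a whole model for the HAL version corresponding to T_Model: operands, operations,
// model inputs and outputs, memory pools, and the overall graph structure.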
template <class T_Model>
bool validateModel(const T_Model& model, ValidationMode mode) {
    NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
    HalVersion version = ModelToHalVersion<T_Model>::version;
    if (model.operations.size() == 0 || model.operands.size() == 0) {
        LOG(ERROR) << "Invalid empty model.";
        return false;
    }
    // We only need versioned operands for their validation. For all the other
    // validations we can use operands upcasted to the latest version.
    const hardware::hidl_vec<V1_3::Operand> latestVersionOperands = convertToV1_3(model.operands);
    return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
                             /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
            validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}, mode) &&
            validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
                                      V1_3::OperandLifeTime::SUBGRAPH_INPUT) &&
            validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
                                      V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) &&
            validatePools(model.pools, version) && validateGraph(model));
}

template bool validateModel<V1_0::Model>(const V1_0::Model& model, ValidationMode mode);
template bool validateModel<V1_1::Model>(const V1_1::Model& model, ValidationMode mode);
template bool validateModel<V1_2::Model>(const V1_2::Model& model, ValidationMode mode);

template <>
bool validateModel(const V1_3::Model& model, ValidationMode mode) {
    NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
    if (model.main.operations.size() == 0 || model.main.operands.size() == 0) {
        LOG(ERROR) << "Invalid empty model.";
        return false;
    }
    auto validateSubgraph = [&model, mode](const V1_3::Subgraph& subgraph) -> bool {
        return (validateOperands(subgraph.operands, model.operandValues, model.pools,
                                 model.referenced, /*allowUnspecifiedRank=*/true) &&
                validateOperations(subgraph.operations, subgraph.operands, model.referenced,
                                   mode) &&
                validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
                                          V1_3::OperandLifeTime::SUBGRAPH_INPUT) &&
                validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
                                          V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) &&
                validateGraph(subgraph));
    };
    return (validateSubgraph(model.main) &&
            std::all_of(model.referenced.begin(), model.referenced.end(), validateSubgraph) &&
            validatePools(model.pools, HalVersion::V1_3) && checkNoReferenceCycles(model));
}

// Validates the arguments of a request. type is either "input" or "output" and is used
// for printing error messages. operandIndexes is the array of input or output operand
// indexes that was passed to ANeuralNetworksModel_identifyInputsAndOutputs.
static bool validateRequestArguments(
        const hardware::hidl_vec<V1_0::RequestArgument>& requestArguments,
        const hardware::hidl_vec<uint32_t>& operandIndexes,
        const hardware::hidl_vec<V1_3::Operand>& operands, const MemoryAccessVerifier& poolVerifier,
        bool allowUnspecified, const char* type) {
    // The request should specify as many arguments as were described in the model.
    const size_t requestArgumentCount = requestArguments.size();
    if (requestArgumentCount != operandIndexes.size()) {
        LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
                   << "s but the model has " << operandIndexes.size();
        return false;
    }
    for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
         requestArgumentIndex++) {
        const V1_0::RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
        const V1_0::DataLocation& location = requestArgument.location;
        // Get the operand index for this argument. We extract it from the list
        // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
        // We assume in this function that the model has been validated already.
        const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
        const V1_3::Operand& operand = operands[operandIndex];
        if (requestArgument.hasNoValue) {
            if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
                requestArgument.dimensions.size() != 0) {
                LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                           << " has no value yet has details.";
                return false;
            }
        } else {
            // Validate the location.
            if (!poolVerifier.validate(location)) {
                return false;
            }
            // If the argument specified a dimension, validate it.
            uint32_t modelRank = operand.dimensions.size();
            uint32_t requestRank = requestArgument.dimensions.size();
            if (requestRank == 0) {
                if (!allowUnspecified) {
                    // NOTE: validateRequestArguments cannot validate unknown tensor rank with
                    // extension operand type.
                    if (!isExtensionOperandType(operand.type) &&
                        !nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type))) {
                        NN_RET_CHECK_GT(modelRank, 0u)
                                << "Model " << type << " " << requestArgumentIndex
                                << " has unknown rank but the request does not specify the rank.";
                    }
                    // Validate that all the dimensions are specified in the model.
                    for (size_t i = 0; i < modelRank; i++) {
                        if (operand.dimensions[i] == 0) {
                            LOG(ERROR)
                                    << "Model has dimension " << i
                                    << " set to 0 but the request does not specify the dimension.";
                            return false;
                        }
                    }
                }
            } else {
                if (modelRank != 0 && requestRank != modelRank) {
                    LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                               << " has number of dimensions (" << requestRank
                               << ") different than the model's (" << modelRank << ")";
                    return false;
                }
                for (size_t i = 0; i < requestRank; i++) {
                    if (modelRank != 0 && requestArgument.dimensions[i] != operand.dimensions[i] &&
                        operand.dimensions[i] != 0) {
                        LOG(ERROR)
                                << "Request " << type << " " << requestArgumentIndex
                                << " has dimension " << i << " of " << requestArgument.dimensions[i]
                                << " different than the model's " << operand.dimensions[i];
                        return false;
                    }
                    if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                                   << " has dimension " << i << " of zero";
                        return false;
                    }
                }
            }
        }
    }
    return true;
}

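// Validates a request against a model: checks the input and output arguments and the memory
// pools for the HAL version corresponding to T_Model.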
template <class T_Request, class T_Model>
bool validateRequest(const T_Request& request, const T_Model& model, bool allowUnspecifiedOutput) {
    HalVersion version = ModelToHalVersion<T_Model>::version;
    MemoryAccessVerifier poolVerifier(request.pools);
    return (validateRequestArguments(request.inputs, model.inputIndexes,
                                     convertToV1_3(model.operands), poolVerifier,
                                     /*allowUnspecified=*/false, "input") &&
            validateRequestArguments(
                    request.outputs, model.outputIndexes, convertToV1_3(model.operands),
                    poolVerifier,
                    /*allowUnspecified=*/version >= HalVersion::V1_2 && allowUnspecifiedOutput,
                    "output") &&
            validatePools(request.pools, version));
}

template bool validateRequest<V1_0::Request, V1_0::Model>(const V1_0::Request& request,
                                                          const V1_0::Model& model,
                                                          bool allowUnspecifiedOutput);
template bool validateRequest<V1_0::Request, V1_1::Model>(const V1_0::Request& request,
                                                          const V1_1::Model& model,
                                                          bool allowUnspecifiedOutput);
template bool validateRequest<V1_0::Request, V1_2::Model>(const V1_0::Request& request,
                                                          const V1_2::Model& model,
                                                          bool allowUnspecifiedOutput);

template <>
bool validateRequest(const V1_3::Request& request, const V1_3::Model& model,
                     bool allowUnspecifiedOutput) {
    return (validateRequestArguments(request.inputs, model.main.inputIndexes, model.main.operands,
                                     request.pools,
                                     /*allowUnspecified=*/false, "input") &&
            validateRequestArguments(request.outputs, model.main.outputIndexes, model.main.operands,
                                     request.pools, allowUnspecifiedOutput, "output") &&
            validatePools(request.pools, HalVersion::V1_3));
}

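// Validates a V1_3::BufferDesc and its input/output buffer roles against the given prepared
// models: each role must reference a valid model input or output with a frequency in
// (0.0, 1.0], all referenced operands must agree on type, scale, zeroPoint, and extraParams,
// and their dimensions must be compatible with desc.dimensions. On success, optionally
// returns the collected roles and the combined operand.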
bool validateMemoryDesc(const V1_3::BufferDesc& desc,
                        const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
                        const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
                        const hardware::hidl_vec<V1_3::BufferRole>& outputRoles,
                        std::function<const V1_3::Model*(const sp<V1_3::IPreparedModel>&)> getModel,
                        std::set<HalPreparedModelRole>* preparedModelRoles,
                        V1_3::Operand* combinedOperand) {
    NN_RET_CHECK(preparedModels.size() != 0);
    NN_RET_CHECK(inputRoles.size() != 0 || outputRoles.size() != 0);

    std::set<HalPreparedModelRole> roles;
    std::vector<V1_3::Operand> operands;
    operands.reserve(inputRoles.size() + outputRoles.size());
    for (const auto& role : inputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& inputIndexes = model->main.inputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
        NN_RET_CHECK_GT(role.frequency, 0.0f);
        NN_RET_CHECK_LE(role.frequency, 1.0f);
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
    }
    for (const auto& role : outputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& outputIndexes = model->main.outputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
        NN_RET_CHECK_GT(role.frequency, 0.0f);
        NN_RET_CHECK_LE(role.frequency, 1.0f);
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
    }

    CHECK(!operands.empty());
    const auto opType = operands[0].type;
    const bool isExtension = isExtensionOperandType(opType);

    std::vector<uint32_t> dimensions = desc.dimensions;
    for (const auto& operand : operands) {
        NN_RET_CHECK(operand.type == operands[0].type)
                << toString(operand.type) << " vs " << toString(operands[0].type);
        NN_RET_CHECK_EQ(operand.scale, operands[0].scale);
        NN_RET_CHECK_EQ(operand.zeroPoint, operands[0].zeroPoint);
        // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
        if (!isExtension) {
            NN_RET_CHECK(operand.extraParams == operands[0].extraParams)
                    << toString(operand.extraParams) << " vs " << toString(operands[0].extraParams);
        }
        const auto combined = combineDimensions(dimensions, operand.dimensions);
        NN_RET_CHECK(combined.has_value());
        dimensions = combined.value();
    }

    // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
    if (!isExtension) {
        NN_RET_CHECK(!nonExtensionOperandTypeIsScalar(static_cast<int>(opType)) ||
                     dimensions.empty())
                << "invalid dimensions with scalar operand type.";
    }

    if (preparedModelRoles != nullptr) {
        *preparedModelRoles = std::move(roles);
    }
    if (combinedOperand != nullptr) {
        *combinedOperand = operands[0];
        combinedOperand->dimensions = dimensions;
    }
    return true;
}

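// Returns true if the preference is one of the execution preferences defined in the 1.1 HAL.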
bool validateExecutionPreference(V1_1::ExecutionPreference preference) {
    return preference == V1_1::ExecutionPreference::LOW_POWER ||
           preference == V1_1::ExecutionPreference::FAST_SINGLE_ANSWER ||
           preference == V1_1::ExecutionPreference::SUSTAINED_SPEED;
}

bool validatePriority(V1_3::Priority priority) {
    return priority == V1_3::Priority::LOW || priority == V1_3::Priority::MEDIUM ||
           priority == V1_3::Priority::HIGH;
}

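// The validOperandType overloads below return true if the operand type is defined in the
// corresponding HAL version; from 1.2 onwards, extension operand types are also accepted.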
bool validOperandType(V1_0::OperandType operandType) {
    switch (operandType) {
        case V1_0::OperandType::FLOAT32:
        case V1_0::OperandType::INT32:
        case V1_0::OperandType::UINT32:
        case V1_0::OperandType::TENSOR_FLOAT32:
        case V1_0::OperandType::TENSOR_INT32:
        case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_0::OperandType::OEM:
        case V1_0::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return false;
    }
}

bool validOperandType(V1_2::OperandType operandType) {
    switch (operandType) {
        case V1_2::OperandType::FLOAT16:
        case V1_2::OperandType::FLOAT32:
        case V1_2::OperandType::INT32:
        case V1_2::OperandType::UINT32:
        case V1_2::OperandType::BOOL:
        case V1_2::OperandType::TENSOR_FLOAT16:
        case V1_2::OperandType::TENSOR_FLOAT32:
        case V1_2::OperandType::TENSOR_INT32:
        case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_2::OperandType::TENSOR_QUANT8_SYMM:
        case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
        case V1_2::OperandType::TENSOR_QUANT16_SYMM:
        case V1_2::OperandType::TENSOR_BOOL8:
        case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case V1_2::OperandType::OEM:
        case V1_2::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return isExtensionOperandType(static_cast<V1_3::OperandType>(operandType));
    }
}

bool validOperandType(V1_3::OperandType operandType) {
    switch (operandType) {
        case V1_3::OperandType::FLOAT16:
        case V1_3::OperandType::FLOAT32:
        case V1_3::OperandType::INT32:
        case V1_3::OperandType::UINT32:
        case V1_3::OperandType::BOOL:
        case V1_3::OperandType::TENSOR_FLOAT16:
        case V1_3::OperandType::TENSOR_FLOAT32:
        case V1_3::OperandType::TENSOR_INT32:
        case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_3::OperandType::TENSOR_QUANT8_SYMM:
        case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
        case V1_3::OperandType::TENSOR_QUANT16_SYMM:
        case V1_3::OperandType::TENSOR_BOOL8:
        case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case V1_3::OperandType::SUBGRAPH:
        case V1_3::OperandType::OEM:
        case V1_3::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return isExtensionOperandType(operandType);
    }
}

}  // namespace nn
}  // namespace android