1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ValidateHal"
18
19 #include "ValidateHal.h"
20
#include <android-base/logging.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <set>
#include <utility>
#include <vector>

#include "NeuralNetworks.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "Utils.h"
#include "nnapi/TypeUtils.h"
34
35 namespace android {
36 namespace nn {
37
38 template <class T_Model>
39 struct ModelToHalVersion;
40 template <>
41 struct ModelToHalVersion<V1_0::Model> {
42 static constexpr HalVersion version = HalVersion::V1_0;
43 };
44 template <>
45 struct ModelToHalVersion<V1_1::Model> {
46 static constexpr HalVersion version = HalVersion::V1_1;
47 };
48 template <>
49 struct ModelToHalVersion<V1_2::Model> {
50 static constexpr HalVersion version = HalVersion::V1_2;
51 };
52 template <>
53 struct ModelToHalVersion<V1_3::Model> {
54 static constexpr HalVersion version = HalVersion::V1_3;
55 };
56
57 class MemoryAccessVerifier {
58 public:
MemoryAccessVerifier(const hardware::hidl_vec<hardware::hidl_memory> & pools)59 MemoryAccessVerifier(const hardware::hidl_vec<hardware::hidl_memory>& pools)
60 : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
61 for (size_t i = 0; i < mPoolCount; i++) {
62 mPoolSizes[i] = pools[i].size();
63 }
64 }
MemoryAccessVerifier(const hardware::hidl_vec<V1_3::Request::MemoryPool> & pools)65 MemoryAccessVerifier(const hardware::hidl_vec<V1_3::Request::MemoryPool>& pools)
66 : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
67 for (size_t i = 0; i < mPoolCount; i++) {
68 switch (pools[i].getDiscriminator()) {
69 case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
70 mPoolSizes[i] = pools[i].hidlMemory().size();
71 break;
72 case V1_3::Request::MemoryPool::hidl_discriminator::token:
73 // Set size to 0 to enforce length == 0 && offset == 0.
74 mPoolSizes[i] = 0;
75 break;
76 }
77 }
78 }
validate(const V1_0::DataLocation & location) const79 bool validate(const V1_0::DataLocation& location) const {
80 if (location.poolIndex >= mPoolCount) {
81 LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
82 return false;
83 }
84 const size_t size = mPoolSizes[location.poolIndex];
85 // Do the addition using size_t to avoid potential wrap-around problems.
86 if (static_cast<size_t>(location.offset) + location.length > size) {
87 LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
88 << location.offset << " and length " << location.length
89 << " exceeds pool size of " << size;
90 return false;
91 }
92 return true;
93 }
94
95 private:
96 size_t mPoolCount;
97 std::vector<size_t> mPoolSizes;
98 };
99
validateOperandExtraParams(const V1_3::Operand & operand,uint32_t index)100 static bool validateOperandExtraParams(const V1_3::Operand& operand, uint32_t index) {
101 switch (operand.type) {
102 case V1_3::OperandType::FLOAT32:
103 case V1_3::OperandType::INT32:
104 case V1_3::OperandType::UINT32:
105 case V1_3::OperandType::BOOL:
106 case V1_3::OperandType::SUBGRAPH:
107 case V1_3::OperandType::TENSOR_FLOAT32:
108 case V1_3::OperandType::TENSOR_FLOAT16:
109 case V1_3::OperandType::TENSOR_INT32:
110 case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
111 case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
112 case V1_3::OperandType::TENSOR_QUANT8_SYMM:
113 case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
114 case V1_3::OperandType::TENSOR_QUANT16_SYMM:
115 case V1_3::OperandType::TENSOR_BOOL8: {
116 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
117 V1_2::Operand::ExtraParams::hidl_discriminator::none)
118 << "Operand " << index << ": Operand of type "
119 << getOperandTypeName(operand.type)
120 << " has incorrect extraParams: " << toString(operand.extraParams);
121 } break;
122 case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
123 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
124 V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant)
125 << "Operand " << index << ": Operand of type "
126 << getOperandTypeName(operand.type) << " without a Channel Quantization params";
127 auto& channelQuant = operand.extraParams.channelQuant();
128
129 size_t count = operand.dimensions.size();
130 NN_RET_CHECK_LT(channelQuant.channelDim, count)
131 << "Operand " << index << ": Operand of type "
132 << getOperandTypeName(operand.type)
133 << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
134 << ", must be valid dimension index in range [0, " << count << ")";
135 uint32_t expected = operand.dimensions[channelQuant.channelDim];
136 NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
137 << "Operand " << index << ": Operand of type "
138 << getOperandTypeName(operand.type) << " with a wrong-sized scales, "
139 << "expected " << expected << " was " << channelQuant.scales.size();
140 NN_RET_CHECK_NE(expected, 0u)
141 << "Operand " << index << ": Operand of type "
142 << getOperandTypeName(operand.type) << " channel dimension "
143 << channelQuant.channelDim << " is underspecified (can't be 0)";
144 for (uint32_t i = 0; i < expected; ++i) {
145 NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
146 << "Operand " << index << ": Operand of type "
147 << getOperandTypeName(operand.type) << " with a negative value in scales["
148 << i << "]=" << channelQuant.scales[i];
149 }
150 } break;
151 default: {
152 if (isExtensionOperandType(operand.type)) {
153 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
154 V1_2::Operand::ExtraParams::hidl_discriminator::extension ||
155 operand.extraParams.getDiscriminator() ==
156 V1_2::Operand::ExtraParams::hidl_discriminator::none)
157 << "Operand " << index << ": Extension operand of type "
158 << getOperandTypeName(operand.type)
159 << " has incorrect extraParams: " << toString(operand.extraParams);
160 }
161 // No validation for OEM types.
162 } break;
163 }
164 return true;
165 }
166
167 template <typename VersionedOperand>
validateOperands(const hardware::hidl_vec<VersionedOperand> & operands,const hardware::hidl_vec<uint8_t> & operandValues,const hardware::hidl_vec<hardware::hidl_memory> & pools,const hardware::hidl_vec<V1_3::Subgraph> & subgraphs,bool allowUnspecifiedRank)168 static bool validateOperands(const hardware::hidl_vec<VersionedOperand>& operands,
169 const hardware::hidl_vec<uint8_t>& operandValues,
170 const hardware::hidl_vec<hardware::hidl_memory>& pools,
171 const hardware::hidl_vec<V1_3::Subgraph>& subgraphs,
172 bool allowUnspecifiedRank) {
173 uint32_t index = 0;
174 MemoryAccessVerifier poolVerifier(pools);
175 for (auto& versionedOperand : operands) {
176 if (!validOperandType(versionedOperand.type)) {
177 LOG(ERROR) << "Operand is not supported by this version: "
178 << toString(versionedOperand.type);
179 return false;
180 }
181 // Once we are sure the operand is supported by its version, it is safe
182 // to convert it to the latest version for the rest of the validations.
183 V1_3::Operand operand = convertToV1_3(versionedOperand);
184 // Validate type and dimensions.
185 switch (operand.type) {
186 case V1_3::OperandType::FLOAT16:
187 case V1_3::OperandType::FLOAT32:
188 case V1_3::OperandType::INT32:
189 case V1_3::OperandType::UINT32:
190 case V1_3::OperandType::BOOL:
191 case V1_3::OperandType::SUBGRAPH:
192 case V1_3::OperandType::OEM: {
193 size_t count = operand.dimensions.size();
194 if (count != 0) {
195 LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
196 << count;
197 return false;
198 }
199 break;
200 }
201 case V1_3::OperandType::TENSOR_FLOAT16:
202 case V1_3::OperandType::TENSOR_FLOAT32:
203 case V1_3::OperandType::TENSOR_INT32:
204 case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
205 case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
206 case V1_3::OperandType::TENSOR_QUANT8_SYMM:
207 case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
208 case V1_3::OperandType::TENSOR_QUANT16_SYMM:
209 case V1_3::OperandType::TENSOR_BOOL8:
210 case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
211 case V1_3::OperandType::TENSOR_OEM_BYTE: {
212 if ((!allowUnspecifiedRank ||
213 operand.lifetime == V1_3::OperandLifeTime::CONSTANT_COPY ||
214 operand.lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE) &&
215 operand.dimensions.size() == 0) {
216 LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
217 return false;
218 }
219 break;
220 }
221 default: {
222 if (!isExtensionOperandType(operand.type)) {
223 LOG(ERROR) << "Operand " << index << ": Invalid operand type "
224 << toString(operand.type);
225 return false;
226 }
227 } break;
228 }
229
230 // Validate the scale.
231 switch (operand.type) {
232 case V1_3::OperandType::FLOAT16:
233 case V1_3::OperandType::FLOAT32:
234 case V1_3::OperandType::INT32:
235 case V1_3::OperandType::UINT32:
236 case V1_3::OperandType::BOOL:
237 case V1_3::OperandType::SUBGRAPH:
238 case V1_3::OperandType::TENSOR_FLOAT16:
239 case V1_3::OperandType::TENSOR_FLOAT32:
240 case V1_3::OperandType::TENSOR_BOOL8:
241 case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
242 if (operand.scale != 0.f) {
243 LOG(ERROR) << "Operand " << index << ": Operand of type "
244 << getOperandTypeName(operand.type) << " with a non-zero scale ("
245 << operand.scale << ")";
246 return false;
247 }
248 break;
249 case V1_3::OperandType::TENSOR_INT32:
250 // TENSOR_INT32 may be used with or without scale, depending on the operation.
251 if (operand.scale < 0.f) {
252 LOG(ERROR) << "Operand " << index << ": Operand of type "
253 << getOperandTypeName(operand.type) << " with a negative scale";
254 return false;
255 }
256 break;
257 case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
258 case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
259 case V1_3::OperandType::TENSOR_QUANT8_SYMM:
260 case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
261 case V1_3::OperandType::TENSOR_QUANT16_SYMM:
262 if (operand.scale <= 0.f) {
263 LOG(ERROR) << "Operand " << index << ": Operand of type "
264 << getOperandTypeName(operand.type) << " with a non-positive scale";
265 return false;
266 }
267 break;
268 default:
269 if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
270 LOG(ERROR) << "Operand " << index << ": Operand of type "
271 << getOperandTypeName(operand.type) << " with a non-zero scale ("
272 << operand.scale << ")";
273 return false;
274 }
275 // No validation for OEM types.
276 // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
277 break;
278 }
279
280 // Validate the zeroPoint.
281 switch (operand.type) {
282 case V1_3::OperandType::FLOAT16:
283 case V1_3::OperandType::FLOAT32:
284 case V1_3::OperandType::INT32:
285 case V1_3::OperandType::UINT32:
286 case V1_3::OperandType::BOOL:
287 case V1_3::OperandType::SUBGRAPH:
288 case V1_3::OperandType::TENSOR_FLOAT16:
289 case V1_3::OperandType::TENSOR_FLOAT32:
290 case V1_3::OperandType::TENSOR_INT32:
291 case V1_3::OperandType::TENSOR_BOOL8:
292 case V1_3::OperandType::TENSOR_QUANT8_SYMM:
293 case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
294 if (operand.zeroPoint != 0) {
295 LOG(ERROR) << "Operand " << index << ": Operand of type "
296 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
297 << operand.zeroPoint;
298 return false;
299 }
300 break;
301 case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
302 if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
303 LOG(ERROR) << "Operand " << index << ": Operand of type "
304 << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
305 << operand.zeroPoint << ", must be in range [0, 255]";
306 return false;
307 }
308 break;
309 case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
310 if (operand.zeroPoint < -128 || operand.zeroPoint > 127) {
311 LOG(ERROR) << "Operand " << index << ": Operand of type "
312 << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
313 << operand.zeroPoint << ", must be in range [-128, 127]";
314 return false;
315 }
316 break;
317 case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
318 if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
319 LOG(ERROR) << "Operand " << index << ": Operand of type "
320 << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
321 << operand.zeroPoint << ", must be in range [0, 65535]";
322 return false;
323 }
324 break;
325 case V1_3::OperandType::TENSOR_QUANT16_SYMM:
326 if (operand.zeroPoint != 0) {
327 LOG(ERROR) << "Operand " << index << ": Operand of type "
328 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
329 << operand.zeroPoint;
330 return false;
331 }
332 break;
333 default:
334 if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
335 LOG(ERROR) << "Operand " << index << ": Operand of type "
336 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
337 << operand.zeroPoint;
338 return false;
339 }
340 // No validation for OEM types.
341 break;
342 }
343
344 NN_RET_CHECK(validateOperandExtraParams(operand, index));
345
346 // Validate the lifetime and the location.
347 const V1_0::DataLocation& location = operand.location;
348 switch (operand.lifetime) {
349 case V1_3::OperandLifeTime::CONSTANT_COPY:
350 if (location.poolIndex != 0) {
351 LOG(ERROR) << "Operand " << index
352 << ": CONSTANT_COPY with a non-zero poolIndex "
353 << location.poolIndex;
354 return false;
355 }
356 // Do the addition using size_t to avoid potential wrap-around problems.
357 if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
358 LOG(ERROR) << "Operand " << index
359 << ": OperandValue location out of range. Starts at "
360 << location.offset << ", length " << location.length << ", max "
361 << operandValues.size();
362 return false;
363 }
364 break;
365 case V1_3::OperandLifeTime::CONSTANT_REFERENCE:
366 if (!poolVerifier.validate(location)) {
367 return false;
368 }
369 break;
370 case V1_3::OperandLifeTime::TEMPORARY_VARIABLE:
371 case V1_3::OperandLifeTime::SUBGRAPH_INPUT:
372 case V1_3::OperandLifeTime::SUBGRAPH_OUTPUT:
373 case V1_3::OperandLifeTime::NO_VALUE:
374 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
375 LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
376 << location.poolIndex << ", offset " << location.offset
377 << ", or length " << location.length << " for operand of lifetime "
378 << toString(operand.lifetime);
379 return false;
380 }
381 break;
382 case V1_3::OperandLifeTime::SUBGRAPH: {
383 if (location.poolIndex != 0) {
384 LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
385 << location.poolIndex;
386 return false;
387 }
388 if (location.offset >= subgraphs.size()) {
389 LOG(ERROR) << "Model::Subgraph index out of range: " << location.offset
390 << " >= " << subgraphs.size();
391 return false;
392 }
393 if (location.length != 0) {
394 LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
395 << location.length;
396 return false;
397 }
398 } break;
399 default:
400 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
401 << toString(operand.lifetime);
402 return false;
403 }
404
405 // Make sure SUBGRAPH operand type and lifetime always go together.
406 if ((operand.type == V1_3::OperandType::SUBGRAPH) !=
407 (operand.lifetime == V1_3::OperandLifeTime::SUBGRAPH)) {
408 LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
409 << " cannot have lifetime " << toString(operand.lifetime);
410 return false;
411 }
412
413 // For constants, validate that the length is as expected. The other lifetimes
414 // expect the length to be 0. Don't validate for OEM types.
415 if (operand.lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE ||
416 operand.lifetime == V1_3::OperandLifeTime::CONSTANT_COPY) {
417 if (!isExtensionOperandType(operand.type) && operand.type != V1_3::OperandType::OEM &&
418 operand.type != V1_3::OperandType::TENSOR_OEM_BYTE) {
419 uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
420 if (location.length != expectedLength) {
421 LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
422 << " expected a size of " << expectedLength << " but got "
423 << location.length;
424 return false;
425 }
426 }
427 }
428
429 index++;
430 }
431 return true;
432 }
433
getHalVersion(const V1_0::Operation &)434 static HalVersion getHalVersion(const V1_0::Operation&) {
435 return HalVersion::V1_0;
436 }
437
getHalVersion(const V1_1::Operation &)438 static HalVersion getHalVersion(const V1_1::Operation&) {
439 return HalVersion::V1_1;
440 }
441
getHalVersion(const V1_2::Operation &)442 static HalVersion getHalVersion(const V1_2::Operation&) {
443 return HalVersion::V1_2;
444 }
445
getHalVersion(const V1_3::Operation &)446 static HalVersion getHalVersion(const V1_3::Operation&) {
447 return HalVersion::V1_3;
448 }
449
450 template <typename VersionedOperation>
validateOperations(const hardware::hidl_vec<VersionedOperation> & operations,const hardware::hidl_vec<V1_3::Operand> & operands,const hardware::hidl_vec<V1_3::Subgraph> & subgraphs,ValidationMode mode)451 static bool validateOperations(const hardware::hidl_vec<VersionedOperation>& operations,
452 const hardware::hidl_vec<V1_3::Operand>& operands,
453 const hardware::hidl_vec<V1_3::Subgraph>& subgraphs,
454 ValidationMode mode) {
455 auto canonicalSubgraphs = uncheckedConvert(subgraphs);
456 auto isValidSubgraphReference = [&canonicalSubgraphs](const Operand& modelOperand) -> bool {
457 NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
458 << "Unexpected operand type: " << modelOperand.type;
459 NN_RET_CHECK_LT(modelOperand.location.offset, canonicalSubgraphs.size())
460 << "Invalid subgraph reference";
461 return true;
462 };
463 auto getSubgraph =
464 [&canonicalSubgraphs](const Operand& modelOperand) -> const Model::Subgraph* {
465 CHECK_LT(modelOperand.location.offset, canonicalSubgraphs.size());
466 return &canonicalSubgraphs[modelOperand.location.offset];
467 };
468 auto getInputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
469 return getSubgraph(modelOperand)->inputIndexes.size();
470 };
471 auto getOutputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
472 return getSubgraph(modelOperand)->outputIndexes.size();
473 };
474 auto getInputOperand = [&getSubgraph](const Operand& modelOperand,
475 uint32_t index) -> const Operand* {
476 const Model::Subgraph& subgraph = *getSubgraph(modelOperand);
477 CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
478 return &subgraph.operands[subgraph.inputIndexes[index]];
479 };
480 auto getOutputOperand = [&getSubgraph](const Operand& modelOperand,
481 uint32_t index) -> const Operand* {
482 const Model::Subgraph& subgraph = *getSubgraph(modelOperand);
483 CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
484 return &subgraph.operands[subgraph.outputIndexes[index]];
485 };
486 for (auto& op : operations) {
487 // TODO Validate the shapes and any known values. This is currently
488 // done in CpuExecutor but should be done here for all drivers.
489 int error = validateOperation(static_cast<int32_t>(op.type), op.inputs.size(),
490 op.inputs.size() > 0 ? op.inputs.data() : nullptr,
491 op.outputs.size(),
492 op.outputs.size() > 0 ? op.outputs.data() : nullptr,
493 uncheckedConvert(operands), getHalVersion(op),
494 {.isValidSubgraphReference = isValidSubgraphReference,
495 .getSubgraphInputCount = getInputCount,
496 .getSubgraphOutputCount = getOutputCount,
497 .getSubgraphInputOperand = getInputOperand,
498 .getSubgraphOutputOperand = getOutputOperand,
499 // 1.3 HAL does not support CF operations with operands of
500 // unknown size. See http://b/132458982#comment63.
501 .allowControlFlowOperationWithOperandOfUnknownSize =
502 mode == ValidationMode::RUNTIME});
503 if (error != ANEURALNETWORKS_NO_ERROR) {
504 LOG(ERROR) << "Invalid operation " << toString(op.type);
505 return false;
506 }
507
508 // This is redundant because of the checks in validateGraph(),
509 // but it is retained here in order to emit more informative
510 // error messages.
511 for (uint32_t i : op.outputs) {
512 const V1_3::Operand& operand = operands[i];
513 if (operand.lifetime != V1_3::OperandLifeTime::TEMPORARY_VARIABLE &&
514 operand.lifetime != V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) {
515 LOG(ERROR) << "Writing to operand " << i << " with incompatible lifetime "
516 << toString(operand.lifetime);
517 return false;
518 }
519 }
520 }
521 return true;
522 }
523
validatePool(const hardware::hidl_memory & pool,HalVersion ver)524 bool validatePool(const hardware::hidl_memory& pool, HalVersion ver) {
525 const auto& name = pool.name();
526 if (name != "ashmem" && name != "mmap_fd" &&
527 ((ver < HalVersion::V1_2) ||
528 (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
529 LOG(ERROR) << "Unsupported memory type " << name;
530 return false;
531 }
532 if (pool.handle() == nullptr) {
533 LOG(ERROR) << "Memory of type " << name << " is null";
534 return false;
535 }
536 return true;
537 }
538
validatePool(const V1_3::Request::MemoryPool & pool,HalVersion ver)539 bool validatePool(const V1_3::Request::MemoryPool& pool, HalVersion ver) {
540 switch (pool.getDiscriminator()) {
541 case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
542 return validatePool(pool.hidlMemory(), ver);
543 case V1_3::Request::MemoryPool::hidl_discriminator::token:
544 return pool.token() > 0;
545 }
546 LOG(FATAL) << "unknown MemoryPool discriminator";
547 return false;
548 }
549
550 template <class T_MemoryPool>
validatePools(const hardware::hidl_vec<T_MemoryPool> & pools,HalVersion ver)551 static bool validatePools(const hardware::hidl_vec<T_MemoryPool>& pools, HalVersion ver) {
552 return std::all_of(pools.begin(), pools.end(),
553 [ver](const auto& pool) { return validatePool(pool, ver); });
554 }
555
validateModelInputOutputs(const hardware::hidl_vec<uint32_t> indexes,const hardware::hidl_vec<V1_3::Operand> & operands,V1_3::OperandLifeTime lifetime)556 static bool validateModelInputOutputs(const hardware::hidl_vec<uint32_t> indexes,
557 const hardware::hidl_vec<V1_3::Operand>& operands,
558 V1_3::OperandLifeTime lifetime) {
559 const size_t operandCount = operands.size();
560 for (uint32_t i : indexes) {
561 if (i >= operandCount) {
562 LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
563 return false;
564 }
565 const V1_3::Operand& operand = operands[i];
566 if (operand.lifetime != lifetime) {
567 LOG(ERROR) << "Model input or output operand " << i << " has lifetime of "
568 << toString(operand.lifetime) << " instead of the expected "
569 << toString(lifetime);
570 return false;
571 }
572 }
573
574 std::vector<uint32_t> sortedIndexes = indexes;
575 std::sort(sortedIndexes.begin(), sortedIndexes.end());
576 auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
577 if (adjacentI != sortedIndexes.end()) {
578 LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
579 return false;
580 }
581
582 for (size_t i = 0; i < operands.size(); ++i) {
583 if (operands[i].lifetime == lifetime &&
584 !binary_search(sortedIndexes.begin(), sortedIndexes.end(), i)) {
585 LOG(ERROR) << "Operand " << i << " marked as " << toString(lifetime)
586 << " but is not included in Model input or output indexes";
587 return false;
588 }
589 }
590
591 return true;
592 }
593
594 template <typename VersionedModelOrSubgraph>
validateGraph(const VersionedModelOrSubgraph & model)595 static bool validateGraph(const VersionedModelOrSubgraph& model) {
596 // set up counts
597 std::vector<uint32_t> operandNumberOfConsumers(model.operands.size(), 0);
598 // Either the operand has a known value before model execution
599 // begins, or we've seen a writer for this operand while
600 // walking operands in execution order.
601 std::vector<bool> operandValueKnown(model.operands.size(), false);
602
603 // mark known operands
604 for (size_t i = 0; i < model.operands.size(); ++i) {
605 const auto& operand = model.operands[i];
606 const V1_3::OperandLifeTime lifetime = convertToV1_3(operand.lifetime);
607 operandValueKnown[i] = lifetime == V1_3::OperandLifeTime::SUBGRAPH_INPUT ||
608 lifetime == V1_3::OperandLifeTime::CONSTANT_COPY ||
609 lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE ||
610 lifetime == V1_3::OperandLifeTime::NO_VALUE ||
611 lifetime == V1_3::OperandLifeTime::SUBGRAPH;
612 }
613
614 // Validate that operations are sorted into execution order.
615 //
616 // If there is a cycle in the graph, the operations will not
617 // appear to be sorted into execution order: Some operation will
618 // have an input for which operandValueKnown[] is false.
619 for (size_t i = 0; i < model.operations.size(); ++i) {
620 const auto& operation = model.operations[i];
621
622 for (size_t j = 0; j < operation.inputs.size(); ++j) {
623 uint32_t k = operation.inputs[j];
624 if (!operandValueKnown[k]) {
625 LOG(ERROR) << "Operation " << i << " input " << j << " (operand " << k
626 << ") is read before it is written";
627 return false;
628 }
629 operandNumberOfConsumers[k]++;
630 }
631
632 for (size_t j = 0; j < operation.outputs.size(); ++j) {
633 uint32_t k = operation.outputs[j];
634 if (operandValueKnown[k]) {
635 // Assuming validateOperations() has returned true, we
636 // know that this output is TEMPORARY_VARIABLE or
637 // MODEL_OUTPUT, and so the only way
638 // operandValueKnown[k] can be true is if we've
639 // already seen a writer for this operand.
640 LOG(ERROR) << "Operation " << i << " output " << j << " (operand " << k
641 << ") has already been written";
642 return false;
643 }
644 operandValueKnown[k] = true;
645 }
646 }
647
648 // validate number of consumers
649 //
650 // TODO Because we have to validate it, there was no point in including it
651 // in struct Operand. For the next release, consider removing unless we have
652 // an additional process in system space that creates this value. In that
653 // case, it would not have to be validated.
654 for (size_t i = 0; i < model.operands.size(); ++i) {
655 if (model.operands[i].numberOfConsumers != operandNumberOfConsumers[i]) {
656 LOG(ERROR) << "Operand " << i << " has incorrect number of consumers "
657 << model.operands[i].numberOfConsumers << ", expected "
658 << operandNumberOfConsumers[i];
659 return false;
660 }
661 }
662
663 // verify all operands are written
664 for (size_t i = 0; i < model.operands.size(); ++i) {
665 if (!operandValueKnown[i]) {
666 LOG(ERROR) << "Operand " << i << " is never written";
667 return false;
668 }
669 }
670
671 return true;
672 }
673
674 // Makes sure the model does not contain subgraph reference cycles.
checkNoReferenceCycles(const V1_3::Model & model,const V1_3::Subgraph & subgraph,std::set<const V1_3::Subgraph * > * path)675 static bool checkNoReferenceCycles(const V1_3::Model& model, const V1_3::Subgraph& subgraph,
676 std::set<const V1_3::Subgraph*>* path) {
677 auto [_, isNew] = path->insert(&subgraph);
678 if (!isNew) {
679 LOG(ERROR) << "Model contains a circular subgraph reference";
680 return false;
681 }
682 for (const V1_3::Operand& operand : subgraph.operands) {
683 if (operand.lifetime == V1_3::OperandLifeTime::SUBGRAPH) {
684 uint32_t refSubgraphIndex = operand.location.offset;
685 if (!checkNoReferenceCycles(model, model.referenced[refSubgraphIndex], path)) {
686 return false;
687 }
688 }
689 }
690 path->erase(&subgraph);
691 return true;
692 }
693
checkNoReferenceCycles(const V1_3::Model & model)694 static bool checkNoReferenceCycles(const V1_3::Model& model) {
695 std::set<const V1_3::Subgraph*> path;
696 return checkNoReferenceCycles(model, model.main, &path);
697 }
698
699 template <class T_Model>
validateModel(const T_Model & model,ValidationMode mode)700 bool validateModel(const T_Model& model, ValidationMode mode) {
701 NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
702 HalVersion version = ModelToHalVersion<T_Model>::version;
703 if (model.operations.size() == 0 || model.operands.size() == 0) {
704 LOG(ERROR) << "Invalid empty model.";
705 return false;
706 }
707 // We only need versioned operands for their validation. For all the other
708 // validations we can use operands upcasted to the latest version.
709 const hardware::hidl_vec<V1_3::Operand> latestVersionOperands = convertToV1_3(model.operands);
710 return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
711 /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
712 validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}, mode) &&
713 validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
714 V1_3::OperandLifeTime::SUBGRAPH_INPUT) &&
715 validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
716 V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) &&
717 validatePools(model.pools, version) && validateGraph(model));
718 }
719
720 template bool validateModel<V1_0::Model>(const V1_0::Model& model, ValidationMode mode);
721 template bool validateModel<V1_1::Model>(const V1_1::Model& model, ValidationMode mode);
722 template bool validateModel<V1_2::Model>(const V1_2::Model& model, ValidationMode mode);
723
724 template <>
validateModel(const V1_3::Model & model,ValidationMode mode)725 bool validateModel(const V1_3::Model& model, ValidationMode mode) {
726 NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
727 if (model.main.operations.size() == 0 || model.main.operands.size() == 0) {
728 LOG(ERROR) << "Invalid empty model.";
729 return false;
730 }
731 auto validateSubgraph = [&model, mode](const V1_3::Subgraph& subgraph) -> bool {
732 return (validateOperands(subgraph.operands, model.operandValues, model.pools,
733 model.referenced, /*allowUnspecifiedRank=*/true) &&
734 validateOperations(subgraph.operations, subgraph.operands, model.referenced,
735 mode) &&
736 validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
737 V1_3::OperandLifeTime::SUBGRAPH_INPUT) &&
738 validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
739 V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) &&
740 validateGraph(subgraph));
741 };
742 return (validateSubgraph(model.main) &&
743 std::all_of(model.referenced.begin(), model.referenced.end(), validateSubgraph) &&
744 validatePools(model.pools, HalVersion::V1_3) && checkNoReferenceCycles(model));
745 }
746
747 // Validates the arguments of a request. type is either "input" or "output" and is used
748 // for printing error messages. The operandIndexes is the appropriate array of input
749 // or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs.
validateRequestArguments(const hardware::hidl_vec<V1_0::RequestArgument> & requestArguments,const hardware::hidl_vec<uint32_t> & operandIndexes,const hardware::hidl_vec<V1_3::Operand> & operands,const MemoryAccessVerifier & poolVerifier,bool allowUnspecified,const char * type)750 static bool validateRequestArguments(
751 const hardware::hidl_vec<V1_0::RequestArgument>& requestArguments,
752 const hardware::hidl_vec<uint32_t>& operandIndexes,
753 const hardware::hidl_vec<V1_3::Operand>& operands, const MemoryAccessVerifier& poolVerifier,
754 bool allowUnspecified, const char* type) {
755 // The request should specify as many arguments as were described in the model.
756 const size_t requestArgumentCount = requestArguments.size();
757 if (requestArgumentCount != operandIndexes.size()) {
758 LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
759 << "s but the model has " << operandIndexes.size();
760 return false;
761 }
762 for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
763 requestArgumentIndex++) {
764 const V1_0::RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
765 const V1_0::DataLocation& location = requestArgument.location;
766 // Get the operand index for this argument. We extract it from the list
767 // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
768 // We assume in this function that the model has been validated already.
769 const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
770 const V1_3::Operand& operand = operands[operandIndex];
771 if (requestArgument.hasNoValue) {
772 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
773 requestArgument.dimensions.size() != 0) {
774 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
775 << " has no value yet has details.";
776 return false;
777 }
778 } else {
779 // Validate the location.
780 if (!poolVerifier.validate(location)) {
781 return false;
782 }
783 // If the argument specified a dimension, validate it.
784 uint32_t modelRank = operand.dimensions.size();
785 uint32_t requestRank = requestArgument.dimensions.size();
786 if (requestRank == 0) {
787 if (!allowUnspecified) {
788 // NOTE: validateRequestArguments cannot validate unknown tensor rank with
789 // extension operand type.
790 if (!isExtensionOperandType(operand.type) &&
791 !nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type))) {
792 NN_RET_CHECK_GT(modelRank, 0u)
793 << "Model " << type << " " << requestArgumentIndex
794 << " has unknown rank but the request does not specify the rank.";
795 }
796 // Validate that all the dimensions are specified in the model.
797 for (size_t i = 0; i < modelRank; i++) {
798 if (operand.dimensions[i] == 0) {
799 LOG(ERROR)
800 << "Model has dimension " << i
801 << " set to 0 but the request does not specify the dimension.";
802 return false;
803 }
804 }
805 }
806 } else {
807 if (modelRank != 0 && requestRank != modelRank) {
808 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
809 << " has number of dimensions (" << requestRank
810 << ") different than the model's (" << modelRank << ")";
811 return false;
812 }
813 for (size_t i = 0; i < requestRank; i++) {
814 if (modelRank != 0 && requestArgument.dimensions[i] != operand.dimensions[i] &&
815 operand.dimensions[i] != 0) {
816 LOG(ERROR)
817 << "Request " << type << " " << requestArgumentIndex
818 << " has dimension " << i << " of " << requestArgument.dimensions[i]
819 << " different than the model's " << operand.dimensions[i];
820 return false;
821 }
822 if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
823 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
824 << " has dimension " << i << " of zero";
825 return false;
826 }
827 }
828 }
829 }
830 }
831 return true;
832 }
833
834 template <class T_Request, class T_Model>
validateRequest(const T_Request & request,const T_Model & model,bool allowUnspecifiedOutput)835 bool validateRequest(const T_Request& request, const T_Model& model, bool allowUnspecifiedOutput) {
836 HalVersion version = ModelToHalVersion<T_Model>::version;
837 MemoryAccessVerifier poolVerifier(request.pools);
838 return (validateRequestArguments(request.inputs, model.inputIndexes,
839 convertToV1_3(model.operands), poolVerifier,
840 /*allowUnspecified=*/false, "input") &&
841 validateRequestArguments(
842 request.outputs, model.outputIndexes, convertToV1_3(model.operands),
843 poolVerifier,
844 /*allowUnspecified=*/version >= HalVersion::V1_2 && allowUnspecifiedOutput,
845 "output") &&
846 validatePools(request.pools, version));
847 }
848
// Explicit instantiations of the generic validateRequest for the pre-1.3 model versions. All of
// them are paired with V1_0::Request. V1_3 uses the full specialization defined below instead.
template bool validateRequest<V1_0::Request, V1_0::Model>(const V1_0::Request& request,
                                                          const V1_0::Model& model,
                                                          bool allowUnspecifiedOutput);
template bool validateRequest<V1_0::Request, V1_1::Model>(const V1_0::Request& request,
                                                          const V1_1::Model& model,
                                                          bool allowUnspecifiedOutput);
template bool validateRequest<V1_0::Request, V1_2::Model>(const V1_0::Request& request,
                                                          const V1_2::Model& model,
                                                          bool allowUnspecifiedOutput);
858
859 template <>
validateRequest(const V1_3::Request & request,const V1_3::Model & model,bool allowUnspecifiedOutput)860 bool validateRequest(const V1_3::Request& request, const V1_3::Model& model,
861 bool allowUnspecifiedOutput) {
862 return (validateRequestArguments(request.inputs, model.main.inputIndexes, model.main.operands,
863 request.pools,
864 /*allowUnspecified=*/false, "input") &&
865 validateRequestArguments(request.outputs, model.main.outputIndexes, model.main.operands,
866 request.pools, allowUnspecifiedOutput, "output") &&
867 validatePools(request.pools, HalVersion::V1_3));
868 }
869
// Validates a buffer descriptor together with the set of roles the buffer will play, and checks
// that all of those roles are mutually compatible.
//
// desc               - the buffer descriptor; desc.dimensions is combined with every role's
//                      operand dimensions.
// preparedModels     - prepared models referenced by BufferRole::modelIndex; must be non-empty.
// inputRoles         - roles in which the buffer is a model input (ioIndex indexes the model's
//                      main.inputIndexes). Each frequency must be in (0, 1].
// outputRoles        - same as inputRoles, but for model outputs. At least one input or output
//                      role must be given.
// getModel           - maps a prepared model to its V1_3::Model. A nullptr result fails validation.
// preparedModelRoles - if non-null, on success receives the (prepared model, IOType, ioIndex)
//                      entries built from the roles.
// combinedOperand    - if non-null, on success receives the first role's operand with its
//                      dimensions replaced by the combination of all collected dimensions.
//
// Every role must refer to an operand with the same type, scale and zeroPoint (and, for
// non-extension types, the same extraParams). The same role may not appear twice. Returns false
// (logging through NN_RET_CHECK*) on the first failed check; out-parameters are only written
// when validation succeeds.
bool validateMemoryDesc(const V1_3::BufferDesc& desc,
                        const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
                        const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
                        const hardware::hidl_vec<V1_3::BufferRole>& outputRoles,
                        std::function<const V1_3::Model*(const sp<V1_3::IPreparedModel>&)> getModel,
                        std::set<HalPreparedModelRole>* preparedModelRoles,
                        V1_3::Operand* combinedOperand) {
    NN_RET_CHECK(preparedModels.size() != 0);
    NN_RET_CHECK(inputRoles.size() != 0 || outputRoles.size() != 0);

    std::set<HalPreparedModelRole> roles;
    // One operand per role: inputs first, then outputs.
    std::vector<V1_3::Operand> operands;
    operands.reserve(inputRoles.size() + outputRoles.size());
    for (const auto& role : inputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& inputIndexes = model->main.inputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
        NN_RET_CHECK_GT(role.frequency, 0.0f);
        NN_RET_CHECK_LE(role.frequency, 1.0f);
        // A failed insertion means this exact role was listed more than once.
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
    }
    for (const auto& role : outputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& outputIndexes = model->main.outputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
        NN_RET_CHECK_GT(role.frequency, 0.0f);
        NN_RET_CHECK_LE(role.frequency, 1.0f);
        // A failed insertion means this exact role was listed more than once.
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
    }

    // Guaranteed by the "at least one role" check above.
    CHECK(!operands.empty());
    const auto opType = operands[0].type;
    const bool isExtension = isExtensionOperandType(opType);

    // Start from the descriptor's dimensions and fold in each role's operand dimensions.
    // combineDimensions presumably fails on conflicting known dimensions; see its definition.
    std::vector<uint32_t> dimensions = desc.dimensions;
    for (const auto& operand : operands) {
        NN_RET_CHECK(operand.type == operands[0].type)
                << toString(operand.type) << " vs " << toString(operands[0].type);
        NN_RET_CHECK_EQ(operand.scale, operands[0].scale);
        NN_RET_CHECK_EQ(operand.zeroPoint, operands[0].zeroPoint);
        // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
        if (!isExtension) {
            NN_RET_CHECK(operand.extraParams == operands[0].extraParams)
                    << toString(operand.extraParams) << " vs " << toString(operands[0].extraParams);
        }
        const auto combined = combineDimensions(dimensions, operand.dimensions);
        NN_RET_CHECK(combined.has_value());
        dimensions = combined.value();
    }

    // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
    if (!isExtension) {
        NN_RET_CHECK(!nonExtensionOperandTypeIsScalar(static_cast<int>(opType)) ||
                     dimensions.empty())
                << "invalid dimensions with scalar operand type.";
    }

    // All checks passed: publish the optional results.
    if (preparedModelRoles != nullptr) {
        *preparedModelRoles = std::move(roles);
    }
    if (combinedOperand != nullptr) {
        *combinedOperand = operands[0];
        combinedOperand->dimensions = dimensions;
    }
    return true;
}
948
validateExecutionPreference(V1_1::ExecutionPreference preference)949 bool validateExecutionPreference(V1_1::ExecutionPreference preference) {
950 return preference == V1_1::ExecutionPreference::LOW_POWER ||
951 preference == V1_1::ExecutionPreference::FAST_SINGLE_ANSWER ||
952 preference == V1_1::ExecutionPreference::SUSTAINED_SPEED;
953 }
954
validatePriority(V1_3::Priority priority)955 bool validatePriority(V1_3::Priority priority) {
956 return priority == V1_3::Priority::LOW || priority == V1_3::Priority::MEDIUM ||
957 priority == V1_3::Priority::HIGH;
958 }
959
validOperandType(V1_0::OperandType operandType)960 bool validOperandType(V1_0::OperandType operandType) {
961 switch (operandType) {
962 case V1_0::OperandType::FLOAT32:
963 case V1_0::OperandType::INT32:
964 case V1_0::OperandType::UINT32:
965 case V1_0::OperandType::TENSOR_FLOAT32:
966 case V1_0::OperandType::TENSOR_INT32:
967 case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
968 case V1_0::OperandType::OEM:
969 case V1_0::OperandType::TENSOR_OEM_BYTE:
970 return true;
971 default:
972 return false;
973 }
974 }
975
validOperandType(V1_2::OperandType operandType)976 bool validOperandType(V1_2::OperandType operandType) {
977 switch (operandType) {
978 case V1_2::OperandType::FLOAT16:
979 case V1_2::OperandType::FLOAT32:
980 case V1_2::OperandType::INT32:
981 case V1_2::OperandType::UINT32:
982 case V1_2::OperandType::BOOL:
983 case V1_2::OperandType::TENSOR_FLOAT16:
984 case V1_2::OperandType::TENSOR_FLOAT32:
985 case V1_2::OperandType::TENSOR_INT32:
986 case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
987 case V1_2::OperandType::TENSOR_QUANT8_SYMM:
988 case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
989 case V1_2::OperandType::TENSOR_QUANT16_SYMM:
990 case V1_2::OperandType::TENSOR_BOOL8:
991 case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
992 case V1_2::OperandType::OEM:
993 case V1_2::OperandType::TENSOR_OEM_BYTE:
994 return true;
995 default:
996 return isExtensionOperandType(static_cast<V1_3::OperandType>(operandType));
997 }
998 }
999
validOperandType(V1_3::OperandType operandType)1000 bool validOperandType(V1_3::OperandType operandType) {
1001 switch (operandType) {
1002 case V1_3::OperandType::FLOAT16:
1003 case V1_3::OperandType::FLOAT32:
1004 case V1_3::OperandType::INT32:
1005 case V1_3::OperandType::UINT32:
1006 case V1_3::OperandType::BOOL:
1007 case V1_3::OperandType::TENSOR_FLOAT16:
1008 case V1_3::OperandType::TENSOR_FLOAT32:
1009 case V1_3::OperandType::TENSOR_INT32:
1010 case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
1011 case V1_3::OperandType::TENSOR_QUANT8_SYMM:
1012 case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
1013 case V1_3::OperandType::TENSOR_QUANT16_SYMM:
1014 case V1_3::OperandType::TENSOR_BOOL8:
1015 case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
1016 case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
1017 case V1_3::OperandType::SUBGRAPH:
1018 case V1_3::OperandType::OEM:
1019 case V1_3::OperandType::TENSOR_OEM_BYTE:
1020 return true;
1021 default:
1022 return isExtensionOperandType(operandType);
1023 }
1024 }
1025
1026 } // namespace nn
1027 } // namespace android
1028