1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "OperationValidationUtils"
18
19 #include "OperationsValidationUtils.h"
20
21 #include <android-base/logging.h>
22
23 #include <functional>
24 #include <vector>
25
26 #include "OperationsUtils.h"
27 #include "nnapi/Validation.h"
28
29 namespace android::nn {
30 namespace {
31
validateOperandTypes(const std::vector<OperandType> & expectedTypes,const char * tag,uint32_t operandCount,std::function<OperandType (uint32_t)> getOperandType)32 bool validateOperandTypes(const std::vector<OperandType>& expectedTypes, const char* tag,
33 uint32_t operandCount,
34 std::function<OperandType(uint32_t)> getOperandType) {
35 NN_RET_CHECK_EQ(operandCount, expectedTypes.size());
36 for (uint32_t i = 0; i < operandCount; ++i) {
37 OperandType type = getOperandType(i);
38 NN_RET_CHECK(type == expectedTypes[i])
39 << "Invalid " << tag << " tensor type " << type << " for " << tag << " " << i
40 << ", expected " << expectedTypes[i];
41 }
42 return true;
43 }
44
45 } // namespace
46
invalidInOutNumberMessage(int expIn,int expOut) const47 std::string IOperationValidationContext::invalidInOutNumberMessage(int expIn, int expOut) const {
48 std::ostringstream os;
49 os << "Invalid number of input operands (" << getNumInputs() << ", expected " << expIn
50 << ") or output operands (" << getNumOutputs() << ", expected " << expOut
51 << ") for operation " << getOperationName();
52 return os.str();
53 }
54
validateOperationOperandTypes(const std::vector<OperandType> & inExpectedTypes,const std::vector<OperandType> & outExpectedInTypes) const55 Result<void> IOperationValidationContext::validateOperationOperandTypes(
56 const std::vector<OperandType>& inExpectedTypes,
57 const std::vector<OperandType>& outExpectedInTypes) const {
58 NN_RET_CHECK_EQ(getNumInputs(), inExpectedTypes.size())
59 << "Wrong operand count: expected " << inExpectedTypes.size() << " inputs, got "
60 << getNumInputs() << " inputs";
61 NN_RET_CHECK_EQ(getNumOutputs(), outExpectedInTypes.size())
62 << "Wrong operand count: expected " << outExpectedInTypes.size() << " outputs, got "
63 << getNumOutputs() << " outputs";
64 for (size_t i = 0; i < getNumInputs(); i++) {
65 NN_RET_CHECK_EQ(getInputType(i), inExpectedTypes[i])
66 << "Invalid input tensor type " << getInputType(i) << " for input " << i
67 << ", expected " << inExpectedTypes[i];
68 }
69 for (size_t i = 0; i < getNumOutputs(); i++) {
70 NN_RET_CHECK_EQ(getOutputType(i), outExpectedInTypes[i])
71 << "Invalid output tensor type " << getOutputType(i) << " for input " << i
72 << ", expected " << outExpectedInTypes[i];
73 }
74
75 return {};
76 }
77
validateInputTypes(const IOperationValidationContext * context,const std::vector<OperandType> & expectedTypes)78 bool validateInputTypes(const IOperationValidationContext* context,
79 const std::vector<OperandType>& expectedTypes) {
80 return validateOperandTypes(expectedTypes, "input", context->getNumInputs(),
81 [context](uint32_t index) { return context->getInputType(index); });
82 }
83
validateOutputTypes(const IOperationValidationContext * context,const std::vector<OperandType> & expectedTypes)84 bool validateOutputTypes(const IOperationValidationContext* context,
85 const std::vector<OperandType>& expectedTypes) {
86 return validateOperandTypes(
87 expectedTypes, "output", context->getNumOutputs(),
88 [context](uint32_t index) { return context->getOutputType(index); });
89 }
90
validateVersion(const IOperationValidationContext * context,Version contextVersion,Version minSupportedVersion)91 bool validateVersion(const IOperationValidationContext* context, Version contextVersion,
92 Version minSupportedVersion) {
93 if (!isCompliantVersion(minSupportedVersion, contextVersion)) {
94 std::ostringstream message;
95 message << "Operation " << context->getOperationName() << " with inputs {";
96 for (uint32_t i = 0, n = context->getNumInputs(); i < n; ++i) {
97 if (i != 0) {
98 message << ", ";
99 }
100 message << context->getInputType(i);
101 }
102 message << "} and outputs {";
103 for (uint32_t i = 0, n = context->getNumOutputs(); i < n; ++i) {
104 if (i != 0) {
105 message << ", ";
106 }
107 message << context->getOutputType(i);
108 }
109 message << "} is only supported since " << minSupportedVersion << " (validating using "
110 << contextVersion << ")";
111 NN_RET_CHECK_FAIL() << message.str();
112 }
113 return true;
114 }
115
116 } // namespace android::nn
117