1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "OperationValidationUtils"
18 
19 #include "OperationsValidationUtils.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <functional>
24 #include <vector>
25 
26 #include "OperationsUtils.h"
27 #include "nnapi/Validation.h"
28 
29 namespace android::nn {
30 namespace {
31 
validateOperandTypes(const std::vector<OperandType> & expectedTypes,const char * tag,uint32_t operandCount,std::function<OperandType (uint32_t)> getOperandType)32 bool validateOperandTypes(const std::vector<OperandType>& expectedTypes, const char* tag,
33                           uint32_t operandCount,
34                           std::function<OperandType(uint32_t)> getOperandType) {
35     NN_RET_CHECK_EQ(operandCount, expectedTypes.size());
36     for (uint32_t i = 0; i < operandCount; ++i) {
37         OperandType type = getOperandType(i);
38         NN_RET_CHECK(type == expectedTypes[i])
39                 << "Invalid " << tag << " tensor type " << type << " for " << tag << " " << i
40                 << ", expected " << expectedTypes[i];
41     }
42     return true;
43 }
44 
45 }  // namespace
46 
invalidInOutNumberMessage(int expIn,int expOut) const47 std::string IOperationValidationContext::invalidInOutNumberMessage(int expIn, int expOut) const {
48     std::ostringstream os;
49     os << "Invalid number of input operands (" << getNumInputs() << ", expected " << expIn
50        << ") or output operands (" << getNumOutputs() << ", expected " << expOut
51        << ") for operation " << getOperationName();
52     return os.str();
53 }
54 
validateOperationOperandTypes(const std::vector<OperandType> & inExpectedTypes,const std::vector<OperandType> & outExpectedInTypes) const55 Result<void> IOperationValidationContext::validateOperationOperandTypes(
56         const std::vector<OperandType>& inExpectedTypes,
57         const std::vector<OperandType>& outExpectedInTypes) const {
58     NN_RET_CHECK_EQ(getNumInputs(), inExpectedTypes.size())
59             << "Wrong operand count: expected " << inExpectedTypes.size() << " inputs, got "
60             << getNumInputs() << " inputs";
61     NN_RET_CHECK_EQ(getNumOutputs(), outExpectedInTypes.size())
62             << "Wrong operand count: expected " << outExpectedInTypes.size() << " outputs, got "
63             << getNumOutputs() << " outputs";
64     for (size_t i = 0; i < getNumInputs(); i++) {
65         NN_RET_CHECK_EQ(getInputType(i), inExpectedTypes[i])
66                 << "Invalid input tensor type " << getInputType(i) << " for input " << i
67                 << ", expected " << inExpectedTypes[i];
68     }
69     for (size_t i = 0; i < getNumOutputs(); i++) {
70         NN_RET_CHECK_EQ(getOutputType(i), outExpectedInTypes[i])
71                 << "Invalid output tensor type " << getOutputType(i) << " for input " << i
72                 << ", expected " << outExpectedInTypes[i];
73     }
74 
75     return {};
76 }
77 
validateInputTypes(const IOperationValidationContext * context,const std::vector<OperandType> & expectedTypes)78 bool validateInputTypes(const IOperationValidationContext* context,
79                         const std::vector<OperandType>& expectedTypes) {
80     return validateOperandTypes(expectedTypes, "input", context->getNumInputs(),
81                                 [context](uint32_t index) { return context->getInputType(index); });
82 }
83 
validateOutputTypes(const IOperationValidationContext * context,const std::vector<OperandType> & expectedTypes)84 bool validateOutputTypes(const IOperationValidationContext* context,
85                          const std::vector<OperandType>& expectedTypes) {
86     return validateOperandTypes(
87             expectedTypes, "output", context->getNumOutputs(),
88             [context](uint32_t index) { return context->getOutputType(index); });
89 }
90 
validateVersion(const IOperationValidationContext * context,Version contextVersion,Version minSupportedVersion)91 bool validateVersion(const IOperationValidationContext* context, Version contextVersion,
92                      Version minSupportedVersion) {
93     if (!isCompliantVersion(minSupportedVersion, contextVersion)) {
94         std::ostringstream message;
95         message << "Operation " << context->getOperationName() << " with inputs {";
96         for (uint32_t i = 0, n = context->getNumInputs(); i < n; ++i) {
97             if (i != 0) {
98                 message << ", ";
99             }
100             message << context->getInputType(i);
101         }
102         message << "} and outputs {";
103         for (uint32_t i = 0, n = context->getNumOutputs(); i < n; ++i) {
104             if (i != 0) {
105                 message << ", ";
106             }
107             message << context->getOutputType(i);
108         }
109         message << "} is only supported since " << minSupportedVersion << " (validating using "
110                 << contextVersion << ")";
111         NN_RET_CHECK_FAIL() << message.str();
112     }
113     return true;
114 }
115 
116 }  // namespace android::nn
117