1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_FLATBUFFER_MODEL_BUILDER_UTILS_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_FLATBUFFER_MODEL_BUILDER_UTILS_H
19 
20 #include <nnapi/Result.h>
21 #include <nnapi/TypeUtils.h>
22 #include <tensorflow/lite/schema/schema_generated.h>
23 
24 #include <algorithm>
25 #include <vector>
26 
27 #include "NeuralNetworks.h"
28 #include "TypeManager.h"
29 
30 namespace android {
31 namespace nn {
32 
// Shorthand names for flatbuffers::Offset types that refer to tflite schema tables
// (and, for the plural "...sFlatbuffer" names, to flatbuffer vectors of those tables).

// A single tflite::SubGraph table, and a vector of them.
using SubGraphFlatbuffer = flatbuffers::Offset<tflite::SubGraph>;
using SubGraphsFlatbuffer = flatbuffers::Offset<flatbuffers::Vector<SubGraphFlatbuffer>>;

// tflite::OperatorCode table, a single tflite::Operator table, and a vector of operators.
using OperatorCodeFlatbuffer = flatbuffers::Offset<tflite::OperatorCode>;
using OperatorFlatbuffer = flatbuffers::Offset<tflite::Operator>;
using OperatorsFlatbuffer = flatbuffers::Offset<flatbuffers::Vector<OperatorFlatbuffer>>;

// A single tflite::Tensor table, and a vector of them.
using TensorFlatbuffer = flatbuffers::Offset<tflite::Tensor>;
using TensorsFlatbuffer = flatbuffers::Offset<flatbuffers::Vector<TensorFlatbuffer>>;

// A single tflite::Buffer table.
using BufferFlatbuffer = flatbuffers::Offset<tflite::Buffer>;

// A single tflite::Metadata table.
using MetadataFlatbuffer = flatbuffers::Offset<tflite::Metadata>;

// The top-level tflite::Model table.
using ModelFlatbuffer = flatbuffers::Offset<tflite::Model>;
48 
49 // Only supports tensor types
50 // Will crash if passed in a scalar type
getTensorFlatbufferOperandType(const OperandType & type)51 inline Result<tflite::TensorType> getTensorFlatbufferOperandType(const OperandType& type) {
52     CHECK(TypeManager::get()->isTensorType(type));
53 
54     // TODO: Map more operands
55     switch (type) {
56         case OperandType::TENSOR_FLOAT32:
57             return tflite::TensorType::TensorType_FLOAT32;
58         case OperandType::TENSOR_INT32:
59             return tflite::TensorType::TensorType_INT32;
60         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
61             return tflite::TensorType::TensorType_INT8;
62         default:
63             NN_RET_CHECK_FAIL() << "OperandType not supported: " << type;
64     }
65 }
66 
getFlatbufferOperator(const OperationType & type)67 inline tflite::BuiltinOperator getFlatbufferOperator(const OperationType& type) {
68     // TODO: Add more operation types
69     switch (type) {
70         case OperationType::PAD:
71             return tflite::BuiltinOperator::BuiltinOperator_PAD;
72         case OperationType::CONV_2D:
73             return tflite::BuiltinOperator::BuiltinOperator_CONV_2D;
74         case OperationType::ADD:
75             return tflite::BuiltinOperator::BuiltinOperator_ADD;
76         case OperationType::DEPTHWISE_CONV_2D:
77             return tflite::BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D;
78         case OperationType::LOGISTIC:
79             return tflite::BuiltinOperator::BuiltinOperator_LOGISTIC;
80         default:
81             LOG(FATAL) << "OperationType not supported: " << type;
82             return {};
83     }
84 }
85 
86 // Referenced from external/tensorflow/tensorflow/lite/tools/versioning/op_version.cc
getMaxOperatorVersionCode(tflite::BuiltinOperator builtinCode)87 inline int32_t getMaxOperatorVersionCode(tflite::BuiltinOperator builtinCode) {
88     // TODO: Add more builtin_codes
89     switch (builtinCode) {
90         case tflite::BuiltinOperator::BuiltinOperator_CONV_2D:
91             return 5;
92         case tflite::BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D:
93             return 6;
94         case tflite::BuiltinOperator::BuiltinOperator_ADD:
95             return 4;
96         case tflite::BuiltinOperator::BuiltinOperator_PAD:
97             return 4;
98         case tflite::BuiltinOperator::BuiltinOperator_LOGISTIC:
99             return 3;
100         default:
101             LOG(FATAL) << "BuiltinOperator not supported: " << builtinCode;
102             return {};
103     }
104 }
105 
getTfliteActivation(FusedActivationFunc activation)106 inline Result<tflite::ActivationFunctionType> getTfliteActivation(FusedActivationFunc activation) {
107     switch (activation) {
108         case FusedActivationFunc::NONE:
109             return tflite::ActivationFunctionType::ActivationFunctionType_NONE;
110         case FusedActivationFunc::RELU:
111             return tflite::ActivationFunctionType::ActivationFunctionType_RELU;
112         case FusedActivationFunc::RELU1:
113             return tflite::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1;
114         case FusedActivationFunc::RELU6:
115             return tflite::ActivationFunctionType::ActivationFunctionType_RELU6;
116         default:
117             NN_RET_CHECK_FAIL() << "FusedActivationFunc not supported: " << activation;
118     }
119 }
120 
tensorOperandHasUnspecifiedRank(const Operand & operand)121 inline bool tensorOperandHasUnspecifiedRank(const Operand& operand) {
122     return TypeManager::get()->isTensorType(operand.type) && operand.dimensions.empty();
123 }
124 
// Verifies that every tensor-typed operand in `operands` has a specified rank, i.e. a non-empty
// dimensions vector (see tensorOperandHasUnspecifiedRank). Non-tensor operands never fail the
// check.
// Returns an error Result if any tensor operand has unspecified rank, and success otherwise.
inline Result<void> checkAllTensorOperandsHaveSpecifiedRank(const std::vector<Operand>& operands) {
    NN_RET_CHECK(std::none_of(operands.begin(), operands.end(), &tensorOperandHasUnspecifiedRank))
            << "At least one Operand has unspecified rank";
    return {};
}
130 
subgraphOutputOperandHasDynamicShape(const Operand & operand)131 inline bool subgraphOutputOperandHasDynamicShape(const Operand& operand) {
132     return operand.lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT &&
133            std::any_of(operand.dimensions.begin(), operand.dimensions.end(),
134                        [](const uint32_t& dim) { return dim == 0; });
135 }
136 
// Verifies that no operand in `operands` is a subgraph output with a dynamic shape, i.e. one
// whose dimensions contain a 0 (see subgraphOutputOperandHasDynamicShape). Operands with other
// lifetimes never fail the check.
// Returns an error Result if such an operand exists, and success otherwise.
inline Result<void> checkNoSubgraphOutputOperandsHaveDynamicShape(
        const std::vector<Operand>& operands) {
    NN_RET_CHECK(
            std::none_of(operands.begin(), operands.end(), &subgraphOutputOperandHasDynamicShape))
            << "At least one subgraph output Operand has dynamic shape";
    return {};
}
144 
isOperandConstant(const Operand & operand)145 inline bool isOperandConstant(const Operand& operand) {
146     return operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
147            operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE;
148 }
149 
getTFLitePadding(int32_t paddingType)150 inline tflite::Padding getTFLitePadding(int32_t paddingType) {
151     switch (paddingType) {
152         case ANEURALNETWORKS_PADDING_VALID:  // VALID
153         case 0:
154             return tflite::Padding::Padding_VALID;
155         case ANEURALNETWORKS_PADDING_SAME:  // SAME
156             return tflite::Padding::Padding_SAME;
157         default:
158             LOG(FATAL) << "Unsupported NNAPI NDK padding type: " << paddingType;
159             return {};
160     }
161 }
162 
// Rewrites, in place, every 0 entry of `*dims` as -1. TFLite only accepts -1 as the marker for
// an unknown dimension. All other entries are left untouched. `dims` must not be null.
inline void replaceZeroDimensions(std::vector<int32_t>* dims) {
    for (int32_t& dim : *dims) {
        if (dim == 0) {
            dim = -1;
        }
    }
}
167 
168 }  // namespace nn
169 }  // namespace android
170 
171 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_FLATBUFFER_MODEL_BUILDER_UTILS_H
172