1 /*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_FLATBUFFER_MODEL_BUILDER_UTILS_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_FLATBUFFER_MODEL_BUILDER_UTILS_H
19
20 #include <nnapi/Result.h>
21 #include <nnapi/TypeUtils.h>
22 #include <tensorflow/lite/schema/schema_generated.h>
23
24 #include <algorithm>
25 #include <vector>
26
27 #include "NeuralNetworks.h"
28 #include "TypeManager.h"
29
30 namespace android {
31 namespace nn {
32
// Convenience aliases for the flatbuffer offset types produced while
// serializing an NNAPI model into the TFLite flatbuffer schema.

// A single TFLite subgraph, and a vector of them.
using SubGraphFlatbuffer = flatbuffers::Offset<tflite::SubGraph>;
using SubGraphsFlatbuffer = flatbuffers::Offset<flatbuffers::Vector<SubGraphFlatbuffer>>;

// Operator codes and operator instances within a subgraph.
using OperatorCodeFlatbuffer = flatbuffers::Offset<tflite::OperatorCode>;
using OperatorFlatbuffer = flatbuffers::Offset<tflite::Operator>;
using OperatorsFlatbuffer = flatbuffers::Offset<flatbuffers::Vector<OperatorFlatbuffer>>;

// Tensors referenced by the subgraph's operators.
using TensorFlatbuffer = flatbuffers::Offset<tflite::Tensor>;
using TensorsFlatbuffer = flatbuffers::Offset<flatbuffers::Vector<TensorFlatbuffer>>;

// Raw data buffer backing a tensor.
using BufferFlatbuffer = flatbuffers::Offset<tflite::Buffer>;

// Key/value metadata entry attached to the model.
using MetadataFlatbuffer = flatbuffers::Offset<tflite::Metadata>;

// The complete serialized model.
using ModelFlatbuffer = flatbuffers::Offset<tflite::Model>;
48
49 // Only supports tensor types
50 // Will crash if passed in a scalar type
getTensorFlatbufferOperandType(const OperandType & type)51 inline Result<tflite::TensorType> getTensorFlatbufferOperandType(const OperandType& type) {
52 CHECK(TypeManager::get()->isTensorType(type));
53
54 // TODO: Map more operands
55 switch (type) {
56 case OperandType::TENSOR_FLOAT32:
57 return tflite::TensorType::TensorType_FLOAT32;
58 case OperandType::TENSOR_INT32:
59 return tflite::TensorType::TensorType_INT32;
60 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
61 return tflite::TensorType::TensorType_INT8;
62 default:
63 NN_RET_CHECK_FAIL() << "OperandType not supported: " << type;
64 }
65 }
66
getFlatbufferOperator(const OperationType & type)67 inline tflite::BuiltinOperator getFlatbufferOperator(const OperationType& type) {
68 // TODO: Add more operation types
69 switch (type) {
70 case OperationType::PAD:
71 return tflite::BuiltinOperator::BuiltinOperator_PAD;
72 case OperationType::CONV_2D:
73 return tflite::BuiltinOperator::BuiltinOperator_CONV_2D;
74 case OperationType::ADD:
75 return tflite::BuiltinOperator::BuiltinOperator_ADD;
76 case OperationType::DEPTHWISE_CONV_2D:
77 return tflite::BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D;
78 case OperationType::LOGISTIC:
79 return tflite::BuiltinOperator::BuiltinOperator_LOGISTIC;
80 default:
81 LOG(FATAL) << "OperationType not supported: " << type;
82 return {};
83 }
84 }
85
86 // Referenced from external/tensorflow/tensorflow/lite/tools/versioning/op_version.cc
getMaxOperatorVersionCode(tflite::BuiltinOperator builtinCode)87 inline int32_t getMaxOperatorVersionCode(tflite::BuiltinOperator builtinCode) {
88 // TODO: Add more builtin_codes
89 switch (builtinCode) {
90 case tflite::BuiltinOperator::BuiltinOperator_CONV_2D:
91 return 5;
92 case tflite::BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D:
93 return 6;
94 case tflite::BuiltinOperator::BuiltinOperator_ADD:
95 return 4;
96 case tflite::BuiltinOperator::BuiltinOperator_PAD:
97 return 4;
98 case tflite::BuiltinOperator::BuiltinOperator_LOGISTIC:
99 return 3;
100 default:
101 LOG(FATAL) << "BuiltinOperator not supported: " << builtinCode;
102 return {};
103 }
104 }
105
getTfliteActivation(FusedActivationFunc activation)106 inline Result<tflite::ActivationFunctionType> getTfliteActivation(FusedActivationFunc activation) {
107 switch (activation) {
108 case FusedActivationFunc::NONE:
109 return tflite::ActivationFunctionType::ActivationFunctionType_NONE;
110 case FusedActivationFunc::RELU:
111 return tflite::ActivationFunctionType::ActivationFunctionType_RELU;
112 case FusedActivationFunc::RELU1:
113 return tflite::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1;
114 case FusedActivationFunc::RELU6:
115 return tflite::ActivationFunctionType::ActivationFunctionType_RELU6;
116 default:
117 NN_RET_CHECK_FAIL() << "FusedActivationFunc not supported: " << activation;
118 }
119 }
120
tensorOperandHasUnspecifiedRank(const Operand & operand)121 inline bool tensorOperandHasUnspecifiedRank(const Operand& operand) {
122 return TypeManager::get()->isTensorType(operand.type) && operand.dimensions.empty();
123 }
124
checkAllTensorOperandsHaveSpecifiedRank(const std::vector<Operand> & operands)125 inline Result<void> checkAllTensorOperandsHaveSpecifiedRank(const std::vector<Operand>& operands) {
126 NN_RET_CHECK(std::none_of(operands.begin(), operands.end(), &tensorOperandHasUnspecifiedRank))
127 << "At least one Operand has unspecified rank";
128 return {};
129 }
130
subgraphOutputOperandHasDynamicShape(const Operand & operand)131 inline bool subgraphOutputOperandHasDynamicShape(const Operand& operand) {
132 return operand.lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT &&
133 std::any_of(operand.dimensions.begin(), operand.dimensions.end(),
134 [](const uint32_t& dim) { return dim == 0; });
135 }
136
checkNoSubgraphOutputOperandsHaveDynamicShape(const std::vector<Operand> & operands)137 inline Result<void> checkNoSubgraphOutputOperandsHaveDynamicShape(
138 const std::vector<Operand>& operands) {
139 NN_RET_CHECK(
140 std::none_of(operands.begin(), operands.end(), &subgraphOutputOperandHasDynamicShape))
141 << "At least one subgraph output Operand has dynamic shape";
142 return {};
143 }
144
isOperandConstant(const Operand & operand)145 inline bool isOperandConstant(const Operand& operand) {
146 return operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
147 operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE;
148 }
149
getTFLitePadding(int32_t paddingType)150 inline tflite::Padding getTFLitePadding(int32_t paddingType) {
151 switch (paddingType) {
152 case ANEURALNETWORKS_PADDING_VALID: // VALID
153 case 0:
154 return tflite::Padding::Padding_VALID;
155 case ANEURALNETWORKS_PADDING_SAME: // SAME
156 return tflite::Padding::Padding_SAME;
157 default:
158 LOG(FATAL) << "Unsupported NNAPI NDK padding type: " << paddingType;
159 return {};
160 }
161 }
162
// Rewrites every 0 entry in *dims to -1 in place, since TFLite only
// supports -1 as an unknown dimension (NNAPI uses 0).
inline void replaceZeroDimensions(std::vector<int32_t>* dims) {
    for (int32_t& dim : *dims) {
        if (dim == 0) {
            dim = -1;
        }
    }
}
167
168 } // namespace nn
169 } // namespace android
170
171 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_FLATBUFFER_MODEL_BUILDER_UTILS_H
172