/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "SubGraphContext"

#include "SubGraphContext.h"

#include <algorithm>
#include <limits>

#include "FlatbufferModelBuilderUtils.h"

namespace android {
namespace nn {

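// Caches the model, subgraph, and shared output vectors, and initializes the
// operand-to-tensor index map (-1 marks operands with no tensor yet) and the
// per-pool memory mapping cache.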
SubGraphContext::SubGraphContext(const Model* model, const Model::Subgraph* subgraph,
                                 flatbuffers::FlatBufferBuilder* builder,
                                 std::vector<OperatorCodeFlatbuffer>* opCodesVector,
                                 std::vector<int>* opCodeIndexForOperationType,
                                 std::vector<BufferFlatbuffer>* bufferVector)
    : mModel(model),
      mSubgraph(subgraph),
      mBuilder(builder),
      mOpCodesVector(opCodesVector),
      mOpCodeIndexForOperationType(opCodeIndexForOperationType),
      mBufferVector(bufferVector) {
    CHECK(model != nullptr);
    CHECK(subgraph != nullptr);
    CHECK(opCodesVector != nullptr);
    CHECK(opCodeIndexForOperationType != nullptr);
    CHECK(bufferVector != nullptr);

    mOperandToTensorIdx.resize(subgraph->operands.size(), -1);
    mMappings.resize(model->pools.size());
}

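// Assembles the completed SubGraph flatbuffer from the tensors, inputs, outputs,
// and operators collected so far.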
SubGraphFlatbuffer SubGraphContext::finish() {
    return tflite::CreateSubGraphDirect(*mBuilder, &mTensorVector, &mInputTensors, &mOutputTensors,
                                        &mOperatorVector);
}

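// Appends a tensor to the subgraph and, for a valid operand index, records the
// operand-to-tensor mapping. Returns the index of the new tensor.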
int SubGraphContext::addTensorFlatbuffer(TensorFlatbuffer tensor, int32_t operandIdx) {
    mTensorVector.push_back(tensor);

    int tensorIdx = mTensorVector.size() - 1;
    if (operandIdx >= 0) {
        CHECK(mOperandToTensorIdx[operandIdx] == -1);
        mOperandToTensorIdx[operandIdx] = tensorIdx;
    }
    return tensorIdx;
}

void SubGraphContext::addOperatorFlatbuffer(OperatorFlatbuffer opFlatbuffer) {
    mOperatorVector.push_back(opFlatbuffer);
}

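// Marks the tensor already created for the given operand as a subgraph input.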
void SubGraphContext::addSubGraphInput(int32_t operandIdx) {
    CHECK(mOperandToTensorIdx[operandIdx] != -1);
    mInputTensors.push_back(mOperandToTensorIdx[operandIdx]);
}

void SubGraphContext::addSubGraphOutput(int32_t operandIdx) {
    CHECK(mOperandToTensorIdx[operandIdx] != -1);
    mOutputTensors.push_back(mOperandToTensorIdx[operandIdx]);
}

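// Returns the operator code index for the given operation type, creating and
// caching a new OperatorCode if one does not exist yet. Builtin operators below
// PLACEHOLDER_FOR_GREATER_OP_CODES are stored in the deprecated_builtin_code
// field; larger codes use the builtin_code field.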
uint32_t SubGraphContext::addOpCode(OperationType operationType) {
    uint32_t idx = static_cast<uint32_t>(operationType);
    if (mOpCodeIndexForOperationType->at(idx) != -1) {
        return mOpCodeIndexForOperationType->at(idx);
    }

    OperatorCodeFlatbuffer opCode;

    tflite::BuiltinOperator builtinCode = getFlatbufferOperator(operationType);
    if (builtinCode < tflite::BuiltinOperator::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)
        opCode = tflite::CreateOperatorCode(
                *mBuilder, static_cast<int8_t>(builtinCode) /* deprecated_builtin_code */,
                0 /* custom_code */, getMaxOperatorVersionCode(builtinCode) /* version */);
    else
        opCode = tflite::CreateOperatorCode(*mBuilder, 0 /* deprecated_builtin_code */,
                                            0 /* custom_code */,
                                            getMaxOperatorVersionCode(builtinCode) /* version */,
                                            builtinCode /* builtin_code */);

    mOpCodesVector->push_back(opCode);
    uint32_t opCodeIdx = mOpCodesVector->size() - 1;
    (*mOpCodeIndexForOperationType)[idx] = opCodeIdx;
    return opCodeIdx;
}

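// Returns the tensor index for the given operand, or -1 if no tensor has been
// created for it.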
int SubGraphContext::getTensorIdxFromOperandIdx(int operandIdx) const {
    return mOperandToTensorIdx[operandIdx];
}

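// Lazily maps the memory pool at poolIndex and caches the result so each pool is
// mapped at most once.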
const Mapping& SubGraphContext::getMapping(uint32_t poolIndex) {
    if (mMappings[poolIndex].size > 0) {
        return mMappings[poolIndex];
    }

    SharedMemory memory = mModel->pools[poolIndex];
    GeneralResult<Mapping> mapping = map(memory);
    CHECK(mapping.has_value()) << "CONSTANT_REFERENCE memory mapping error: "
                               << mapping.error().message;

    mMappings[poolIndex] = std::move(mapping).value();
    return mMappings[poolIndex];
}

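// Returns a pointer to, and the length of, the operand's constant data: read from
// the model's operandValues for CONSTANT_COPY operands, or from the mapped memory
// pool for CONSTANT_REFERENCE operands.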
std::pair<const uint8_t*, uint32_t> SubGraphContext::getConstantPointerAndLength(
        const Operand& operand) {
    CHECK(isOperandConstant(operand));

    if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
        return std::make_pair(mModel->operandValues.data() + operand.location.offset,
                              operand.location.length);
    }

    const Mapping& mapping = getMapping(operand.location.poolIndex);
    const uint8_t* memoryPtr = static_cast<const uint8_t*>(
            std::visit([](auto ptr) { return static_cast<const void*>(ptr); }, mapping.pointer));

    return std::make_pair(memoryPtr + operand.location.offset, operand.location.length);
}

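// Copies the given bytes into a new Buffer flatbuffer and returns its index in the
// shared buffer vector.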
uint32_t SubGraphContext::addBufferFromData(const uint8_t* data, uint32_t length) {
    auto dataVectorFlatbuffer = mBuilder->CreateVector(data, length);

    auto buffer = tflite::CreateBuffer(*mBuilder, dataVectorFlatbuffer);
    mBufferVector->push_back(buffer);

    return mBufferVector->size() - 1;
}

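// Creates a TFLite Tensor for the operand at operandIdx, including quantization
// parameters and, for constant operands, a backing buffer. When makeSymmetric is
// true, TENSOR_QUANT8_ASYMM_SIGNED data is shifted by the zero point (and clamped
// to the int8 range) to produce symmetric data.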
Result<void> SubGraphContext::createTensorFlatbufferFromOperand(uint32_t operandIdx,
                                                                bool makeSymmetric) {
    // An output Operand to one Operation can be an input Operand to
    // another Operation, so this function can be run more than once.
    // We simply return if the Tensor for the Operand is already created.
    if (mOperandToTensorIdx[operandIdx] != -1) return {};

    const Operand& operand = mSubgraph->operands[operandIdx];

    std::vector<float> scaleVector{operand.scale};
    std::vector<int64_t> zeroPointVector{operand.zeroPoint};
    // min and max are only used when converting TFLite models to TF models, so they are
    // unused here and can be set to 0
    std::vector<float> minVector{0};
    std::vector<float> maxVector{0};

    // build quantization parameters
    auto quantizationParams = tflite::CreateQuantizationParametersDirect(
            *mBuilder, &minVector /* min */, &maxVector /* max */, &scaleVector /* scale */,
            &zeroPointVector /* zero_point */,
            tflite::QuantizationDetails::QuantizationDetails_NONE /* details_type */);

    // add buffer if constant operand
    // buffer at index 0 is reserved for tensors without a buffer
    uint32_t bufferIdx = 0;
    if (isOperandConstant(operand)) {
        auto [data, dataLength] = getConstantPointerAndLength(operand);
        if (makeSymmetric && operand.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
            std::vector<int8_t> dataVector(reinterpret_cast<const int8_t*>(data),
                                           reinterpret_cast<const int8_t*>(data) + dataLength);
            bool emitWarning = false;
            for (uint32_t i = 0; i < dataLength; i++) {
                int32_t newValue = static_cast<int32_t>(dataVector[i]) - operand.zeroPoint;
                if (newValue < std::numeric_limits<int8_t>::min() ||
                    newValue > std::numeric_limits<int8_t>::max()) {
                    emitWarning = true;
                }
                dataVector[i] = static_cast<int8_t>(std::clamp(
                        newValue, static_cast<int32_t>(std::numeric_limits<int8_t>::min()),
                        static_cast<int32_t>(std::numeric_limits<int8_t>::max())));
            }

            if (emitWarning) {
                LOG(WARNING) << "Asymmetric to symmetric conversion will result in "
                                "underflow/overflow. Clamping data";
            }
            bufferIdx = addBufferFromData(reinterpret_cast<const uint8_t*>(dataVector.data()),
                                          dataLength);
        } else {
            bufferIdx = addBufferFromData(data, dataLength);
        }
    }

    // shape of tensor
    std::vector<int32_t> shape(operand.dimensions.begin(), operand.dimensions.end());
    replaceZeroDimensions(&shape);

    // build tensor
    TensorFlatbuffer tensor = tflite::CreateTensorDirect(
            *mBuilder, &shape, NN_TRY(getTensorFlatbufferOperandType(operand.type)) /* type */,
            bufferIdx /* buffer */, 0 /* name */, quantizationParams /* quantization */);
    addTensorFlatbuffer(tensor, operandIdx);

    return {};
}

}  // namespace nn
}  // namespace android