1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SubGraphContext"
18 
19 #include "SubGraphContext.h"
20 
21 #include <limits>
22 
23 #include "FlatbufferModelBuilderUtils.h"
24 
25 namespace android {
26 namespace nn {
27 
SubGraphContext(const Model * model,const Model::Subgraph * subgraph,flatbuffers::FlatBufferBuilder * builder,std::vector<OperatorCodeFlatbuffer> * opCodesVector,std::vector<int> * opCodeIndexForOperationType,std::vector<BufferFlatbuffer> * bufferVector)28 SubGraphContext::SubGraphContext(const Model* model, const Model::Subgraph* subgraph,
29                                  flatbuffers::FlatBufferBuilder* builder,
30                                  std::vector<OperatorCodeFlatbuffer>* opCodesVector,
31                                  std::vector<int>* opCodeIndexForOperationType,
32                                  std::vector<BufferFlatbuffer>* bufferVector)
33     : mModel(model),
34       mSubgraph(subgraph),
35       mBuilder(builder),
36       mOpCodesVector(opCodesVector),
37       mOpCodeIndexForOperationType(opCodeIndexForOperationType),
38       mBufferVector(bufferVector) {
39     CHECK(model != nullptr);
40     CHECK(subgraph != nullptr);
41     CHECK(opCodesVector != nullptr);
42     CHECK(opCodeIndexForOperationType != nullptr);
43     CHECK(bufferVector != nullptr);
44 
45     mOperandToTensorIdx.resize(subgraph->operands.size(), -1);
46     mMappings.resize(model->pools.size());
47 }
48 
finish()49 SubGraphFlatbuffer SubGraphContext::finish() {
50     return tflite::CreateSubGraphDirect(*mBuilder, &mTensorVector, &mInputTensors, &mOutputTensors,
51                                         &mOperatorVector);
52 }
53 
addTensorFlatbuffer(TensorFlatbuffer tensor,int32_t operandIdx)54 int SubGraphContext::addTensorFlatbuffer(TensorFlatbuffer tensor, int32_t operandIdx) {
55     mTensorVector.push_back(tensor);
56 
57     int tensorIdx = mTensorVector.size() - 1;
58     if (operandIdx >= 0) {
59         CHECK(mOperandToTensorIdx[operandIdx] == -1);
60         mOperandToTensorIdx[operandIdx] = tensorIdx;
61     }
62     return tensorIdx;
63 }
64 
addOperatorFlatbuffer(OperatorFlatbuffer opFlatbuffer)65 void SubGraphContext::addOperatorFlatbuffer(OperatorFlatbuffer opFlatbuffer) {
66     mOperatorVector.push_back(opFlatbuffer);
67 }
68 
addSubGraphInput(int32_t operandIdx)69 void SubGraphContext::addSubGraphInput(int32_t operandIdx) {
70     CHECK(mOperandToTensorIdx[operandIdx] != -1);
71     mInputTensors.push_back(mOperandToTensorIdx[operandIdx]);
72 }
73 
addSubGraphOutput(int32_t operandIdx)74 void SubGraphContext::addSubGraphOutput(int32_t operandIdx) {
75     CHECK(mOperandToTensorIdx[operandIdx] != -1);
76     mOutputTensors.push_back(mOperandToTensorIdx[operandIdx]);
77 }
78 
addOpCode(OperationType operationType)79 uint32_t SubGraphContext::addOpCode(OperationType operationType) {
80     uint32_t idx = static_cast<uint32_t>(operationType);
81     if (mOpCodeIndexForOperationType->at(idx) != -1) {
82         return mOpCodeIndexForOperationType->at(idx);
83     }
84 
85     OperatorCodeFlatbuffer opCode;
86 
87     tflite::BuiltinOperator builtinCode = getFlatbufferOperator(operationType);
88     if (builtinCode < tflite::BuiltinOperator::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)
89         opCode = tflite::CreateOperatorCode(
90                 *mBuilder, static_cast<int8_t>(builtinCode) /* deprecated_builtin_code */,
91                 0 /* custom_code */, getMaxOperatorVersionCode(builtinCode) /* version */);
92     else
93         opCode = tflite::CreateOperatorCode(*mBuilder, 0 /* deprecated_builtin_code */,
94                                             0 /* custom_code */,
95                                             getMaxOperatorVersionCode(builtinCode) /* version */,
96                                             builtinCode /* builtin_code */);
97 
98     mOpCodesVector->push_back(opCode);
99     uint32_t opCodeIdx = mOpCodesVector->size() - 1;
100     (*mOpCodeIndexForOperationType)[idx] = opCodeIdx;
101     return opCodeIdx;
102 }
103 
getTensorIdxFromOperandIdx(int operandIdx) const104 int SubGraphContext::getTensorIdxFromOperandIdx(int operandIdx) const {
105     return mOperandToTensorIdx[operandIdx];
106 }
107 
getMapping(uint32_t poolIndex)108 const Mapping& SubGraphContext::getMapping(uint32_t poolIndex) {
109     if (mMappings[poolIndex].size > 0) {
110         return mMappings[poolIndex];
111     }
112 
113     SharedMemory memory = mModel->pools[poolIndex];
114     GeneralResult<Mapping> mapping = map(memory);
115     CHECK(mapping.has_value()) << "CONSTANT_REFERENCE memory mapping error: "
116                                << mapping.error().message;
117 
118     mMappings[poolIndex] = std::move(mapping).value();
119     return mMappings[poolIndex];
120 }
121 
getConstantPointerAndLength(const Operand & operand)122 std::pair<const uint8_t*, uint32_t> SubGraphContext::getConstantPointerAndLength(
123         const Operand& operand) {
124     CHECK(isOperandConstant(operand));
125 
126     if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
127         return std::make_pair(mModel->operandValues.data() + operand.location.offset,
128                               operand.location.length);
129     }
130 
131     const Mapping& mapping = getMapping(operand.location.poolIndex);
132     const uint8_t* memoryPtr = static_cast<const uint8_t*>(
133             std::visit([](auto ptr) { return static_cast<const void*>(ptr); }, mapping.pointer));
134 
135     return std::make_pair(memoryPtr + operand.location.offset, operand.location.length);
136 }
137 
addBufferFromData(const uint8_t * data,uint32_t length)138 uint32_t SubGraphContext::addBufferFromData(const uint8_t* data, uint32_t length) {
139     auto dataVectorFlatbuffer = mBuilder->CreateVector(data, length);
140 
141     auto buffer = tflite::CreateBuffer(*mBuilder, dataVectorFlatbuffer);
142     mBufferVector->push_back(buffer);
143 
144     return mBufferVector->size() - 1;
145 }
146 
createTensorFlatbufferFromOperand(uint32_t operandIdx,bool makeSymmetric)147 Result<void> SubGraphContext::createTensorFlatbufferFromOperand(uint32_t operandIdx,
148                                                                 bool makeSymmetric) {
149     // An output Operand to one Operation can be an input Operand to
150     // another Operation, so this function can be run more than once.
151     // We simply return if the Tensor for the Operand is already created.
152     if (mOperandToTensorIdx[operandIdx] != -1) return {};
153 
154     const Operand& operand = mSubgraph->operands[operandIdx];
155 
156     std::vector<float> scaleVector{operand.scale};
157     std::vector<int64_t> zeroPointVector{operand.zeroPoint};
158     // min and max used to convert TFLite models to TF models, so it is unused in this case and can
159     // be set to 0
160     std::vector<float> minVector{0};
161     std::vector<float> maxVector{0};
162 
163     // build quantization parameters
164     auto quantizationParams = tflite::CreateQuantizationParametersDirect(
165             *mBuilder, &minVector /* min */, &maxVector /* max */, &scaleVector /* scale */,
166             &zeroPointVector /* zero_point */,
167             tflite::QuantizationDetails::QuantizationDetails_NONE /* details_type */);
168 
169     // add buffer if constant operand
170     // buffer at index 0 is reserved for tensors without a buffer
171     uint32_t bufferIdx = 0;
172     if (isOperandConstant(operand)) {
173         auto [data, dataLength] = getConstantPointerAndLength(operand);
174         if (makeSymmetric && operand.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
175             std::vector<int8_t> dataVector(reinterpret_cast<const int8_t*>(data),
176                                            reinterpret_cast<const int8_t*>(data) + dataLength);
177             bool emitWarning = false;
178             for (uint32_t i = 0; i < dataLength; i++) {
179                 int32_t newValue = static_cast<int32_t>(dataVector[i]) - operand.zeroPoint;
180                 if (newValue < std::numeric_limits<int8_t>::min() ||
181                     newValue > std::numeric_limits<int8_t>::max()) {
182                     emitWarning = true;
183                 }
184                 dataVector[i] = static_cast<int8_t>(std::clamp(
185                         newValue, static_cast<int32_t>(std::numeric_limits<int8_t>::min()),
186                         static_cast<int32_t>(std::numeric_limits<int8_t>::max())));
187             }
188 
189             if (emitWarning) {
190                 LOG(WARNING) << "Asymmetric to symmetric conversion will result in "
191                                 "underflow/overflow. Clamping data";
192             }
193             bufferIdx = addBufferFromData(reinterpret_cast<const uint8_t*>(dataVector.data()),
194                                           dataLength);
195         } else {
196             bufferIdx = addBufferFromData(data, dataLength);
197         }
198     }
199 
200     // shape of tensor
201     std::vector<int32_t> shape(operand.dimensions.begin(), operand.dimensions.end());
202     replaceZeroDimensions(&shape);
203 
204     // build tensor
205     TensorFlatbuffer tensor = tflite::CreateTensorDirect(
206             *mBuilder, &shape, NN_TRY(getTensorFlatbufferOperandType(operand.type)) /* type */,
207             bufferIdx /* buffer */, 0 /* name */, quantizationParams /* quantization */);
208     addTensorFlatbuffer(tensor, operandIdx);
209 
210     return {};
211 }
212 
213 }  // namespace nn
214 }  // namespace android