1 /*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_OPERATION_CONVERTERS_SUBGRAPH_CONTEXT_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_OPERATION_CONVERTERS_SUBGRAPH_CONTEXT_H
19
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <utility>
#include <vector>

#include "FlatbufferModelBuilderUtils.h"
#include "NeuralNetworks.h"
25
26 namespace android {
27 namespace nn {
28
29 // This keeps track of all the data needed to convert NNAPI subgraphs to TFLite subgraphs
30 // This also provides information needed to convert NNAPI Operations to TFLite Operators
31 // Once the subgraph is done building, call finish() to return the flatbuffer
32 class SubGraphContext {
33 public:
34 SubGraphContext(const Model* model, const Model::Subgraph* subgraph,
35 flatbuffers::FlatBufferBuilder* builder,
36 std::vector<OperatorCodeFlatbuffer>* opCodesVector,
37 std::vector<int>* opCodeIndexForOperationType,
38 std::vector<BufferFlatbuffer>* bufferVector);
39
40 SubGraphFlatbuffer finish();
41
42 // If the operandIdx is -1, it suggests that the tensor being added doesn't have a
43 // corresponding Operand from the NNAPI NDK model.
44 // Returns index of Tensor being added.
45 int addTensorFlatbuffer(TensorFlatbuffer tensor, int32_t operandIdx = -1);
46 void addOperatorFlatbuffer(OperatorFlatbuffer opFlatbuffer);
47 void addSubGraphInput(int32_t operandIdx);
48 void addSubGraphOutput(int32_t operandIdx);
49
getSubgraph()50 const Model::Subgraph* getSubgraph() const { return mSubgraph; }
51 // Returns -1 if there is no corresponding tensor index
52 int getTensorIdxFromOperandIdx(int operandIdx) const;
53 uint32_t addOpCode(OperationType operationType);
getBuilder()54 flatbuffers::FlatBufferBuilder& getBuilder() { return *mBuilder; }
55
56 // OperandLifeTime must be CONSTANT_COPY or CONSTANT_REFERENCE
57 // Will crash if OperandLifeTime is not either of the two.
58 // dataSize is the size of data in bytes.
59 template <typename Type>
60 void copyConstantValueToData(const Operand& operand, Type* data, size_t dataSize);
61 template <typename Type>
62 Type getConstantScalar(const Operand& operand);
63
64 // Returns Buffer index
65 uint32_t addBufferFromData(const uint8_t* data, uint32_t length);
66 // makeSymmetric turns asymmetric tensors to symmetric by doing setting data = data - zeroPoint
67 // makeSymmetric is supported only for constant OperandType::TENSOR_QUANT8_ASYMM_SIGNED
68 // If unsupported type is passed, makeSymmetric is ignored
69 Result<void> createTensorFlatbufferFromOperand(uint32_t operandIdx, bool makeSymmetric = false);
70
71 private:
72 const Mapping& getMapping(uint32_t poolIndex);
73 std::pair<const uint8_t*, uint32_t> getConstantPointerAndLength(const Operand& operand);
74
75 const Model* mModel;
76 const Model::Subgraph* mSubgraph;
77 flatbuffers::FlatBufferBuilder* mBuilder;
78
79 std::vector<OperatorCodeFlatbuffer>* mOpCodesVector;
80 std::vector<int>* mOpCodeIndexForOperationType;
81 std::vector<BufferFlatbuffer>* mBufferVector;
82
83 std::vector<OperatorFlatbuffer> mOperatorVector;
84 std::vector<TensorFlatbuffer> mTensorVector;
85 std::vector<int32_t> mInputTensors;
86 std::vector<int32_t> mOutputTensors;
87 std::vector<int> mOperandToTensorIdx;
88 // Each index corresponds to the pool index of shared memory
89 std::vector<Mapping> mMappings;
90 };
91
92 template <typename Type>
copyConstantValueToData(const Operand & operand,Type * data,size_t dataSize)93 void SubGraphContext::copyConstantValueToData(const Operand& operand, Type* data, size_t dataSize) {
94 auto [pointer, length] = getConstantPointerAndLength(operand);
95 CHECK_GE(dataSize, length);
96
97 std::memcpy(data, pointer, length);
98 }
99
100 template <typename Type>
getConstantScalar(const Operand & operand)101 Type SubGraphContext::getConstantScalar(const Operand& operand) {
102 Type data;
103 copyConstantValueToData(operand, &data, sizeof(Type));
104 return data;
105 }
106
107 } // namespace nn
108 } // namespace android
109
110 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_OPERATION_CONVERTERS_SUBGRAPH_CONTEXT_H