1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Class used to build a model through a succession of successive calls 18 // to the NN API. 19 20 #ifndef ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 21 #define ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 22 23 #include "HalInterfaces.h" 24 #include "Memory.h" 25 #include "NeuralNetworks.h" 26 #include "Utils.h" 27 28 namespace android { 29 namespace nn { 30 31 class CompilationBuilder; 32 class Device; 33 class ExecutionPlan; 34 class Memory; 35 36 class ModelBuilder { 37 public: 38 // Adds an operand to the model. 39 int addOperand(const ANeuralNetworksOperandType& type); 40 int setOperandValue(uint32_t index, const void* buffer, size_t length); 41 int setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 42 size_t length); 43 44 int addOperation(ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, 45 uint32_t outputCount, const uint32_t* outputs); 46 int identifyInputsAndOutputs(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, 47 const uint32_t* outputs); 48 int relaxComputationFloat32toFloat16(bool allow); isComputationFloat32RelaxedToFloat16()49 bool isComputationFloat32RelaxedToFloat16() const { return mRelaxComputationFloat32toFloat16; } 50 51 int finish(); isFinished()52 bool isFinished() const { return mCompletedModel; } 53 54 int createCompilation(CompilationBuilder** compilation); 55 56 void setHidlModel(Model* model) const; 57 operandCount()58 uint32_t operandCount() const { 59 // We don't allow more than uint32_t worth of operands 60 return static_cast<uint32_t>(mOperands.size()); 61 } operationCount()62 uint32_t operationCount() const { 63 // We don't allow more than uint32_t worth of operations 64 return static_cast<uint32_t>(mOperations.size()); 65 } inputCount()66 uint32_t inputCount() const { return static_cast<uint32_t>(mInputIndexes.size()); } outputCount()67 uint32_t outputCount() const { return static_cast<uint32_t>(mOutputIndexes.size()); } getInputOperandIndex(uint32_t i)68 uint32_t getInputOperandIndex(uint32_t i) const { return mInputIndexes[i]; } getInputOperand(uint32_t i)69 const Operand& getInputOperand(uint32_t i) const { 70 return mOperands[getInputOperandIndex(i)]; 71 } getOutputOperandIndex(uint32_t i)72 uint32_t getOutputOperandIndex(uint32_t i) const { return mOutputIndexes[i]; } getOutputOperand(uint32_t i)73 const Operand& getOutputOperand(uint32_t i) const { 74 return mOperands[getOutputOperandIndex(i)]; 75 } getOperand(uint32_t index)76 const Operand& getOperand(uint32_t index) const { return mOperands[index]; } getOperation(uint32_t index)77 const Operation& getOperation(uint32_t index) const { return mOperations[index]; } getMemories()78 const MemoryTracker& getMemories() const { return mMemories; } getOperations()79 const std::vector<Operation>& getOperations() const { return mOperations; } getPointerToOperandValue(uint32_t offset)80 const uint8_t* getPointerToOperandValue(uint32_t offset) const { 81 return mSmallOperandValues.data() + offset; 82 } 83 84 int partitionTheWork(const std::vector<std::shared_ptr<Device>>& devices, 85 uint32_t preference, ExecutionPlan* plan) const; 86 87 private: 88 // TODO: move partitionTheWork, findBestDeviceForEachOperation, 89 // sortIntoRunOrder to CompilationBuilder? 90 91 int findBestDeviceForEachOperation(uint32_t preference, 92 const std::vector<std::shared_ptr<Device>>& devices, 93 const size_t deviceCount, 94 std::vector<int>* bestDeviceForOperation) const; 95 PerformanceInfo getPerformanceInfo(const std::shared_ptr<Device> device, 96 uint32_t operationIndex) const; 97 98 // Return true if either mCompleteModel or mInvalidModel is true. 99 bool badState(const char* name); 100 101 // Sorts the operations to be in the correct order for single threaded 102 // node-at-a-time execution. 103 void sortIntoRunOrder(); 104 105 // Copies the large values to a shared memory, if we have any. 106 int copyLargeValuesToSharedMemory(); 107 108 // The operations of the graph. 109 std::vector<Operation> mOperations; 110 // The description of the operands of the graph. 111 std::vector<Operand> mOperands; 112 // Specifies where to find the list of indexes identifying 113 // the inputs and outputs of the model. The offset is into 114 // the mOperandIndexes table. 115 std::vector<uint32_t> mInputIndexes; 116 std::vector<uint32_t> mOutputIndexes; 117 118 MemoryTracker mMemories; 119 120 // The value of the small operands that are defined at model 121 // creation time. 122 std::vector<uint8_t> mSmallOperandValues; 123 124 struct LargeValue { 125 uint32_t operandIndex; 126 const void* buffer; 127 }; 128 // Operand index and buffer pointer for all the large operand values of this model. 129 std::vector<LargeValue> mLargeOperandValues; 130 // The shared memory region that will contain the large values. 131 Memory mLargeValueMemory; 132 133 // Once the model has been finished, we should not allow further 134 // modifications to the model. 135 mutable bool mCompletedModel = false; 136 137 // Any invalid manipulation of the model will mark the model invalid. 138 // No further modifications are allowed to the model. 139 mutable bool mInvalidModel = false; 140 141 // 'true' indicates TENSOR_FLOAT32 may be calculated with range and/or 142 // precision as low as that of the IEEE 754 16-bit floating-point format. 143 // 'false' indicates TENSOR_FLOAT32 must be calculated using at least the 144 // range and precision of the IEEE 754 32-bit floating-point format. 145 bool mRelaxComputationFloat32toFloat16 = false; 146 }; 147 148 } // namespace nn 149 } // namespace android 150 151 #endif // ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 152