1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Class used to build a model through a succession of successive calls 18 // to the NN API. 19 20 #ifndef ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 21 #define ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 22 23 #include "HalInterfaces.h" 24 #include "Memory.h" 25 #include "NeuralNetworks.h" 26 #include "Utils.h" 27 28 namespace android { 29 namespace nn { 30 31 class CompilationBuilder; 32 class Device; 33 class ExecutionPlan; 34 class Memory; 35 36 class ModelBuilder { 37 public: ModelBuilder()38 ModelBuilder() {} 39 // Returns an operand/operation type corresponding to a given extension operand/operation type. 40 int getExtensionType(const char* extensionName, uint16_t typeWithinExtension, int32_t* type); 41 // Adds an operand to the model. 42 int addOperand(const ANeuralNetworksOperandType& type); 43 int setOperandValue(uint32_t index, const void* buffer, size_t length); 44 int setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 45 size_t length); 46 int setOperandSymmPerChannelQuantParams( 47 uint32_t index, const ANeuralNetworksSymmPerChannelQuantParams& extraParams); 48 int setOperandExtensionData(uint32_t index, const void* data, size_t length); 49 50 int addOperation(ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, 51 uint32_t outputCount, const uint32_t* outputs); 52 int identifyInputsAndOutputs(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, 53 const uint32_t* outputs); 54 int relaxComputationFloat32toFloat16(bool allow); isComputationFloat32RelaxedToFloat16()55 bool isComputationFloat32RelaxedToFloat16() const { return mRelaxComputationFloat32toFloat16; } 56 57 int finish(); isFinished()58 bool isFinished() const { return mCompletedModel; } isValid()59 bool isValid() const { return !mInvalidModel; } 60 hasOEMOperation()61 bool hasOEMOperation() const { return mHasOEMOperation; } hasExtensionOperation()62 bool hasExtensionOperation() const { return mHasExtensionOperation; } 63 64 // explicitDeviceList is true if the list of devices was provided explicitly 65 // via the ANeuralNetworksModel_createForDevices API (which has certain 66 // special semantics) and false otherwise. 67 int createCompilation(CompilationBuilder** compilation, 68 const std::vector<std::shared_ptr<Device>>& devices, 69 bool explicitDeviceList = false); 70 71 void setHidlModel(Model* model) const; 72 operandCount()73 uint32_t operandCount() const { 74 // We don't allow more than uint32_t worth of operands 75 return static_cast<uint32_t>(mOperands.size()); 76 } operationCount()77 uint32_t operationCount() const { 78 // We don't allow more than uint32_t worth of operations 79 return static_cast<uint32_t>(mOperations.size()); 80 } inputCount()81 uint32_t inputCount() const { return static_cast<uint32_t>(mInputIndexes.size()); } outputCount()82 uint32_t outputCount() const { return static_cast<uint32_t>(mOutputIndexes.size()); } getInputOperandIndex(uint32_t i)83 uint32_t getInputOperandIndex(uint32_t i) const { return mInputIndexes[i]; } getInputOperandIndexes()84 const std::vector<uint32_t>& getInputOperandIndexes() const { return mInputIndexes; } getInputOperand(uint32_t i)85 const Operand& getInputOperand(uint32_t i) const { return mOperands[getInputOperandIndex(i)]; } getOutputOperandIndex(uint32_t i)86 uint32_t getOutputOperandIndex(uint32_t i) const { return mOutputIndexes[i]; } getOutputOperandIndexes()87 const std::vector<uint32_t>& getOutputOperandIndexes() const { return mOutputIndexes; } getOutputOperand(uint32_t i)88 const Operand& getOutputOperand(uint32_t i) const { 89 return mOperands[getOutputOperandIndex(i)]; 90 } getOperand(uint32_t index)91 const Operand& getOperand(uint32_t index) const { return mOperands[index]; } getOperation(uint32_t index)92 const Operation& getOperation(uint32_t index) const { return mOperations[index]; } getMemories()93 const MemoryTracker& getMemories() const { return mMemories; } getOperations()94 const std::vector<Operation>& getOperations() const { return mOperations; } getSortedOperationMapping()95 const std::vector<uint32_t>& getSortedOperationMapping() const { 96 return mSortedOperationIndexMap; 97 } getPointerToOperandValue(uint32_t offset)98 const uint8_t* getPointerToOperandValue(uint32_t offset) const { 99 return mSmallOperandValues.data() + offset; 100 } 101 102 int partitionTheWork(const std::vector<std::shared_ptr<Device>>& devices, uint32_t preference, 103 ExecutionPlan* plan) const; 104 105 private: 106 // TODO: move partitionTheWork, findBestDeviceForEachOperation, 107 // sortIntoRunOrder to CompilationBuilder? 108 109 int findBestDeviceForEachOperation(uint32_t preference, 110 const std::vector<std::shared_ptr<Device>>& devices, 111 std::vector<int>* bestDeviceForOperation) const; 112 PerformanceInfo getPerformanceInfo(const std::shared_ptr<Device> device, 113 uint32_t operationIndex) const; 114 115 // Return true if either mCompleteModel or mInvalidModel is true. 116 bool badState(const char* name); 117 118 // Sorts the operations to be in the correct order for single threaded 119 // node-at-a-time execution. 120 void sortIntoRunOrder(); 121 122 // Copies the large values to a shared memory, if we have any. 123 int copyLargeValuesToSharedMemory(); 124 125 // Returns the list of extension names and corresponding numeric "prefixes" 126 // of operand and operation type values used in the model. 127 // 128 // Devices rely on this mapping to interpret extension types. 129 std::vector<Model::ExtensionNameAndPrefix> getExtensionNameToPrefixMap() const; 130 131 // The operations of the graph. 132 std::vector<Operation> mOperations; 133 // The mapping from sorted index to the original index of operations in mOperations. 134 // mSortedOperationIndexMap is empty before sortIntoRunOrder() is called. 135 std::vector<uint32_t> mSortedOperationIndexMap; 136 // Is at least one of those operations an OEM_OPERATION? 137 bool mHasOEMOperation = false; 138 // Is at least one of those operations an extension operation? 139 bool mHasExtensionOperation = false; 140 // The description of the operands of the graph. 141 std::vector<Operand> mOperands; 142 // Specifies where to find the list of indexes identifying 143 // the inputs and outputs of the model. The offset is into 144 // the mOperandIndexes table. 145 std::vector<uint32_t> mInputIndexes; 146 std::vector<uint32_t> mOutputIndexes; 147 148 MemoryTracker mMemories; 149 150 // The value of the small operands that are defined at model 151 // creation time. 152 std::vector<uint8_t> mSmallOperandValues; 153 154 struct LargeValue { 155 uint32_t operandIndex; 156 const void* buffer; 157 }; 158 // Operand index and buffer pointer for all the large operand values of this model. 159 std::vector<LargeValue> mLargeOperandValues; 160 // The shared memory region that will contain the large values. 161 Memory mLargeValueMemory; 162 163 // Once the model has been finished, we should not allow further 164 // modifications to the model. 165 bool mCompletedModel = false; 166 167 // Any invalid manipulation of the model will mark the model invalid. 168 // No further modifications are allowed to the model. 169 bool mInvalidModel = false; 170 171 172 // 'true' indicates TENSOR_FLOAT32 may be calculated with range and/or 173 // precision as low as that of the IEEE 754 16-bit floating-point format. 174 // 'false' indicates TENSOR_FLOAT32 must be calculated using at least the 175 // range and precision of the IEEE 754 32-bit floating-point format. 176 bool mRelaxComputationFloat32toFloat16 = false; 177 }; 178 179 } // namespace nn 180 } // namespace android 181 182 #endif // ANDROID_ML_NN_RUNTIME_MODEL_BUILDER_H 183