/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ModelBuilder"

#include "ModelBuilder.h"

#include <GraphDump.h>
#include <LegacyUtils.h>
#include <ModelUtils.h>
#include <android-base/logging.h>
#include <nnapi/Validation.h>

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "CompilationBuilder.h"
#include "Manager.h"
#include "ModelArchHasher.h"
#include "TypeManager.h"

namespace android {
namespace nn {

// The maximum number of operands and operations that a model may have.
const uint32_t MAX_NUMBER_OF_OPERANDS = 0xFFFFFFFE;
const uint32_t MAX_NUMBER_OF_OPERATIONS = 0xFFFFFFFE;

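// Checks that a pointer argument and its count are consistent: the pointer may
// be null only when the count is zero, and vice versa. On a mismatch, logs an
// error for the named ANeuralNetworksModel_* entry point and makes the
// enclosing function return ANEURALNETWORKS_BAD_DATA.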
#define NN_VALIDATE_NULL_OR_SIZED(tag, data, length)                                          \
    if ((data == nullptr) != (length == 0)) {                                                 \
        LOG(ERROR) << "ANeuralNetworksModel_" << tag << " " << #data << " is "                \
                   << (data == nullptr ? "null" : "not null") << " but " << #length << " is " \
                   << length;                                                                 \
        return ANEURALNETWORKS_BAD_DATA;                                                      \
    }

template <typename Type>
static std::vector<Type> makeVector(const Type* data, uint32_t length) {
    return length > 0 ? std::vector<Type>(data, data + length) : std::vector<Type>();
}

bool ModelBuilder::badState(const char* name) {
    if (mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify after model finished";
        return true;
    }
    if (mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify an invalid model";
        return true;
    }
    return false;
}

int ModelBuilder::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
                                   int32_t* type) {
    return TypeManager::get()->getExtensionType(extensionName, typeWithinExtension, type)
                   ? ANEURALNETWORKS_NO_ERROR
                   : ANEURALNETWORKS_BAD_DATA;
}

int ModelBuilder::addOperand(const ANeuralNetworksOperandType& type) {
    if (badState("addOperand")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    OperandType operandType = static_cast<OperandType>(type.type);
    if (isExtension(operandType) && !TypeManager::get()->areExtensionsAllowed()) {
        LOG(ERROR) << "Extensions are not supported for this process.";
        return ANEURALNETWORKS_BAD_DATA;
    }
    bool isOemOperand =
            operandType == OperandType::OEM || operandType == OperandType::TENSOR_OEM_BYTE;
    if (isOemOperand && !mHasOEMOperand) {
        LOG(WARNING) << "OEM data type is deprecated. Use Extensions instead.";
    }

    const Extension::OperandTypeInformation* info = nullptr;
    if (isExtension(operandType) &&
        !TypeManager::get()->getExtensionOperandTypeInfo(operandType, &info)) {
        LOG(ERROR) << "Extension operand type " << operandType << " is not registered";
        return ANEURALNETWORKS_BAD_DATA;
    }
    NN_VALIDATE_NULL_OR_SIZED("addOperand", type.dimensions, type.dimensionCount);
    Operand operand = {
            .type = operandType,
            .dimensions = makeVector(type.dimensions, type.dimensionCount),
            .scale = type.scale,
            .zeroPoint = type.zeroPoint,
            .lifetime = Operand::LifeTime::TEMPORARY_VARIABLE,
            .location = {.poolIndex = 0, .offset = 0, .length = 0},
            .extraParams = {},
    };
    if (auto result = validateOperandType(operand, info, "ANeuralNetworksModel_addOperand", true);
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    size_t idx = mOperands.size();
    if (idx >= MAX_NUMBER_OF_OPERANDS) {
        LOG(ERROR) << "ANeuralNetworksModel_addOperand exceeds max operands";
        return ANEURALNETWORKS_BAD_DATA;
    }

    mOperands.push_back(std::move(operand));
    mHasOEMOperand |= isOemOperand;
    mHasControlFlow |= (operandType == OperandType::SUBGRAPH);
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValue(uint32_t index, const void* buffer, size_t length) {
    VLOG(MODEL) << __func__ << " for operand " << index << " size " << length;
    if (badState("setOperandValue")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index << " of "
                   << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    NN_VALIDATE_NULL_OR_SIZED("setOperandValue", buffer, length);
    if (buffer == nullptr) {
        operand.lifetime = Operand::LifeTime::NO_VALUE;
        // The location is unused and is set to zeros.
        operand.location = {.poolIndex = 0, .offset = 0, .length = 0};
    } else {
        if (TypeManager::get()->isTensorType(operand.type) &&
            tensorHasUnspecifiedDimensions(operand)) {
            LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index
                       << " which has operand type that is not fully specified";
            return ANEURALNETWORKS_BAD_DATA;
        }
        if (length > 0xFFFFFFFF) {
            LOG(ERROR) << "ANeuralNetworksModel_setOperandValue value length of " << length
                       << " exceeds max size";
            return ANEURALNETWORKS_BAD_DATA;
        }
        uint32_t valueLength = static_cast<uint32_t>(length);
        if (operand.type != OperandType::OEM) {
            uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
            if (neededLength != valueLength) {
                LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
                           << " bytes when needing " << neededLength;
                return ANEURALNETWORKS_BAD_DATA;
            }
        }
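        // Small constant values are copied inline into mSmallOperandValues so they
        // travel with the model. Larger values are only recorded here; they are
        // staged into a single shared memory pool later, when finish() calls
        // copyLargeValuesToSharedMemory().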
        if (valueLength <= ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES) {
            uint32_t existingSize = static_cast<uint32_t>(mSmallOperandValues.size());
            uint32_t extraBytes = alignBytesNeeded(existingSize, valueLength);
            mSmallOperandValues.resize(existingSize + extraBytes + valueLength);
            operand.lifetime = Operand::LifeTime::CONSTANT_COPY;
            operand.location = {
                    .poolIndex = 0, .offset = existingSize + extraBytes, .length = valueLength};
            memcpy(&mSmallOperandValues[operand.location.offset], buffer, valueLength);
            VLOG(MODEL) << "Copied small value to offset " << operand.location.offset;
        } else {
            VLOG(MODEL) << "Saving large value";
            operand.lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
            // The values for poolIndex and offset will be set when the model is finished.
            typedef decltype(operand.location.poolIndex) PoolIndexType;
            typedef decltype(operand.location.offset) OffsetType;
            operand.location = {.poolIndex = ~PoolIndexType(0),
                                .offset = ~OffsetType(0),
                                .length = valueLength};
            // We keep track of the buffers. We'll allocate the shared memory only
            // once we know the total size, to avoid needless copies.
            mLargeOperandValues.push_back(LargeValue{.operandIndex = index, .buffer = buffer});
        }
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValueFromModel(uint32_t index, const ModelBuilder* value) {
    VLOG(MODEL) << __func__ << " for operand " << index << " model " << value;
    if (badState("setOperandValueFromModel")) {
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (!value->mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model must be finished";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (value->mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model is invalid";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel setting operand " << index
                   << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    operand.lifetime = Operand::LifeTime::SUBGRAPH;
    operand.location = {
            .poolIndex = 0,
            .offset = static_cast<uint32_t>(mReferencedModels.size()),
            .length = 0,
    };
    mReferencedModels.push_back(value);
    mReferencedSubgraphsForValidation.push_back(value->makeModel().main);
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandSymmPerChannelQuantParams(
        uint32_t index, const ANeuralNetworksSymmPerChannelQuantParams& channelQuant) {
    if (badState("setOperandSymmPerChannelQuantParams")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
                   << "setting per-channel quantization parameters for operand " << index << " of "
                   << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];

    NN_VALIDATE_NULL_OR_SIZED("setOperandSymmPerChannelQuantParams", channelQuant.scales,
                              channelQuant.scaleCount);
    Operand::SymmPerChannelQuantParams extraParams = {
            .scales = makeVector(channelQuant.scales, channelQuant.scaleCount),
            .channelDim = channelQuant.channelDim,
    };
    if (auto result = validateOperandSymmPerChannelQuantParams(
                operand, extraParams, "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }
    switch (operand.type) {
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            operand.extraParams = std::move(extraParams);
            break;
        default:
            LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
                       << "invalid operand type " << static_cast<int32_t>(operand.type);
            return ANEURALNETWORKS_BAD_DATA;
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandExtensionData(uint32_t index, const void* data, size_t length) {
    if (badState("setOperandExtensionData")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
                   << "setting extension data for operand " << index << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];

    if (!isExtension(operand.type)) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
                   << "setting extension data for a base operand type "
                   << static_cast<int32_t>(operand.type);
        return ANEURALNETWORKS_BAD_DATA;
    }

    NN_VALIDATE_NULL_OR_SIZED("setOperandExtensionData", data, length);
    if (data == nullptr) {
        operand.extraParams = {};
    } else {
        operand.extraParams = Operand::ExtensionParams(
                std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
                                     reinterpret_cast<const uint8_t*>(data) + length));
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::copyLargeValuesToSharedMemory() {
    VLOG(MODEL) << __func__ << " has " << mLargeOperandValues.size() << " values.";
    if (!mLargeOperandValues.empty()) {
        // Calculate the size of the shared memory needed for all the large values.
        // Also set the offset for each value within the memory.
        size_t poolSize = 0;
        for (LargeValue& l : mLargeOperandValues) {
            Operand& operand = mOperands[l.operandIndex];
            CHECK_EQ(operand.lifetime, Operand::LifeTime::CONSTANT_REFERENCE);
            poolSize += alignBytesNeeded(poolSize, operand.location.length);
            operand.location.offset = poolSize;
            poolSize += operand.location.length;
        }

        // Allocate the shared memory.
        int n;
        std::tie(n, mLargeValueMemory) = MemoryAshmem::create(poolSize);
        NN_RETURN_IF_ERROR(n);
        uint8_t* memoryPointer = mLargeValueMemory->getPointer();
        uint32_t poolIndex = mMemories.add(mLargeValueMemory.get());
        VLOG(MODEL) << "Allocated large value pool of size " << poolSize << " at index "
                    << poolIndex;

        // Copy the values to this memory.
        for (LargeValue& l : mLargeOperandValues) {
            Operand& operand = mOperands[l.operandIndex];
            operand.location.poolIndex = poolIndex;
            memcpy(memoryPointer + operand.location.offset, l.buffer, operand.location.length);
        }
    }

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValueFromMemory(uint32_t index, const RuntimeMemory* memory,
                                            uint32_t offset, size_t length) {
    VLOG(MODEL) << __func__ << " for operand " << index << " offset " << offset << " size "
                << length;
    if (badState("setOperandValueFromMemory")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
                   << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    if (TypeManager::get()->isTensorType(operand.type) && tensorHasUnspecifiedDimensions(operand)) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
                   << " which has operand type that is not fully specified";
        return ANEURALNETWORKS_BAD_DATA;
    }
    uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
    if (neededLength != length) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
                   << " bytes when needing " << neededLength;
        return ANEURALNETWORKS_BAD_DATA;
    }
    // Set compilation = nullptr to indicate that the memory is used for a model constant.
    // In this case, IOType::INPUT is a placeholder value that is ignored by the validator.
    if (!memory->getValidator().validate(/*compilation=*/nullptr, /*placeholder*/ IOType::INPUT,
                                         index, nullptr, offset, length)) {
        return ANEURALNETWORKS_BAD_DATA;
    }
    operand.lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
    operand.location = {.poolIndex = mMemories.add(memory),
                        .offset = offset,
                        .length = static_cast<uint32_t>(length)};
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::addOperation(ANeuralNetworksOperationType type, uint32_t inputCount,
                               const uint32_t* inputs, uint32_t outputCount,
                               const uint32_t* outputs) {
    if (badState("addOperation")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    OperationType operationType = static_cast<OperationType>(type);
    if (isExtension(operationType) && !TypeManager::get()->areExtensionsAllowed()) {
        LOG(ERROR) << "Extensions are not supported for this process.";
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (operationType == OperationType::OEM_OPERATION && !mHasOEMOperation) {
        LOG(WARNING) << "OEM_OPERATION is deprecated. Use Extensions instead.";
    }

    if (!isExtension(operationType)) {
        bool allowExperimental = false;
#ifdef NN_EXPERIMENTAL_FEATURE
        if (type >= BuiltinOperationResolver::kStartOfExperimentalOperations &&
            type < BuiltinOperationResolver::kStartOfExperimentalOperations +
                            BuiltinOperationResolver::kNumberOfExperimentalOperationTypes) {
            allowExperimental = true;
        }
#endif  // NN_EXPERIMENTAL_FEATURE
        if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, type) &&
            !allowExperimental) {
            LOG(ERROR) << "ANeuralNetworksModel_addOperation invalid operation type " << type;
            return ANEURALNETWORKS_BAD_DATA;
        }
    } else {
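        // For extension operations, check that the extension prefix encoded in the
        // operation type corresponds to an extension registered with the TypeManager.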
        const Extension* extension;
        uint16_t extensionPrefix = getExtensionPrefix(static_cast<uint32_t>(operationType));
        if (!TypeManager::get()->getExtensionInfo(extensionPrefix, &extension)) {
            LOG(ERROR) << "Extension operation type " << operationType << " is not recognized";
            return ANEURALNETWORKS_BAD_DATA;
        }
    }

    NN_VALIDATE_NULL_OR_SIZED("addOperation", inputs, inputCount);
    NN_VALIDATE_NULL_OR_SIZED("addOperation", outputs, outputCount);
    Operation operation = {
            .type = operationType,
            .inputs = makeVector(inputs, inputCount),
            .outputs = makeVector(outputs, outputCount),
    };
    if (auto result = validateOperationButNotOperands(operation, mOperands,
                                                      mReferencedSubgraphsForValidation);
        !result.ok()) {
        LOG(ERROR) << "Invalid Operation: " << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    uint32_t operationIndex = operationCount();
    if (operationIndex >= MAX_NUMBER_OF_OPERATIONS) {
        LOG(ERROR) << "ANeuralNetworksModel_addOperation exceeds max operations";
        return ANEURALNETWORKS_BAD_DATA;
    }

    mOperations.push_back(std::move(operation));
    mHasOEMOperation |= (operationType == OperationType::OEM_OPERATION);
    mHasExtensionOperation |= isExtension(operationType);
    mHasControlFlow |=
            (operationType == OperationType::IF || operationType == OperationType::WHILE);

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::identifyInputsAndOutputs(uint32_t inputCount, const uint32_t* inputs,
                                           uint32_t outputCount, const uint32_t* outputs) {
    if (badState("identifyInputsAndOutputs")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    NN_VALIDATE_NULL_OR_SIZED("identifyInputsAndOutputs", inputs, inputCount);
    if (auto result = validateOperandList(makeVector(inputs, inputCount), operandCount(),
                                          "ANeuralNetworksModel_identifyInputsAndOutputs inputs");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }
    NN_VALIDATE_NULL_OR_SIZED("identifyInputsAndOutputs", outputs, outputCount);
    if (auto result = validateOperandList(makeVector(outputs, outputCount), operandCount(),
                                          "ANeuralNetworksModel_identifyInputsAndOutputs outputs");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    // Makes a copy of the index list, validates the arguments, and changes
    // the lifetime info of the corresponding operand.
    auto setArguments = [&](std::vector<uint32_t>* indexVector, uint32_t indexCount,
                            const uint32_t* indexList, Operand::LifeTime lifetime) -> bool {
        indexVector->resize(indexCount);
        for (uint32_t i = 0; i < indexCount; i++) {
            const uint32_t operandIndex = indexList[i];
            if (operandIndex >= mOperands.size()) {
                LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set input or "
                              "output to be "
                           << operandIndex << " as this exceeds the number of operands "
                           << mOperands.size();
                return false;
            }
            (*indexVector)[i] = operandIndex;
            Operand& operand = mOperands[operandIndex];
            if (operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE) {
                LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set operand "
                           << operandIndex
                           << " to be an input or output.  Check that it's not a constant or "
                              "already an input or output";
                return false;
            }
            operand.lifetime = lifetime;
        }
        return true;
    };

    if (!setArguments(&mInputIndexes, inputCount, inputs, Operand::LifeTime::SUBGRAPH_INPUT) ||
        !setArguments(&mOutputIndexes, outputCount, outputs, Operand::LifeTime::SUBGRAPH_OUTPUT)) {
        return ANEURALNETWORKS_BAD_DATA;
    }

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::relaxComputationFloat32toFloat16(bool allow) {
    if (badState("relaxComputationFloat32toFloat16")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    mRelaxComputationFloat32toFloat16 = allow;

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::createCompilation(CompilationBuilder** compilation,
                                    const std::vector<std::shared_ptr<Device>>& devices,
                                    bool explicitDeviceList) {
    if (!mCompletedModel || mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksCompilation_create passed an unfinished or invalid model";
        *compilation = nullptr;
        return ANEURALNETWORKS_BAD_STATE;
    }
    *compilation = new (std::nothrow) CompilationBuilder(this, devices, explicitDeviceList);
    return (*compilation ? ANEURALNETWORKS_NO_ERROR : ANEURALNETWORKS_OUT_OF_MEMORY);
}

int ModelBuilder::finish() {
    if (mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called more than once";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on an invalid model";
        return ANEURALNETWORKS_BAD_STATE;
    }

    int n = copyLargeValuesToSharedMemory();
    if (n != ANEURALNETWORKS_NO_ERROR) {
        return n;
    }

    // We sort the operations so that they will be in the appropriate
    // order for a single-threaded, operation-at-a-time execution.
    // TODO: we don't need this if we always run the partitioner.
    if (!sortIntoRunOrder()) {
        // We expect sortIntoRunOrder() to have logged an appropriate error message.
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }

    // TODO: Modify validation so that it can be called without creating a Model.
    // NOTE: Must sortIntoRunOrder() before validation; validator expects operations
    //       to have been sorted.
    // NOTE: Must copyLargeValuesToSharedMemory() before validation; otherwise,
    //       a CONSTANT_REFERENCE operand will not have correct .poolIndex, and
    //       validation will not work properly.
    const Model modelForValidation = makeModel();
    const auto maybeVersion = validate(modelForValidation);
    if (!maybeVersion.ok()) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on an invalid model: "
                   << maybeVersion.error();
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (!isCompliantVersion(maybeVersion.value(), DeviceManager::get()->getRuntimeVersion())) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on a model that is newer than what is "
                      "allowed. Model version needed: "
                   << maybeVersion.value() << ", current runtime version supported: "
                   << DeviceManager::get()->getRuntimeVersion();
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (VLOG_IS_ON(MODEL)) {
        graphDump("ModelBuilder::finish", modelForValidation, nullptr);
    }

    removeTrailingArgumentsWithDefaultValues();
    simplifyModel();

    mCompletedModel = true;
    CHECK(calcModelArchHash(modelForValidation, mModelArchHash))
            << "Failed to calculate model arch hash";
    return ANEURALNETWORKS_NO_ERROR;
}

static void logRemoval(const Operation& operation, uint32_t count,
                       const std::vector<Operand>& operands) {
    std::ostringstream message;
    message << "Operation " << operation.type << " with inputs {";
    for (uint32_t i = 0; i < operation.inputs.size(); ++i) {
        if (i != 0) {
            message << ", ";
        }
        message << operands[operation.inputs[i]].type;
    }
    message << "} has trailing optional inputs set to default values. Removing " << count
            << " trailing inputs.";
    VLOG(MODEL) << message.str();
}

void ModelBuilder::removeTrailingArgumentsWithDefaultValues() {
    for (Operation& operation : mOperations) {
        const uint32_t count = getNumTrailingArgumentsToRemove(operation);
        if (count == 0) {
            continue;
        }
        if (VLOG_IS_ON(MODEL)) {
            logRemoval(operation, count, mOperands);
        }
        const uint32_t inputCount = operation.inputs.size();
        CHECK_LT(count, inputCount);
        const uint32_t newInputCount = inputCount - count;
        operation.inputs.resize(newInputCount);
    }
}

// See countMatchingTrailingArguments().
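// Each value names the default that a trailing optional input must hold for it
// to be removable: a BOOL scalar equal to false, or an INT32 scalar equal to 1
// or -1, respectively.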
enum class TailSpec {
    BOOL_FALSE,
    INT32_ONE,
    INT32_NEGATIVE_ONE,
};

// See countMatchingTrailingArguments().
static bool matchesSpec(TailSpec spec, const Operand& operand,
                        const std::vector<uint8_t>& mSmallOperandValues) {
    const void* valuePtr = nullptr;
    if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
        valuePtr = static_cast<const void*>(&mSmallOperandValues[operand.location.offset]);
    } else if (operand.lifetime == Operand::LifeTime::POINTER) {
        valuePtr = std::get<const void*>(operand.location.pointer);
    } else {
        // CONSTANT_REFERENCE operands are not supported to avoid mapping memory
        // during compilation.
        return false;
    }
    switch (spec) {
        case TailSpec::BOOL_FALSE:
            return operand.type == OperandType::BOOL &&
                   *static_cast<const bool8*>(valuePtr) == false;
        case TailSpec::INT32_ONE:
            return operand.type == OperandType::INT32 &&
                   *static_cast<const int32_t*>(valuePtr) == 1;
        case TailSpec::INT32_NEGATIVE_ONE:
            return operand.type == OperandType::INT32 &&
                   *static_cast<const int32_t*>(valuePtr) == -1;
        default:
            CHECK(false) << "Unhandled TailSpec: " << static_cast<int>(spec);
    }
}

// Returns the number of trailing operation inputs that match the specification.
//
// Example:
//     operation.inputs = {BOOL_TRUE, BOOL_TRUE,  INT32_ONE, INT32_NEGATIVE_ONE}
//     tail             =            {BOOL_FALSE, INT32_ONE, INT32_NEGATIVE_ONE}
//     tailStartIndex   = 1    matching elements: ^^^^^^^^^  ^^^^^^^^^^^^^^^^^^
static uint32_t countMatchingTrailingArguments(uint32_t tailStartIndex,
                                               const std::vector<TailSpec>& tail,
                                               const Operation& operation,
                                               const std::vector<Operand>& operands,
                                               const std::vector<uint8_t>& smallOperandValues) {
    const uint32_t inputCount = operation.inputs.size();
    uint32_t count = 0;
    for (uint32_t i = inputCount - 1; i >= tailStartIndex; --i) {
        const Operand& operand = operands[operation.inputs[i]];
        if (!matchesSpec(tail[i - tailStartIndex], operand, smallOperandValues)) {
            break;
        }
        ++count;
    }
    return count;
}

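// Returns how many trailing inputs of the given operation are optional inputs
// currently set to their default values, and can therefore be dropped by
// removeTrailingArgumentsWithDefaultValues() without changing the operation's
// meaning.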
uint32_t ModelBuilder::getNumTrailingArgumentsToRemove(const Operation& operation) const {
    const uint32_t inputCount = operation.inputs.size();
    auto getCount = [this, &operation](uint32_t tailStartIndex, const std::vector<TailSpec>& tail) {
        return countMatchingTrailingArguments(tailStartIndex, tail, operation, mOperands,
                                              mSmallOperandValues);
    };
    using TS = TailSpec;
    // Check if the operation has optional arguments that might be set to default
    // values. Skip the counting if no optional arguments are present.
    switch (operation.type) {
        case OperationType::AVERAGE_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::CONV_2D: {
            if (10 < inputCount && inputCount <= 13 &&
                mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 13 inputs
                // API level 27: 10 inputs
                uint32_t count = getCount(10, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 11 and 12 must come together.
                return inputCount - count == 12 ? 0 : count;
            } else if (7 < inputCount && inputCount <= 10 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 10 inputs
                // API level 27: 7 inputs
                uint32_t count = getCount(7, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 8 and 9 must come together.
                return inputCount - count == 9 ? 0 : count;
            }
        } break;
        case OperationType::DEPTHWISE_CONV_2D: {
            if (11 < inputCount && inputCount <= 14 &&
                mOperands[operation.inputs[8]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 11 to 14 inputs
                // API level 27: 11 inputs
                uint32_t count = getCount(11, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 12 and 13 must come together.
                return inputCount - count == 13 ? 0 : count;
            } else if (8 < inputCount && inputCount <= 11 &&
                       mOperands[operation.inputs[8]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 8 to 11 inputs
                // API level 27: 8 inputs
                uint32_t count = getCount(8, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 9 and 10 must come together.
                return inputCount - count == 10 ? 0 : count;
            }
        } break;
        case OperationType::DEPTH_TO_SPACE: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::L2_NORMALIZATION: {
            if (inputCount == 2) {
                // API level 29: 1 to 2 inputs
                // API level 27: 1 input
                return getCount(1, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::L2_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
            if (inputCount == 6) {
                // API level 29: 5 to 6 inputs
                // API level 27: 5 inputs
                return getCount(5, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::MAX_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::RESIZE_BILINEAR: {
            if (3 < inputCount && inputCount <= 6) {
                // By shape:
                //     API level 30: 3 to 6 inputs
                //     API level 29: 3 to 4 inputs
                //     API level 27: 3 inputs
                // By scale:
                //     API level 30: 3 to 6 inputs
                //     API level 29: 3 to 4 inputs
                return getCount(3, {TS::BOOL_FALSE, TS::BOOL_FALSE, TS::BOOL_FALSE});
            }
        } break;
        case OperationType::SOFTMAX: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::SPACE_TO_DEPTH: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::BATCH_TO_SPACE_ND: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 28: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::SPACE_TO_BATCH_ND: {
            if (inputCount == 4) {
                // API level 29: 3 to 4 inputs
                // API level 28: 3 inputs
                return getCount(3, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::RESIZE_NEAREST_NEIGHBOR: {
            if (4 < inputCount && inputCount <= 6) {
                // By shape or scale
                // API level 30: 4 to 6 inputs
                // API level 29: 4 inputs
                return getCount(4, {TS::BOOL_FALSE, TS::BOOL_FALSE});
            }
        } break;
        default: {
            // Do nothing.
        } break;
    }
    // No trailing optional arguments to check.
    return 0;
}

bool ModelBuilder::sortIntoRunOrder() {
    // Note that this may be called before the model has been
    // validated, so we must code defensively.  However, we can assume
    // an Operation's inputs and outputs have legal indices -- this
    // should have been checked in addOperation().
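    // The sort below is a Kahn-style topological sort: an operation becomes ready
    // once every TEMPORARY_VARIABLE or SUBGRAPH_OUTPUT operand it consumes has been
    // produced by an already-scheduled operation.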

    if (!mSortedOperationIndexMap.empty()) {
        LOG(ERROR) << "Operations were already sorted into run order.";
        return true;
    }

    // Tracks the operations that can be executed.
    std::vector<uint32_t> sortedOperationIndexMap;
    std::vector<uint32_t> opsReadyToRun;
    std::vector<Operation> runOrder;

    // Tracks how many inputs are needed for each operation to be ready to run.
    std::multimap<uint32_t, uint32_t> operandToOperations;
    std::vector<uint32_t> unknownInputCount(operationCount());
    for (uint32_t operationIndex = 0; operationIndex < operationCount(); operationIndex++) {
        uint32_t& count = unknownInputCount[operationIndex];
        count = 0;
        for (uint32_t operandIndex : mOperations[operationIndex].inputs) {
            auto lifetime = mOperands[operandIndex].lifetime;
            if (lifetime == Operand::LifeTime::TEMPORARY_VARIABLE ||
                lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT) {
                count++;
                operandToOperations.insert(
                        std::pair<uint32_t, uint32_t>(operandIndex, operationIndex));
            }
        }
        if (count == 0) {
            opsReadyToRun.push_back(operationIndex);
        }
    }

    while (opsReadyToRun.size() > 0) {
        // Execute the next op
        int opIndex = opsReadyToRun.back();
        opsReadyToRun.pop_back();
        const Operation& operation = mOperations[opIndex];

        runOrder.push_back(mOperations[opIndex]);
        sortedOperationIndexMap.push_back(opIndex);

        // Mark all its outputs as known.
        for (uint32_t operandIndex : operation.outputs) {
            auto range = operandToOperations.equal_range(operandIndex);
            for (auto i = range.first; i != range.second; i++) {
                uint32_t& count = unknownInputCount[i->second];
                if (--count == 0) {
                    opsReadyToRun.push_back(i->second);
                }
            }
        }
    }

    if (runOrder.size() != mOperations.size()) {
        CHECK_LT(runOrder.size(), mOperations.size());
        // Graph must contain at least one cycle or one never-written
        // operand, because there is at least one Operation that never
        // became ready.
        LOG(ERROR) << "Graph contains at least one cycle or one never-written operand";
        return false;
    }

    mSortedOperationIndexMap = std::move(sortedOperationIndexMap);
    mOperations = std::move(runOrder);
    return true;
}

// A helper class to simplify state management when creating a Model.
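// ModelMaker::run() flattens a ModelBuilder (and any ModelBuilders it references
// for control flow) into a single Model: referenced builders become
// Model::Subgraphs, and their constant values and memory pools are merged into
// the Model's operandValues and pools.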
class ModelBuilder::ModelMaker {
   public:
    static Model run(const ModelBuilder* model, bool simplifyModel);

   private:
    static Model::Subgraph makeSubgraph(const ModelBuilder* model);
    explicit ModelMaker(bool simplifyModel) : mSimplifyModel(simplifyModel) {}
    Model makeModel(const ModelBuilder* mainModel);
    uint32_t addSubgraph(const ModelBuilder* refModel);
    void updateOperandLocations(const ModelBuilder* refModel, Model::Subgraph* subgraph);
    void addExtensions(const ModelBuilder* model);
    void addExtensionWithPrefix(uint16_t prefix);

    bool mSimplifyModel;
    std::vector<Model::Subgraph> mRefSubgraphs;
    Model::OperandValues mOperandValues;
    MemoryTracker mMemories;
    std::vector<ExtensionNameAndPrefix> mExtensionNameToPrefix;
    std::set<uint16_t> mPrefixSet;
};

void ModelBuilder::simplifyModel() {
    mSimplifyModel = true;
}

Model ModelBuilder::makeModel() const {
    // TODO: Cache the Model to speed up subsequent calls.
    return ModelMaker::run(this, mSimplifyModel);
}

Model ModelBuilder::ModelMaker::run(const ModelBuilder* model, bool simplifyModel) {
    // run() ensures the state of ModelMaker is destroyed after the call.
    return ModelMaker(simplifyModel).makeModel(model);
}

Model ModelBuilder::ModelMaker::makeModel(const ModelBuilder* mainModel) {
    addExtensions(mainModel);
    Model model;
    model.main = makeSubgraph(mainModel);
    updateOperandLocations(mainModel, &model.main);
    model.referenced = std::move(mRefSubgraphs);
    model.operandValues = std::move(mOperandValues);
    model.pools.reserve(mMemories.size());
    std::transform(mMemories.begin(), mMemories.end(), std::back_inserter(model.pools),
                   [](const RuntimeMemory* m) { return m->getMemory(); });
    model.relaxComputationFloat32toFloat16 = mainModel->mRelaxComputationFloat32toFloat16;
    model.extensionNameToPrefix = std::move(mExtensionNameToPrefix);
    if (mSimplifyModel) {
        removeDeadOperands(&model);
    }
    return model;
}

Model::Subgraph ModelBuilder::ModelMaker::makeSubgraph(const ModelBuilder* model) {
    Model::Subgraph subgraph;
    subgraph.operands = model->mOperands;
    subgraph.operations = model->mOperations;
    subgraph.inputIndexes = model->mInputIndexes;
    subgraph.outputIndexes = model->mOutputIndexes;
    return subgraph;
}

void ModelBuilder::ModelMaker::updateOperandLocations(const ModelBuilder* refModel,
                                                      Model::Subgraph* subgraph) {
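    // Coalesce CONSTANT_COPY values from the builder's small-value buffer into the
    // Model's single operandValues blob, and re-index CONSTANT_REFERENCE pools into
    // the merged memory tracker.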
    for (Operand& operand : subgraph->operands) {
        if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
            uint32_t valueLength = operand.location.length;
            uint32_t originalOffset = operand.location.offset;
            operand.location = mOperandValues.append(&refModel->mSmallOperandValues[originalOffset],
                                                     valueLength);
        } else if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) {
            uint32_t originalPoolIndex = operand.location.poolIndex;
            operand.location.poolIndex = mMemories.add(refModel->mMemories[originalPoolIndex]);
        }
    }
    // Do recursive calls at the end to improve locality of mOperandValues.
    for (Operand& operand : subgraph->operands) {
        if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
            uint32_t refModelIndex = operand.location.offset;
            // TODO(b/147875885): Avoid creating duplicate refSubgraphs when
            // a single refModel is referenced multiple times.
            operand.location.offset = addSubgraph(refModel->mReferencedModels[refModelIndex]);
        }
    }
}

uint32_t ModelBuilder::ModelMaker::addSubgraph(const ModelBuilder* refModel) {
    uint32_t index = mRefSubgraphs.size();
    mRefSubgraphs.push_back(makeSubgraph(refModel));
    updateOperandLocations(refModel, &mRefSubgraphs.back());
    return index;
}

void ModelBuilder::ModelMaker::addExtensions(const ModelBuilder* model) {
    for (const auto& operand : model->mOperands) {
        if (isExtension(operand.type)) {
            addExtensionWithPrefix(static_cast<uint32_t>(operand.type) >> kExtensionTypeBits);
        }
    }
    for (const auto& operation : model->mOperations) {
        if (isExtension(operation.type)) {
            addExtensionWithPrefix(static_cast<uint32_t>(operation.type) >> kExtensionTypeBits);
        }
    }
    for (const auto& refModel : model->mReferencedModels) {
        addExtensions(refModel);
    }
}

void ModelBuilder::ModelMaker::addExtensionWithPrefix(uint16_t prefix) {
    if (!mPrefixSet.insert(prefix).second) {
        return;
    }
    const Extension* extension;
    CHECK(TypeManager::get()->getExtensionInfo(prefix, &extension));
    mExtensionNameToPrefix.push_back({
            .name = extension->name,
            .prefix = prefix,
    });
}

const uint8_t* ModelBuilder::getModelArchHash() const {
    CHECK(mCompletedModel) << "Calling getModelArchHash on an unfinished model";
    return mModelArchHash;
}

#undef NN_VALIDATE_NULL_OR_SIZED

}  // namespace nn
}  // namespace android