1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ModelBuilder"
18 
19 #include "ModelBuilder.h"
20 
21 #include <algorithm>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <utility>
26 #include <vector>
27 
28 #include "CompilationBuilder.h"
29 #include "GraphDump.h"
30 #include "Manager.h"
31 #include "TypeManager.h"
32 #include "Utils.h"
33 #include "ValidateHal.h"
34 
35 namespace android {
36 namespace nn {
37 
38 using namespace hal;
39 
// Upper bounds on how many operands and how many operations a single model may
// contain. Both equal the largest uint32_t value minus one.
constexpr uint32_t MAX_NUMBER_OF_OPERANDS = 0xFFFFFFFE;
constexpr uint32_t MAX_NUMBER_OF_OPERATIONS = 0xFFFFFFFE;
43 
badState(const char * name)44 bool ModelBuilder::badState(const char* name) {
45     if (mCompletedModel) {
46         LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify after model finished";
47         return true;
48     }
49     if (mInvalidModel) {
50         LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify an invalid model";
51         return true;
52     }
53     return false;
54 }
55 
getExtensionType(const char * extensionName,uint16_t typeWithinExtension,int32_t * type)56 int ModelBuilder::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
57                                    int32_t* type) {
58     return TypeManager::get()->getExtensionType(extensionName, typeWithinExtension, type)
59                    ? ANEURALNETWORKS_NO_ERROR
60                    : ANEURALNETWORKS_BAD_DATA;
61 }
62 
addOperand(const ANeuralNetworksOperandType & type)63 int ModelBuilder::addOperand(const ANeuralNetworksOperandType& type) {
64     if (badState("addOperand")) {
65         return ANEURALNETWORKS_BAD_STATE;
66     }
67 
68     OperandType operandType = static_cast<OperandType>(type.type);
69     if (isExtensionOperandType(operandType) && !TypeManager::get()->areExtensionsAllowed()) {
70         LOG(ERROR) << "Extensions are not supported for this process.";
71         return ANEURALNETWORKS_BAD_DATA;
72     }
73     bool isOemOperand =
74             operandType == OperandType::OEM || operandType == OperandType::TENSOR_OEM_BYTE;
75     if (isOemOperand && !mHasOEMOperand) {
76         LOG(WARNING) << "OEM data type is deprecated. Use Extensions instead.";
77     }
78 
79     const Extension::OperandTypeInformation* info = nullptr;
80     if (isExtensionOperandType(operandType) &&
81         !TypeManager::get()->getExtensionOperandTypeInfo(operandType, &info)) {
82         LOG(ERROR) << "Extension operand type " << toString(operandType) << " is not registered";
83         return ANEURALNETWORKS_BAD_DATA;
84     }
85     NN_RETURN_IF_ERROR(validateOperandType(type, info, "ANeuralNetworksModel_addOperand", true));
86     size_t idx = mOperands.size();
87     if (idx >= MAX_NUMBER_OF_OPERANDS) {
88         LOG(ERROR) << "ANeuralNetworksModel_addOperand exceed max operands";
89         return ANEURALNETWORKS_BAD_DATA;
90     }
91 
92     mOperands.push_back({
93             .type = operandType,
94             .dimensions =
95                     hidl_vec<uint32_t>(type.dimensions, type.dimensions + type.dimensionCount),
96             .numberOfConsumers = 0,
97             .scale = type.scale,
98             .zeroPoint = type.zeroPoint,
99             .lifetime = OperandLifeTime::TEMPORARY_VARIABLE,
100             .location = {.poolIndex = 0, .offset = 0, .length = 0},
101             .extraParams = OperandExtraParams(),
102     });
103     mHasOEMOperand |= isOemOperand;
104     return ANEURALNETWORKS_NO_ERROR;
105 }
106 
setOperandValue(uint32_t index,const void * buffer,size_t length)107 int ModelBuilder::setOperandValue(uint32_t index, const void* buffer, size_t length) {
108     VLOG(MODEL) << __func__ << " for operand " << index << " size " << length;
109     if (badState("setOperandValue")) {
110         return ANEURALNETWORKS_BAD_STATE;
111     }
112 
113     if (index >= operandCount()) {
114         LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index << " of "
115                    << operandCount();
116         return ANEURALNETWORKS_BAD_DATA;
117     }
118     Operand& operand = mOperands[index];
119     if (buffer == nullptr) {
120         if (length) {
121             LOG(ERROR) << "ANeuralNetworksModel_setOperandValue buffer is nullptr but length is "
122                           "not 0";
123             return ANEURALNETWORKS_BAD_DATA;
124         }
125         operand.lifetime = OperandLifeTime::NO_VALUE;
126         // The location is unused and is set to zeros.
127         operand.location = {.poolIndex = 0, .offset = 0, .length = 0};
128     } else {
129         if (TypeManager::get()->isTensorType(operand.type) &&
130             tensorHasUnspecifiedDimensions(operand)) {
131             LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index
132                        << " which has operand type that is not fully specified";
133             return ANEURALNETWORKS_BAD_DATA;
134         }
135         if (length > 0xFFFFFFFF) {
136             LOG(ERROR) << "ANeuralNetworksModel_setOperandValue value length of " << length
137                        << " exceeds max size";
138             return ANEURALNETWORKS_BAD_DATA;
139         }
140         uint32_t valueLength = static_cast<uint32_t>(length);
141         if (operand.type != OperandType::OEM) {
142             uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
143             if (neededLength != valueLength) {
144                 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
145                            << " bytes when needing " << neededLength;
146                 return ANEURALNETWORKS_BAD_DATA;
147             }
148         }
149         if (valueLength <= ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES) {
150             uint32_t existingSize = static_cast<uint32_t>(mSmallOperandValues.size());
151             uint32_t extraBytes = alignBytesNeeded(existingSize, valueLength);
152             mSmallOperandValues.resize(existingSize + extraBytes + valueLength);
153             operand.lifetime = OperandLifeTime::CONSTANT_COPY;
154             operand.location = {
155                     .poolIndex = 0, .offset = existingSize + extraBytes, .length = valueLength};
156             memcpy(&mSmallOperandValues[operand.location.offset], buffer, valueLength);
157             VLOG(MODEL) << "Copied small value to offset " << operand.location.offset;
158         } else {
159             VLOG(MODEL) << "Saving large value";
160             operand.lifetime = OperandLifeTime::CONSTANT_REFERENCE;
161             // The values for poolIndex and offset will be set when the model is finished.
162             typedef decltype(operand.location.poolIndex) PoolIndexType;
163             typedef decltype(operand.location.offset) OffsetType;
164             operand.location = {.poolIndex = ~PoolIndexType(0),
165                                 .offset = ~OffsetType(0),
166                                 .length = valueLength};
167             // We keep track of the buffers. We'll allocate the shared memory only
168             // once we know the total size, to avoid needless copies.
169             mLargeOperandValues.push_back(LargeValue{.operandIndex = index, .buffer = buffer});
170         }
171     }
172     return ANEURALNETWORKS_NO_ERROR;
173 }
174 
setOperandValueFromModel(uint32_t index,const ModelBuilder * value)175 int ModelBuilder::setOperandValueFromModel(uint32_t index, const ModelBuilder* value) {
176     VLOG(MODEL) << __func__ << " for operand " << index << " model " << value;
177     if (badState("setOperandValueFromModel")) {
178         return ANEURALNETWORKS_BAD_STATE;
179     }
180     if (!value->mCompletedModel) {
181         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model must be finished";
182         return ANEURALNETWORKS_BAD_STATE;
183     }
184     if (value->mInvalidModel) {
185         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model is invalid";
186         return ANEURALNETWORKS_BAD_STATE;
187     }
188     if (index >= operandCount()) {
189         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel setting operand " << index
190                    << " of " << operandCount();
191         return ANEURALNETWORKS_BAD_DATA;
192     }
193     Operand& operand = mOperands[index];
194     operand.lifetime = OperandLifeTime::SUBGRAPH;
195     operand.location = {
196             .poolIndex = 0,
197             .offset = static_cast<uint32_t>(mReferencedModels.size()),
198             .length = 0,
199     };
200     mReferencedModels.push_back(value);
201     return ANEURALNETWORKS_NO_ERROR;
202 }
203 
setOperandSymmPerChannelQuantParams(uint32_t index,const ANeuralNetworksSymmPerChannelQuantParams & channelQuant)204 int ModelBuilder::setOperandSymmPerChannelQuantParams(
205         uint32_t index, const ANeuralNetworksSymmPerChannelQuantParams& channelQuant) {
206     if (badState("setOperandSymmPerChannelQuantParams")) {
207         return ANEURALNETWORKS_BAD_STATE;
208     }
209 
210     if (index >= operandCount()) {
211         LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
212                    << "setting per-channel quantization parameters for operand " << index << " of "
213                    << operandCount();
214         return ANEURALNETWORKS_BAD_DATA;
215     }
216     Operand& operand = mOperands[index];
217 
218     if (!validateOperandSymmPerChannelQuantParams(
219                 operand, channelQuant,
220                 "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams")) {
221         return ANEURALNETWORKS_BAD_DATA;
222     }
223     switch (operand.type) {
224         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
225             operand.extraParams.channelQuant({
226                     .scales = hidl_vec<float>(channelQuant.scales,
227                                               channelQuant.scales + channelQuant.scaleCount),
228                     .channelDim = channelQuant.channelDim,
229             });
230             break;
231         default:
232             LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
233                        << "invalid operand type " << static_cast<int32_t>(operand.type);
234             return ANEURALNETWORKS_BAD_DATA;
235     }
236     return ANEURALNETWORKS_NO_ERROR;
237 }
238 
setOperandExtensionData(uint32_t index,const void * data,size_t length)239 int ModelBuilder::setOperandExtensionData(uint32_t index, const void* data, size_t length) {
240     if (badState("setOperandExtensionData")) {
241         return ANEURALNETWORKS_BAD_STATE;
242     }
243 
244     if (index >= operandCount()) {
245         LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
246                    << "setting extension data for operand " << index << " of " << operandCount();
247         return ANEURALNETWORKS_BAD_DATA;
248     }
249     Operand& operand = mOperands[index];
250 
251     if (data == nullptr && length != 0) {
252         LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData data is nullptr but length is "
253                    << length;
254         return ANEURALNETWORKS_BAD_DATA;
255     }
256     if (data != nullptr && length == 0) {
257         LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData data is not nullptr but length "
258                    << "is zero";
259         return ANEURALNETWORKS_BAD_DATA;
260     }
261     if (!isExtensionOperandType(operand.type)) {
262         LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
263                    << "setting extension data for a base operand type "
264                    << static_cast<int32_t>(operand.type);
265         return ANEURALNETWORKS_BAD_DATA;
266     }
267 
268     if (data == nullptr) {
269         operand.extraParams.none();
270     } else {
271         operand.extraParams.extension(
272                 hidl_vec<uint8_t>(reinterpret_cast<const uint8_t*>(data),
273                                   reinterpret_cast<const uint8_t*>(data) + length));
274     }
275     return ANEURALNETWORKS_NO_ERROR;
276 }
277 
copyLargeValuesToSharedMemory()278 int ModelBuilder::copyLargeValuesToSharedMemory() {
279     VLOG(MODEL) << __func__ << " has " << mLargeOperandValues.size() << " values.";
280     if (!mLargeOperandValues.empty()) {
281         // Calculate the size of the shared memory needed for all the large values.
282         // Also sets the offset for each value within the memory.
283         size_t poolSize = 0;
284         for (LargeValue& l : mLargeOperandValues) {
285             Operand& operand = mOperands[l.operandIndex];
286             nnAssert(operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE);
287             poolSize += alignBytesNeeded(poolSize, operand.location.length);
288             operand.location.offset = poolSize;
289             poolSize += operand.location.length;
290         }
291 
292         // Allocate the shared memory.
293         int n;
294         std::tie(n, mLargeValueMemory) = MemoryAshmem::create(poolSize);
295         NN_RETURN_IF_ERROR(n);
296         uint8_t* memoryPointer = mLargeValueMemory->getPointer();
297         uint32_t poolIndex = mMemories.add(mLargeValueMemory.get());
298         VLOG(MODEL) << "Allocated large value pool of size " << poolSize << " at index "
299                     << poolIndex;
300 
301         // Copy the values to this memory.
302         for (LargeValue& l : mLargeOperandValues) {
303             Operand& operand = mOperands[l.operandIndex];
304             operand.location.poolIndex = poolIndex;
305             memcpy(memoryPointer + operand.location.offset, l.buffer, operand.location.length);
306         }
307     }
308     return ANEURALNETWORKS_NO_ERROR;
309 }
310 
setOperandValueFromMemory(uint32_t index,const Memory * memory,uint32_t offset,size_t length)311 int ModelBuilder::setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
312                                             size_t length) {
313     VLOG(MODEL) << __func__ << " for operand " << index << " offset " << offset << " size "
314                 << length;
315     if (badState("setOperandValueFromMemory")) {
316         return ANEURALNETWORKS_BAD_STATE;
317     }
318 
319     if (index >= operandCount()) {
320         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
321                    << " of " << operandCount();
322         return ANEURALNETWORKS_BAD_DATA;
323     }
324     Operand& operand = mOperands[index];
325     if (TypeManager::get()->isTensorType(operand.type) && tensorHasUnspecifiedDimensions(operand)) {
326         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
327                    << " which has operand type that is not fully specified";
328         return ANEURALNETWORKS_BAD_DATA;
329     }
330     uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
331     if (neededLength != length) {
332         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
333                    << " bytes when needing " << neededLength;
334         return ANEURALNETWORKS_BAD_DATA;
335     }
336     // Set compilation = nullptr to indicate that the memory is used for a model constant.
337     // In this case, IOType::INPUT is a dummy value that is ignored by the validator.
338     if (!memory->getValidator().validate(/*compilation=*/nullptr, /*dummy*/ IOType::INPUT, index,
339                                          nullptr, offset, length)) {
340         return ANEURALNETWORKS_BAD_DATA;
341     }
342     operand.lifetime = OperandLifeTime::CONSTANT_REFERENCE;
343     operand.location = {.poolIndex = mMemories.add(memory),
344                         .offset = offset,
345                         .length = static_cast<uint32_t>(length)};
346     return ANEURALNETWORKS_NO_ERROR;
347 }
348 
addOperation(ANeuralNetworksOperationType type,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)349 int ModelBuilder::addOperation(ANeuralNetworksOperationType type, uint32_t inputCount,
350                                const uint32_t* inputs, uint32_t outputCount,
351                                const uint32_t* outputs) {
352     if (badState("addOperation")) {
353         return ANEURALNETWORKS_BAD_STATE;
354     }
355 
356     OperationType operationType = static_cast<OperationType>(type);
357     if (isExtensionOperationType(operationType) && !TypeManager::get()->areExtensionsAllowed()) {
358         LOG(ERROR) << "Extensions are not supported for this process.";
359         return ANEURALNETWORKS_BAD_DATA;
360     }
361     if (operationType == OperationType::OEM_OPERATION && !mHasOEMOperation) {
362         LOG(WARNING) << "OEM_OPERATION is deprecated. Use Extensions instead.";
363     }
364 
365     if (!isExtensionOperationType(operationType)) {
366         if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, type)) {
367             LOG(ERROR) << "ANeuralNetworksModel_addOperation invalid operation type " << type;
368             return ANEURALNETWORKS_BAD_DATA;
369         }
370     }
371 
372     auto isValidSubgraphReference = [this](const Operand& modelOperand) -> bool {
373         NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
374                 << "Unexpected operand type: " << toString(modelOperand.type);
375         NN_RET_CHECK_LT(modelOperand.location.offset, referencedModelCount())
376                 << "Invalid subgraph model reference";
377         return true;
378     };
379     auto getInputCount = [this](const Operand& modelOperand) -> uint32_t {
380         return getReferencedModel(modelOperand)->inputCount();
381     };
382     auto getOutputCount = [this](const Operand& modelOperand) -> uint32_t {
383         return getReferencedModel(modelOperand)->outputCount();
384     };
385     auto getInputOperand = [this](const Operand& modelOperand, uint32_t index) -> const Operand* {
386         return &getReferencedModel(modelOperand)->getInputOperand(index);
387     };
388     auto getOutputOperand = [this](const Operand& modelOperand, uint32_t index) -> const Operand* {
389         return &getReferencedModel(modelOperand)->getOutputOperand(index);
390     };
391     NN_RETURN_IF_ERROR(validateOperation(
392             type, inputCount, inputs, outputCount, outputs, mOperands, HalVersion::LATEST,
393             {.isValidSubgraphReference = isValidSubgraphReference,
394              .getSubgraphInputCount = getInputCount,
395              .getSubgraphOutputCount = getOutputCount,
396              .getSubgraphInputOperand = getInputOperand,
397              .getSubgraphOutputOperand = getOutputOperand,
398              .allowControlFlowOperationWithOperandOfUnknownSize = true}));
399 
400     uint32_t operationIndex = operationCount();
401     if (operationIndex >= MAX_NUMBER_OF_OPERATIONS) {
402         LOG(ERROR) << "ANeuralNetworksModel_addOperation exceed max operations";
403         return ANEURALNETWORKS_BAD_DATA;
404     }
405 
406     mOperations.push_back({
407             .type = operationType,
408             .inputs = hidl_vec<uint32_t>(inputs, inputs + inputCount),
409             .outputs = hidl_vec<uint32_t>(outputs, outputs + outputCount),
410     });
411     for (uint32_t i : mOperations.back().inputs) {
412         mOperands[i].numberOfConsumers++;
413     }
414     mHasOEMOperation |= (operationType == OperationType::OEM_OPERATION);
415     mHasExtensionOperation |= isExtensionOperationType(operationType);
416 
417     return ANEURALNETWORKS_NO_ERROR;
418 }
419 
identifyInputsAndOutputs(uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)420 int ModelBuilder::identifyInputsAndOutputs(uint32_t inputCount, const uint32_t* inputs,
421                                            uint32_t outputCount, const uint32_t* outputs) {
422     if (badState("identifyInputsAndOutputs")) {
423         return ANEURALNETWORKS_BAD_STATE;
424     }
425 
426     int n = validateOperandList(inputCount, inputs, operandCount(),
427                                 "ANeuralNetworksModel_identifyInputsAndOutputs inputs");
428     if (n != ANEURALNETWORKS_NO_ERROR) {
429         return n;
430     }
431     n = validateOperandList(outputCount, outputs, operandCount(),
432                             "ANeuralNetworksModel_identifyInputsAndOutputs outputs");
433     if (n != ANEURALNETWORKS_NO_ERROR) {
434         return n;
435     }
436 
437     // Makes a copy of the index list, validates the arguments, and changes
438     // the lifetime info of the corresponding operand.
439     auto setArguments = [&](std::vector<uint32_t>* indexVector, uint32_t indexCount,
440                             const uint32_t* indexList, OperandLifeTime lifetime) -> bool {
441         indexVector->resize(indexCount);
442         for (uint32_t i = 0; i < indexCount; i++) {
443             const uint32_t operandIndex = indexList[i];
444             if (operandIndex >= mOperands.size()) {
445                 LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set input or "
446                               "output "
447                               "to be "
448                            << operandIndex << " as this exceeds the numbe of operands "
449                            << mOperands.size();
450                 return false;
451             }
452             (*indexVector)[i] = operandIndex;
453             Operand& operand = mOperands[operandIndex];
454             if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE) {
455                 LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set operand "
456                            << operandIndex
457                            << " to be an input or output.  Check that it's not a constant or "
458                               "already an input or output";
459                 return false;
460             }
461             operand.lifetime = lifetime;
462         }
463         return true;
464     };
465 
466     if (!setArguments(&mInputIndexes, inputCount, inputs, OperandLifeTime::SUBGRAPH_INPUT) ||
467         !setArguments(&mOutputIndexes, outputCount, outputs, OperandLifeTime::SUBGRAPH_OUTPUT)) {
468         return ANEURALNETWORKS_BAD_DATA;
469     }
470 
471     return ANEURALNETWORKS_NO_ERROR;
472 }
473 
relaxComputationFloat32toFloat16(bool allow)474 int ModelBuilder::relaxComputationFloat32toFloat16(bool allow) {
475     if (badState("relaxComputationFloat32toFloat16")) {
476         return ANEURALNETWORKS_BAD_STATE;
477     }
478 
479     mRelaxComputationFloat32toFloat16 = allow;
480 
481     return ANEURALNETWORKS_NO_ERROR;
482 }
483 
createCompilation(CompilationBuilder ** compilation,const std::vector<std::shared_ptr<Device>> & devices,bool explicitDeviceList)484 int ModelBuilder::createCompilation(CompilationBuilder** compilation,
485                                     const std::vector<std::shared_ptr<Device>>& devices,
486                                     bool explicitDeviceList) {
487     if (!mCompletedModel || mInvalidModel) {
488         LOG(ERROR) << "ANeuralNetworksCompilation_create passed an unfinished or invalid model";
489         *compilation = nullptr;
490         return ANEURALNETWORKS_BAD_STATE;
491     }
492     *compilation = new (std::nothrow) CompilationBuilder(this, devices, explicitDeviceList);
493     return (*compilation ? ANEURALNETWORKS_NO_ERROR : ANEURALNETWORKS_OUT_OF_MEMORY);
494 }
495 
finish()496 int ModelBuilder::finish() {
497     if (mCompletedModel) {
498         LOG(ERROR) << "ANeuralNetworksModel_finish called more than once";
499         return ANEURALNETWORKS_BAD_STATE;
500     }
501     if (mInvalidModel) {
502         LOG(ERROR) << "ANeuralNetworksModel_finish called on an invalid model";
503         return ANEURALNETWORKS_BAD_STATE;
504     }
505 
506     int n = copyLargeValuesToSharedMemory();
507     if (n != ANEURALNETWORKS_NO_ERROR) {
508         return n;
509     }
510 
511     // We sort the operations so that they will be in the appropriate
512     // order for a single-threaded, op at a time execution.
513     // TODO: we don't need this if we always run the partitioner.
514     if (!sortIntoRunOrder()) {
515         // We expect sortIntoRunOrder() to have logged an appropriate error message.
516         mInvalidModel = true;
517         return ANEURALNETWORKS_BAD_DATA;
518     }
519 
520     // TODO: Modify validation so that it can be called without creating a HAL Model.
521     // NOTE: Must sortIntoRunOrder() before validation; validator expects operations
522     //       to have been sorted.
523     // NOTE: Must copyLargeValuesToSharedMemory() before validation; otherwise,
524     //       a CONSTANT_REFERENCE operand will not have correct .poolIndex, and
525     //       validation will not work properly.
526     const Model modelForValidation = makeHidlModel();
527     if (!validateModel(modelForValidation, ValidationMode::RUNTIME)) {
528         LOG(ERROR) << "ANeuralNetworksModel_finish called on invalid model";
529         mInvalidModel = true;
530         return ANEURALNETWORKS_BAD_DATA;
531     }
532     if (VLOG_IS_ON(MODEL)) {
533         graphDump("ModelBuilder::finish", modelForValidation, nullptr);
534     }
535 
536     removeTrailingArgumentsWithDefaultValues();
537 
538     mCompletedModel = true;
539     return ANEURALNETWORKS_NO_ERROR;
540 }
541 
logRemoval(const Operation & operation,uint32_t count,const std::vector<Operand> & operands)542 static void logRemoval(const Operation& operation, uint32_t count,
543                        const std::vector<Operand>& operands) {
544     std::ostringstream message;
545     message << "Operation " << toString(operation.type) << " with inputs {";
546     for (uint32_t i = 0; i < operation.inputs.size(); ++i) {
547         if (i != 0) {
548             message << ", ";
549         }
550         message << toString(operands[operation.inputs[i]].type);
551     }
552     message << "} has trailing optional inputs set to default values. Removing " << count
553             << " trailing inputs.";
554     VLOG(MODEL) << message.str();
555 }
556 
removeTrailingArgumentsWithDefaultValues()557 void ModelBuilder::removeTrailingArgumentsWithDefaultValues() {
558     for (Operation& operation : mOperations) {
559         const uint32_t count = getNumTrailingArgumentsToRemove(operation);
560         if (count == 0) {
561             continue;
562         }
563         if (VLOG_IS_ON(MODEL)) {
564             logRemoval(operation, count, mOperands);
565         }
566         const uint32_t inputCount = operation.inputs.size();
567         CHECK_LT(count, inputCount);
568         const uint32_t newInputCount = inputCount - count;
569         for (uint32_t i = newInputCount; i < inputCount; ++i) {
570             --mOperands[operation.inputs[i]].numberOfConsumers;
571         }
572         operation.inputs.resize(newInputCount);
573     }
574 }
575 
// Expected default value of one trailing optional operation input.
// See countMatchingTrailingArguments().
enum class TailSpec : int {
    BOOL_FALSE,          // a BOOL scalar equal to false
    INT32_ONE,           // an INT32 scalar equal to 1
    INT32_NEGATIVE_ONE,  // an INT32 scalar equal to -1
};
582 
583 // See countMatchingTrailingArguments().
matchesSpec(TailSpec spec,const Operand & operand,const std::vector<uint8_t> & mSmallOperandValues)584 static bool matchesSpec(TailSpec spec, const Operand& operand,
585                         const std::vector<uint8_t>& mSmallOperandValues) {
586     if (operand.lifetime != OperandLifeTime::CONSTANT_COPY) {
587         // CONSTANT_REFERENCE operands are not supported to avoid mapping memory
588         // during compilation.
589         return false;
590     }
591     auto valuePtr = static_cast<const void*>(&mSmallOperandValues[operand.location.offset]);
592     switch (spec) {
593         case TailSpec::BOOL_FALSE:
594             return operand.type == OperandType::BOOL &&
595                    *static_cast<const bool8*>(valuePtr) == false;
596         case TailSpec::INT32_ONE:
597             return operand.type == OperandType::INT32 &&
598                    *static_cast<const int32_t*>(valuePtr) == 1;
599         case TailSpec::INT32_NEGATIVE_ONE:
600             return operand.type == OperandType::INT32 &&
601                    *static_cast<const int32_t*>(valuePtr) == -1;
602         default:
603             CHECK(false) << "Unhandled TailSpec: " << static_cast<int>(spec);
604     }
605 }
606 
607 // Returns the number of trailing operation inputs that match the specification.
608 //
609 // Example:
610 //     opeation.inputs = {BOOL_TRUE, BOOL_TRUE,  INT32_ONE, INT32_NEGATIVE_ONE}
611 //     tail            =            {BOOL_FALSE, INT32_ONE, INT32_NEGATIVE_ONE}
612 //     tailStartIndex  = 1    matching elements: ^^^^^^^^^  ^^^^^^^^^^^^^^^^^^
countMatchingTrailingArguments(uint32_t tailStartIndex,const std::vector<TailSpec> & tail,const Operation & operation,const std::vector<Operand> & operands,const std::vector<uint8_t> & smallOperandValues)613 static uint32_t countMatchingTrailingArguments(uint32_t tailStartIndex,
614                                                const std::vector<TailSpec>& tail,
615                                                const Operation& operation,
616                                                const std::vector<Operand>& operands,
617                                                const std::vector<uint8_t>& smallOperandValues) {
618     const uint32_t inputCount = operation.inputs.size();
619     uint32_t count = 0;
620     for (uint32_t i = inputCount - 1; i >= tailStartIndex; --i) {
621         const Operand& operand = operands[operation.inputs[i]];
622         if (!matchesSpec(tail[i - tailStartIndex], operand, smallOperandValues)) {
623             break;
624         }
625         ++count;
626     }
627     return count;
628 }
629 
getNumTrailingArgumentsToRemove(const Operation & operation) const630 uint32_t ModelBuilder::getNumTrailingArgumentsToRemove(const Operation& operation) const {
631     const uint32_t inputCount = operation.inputs.size();
632     auto getCount = [this, &operation](uint32_t tailStartIndex, const std::vector<TailSpec>& tail) {
633         return countMatchingTrailingArguments(tailStartIndex, tail, operation, mOperands,
634                                               mSmallOperandValues);
635     };
636     using TS = TailSpec;
637     // Check if the operation has optional arguments that might be set to default
638     // values. Skip the counting if no optional arguments are present.
639     switch (operation.type) {
640         case OperationType::AVERAGE_POOL_2D: {
641             if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
642                 // Explicit padding
643                 // API level 29: 10 to 11 inputs
644                 // API level 27: 10 inputs
645                 return getCount(10, {TS::BOOL_FALSE});
646             } else if (inputCount == 8 &&
647                        mOperands[operation.inputs[7]].type == OperandType::BOOL) {
648                 // Implicit padding
649                 // API level 29: 7 to 8 inputs
650                 // API level 27: 7 inputs
651                 return getCount(7, {TS::BOOL_FALSE});
652             }
653         } break;
654         case OperationType::CONV_2D: {
655             if (10 < inputCount && inputCount <= 13 &&
656                 mOperands[operation.inputs[7]].type == OperandType::INT32) {
657                 // Explicit padding
658                 // API level 29: 10 to 13 inputs
659                 // API level 27: 10 inputs
660                 uint32_t count = getCount(10, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
661                 // Inputs 11 and 12 must come together.
662                 return inputCount - count == 12 ? 0 : count;
663             } else if (7 < inputCount && inputCount <= 10 &&
664                        mOperands[operation.inputs[7]].type == OperandType::BOOL) {
665                 // Implicit padding
666                 // API level 29: 7 to 10 inputs
667                 // API level 27: 7 inputs
668                 uint32_t count = getCount(7, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
669                 // Inputs 8 and 9 must come together.
670                 return inputCount - count == 9 ? 0 : count;
671             }
672         } break;
673         case OperationType::DEPTHWISE_CONV_2D: {
674             if (11 < inputCount && inputCount <= 14 &&
675                 mOperands[operation.inputs[8]].type == OperandType::INT32) {
676                 // Explicit padding
677                 // API level 29: 11 to 14 inputs
678                 // API level 27: 11 inputs
679                 uint32_t count = getCount(11, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
680                 // Inputs 12 and 13 must come together.
681                 return inputCount - count == 13 ? 0 : count;
682             } else if (8 < inputCount && inputCount <= 11 &&
683                        mOperands[operation.inputs[8]].type == OperandType::BOOL) {
684                 // Implicit padding
685                 // API level 29: 8 to 11 inputs
686                 // API level 27: 8 inputs
687                 uint32_t count = getCount(8, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
688                 // Inputs 9 and 10 must come together.
689                 return inputCount - count == 10 ? 0 : count;
690             }
691         } break;
692         case OperationType::DEPTH_TO_SPACE: {
693             if (inputCount == 3) {
694                 // API level 29: 2 to 3 inputs
695                 // API level 27: 2 inputs
696                 return getCount(2, {TS::BOOL_FALSE});
697             }
698         } break;
699         case OperationType::L2_NORMALIZATION: {
700             if (inputCount == 2) {
701                 // API level 29: 1 to 2 inputs
702                 // API level 27: 1 inputs
703                 return getCount(1, {TS::INT32_NEGATIVE_ONE});
704             }
705         } break;
706         case OperationType::L2_POOL_2D: {
707             if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
708                 // Explicit padding
709                 // API level 29: 10 to 11 inputs
710                 // API level 27: 10 inputs
711                 return getCount(10, {TS::BOOL_FALSE});
712             } else if (inputCount == 8 &&
713                        mOperands[operation.inputs[7]].type == OperandType::BOOL) {
714                 // Implicit padding
715                 // API level 29: 7 to 8 inputs
716                 // API level 27: 7 inputs
717                 return getCount(7, {TS::BOOL_FALSE});
718             }
719         } break;
720         case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
721             if (inputCount == 6) {
722                 // API level 29: 5 to 6 inputs
723                 // API level 27: 5 inputs
724                 return getCount(5, {TS::INT32_NEGATIVE_ONE});
725             }
726         } break;
727         case OperationType::MAX_POOL_2D: {
728             if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
729                 // Explicit padding
730                 // API level 29: 10 to 11 inputs
731                 // API level 27: 10 inputs
732                 return getCount(10, {TS::BOOL_FALSE});
733             } else if (inputCount == 8 &&
734                        mOperands[operation.inputs[7]].type == OperandType::BOOL) {
735                 // Implicit padding
736                 // API level 29: 7 to 8 inputs
737                 // API level 27: 7 inputs
738                 return getCount(7, {TS::BOOL_FALSE});
739             }
740         } break;
741         case OperationType::RESIZE_BILINEAR: {
742             if (3 < inputCount && inputCount <= 6) {
743                 // By shape:
744                 //     API level 30: 3 to 6 inputs
745                 //     API level 29: 3 to 4 inputs
746                 //     API level 27: 3 inputs
747                 // By scale:
748                 //     API level 30: 3 to 6 inputs
749                 //     API level 29: 3 to 4 inputs
750                 return getCount(3, {TS::BOOL_FALSE, TS::BOOL_FALSE, TS::BOOL_FALSE});
751             }
752         } break;
753         case OperationType::SOFTMAX: {
754             if (inputCount == 3) {
755                 // API level 29: 2 to 3 inputs
756                 // API level 27: 2 inputs
757                 return getCount(2, {TS::INT32_NEGATIVE_ONE});
758             }
759         } break;
760         case OperationType::SPACE_TO_DEPTH: {
761             if (inputCount == 3) {
762                 // API level 29: 2 to 3 inputs
763                 // API level 27: 2 inputs
764                 return getCount(2, {TS::BOOL_FALSE});
765             }
766         } break;
767         case OperationType::BATCH_TO_SPACE_ND: {
768             if (inputCount == 3) {
769                 // API level 29: 2 to 3 inputs
770                 // API level 28: 2 inputs
771                 return getCount(2, {TS::BOOL_FALSE});
772             }
773         } break;
774         case OperationType::SPACE_TO_BATCH_ND: {
775             if (inputCount == 4) {
776                 // API level 29: 3 to 4 inputs
777                 // API level 28: 3 inputs
778                 return getCount(3, {TS::BOOL_FALSE});
779             }
780         } break;
781         case OperationType::RESIZE_NEAREST_NEIGHBOR: {
782             if (4 < inputCount && inputCount <= 6) {
783                 // By shape or scale
784                 // API level 30: 4 to 6 inputs
785                 // API level 29: 4 inputs
786                 return getCount(4, {TS::BOOL_FALSE, TS::BOOL_FALSE});
787             }
788         } break;
789         default: {
790             // Do nothing.
791         } break;
792     }
793     // No trailing optional arguments to check.
794     return 0;
795 }
796 
sortIntoRunOrder()797 bool ModelBuilder::sortIntoRunOrder() {
798     // Note that this may be called before the model has been
799     // validated, so we must code defensively.  However, we can assume
800     // an Operation's inputs and outputs have legal indices -- this
801     // should have been checked in addOperation().
802 
803     if (!mSortedOperationIndexMap.empty()) {
804         LOG(ERROR) << "Operations were already sorted into run order.";
805         return true;
806     }
807 
808     // Tracks the operations that can be executed.
809     std::vector<uint32_t> sortedOperationIndexMap;
810     std::vector<uint32_t> opsReadyToRun;
811     std::vector<Operation> runOrder;
812 
813     // Tracks how many inputs are needed for each operation to be ready to run.
814     std::multimap<uint32_t, uint32_t> operandToOperations;
815     std::vector<uint32_t> unknownInputCount(operationCount());
816     for (uint32_t operationIndex = 0; operationIndex < operationCount(); operationIndex++) {
817         uint32_t& count = unknownInputCount[operationIndex];
818         count = 0;
819         for (uint32_t operandIndex : mOperations[operationIndex].inputs) {
820             auto lifetime = mOperands[operandIndex].lifetime;
821             if (lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
822                 lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
823                 count++;
824                 operandToOperations.insert(
825                         std::pair<uint32_t, uint32_t>(operandIndex, operationIndex));
826             }
827         }
828         if (count == 0) {
829             opsReadyToRun.push_back(operationIndex);
830         }
831     }
832 
833     while (opsReadyToRun.size() > 0) {
834         // Execute the next op
835         int opIndex = opsReadyToRun.back();
836         opsReadyToRun.pop_back();
837         const Operation& operation = mOperations[opIndex];
838 
839         runOrder.push_back(mOperations[opIndex]);
840         sortedOperationIndexMap.push_back(opIndex);
841 
842         // Mark all its outputs as known.
843         for (uint32_t operandIndex : operation.outputs) {
844             auto range = operandToOperations.equal_range(operandIndex);
845             for (auto i = range.first; i != range.second; i++) {
846                 uint32_t& count = unknownInputCount[i->second];
847                 if (--count == 0) {
848                     opsReadyToRun.push_back(i->second);
849                 }
850             }
851         }
852     }
853 
854     if (runOrder.size() != mOperations.size()) {
855         nnAssert(runOrder.size() < mOperations.size());
856         // Graph must contain at least one cycle or one never-written
857         // operand, because there is at least one Operation that never
858         // became ready.
859         LOG(ERROR) << "Graph contains at least one cycle or one never-written operand";
860         return false;
861     }
862 
863     mSortedOperationIndexMap = std::move(sortedOperationIndexMap);
864     mOperations = std::move(runOrder);
865     return true;
866 }
867 
868 // A helper class to simplify state management when creating a HIDL model.
869 class ModelBuilder::HidlModelMaker {
870    public:
871     static Model run(const ModelBuilder* model);
872 
873    private:
874     static Subgraph makeSubgraph(const ModelBuilder* model);
HidlModelMaker()875     HidlModelMaker() {}
876     Model makeHidlModel(const ModelBuilder* mainModel);
877     uint32_t addSubgraph(const ModelBuilder* refModel);
878     void updateOperandLocations(const ModelBuilder* refModel, Subgraph* subgraph);
879     void addExtensions(const ModelBuilder* model);
880     void addExtensionWithPrefix(uint16_t prefix);
881 
882     std::vector<Subgraph> mRefSubgraphs;
883     std::vector<uint8_t> mOperandValues;
884     MemoryTracker mMemories;
885     std::vector<ExtensionNameAndPrefix> mExtensionNameToPrefix;
886     std::set<uint16_t> mPrefixSet;
887 };
888 
makeHidlModel() const889 Model ModelBuilder::makeHidlModel() const {
890     // TODO: Cache the HIDL model to speed up subsequent calls.
891     return HidlModelMaker::run(this);
892 }
893 
run(const ModelBuilder * model)894 Model ModelBuilder::HidlModelMaker::run(const ModelBuilder* model) {
895     // run() ensures the state of HidlModelMaker is destroyed after the call.
896     return HidlModelMaker().makeHidlModel(model);
897 }
898 
makeHidlModel(const ModelBuilder * mainModel)899 Model ModelBuilder::HidlModelMaker::makeHidlModel(const ModelBuilder* mainModel) {
900     addExtensions(mainModel);
901     Model model;
902     model.main = makeSubgraph(mainModel);
903     updateOperandLocations(mainModel, &model.main);
904     model.referenced = std::move(mRefSubgraphs);
905     model.operandValues = std::move(mOperandValues);
906     model.pools.resize(mMemories.size());
907     std::transform(mMemories.begin(), mMemories.end(), model.pools.begin(),
908                    [](const Memory* m) { return m->getHidlMemory(); });
909     model.relaxComputationFloat32toFloat16 = mainModel->mRelaxComputationFloat32toFloat16;
910     model.extensionNameToPrefix = std::move(mExtensionNameToPrefix);
911     return model;
912 }
913 
makeSubgraph(const ModelBuilder * model)914 Subgraph ModelBuilder::HidlModelMaker::makeSubgraph(const ModelBuilder* model) {
915     Subgraph subgraph;
916     subgraph.operands = model->mOperands;
917     subgraph.operations = model->mOperations;
918     subgraph.inputIndexes = model->mInputIndexes;
919     subgraph.outputIndexes = model->mOutputIndexes;
920     return subgraph;
921 }
922 
updateOperandLocations(const ModelBuilder * refModel,Subgraph * subgraph)923 void ModelBuilder::HidlModelMaker::updateOperandLocations(const ModelBuilder* refModel,
924                                                           Subgraph* subgraph) {
925     for (Operand& operand : subgraph->operands) {
926         if (operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
927             uint32_t valueLength = operand.location.length;
928             uint32_t existingSize = mOperandValues.size();
929             uint32_t extraBytes = alignBytesNeeded(existingSize, valueLength);
930             uint32_t originalOffset = operand.location.offset;
931             uint32_t offset = existingSize + extraBytes;
932             mOperandValues.resize(offset + valueLength);
933             memcpy(&mOperandValues[offset], &refModel->mSmallOperandValues[originalOffset],
934                    valueLength);
935             operand.location.offset = offset;
936         } else if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE) {
937             uint32_t originalPoolIndex = operand.location.poolIndex;
938             operand.location.poolIndex = mMemories.add(refModel->mMemories[originalPoolIndex]);
939         }
940     }
941     // Do recursive calls at the end to improve locality of mOperandValues.
942     for (Operand& operand : subgraph->operands) {
943         if (operand.lifetime == OperandLifeTime::SUBGRAPH) {
944             uint32_t refModelIndex = operand.location.offset;
945             // TODO(b/147875885): Avoid creating duplicate refSubgraphs when
946             // a single refModel is referenced multiple times.
947             operand.location.offset = addSubgraph(refModel->mReferencedModels[refModelIndex]);
948         }
949     }
950 }
951 
addSubgraph(const ModelBuilder * refModel)952 uint32_t ModelBuilder::HidlModelMaker::addSubgraph(const ModelBuilder* refModel) {
953     uint32_t index = mRefSubgraphs.size();
954     mRefSubgraphs.push_back(makeSubgraph(refModel));
955     updateOperandLocations(refModel, &mRefSubgraphs.back());
956     return index;
957 }
958 
addExtensions(const ModelBuilder * model)959 void ModelBuilder::HidlModelMaker::addExtensions(const ModelBuilder* model) {
960     constexpr uint8_t kLowBitsType = static_cast<uint8_t>(ExtensionTypeEncoding::LOW_BITS_TYPE);
961     for (const auto& operand : model->mOperands) {
962         if (isExtensionOperandType(operand.type)) {
963             addExtensionWithPrefix(static_cast<uint32_t>(operand.type) >> kLowBitsType);
964         }
965     }
966     for (const auto& operation : model->mOperations) {
967         if (isExtensionOperationType(operation.type)) {
968             addExtensionWithPrefix(static_cast<uint32_t>(operation.type) >> kLowBitsType);
969         }
970     }
971     for (const auto& refModel : model->mReferencedModels) {
972         addExtensions(refModel);
973     }
974 }
975 
addExtensionWithPrefix(uint16_t prefix)976 void ModelBuilder::HidlModelMaker::addExtensionWithPrefix(uint16_t prefix) {
977     if (!mPrefixSet.insert(prefix).second) {
978         return;
979     }
980     const Extension* extension;
981     CHECK(TypeManager::get()->getExtensionInfo(prefix, &extension));
982     mExtensionNameToPrefix.push_back({
983             .name = extension->name,
984             .prefix = prefix,
985     });
986 }
987 
988 }  // namespace nn
989 }  // namespace android
990