/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ModelBuilder"

#include "ModelBuilder.h"

#include <GraphDump.h>
#include <LegacyUtils.h>
#include <ModelUtils.h>
#include <android-base/logging.h>
#include <nnapi/Validation.h>

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "CompilationBuilder.h"
#include "Manager.h"
#include "ModelArchHasher.h"
#include "TypeManager.h"

namespace android {
namespace nn {

// The maximum number of operands and operations that a model may have.
const uint32_t MAX_NUMBER_OF_OPERANDS = 0xFFFFFFFE;
const uint32_t MAX_NUMBER_OF_OPERATIONS = 0xFFFFFFFE;

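// Checks that a pointer-style argument and its count/length argument are consistent: either both
// are provided or both are empty/zero. Logs and returns ANEURALNETWORKS_BAD_DATA otherwise.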
#define NN_VALIDATE_NULL_OR_SIZED(tag, data, length)                                          \
    if ((data == nullptr) != (length == 0)) {                                                 \
        LOG(ERROR) << "ANeuralNetworksModel_" << tag << " " << #data << " is "                \
                   << (data == nullptr ? "null" : "not null") << " but " << #length << " is " \
                   << length;                                                                 \
        return ANEURALNETWORKS_BAD_DATA;                                                      \
    }

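// Copies |length| elements from |data| into a new vector; returns an empty vector when |length|
// is 0, in which case |data| may be null.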
template <typename Type>
static std::vector<Type> makeVector(const Type* data, uint32_t length) {
    return length > 0 ? std::vector<Type>(data, data + length) : std::vector<Type>();
}

bool ModelBuilder::badState(const char* name) {
    if (mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify after model finished";
        return true;
    }
    if (mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify an invalid model";
        return true;
    }
    return false;
}

int ModelBuilder::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
                                   int32_t* type) {
    return TypeManager::get()->getExtensionType(extensionName, typeWithinExtension, type)
                   ? ANEURALNETWORKS_NO_ERROR
                   : ANEURALNETWORKS_BAD_DATA;
}

int ModelBuilder::addOperand(const ANeuralNetworksOperandType& type) {
    if (badState("addOperand")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    OperandType operandType = static_cast<OperandType>(type.type);
    if (isExtension(operandType) && !TypeManager::get()->areExtensionsAllowed()) {
        LOG(ERROR) << "Extensions are not supported for this process.";
        return ANEURALNETWORKS_BAD_DATA;
    }
    bool isOemOperand =
            operandType == OperandType::OEM || operandType == OperandType::TENSOR_OEM_BYTE;
    if (isOemOperand && !mHasOEMOperand) {
        LOG(WARNING) << "OEM data type is deprecated. Use Extensions instead.";
    }

    const Extension::OperandTypeInformation* info = nullptr;
    if (isExtension(operandType) &&
        !TypeManager::get()->getExtensionOperandTypeInfo(operandType, &info)) {
        LOG(ERROR) << "Extension operand type " << operandType << " is not registered";
        return ANEURALNETWORKS_BAD_DATA;
    }
    NN_VALIDATE_NULL_OR_SIZED("addOperand", type.dimensions, type.dimensionCount);
    Operand operand = {
            .type = operandType,
            .dimensions = makeVector(type.dimensions, type.dimensionCount),
            .scale = type.scale,
            .zeroPoint = type.zeroPoint,
            .lifetime = Operand::LifeTime::TEMPORARY_VARIABLE,
            .location = {.poolIndex = 0, .offset = 0, .length = 0},
            .extraParams = {},
    };
    if (auto result = validateOperandType(operand, info, "ANeuralNetworksModel_addOperand", true);
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    size_t idx = mOperands.size();
    if (idx >= MAX_NUMBER_OF_OPERANDS) {
        LOG(ERROR) << "ANeuralNetworksModel_addOperand exceeds max operands";
        return ANEURALNETWORKS_BAD_DATA;
    }

    mOperands.push_back(std::move(operand));
    mHasOEMOperand |= isOemOperand;
    mHasControlFlow |= (operandType == OperandType::SUBGRAPH);
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValue(uint32_t index, const void* buffer, size_t length) {
    VLOG(MODEL) << __func__ << " for operand " << index << " size " << length;
    if (badState("setOperandValue")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index << " of "
                   << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    NN_VALIDATE_NULL_OR_SIZED("setOperandValue", buffer, length);
    if (buffer == nullptr) {
        operand.lifetime = Operand::LifeTime::NO_VALUE;
        // The location is unused and is set to zeros.
        operand.location = {.poolIndex = 0, .offset = 0, .length = 0};
    } else {
        if (TypeManager::get()->isTensorType(operand.type) &&
            tensorHasUnspecifiedDimensions(operand)) {
            LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index
                       << " which has operand type that is not fully specified";
            return ANEURALNETWORKS_BAD_DATA;
        }
        if (length > 0xFFFFFFFF) {
            LOG(ERROR) << "ANeuralNetworksModel_setOperandValue value length of " << length
                       << " exceeds max size";
            return ANEURALNETWORKS_BAD_DATA;
        }
        uint32_t valueLength = static_cast<uint32_t>(length);
        if (operand.type != OperandType::OEM) {
            uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
            if (neededLength != valueLength) {
                LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
                           << " bytes when needing " << neededLength;
                return ANEURALNETWORKS_BAD_DATA;
            }
        }
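        // Small values are copied inline into mSmallOperandValues (padded so each value starts at
        // an aligned offset); larger values are only recorded here and are packed into shared
        // memory later by copyLargeValuesToSharedMemory().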
        if (valueLength <= ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES) {
            uint32_t existingSize = static_cast<uint32_t>(mSmallOperandValues.size());
            uint32_t extraBytes = alignBytesNeeded(existingSize, valueLength);
            mSmallOperandValues.resize(existingSize + extraBytes + valueLength);
            operand.lifetime = Operand::LifeTime::CONSTANT_COPY;
            operand.location = {
                    .poolIndex = 0, .offset = existingSize + extraBytes, .length = valueLength};
            memcpy(&mSmallOperandValues[operand.location.offset], buffer, valueLength);
            VLOG(MODEL) << "Copied small value to offset " << operand.location.offset;
        } else {
            VLOG(MODEL) << "Saving large value";
            operand.lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
            // The values for poolIndex and offset will be set when the model is finished.
            typedef decltype(operand.location.poolIndex) PoolIndexType;
            typedef decltype(operand.location.offset) OffsetType;
            operand.location = {.poolIndex = ~PoolIndexType(0),
                                .offset = ~OffsetType(0),
                                .length = valueLength};
            // We keep track of the buffers. We'll allocate the shared memory only
            // once we know the total size, to avoid needless copies.
            mLargeOperandValues.push_back(LargeValue{.operandIndex = index, .buffer = buffer});
        }
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValueFromModel(uint32_t index, const ModelBuilder* value) {
    VLOG(MODEL) << __func__ << " for operand " << index << " model " << value;
    if (badState("setOperandValueFromModel")) {
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (!value->mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model must be finished";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (value->mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model is invalid";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel setting operand " << index
                   << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    operand.lifetime = Operand::LifeTime::SUBGRAPH;
    operand.location = {
            .poolIndex = 0,
            .offset = static_cast<uint32_t>(mReferencedModels.size()),
            .length = 0,
    };
    mReferencedModels.push_back(value);
    mReferencedSubgraphsForValidation.push_back(value->makeModel().main);
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandSymmPerChannelQuantParams(
        uint32_t index, const ANeuralNetworksSymmPerChannelQuantParams& channelQuant) {
    if (badState("setOperandSymmPerChannelQuantParams")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
                   << "setting per-channel quantization parameters for operand " << index << " of "
                   << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];

    NN_VALIDATE_NULL_OR_SIZED("setOperandSymmPerChannelQuantParams", channelQuant.scales,
                              channelQuant.scaleCount);
    Operand::SymmPerChannelQuantParams extraParams = {
            .scales = makeVector(channelQuant.scales, channelQuant.scaleCount),
            .channelDim = channelQuant.channelDim,
    };
    if (auto result = validateOperandSymmPerChannelQuantParams(
                operand, extraParams, "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }
    switch (operand.type) {
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            operand.extraParams = std::move(extraParams);
            break;
        default:
            LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
                       << "invalid operand type " << static_cast<int32_t>(operand.type);
            return ANEURALNETWORKS_BAD_DATA;
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandExtensionData(uint32_t index, const void* data, size_t length) {
    if (badState("setOperandExtensionData")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
                   << "setting extension data for operand " << index << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];

    if (!isExtension(operand.type)) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
                   << "setting extension data for a base operand type "
                   << static_cast<int32_t>(operand.type);
        return ANEURALNETWORKS_BAD_DATA;
    }

    NN_VALIDATE_NULL_OR_SIZED("setOperandExtensionData", data, length);
    if (data == nullptr) {
        operand.extraParams = {};
    } else {
        operand.extraParams = Operand::ExtensionParams(
                std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
                                     reinterpret_cast<const uint8_t*>(data) + length));
    }
    return ANEURALNETWORKS_NO_ERROR;
}

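// Packs all large constant values recorded by setOperandValue() into a single ashmem pool:
// a first pass computes an aligned offset for each value and the total pool size, and a second
// pass copies the buffers once the memory has been allocated.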
int ModelBuilder::copyLargeValuesToSharedMemory() {
    VLOG(MODEL) << __func__ << " has " << mLargeOperandValues.size() << " values.";
    if (!mLargeOperandValues.empty()) {
        // Calculate the size of the shared memory needed for all the large values.
        // Also sets the offset for each value within the memory.
        size_t poolSize = 0;
        for (LargeValue& l : mLargeOperandValues) {
            Operand& operand = mOperands[l.operandIndex];
            CHECK_EQ(operand.lifetime, Operand::LifeTime::CONSTANT_REFERENCE);
            poolSize += alignBytesNeeded(poolSize, operand.location.length);
            operand.location.offset = poolSize;
            poolSize += operand.location.length;
        }

        // Allocate the shared memory.
        int n;
        std::tie(n, mLargeValueMemory) = MemoryAshmem::create(poolSize);
        NN_RETURN_IF_ERROR(n);
        uint8_t* memoryPointer = mLargeValueMemory->getPointer();
        uint32_t poolIndex = mMemories.add(mLargeValueMemory.get());
        VLOG(MODEL) << "Allocated large value pool of size " << poolSize << " at index "
                    << poolIndex;

        // Copy the values to this memory.
        for (LargeValue& l : mLargeOperandValues) {
            Operand& operand = mOperands[l.operandIndex];
            operand.location.poolIndex = poolIndex;
            memcpy(memoryPointer + operand.location.offset, l.buffer, operand.location.length);
        }
    }

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValueFromMemory(uint32_t index, const RuntimeMemory* memory,
                                            uint32_t offset, size_t length) {
    VLOG(MODEL) << __func__ << " for operand " << index << " offset " << offset << " size "
                << length;
    if (badState("setOperandValueFromMemory")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
                   << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    if (TypeManager::get()->isTensorType(operand.type) && tensorHasUnspecifiedDimensions(operand)) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
                   << " which has operand type that is not fully specified";
        return ANEURALNETWORKS_BAD_DATA;
    }
    uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
    if (neededLength != length) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
                   << " bytes when needing " << neededLength;
        return ANEURALNETWORKS_BAD_DATA;
    }
    // Set compilation = nullptr to indicate that the memory is used for a model constant.
    // In this case, IOType::INPUT is a placeholder value that is ignored by the validator.
    if (!memory->getValidator().validate(/*compilation=*/nullptr, /*placeholder*/ IOType::INPUT,
                                         index, nullptr, offset, length)) {
        return ANEURALNETWORKS_BAD_DATA;
    }
    operand.lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
    operand.location = {.poolIndex = mMemories.add(memory),
                        .offset = offset,
                        .length = static_cast<uint32_t>(length)};
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::addOperation(ANeuralNetworksOperationType type, uint32_t inputCount,
                               const uint32_t* inputs, uint32_t outputCount,
                               const uint32_t* outputs) {
    if (badState("addOperation")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    OperationType operationType = static_cast<OperationType>(type);
    if (isExtension(operationType) && !TypeManager::get()->areExtensionsAllowed()) {
        LOG(ERROR) << "Extensions are not supported for this process.";
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (operationType == OperationType::OEM_OPERATION && !mHasOEMOperation) {
        LOG(WARNING) << "OEM_OPERATION is deprecated. Use Extensions instead.";
    }

    if (!isExtension(operationType)) {
        bool allowExperimental = false;
#ifdef NN_EXPERIMENTAL_FEATURE
        if (type >= BuiltinOperationResolver::kStartOfExperimentalOperations &&
            type < BuiltinOperationResolver::kStartOfExperimentalOperations +
                           BuiltinOperationResolver::kNumberOfExperimentalOperationTypes) {
            allowExperimental = true;
        }
#endif  // NN_EXPERIMENTAL_FEATURE
        if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, type) &&
            !allowExperimental) {
            LOG(ERROR) << "ANeuralNetworksModel_addOperation invalid operation type " << type;
            return ANEURALNETWORKS_BAD_DATA;
        }
    } else {
        const Extension* extension;
        uint16_t extensionPrefix = getExtensionPrefix(static_cast<uint32_t>(operationType));
        if (!TypeManager::get()->getExtensionInfo(extensionPrefix, &extension)) {
            LOG(ERROR) << "Extension operation type " << operationType << " is not recognized";
            return ANEURALNETWORKS_BAD_DATA;
        }
    }

    NN_VALIDATE_NULL_OR_SIZED("addOperation", inputs, inputCount);
    NN_VALIDATE_NULL_OR_SIZED("addOperation", outputs, outputCount);
    Operation operation = {
            .type = operationType,
            .inputs = makeVector(inputs, inputCount),
            .outputs = makeVector(outputs, outputCount),
    };
    if (auto result = validateOperationButNotOperands(operation, mOperands,
                                                      mReferencedSubgraphsForValidation);
        !result.ok()) {
        LOG(ERROR) << "Invalid Operation: " << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    uint32_t operationIndex = operationCount();
    if (operationIndex >= MAX_NUMBER_OF_OPERATIONS) {
        LOG(ERROR) << "ANeuralNetworksModel_addOperation exceeds max operations";
        return ANEURALNETWORKS_BAD_DATA;
    }

    mOperations.push_back(std::move(operation));
    mHasOEMOperation |= (operationType == OperationType::OEM_OPERATION);
    mHasExtensionOperation |= isExtension(operationType);
    mHasControlFlow |=
            (operationType == OperationType::IF || operationType == OperationType::WHILE);

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::identifyInputsAndOutputs(uint32_t inputCount, const uint32_t* inputs,
                                           uint32_t outputCount, const uint32_t* outputs) {
    if (badState("identifyInputsAndOutputs")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    NN_VALIDATE_NULL_OR_SIZED("identifyInputsAndOutputs", inputs, inputCount);
    if (auto result = validateOperandList(makeVector(inputs, inputCount), operandCount(),
                                          "ANeuralNetworksModel_identifyInputsAndOutputs inputs");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }
    NN_VALIDATE_NULL_OR_SIZED("identifyInputsAndOutputs", outputs, outputCount);
    if (auto result = validateOperandList(makeVector(outputs, outputCount), operandCount(),
                                          "ANeuralNetworksModel_identifyInputsAndOutputs outputs");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    // Makes a copy of the index list, validates the arguments, and changes
    // the lifetime info of the corresponding operand.
    auto setArguments = [&](std::vector<uint32_t>* indexVector, uint32_t indexCount,
                            const uint32_t* indexList, Operand::LifeTime lifetime) -> bool {
        indexVector->resize(indexCount);
        for (uint32_t i = 0; i < indexCount; i++) {
            const uint32_t operandIndex = indexList[i];
            if (operandIndex >= mOperands.size()) {
                LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set input or "
                              "output to be "
                           << operandIndex << " as this exceeds the number of operands "
                           << mOperands.size();
                return false;
            }
            (*indexVector)[i] = operandIndex;
            Operand& operand = mOperands[operandIndex];
            if (operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE) {
                LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set operand "
                           << operandIndex
                           << " to be an input or output. Check that it's not a constant or "
                              "already an input or output";
                return false;
            }
            operand.lifetime = lifetime;
        }
        return true;
    };

    if (!setArguments(&mInputIndexes, inputCount, inputs, Operand::LifeTime::SUBGRAPH_INPUT) ||
        !setArguments(&mOutputIndexes, outputCount, outputs, Operand::LifeTime::SUBGRAPH_OUTPUT)) {
        return ANEURALNETWORKS_BAD_DATA;
    }

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::relaxComputationFloat32toFloat16(bool allow) {
    if (badState("relaxComputationFloat32toFloat16")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    mRelaxComputationFloat32toFloat16 = allow;

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::createCompilation(CompilationBuilder** compilation,
                                    const std::vector<std::shared_ptr<Device>>& devices,
                                    bool explicitDeviceList) {
    if (!mCompletedModel || mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksCompilation_create passed an unfinished or invalid model";
        *compilation = nullptr;
        return ANEURALNETWORKS_BAD_STATE;
    }
    *compilation = new (std::nothrow) CompilationBuilder(this, devices, explicitDeviceList);
    return (*compilation ? ANEURALNETWORKS_NO_ERROR : ANEURALNETWORKS_OUT_OF_MEMORY);
}

int ModelBuilder::finish() {
    if (mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called more than once";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on an invalid model";
        return ANEURALNETWORKS_BAD_STATE;
    }

    int n = copyLargeValuesToSharedMemory();
    if (n != ANEURALNETWORKS_NO_ERROR) {
        return n;
    }

    // We sort the operations so that they will be in the appropriate
    // order for a single-threaded, op at a time execution.
    // TODO: we don't need this if we always run the partitioner.
    if (!sortIntoRunOrder()) {
        // We expect sortIntoRunOrder() to have logged an appropriate error message.
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }

    // TODO: Modify validation so that it can be called without creating a Model.
    // NOTE: Must sortIntoRunOrder() before validation; validator expects operations
    //       to have been sorted.
    // NOTE: Must copyLargeValuesToSharedMemory() before validation; otherwise,
    //       a CONSTANT_REFERENCE operand will not have correct .poolIndex, and
    //       validation will not work properly.
    const Model modelForValidation = makeModel();
    const auto maybeVersion = validate(modelForValidation);
    if (!maybeVersion.ok()) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on invalid model: "
                   << maybeVersion.error();
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (!isCompliantVersion(maybeVersion.value(), DeviceManager::get()->getRuntimeVersion())) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on a model that is newer than what is "
                      "allowed. Model version needed: "
                   << maybeVersion.value() << ", current runtime version supported: "
                   << DeviceManager::get()->getRuntimeVersion();
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (VLOG_IS_ON(MODEL)) {
        graphDump("ModelBuilder::finish", modelForValidation, nullptr);
    }

    removeTrailingArgumentsWithDefaultValues();
    simplifyModel();

    mCompletedModel = true;
    CHECK(calcModelArchHash(modelForValidation, mModelArchHash))
            << "Failed to calculate model arch hash";
    return ANEURALNETWORKS_NO_ERROR;
}

static void logRemoval(const Operation& operation, uint32_t count,
                       const std::vector<Operand>& operands) {
    std::ostringstream message;
    message << "Operation " << operation.type << " with inputs {";
    for (uint32_t i = 0; i < operation.inputs.size(); ++i) {
        if (i != 0) {
            message << ", ";
        }
        message << operands[operation.inputs[i]].type;
    }
    message << "} has trailing optional inputs set to default values. Removing " << count
            << " trailing inputs.";
    VLOG(MODEL) << message.str();
}

void ModelBuilder::removeTrailingArgumentsWithDefaultValues() {
    for (Operation& operation : mOperations) {
        const uint32_t count = getNumTrailingArgumentsToRemove(operation);
        if (count == 0) {
            continue;
        }
        if (VLOG_IS_ON(MODEL)) {
            logRemoval(operation, count, mOperands);
        }
        const uint32_t inputCount = operation.inputs.size();
        CHECK_LT(count, inputCount);
        const uint32_t newInputCount = inputCount - count;
        operation.inputs.resize(newInputCount);
    }
}

// See countMatchingTrailingArguments().
enum class TailSpec {
    BOOL_FALSE,
    INT32_ONE,
    INT32_NEGATIVE_ONE,
};

// See countMatchingTrailingArguments().
static bool matchesSpec(TailSpec spec, const Operand& operand,
                        const std::vector<uint8_t>& mSmallOperandValues) {
    const void* valuePtr = nullptr;
    if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
        valuePtr = static_cast<const void*>(&mSmallOperandValues[operand.location.offset]);
    } else if (operand.lifetime == Operand::LifeTime::POINTER) {
        valuePtr = std::get<const void*>(operand.location.pointer);
    } else {
        // CONSTANT_REFERENCE operands are not supported to avoid mapping memory
        // during compilation.
        return false;
    }
    switch (spec) {
        case TailSpec::BOOL_FALSE:
            return operand.type == OperandType::BOOL &&
                   *static_cast<const bool8*>(valuePtr) == false;
        case TailSpec::INT32_ONE:
            return operand.type == OperandType::INT32 &&
                   *static_cast<const int32_t*>(valuePtr) == 1;
        case TailSpec::INT32_NEGATIVE_ONE:
            return operand.type == OperandType::INT32 &&
                   *static_cast<const int32_t*>(valuePtr) == -1;
        default:
            CHECK(false) << "Unhandled TailSpec: " << static_cast<int>(spec);
    }
}

// Returns the number of trailing operation inputs that match the specification.
//
// Example:
//     operation.inputs = {BOOL_TRUE, BOOL_TRUE, INT32_ONE, INT32_NEGATIVE_ONE}
//     tail             = {BOOL_FALSE, INT32_ONE, INT32_NEGATIVE_ONE}
//     tailStartIndex   = 1   matching elements: ^^^^^^^^^  ^^^^^^^^^^^^^^^^^^
static uint32_t countMatchingTrailingArguments(uint32_t tailStartIndex,
                                               const std::vector<TailSpec>& tail,
                                               const Operation& operation,
                                               const std::vector<Operand>& operands,
                                               const std::vector<uint8_t>& smallOperandValues) {
    const uint32_t inputCount = operation.inputs.size();
    uint32_t count = 0;
    for (uint32_t i = inputCount - 1; i >= tailStartIndex; --i) {
        const Operand& operand = operands[operation.inputs[i]];
        if (!matchesSpec(tail[i - tailStartIndex], operand, smallOperandValues)) {
            break;
        }
        ++count;
    }
    return count;
}

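// Returns how many trailing optional inputs of |operation| are explicitly set to their default
// values and can therefore be dropped without changing the operation's behavior, or 0 if there
// is nothing to remove.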
uint32_t ModelBuilder::getNumTrailingArgumentsToRemove(const Operation& operation) const {
    const uint32_t inputCount = operation.inputs.size();
    auto getCount = [this, &operation](uint32_t tailStartIndex, const std::vector<TailSpec>& tail) {
        return countMatchingTrailingArguments(tailStartIndex, tail, operation, mOperands,
                                              mSmallOperandValues);
    };
    using TS = TailSpec;
    // Check if the operation has optional arguments that might be set to default
    // values. Skip the counting if no optional arguments are present.
    switch (operation.type) {
        case OperationType::AVERAGE_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::CONV_2D: {
            if (10 < inputCount && inputCount <= 13 &&
                mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 13 inputs
                // API level 27: 10 inputs
                uint32_t count = getCount(10, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 11 and 12 must come together.
                return inputCount - count == 12 ? 0 : count;
            } else if (7 < inputCount && inputCount <= 10 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 10 inputs
                // API level 27: 7 inputs
                uint32_t count = getCount(7, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 8 and 9 must come together.
                return inputCount - count == 9 ? 0 : count;
            }
        } break;
        case OperationType::DEPTHWISE_CONV_2D: {
            if (11 < inputCount && inputCount <= 14 &&
                mOperands[operation.inputs[8]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 11 to 14 inputs
                // API level 27: 11 inputs
                uint32_t count = getCount(11, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 12 and 13 must come together.
                return inputCount - count == 13 ? 0 : count;
            } else if (8 < inputCount && inputCount <= 11 &&
                       mOperands[operation.inputs[8]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 8 to 11 inputs
                // API level 27: 8 inputs
                uint32_t count = getCount(8, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 9 and 10 must come together.
                return inputCount - count == 10 ? 0 : count;
            }
        } break;
        case OperationType::DEPTH_TO_SPACE: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::L2_NORMALIZATION: {
            if (inputCount == 2) {
                // API level 29: 1 to 2 inputs
                // API level 27: 1 input
                return getCount(1, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::L2_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
            if (inputCount == 6) {
                // API level 29: 5 to 6 inputs
                // API level 27: 5 inputs
                return getCount(5, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::MAX_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::RESIZE_BILINEAR: {
            if (3 < inputCount && inputCount <= 6) {
                // By shape:
                //     API level 30: 3 to 6 inputs
                //     API level 29: 3 to 4 inputs
                //     API level 27: 3 inputs
                // By scale:
                //     API level 30: 3 to 6 inputs
                //     API level 29: 3 to 4 inputs
                return getCount(3, {TS::BOOL_FALSE, TS::BOOL_FALSE, TS::BOOL_FALSE});
            }
        } break;
        case OperationType::SOFTMAX: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::SPACE_TO_DEPTH: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::BATCH_TO_SPACE_ND: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 28: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::SPACE_TO_BATCH_ND: {
            if (inputCount == 4) {
                // API level 29: 3 to 4 inputs
                // API level 28: 3 inputs
                return getCount(3, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::RESIZE_NEAREST_NEIGHBOR: {
            if (4 < inputCount && inputCount <= 6) {
                // By shape or scale
                //     API level 30: 4 to 6 inputs
                //     API level 29: 4 inputs
                return getCount(4, {TS::BOOL_FALSE, TS::BOOL_FALSE});
            }
        } break;
        default: {
            // Do nothing.
        } break;
    }
    // No trailing optional arguments to check.
    return 0;
}

bool ModelBuilder::sortIntoRunOrder() {
    // Note that this may be called before the model has been
    // validated, so we must code defensively. However, we can assume
    // an Operation's inputs and outputs have legal indices -- this
    // should have been checked in addOperation().

    if (!mSortedOperationIndexMap.empty()) {
        LOG(ERROR) << "Operations were already sorted into run order.";
        return true;
    }

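    // Kahn-style topological sort over the operand dependency graph: an operation becomes ready
    // to run once every TEMPORARY_VARIABLE or SUBGRAPH_OUTPUT operand it consumes has been
    // produced by an already-scheduled operation.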
    // Tracks the operations that can be executed.
    std::vector<uint32_t> sortedOperationIndexMap;
    std::vector<uint32_t> opsReadyToRun;
    std::vector<Operation> runOrder;

    // Tracks how many inputs are needed for each operation to be ready to run.
    std::multimap<uint32_t, uint32_t> operandToOperations;
    std::vector<uint32_t> unknownInputCount(operationCount());
    for (uint32_t operationIndex = 0; operationIndex < operationCount(); operationIndex++) {
        uint32_t& count = unknownInputCount[operationIndex];
        count = 0;
        for (uint32_t operandIndex : mOperations[operationIndex].inputs) {
            auto lifetime = mOperands[operandIndex].lifetime;
            if (lifetime == Operand::LifeTime::TEMPORARY_VARIABLE ||
                lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT) {
                count++;
                operandToOperations.insert(
                        std::pair<uint32_t, uint32_t>(operandIndex, operationIndex));
            }
        }
        if (count == 0) {
            opsReadyToRun.push_back(operationIndex);
        }
    }

    while (opsReadyToRun.size() > 0) {
        // Execute the next op
        int opIndex = opsReadyToRun.back();
        opsReadyToRun.pop_back();
        const Operation& operation = mOperations[opIndex];

        runOrder.push_back(mOperations[opIndex]);
        sortedOperationIndexMap.push_back(opIndex);

        // Mark all its outputs as known.
        for (uint32_t operandIndex : operation.outputs) {
            auto range = operandToOperations.equal_range(operandIndex);
            for (auto i = range.first; i != range.second; i++) {
                uint32_t& count = unknownInputCount[i->second];
                if (--count == 0) {
                    opsReadyToRun.push_back(i->second);
                }
            }
        }
    }

    if (runOrder.size() != mOperations.size()) {
        CHECK_LT(runOrder.size(), mOperations.size());
        // Graph must contain at least one cycle or one never-written
        // operand, because there is at least one Operation that never
        // became ready.
        LOG(ERROR) << "Graph contains at least one cycle or one never-written operand";
        return false;
    }

    mSortedOperationIndexMap = std::move(sortedOperationIndexMap);
    mOperations = std::move(runOrder);
    return true;
}

// A helper class to simplify state management when creating a Model.
class ModelBuilder::ModelMaker {
   public:
    static Model run(const ModelBuilder* model, bool simplifyModel);

   private:
    static Model::Subgraph makeSubgraph(const ModelBuilder* model);
    explicit ModelMaker(bool simplifyModel) : mSimplifyModel(simplifyModel) {}
    Model makeModel(const ModelBuilder* mainModel);
    uint32_t addSubgraph(const ModelBuilder* refModel);
    void updateOperandLocations(const ModelBuilder* refModel, Model::Subgraph* subgraph);
    void addExtensions(const ModelBuilder* model);
    void addExtensionWithPrefix(uint16_t prefix);

    bool mSimplifyModel;
    std::vector<Model::Subgraph> mRefSubgraphs;
    Model::OperandValues mOperandValues;
    MemoryTracker mMemories;
    std::vector<ExtensionNameAndPrefix> mExtensionNameToPrefix;
    std::set<uint16_t> mPrefixSet;
};

void ModelBuilder::simplifyModel() {
    mSimplifyModel = true;
}

Model ModelBuilder::makeModel() const {
    // TODO: Cache the Model to speed up subsequent calls.
    return ModelMaker::run(this, mSimplifyModel);
}

Model ModelBuilder::ModelMaker::run(const ModelBuilder* model, bool simplifyModel) {
    // run() ensures the state of ModelMaker is destroyed after the call.
    return ModelMaker(simplifyModel).makeModel(model);
}

Model ModelBuilder::ModelMaker::makeModel(const ModelBuilder* mainModel) {
    addExtensions(mainModel);
    Model model;
    model.main = makeSubgraph(mainModel);
    updateOperandLocations(mainModel, &model.main);
    model.referenced = std::move(mRefSubgraphs);
    model.operandValues = std::move(mOperandValues);
    model.pools.reserve(mMemories.size());
    std::transform(mMemories.begin(), mMemories.end(), std::back_inserter(model.pools),
                   [](const RuntimeMemory* m) { return m->getMemory(); });
    model.relaxComputationFloat32toFloat16 = mainModel->mRelaxComputationFloat32toFloat16;
    model.extensionNameToPrefix = std::move(mExtensionNameToPrefix);
    if (mSimplifyModel) {
        removeDeadOperands(&model);
    }
    return model;
}

Model::Subgraph ModelBuilder::ModelMaker::makeSubgraph(const ModelBuilder* model) {
    Model::Subgraph subgraph;
    subgraph.operands = model->mOperands;
    subgraph.operations = model->mOperations;
    subgraph.inputIndexes = model->mInputIndexes;
    subgraph.outputIndexes = model->mOutputIndexes;
    return subgraph;
}

void ModelBuilder::ModelMaker::updateOperandLocations(const ModelBuilder* refModel,
                                                      Model::Subgraph* subgraph) {
    for (Operand& operand : subgraph->operands) {
        if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
            uint32_t valueLength = operand.location.length;
            uint32_t originalOffset = operand.location.offset;
            operand.location = mOperandValues.append(&refModel->mSmallOperandValues[originalOffset],
                                                     valueLength);
        } else if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) {
            uint32_t originalPoolIndex = operand.location.poolIndex;
            operand.location.poolIndex = mMemories.add(refModel->mMemories[originalPoolIndex]);
        }
    }
    // Do recursive calls at the end to improve locality of mOperandValues.
    for (Operand& operand : subgraph->operands) {
        if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
            uint32_t refModelIndex = operand.location.offset;
            // TODO(b/147875885): Avoid creating duplicate refSubgraphs when
            // a single refModel is referenced multiple times.
            operand.location.offset = addSubgraph(refModel->mReferencedModels[refModelIndex]);
        }
    }
}

uint32_t ModelBuilder::ModelMaker::addSubgraph(const ModelBuilder* refModel) {
    uint32_t index = mRefSubgraphs.size();
    mRefSubgraphs.push_back(makeSubgraph(refModel));
    updateOperandLocations(refModel, &mRefSubgraphs.back());
    return index;
}

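// Recursively collects the extension prefixes used by the operands and operations of |model| and
// of every model it references, so that the resulting Model carries a complete name-to-prefix
// mapping.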
void ModelBuilder::ModelMaker::addExtensions(const ModelBuilder* model) {
    for (const auto& operand : model->mOperands) {
        if (isExtension(operand.type)) {
            addExtensionWithPrefix(static_cast<uint32_t>(operand.type) >> kExtensionTypeBits);
        }
    }
    for (const auto& operation : model->mOperations) {
        if (isExtension(operation.type)) {
            addExtensionWithPrefix(static_cast<uint32_t>(operation.type) >> kExtensionTypeBits);
        }
    }
    for (const auto& refModel : model->mReferencedModels) {
        addExtensions(refModel);
    }
}

void ModelBuilder::ModelMaker::addExtensionWithPrefix(uint16_t prefix) {
    if (!mPrefixSet.insert(prefix).second) {
        return;
    }
    const Extension* extension;
    CHECK(TypeManager::get()->getExtensionInfo(prefix, &extension));
    mExtensionNameToPrefix.push_back({
            .name = extension->name,
            .prefix = prefix,
    });
}

const uint8_t* ModelBuilder::getModelArchHash() const {
    CHECK(mCompletedModel) << "Calling getModelArchHash on an unfinished model";
    return mModelArchHash;
}

#undef NN_VALIDATE_NULL_OR_SIZED

}  // namespace nn
}  // namespace android