1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <android-base/logging.h>
18 #include <nnapi/OperandTypes.h>
19 #include <nnapi/OperationTypes.h>
20 #include <nnapi/Result.h>
21 #include <nnapi/SharedMemory.h>
22 #include <nnapi/TypeUtils.h>
23 #include <nnapi/Types.h>
24
25 #include <algorithm>
26 #include <iterator>
27 #include <limits>
28 #include <memory>
29 #include <utility>
30 #include <vector>
31
32 #include "TestHarness.h"
33
34 namespace android::nn::test {
35 namespace {
36
37 using ::test_helper::TestModel;
38 using ::test_helper::TestOperand;
39 using ::test_helper::TestOperandLifeTime;
40 using ::test_helper::TestOperandType;
41 using ::test_helper::TestOperation;
42 using ::test_helper::TestSubgraph;
43
createOperand(const TestOperand & operand,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)44 Result<Operand> createOperand(const TestOperand& operand, Model::OperandValues* operandValues,
45 ConstantMemoryBuilder* memoryBuilder) {
46 CHECK(operandValues != nullptr);
47 CHECK(memoryBuilder != nullptr);
48
49 const OperandType type = static_cast<OperandType>(operand.type);
50 Operand::LifeTime lifetime = static_cast<Operand::LifeTime>(operand.lifetime);
51
52 DataLocation location;
53 switch (operand.lifetime) {
54 case TestOperandLifeTime::TEMPORARY_VARIABLE:
55 case TestOperandLifeTime::SUBGRAPH_INPUT:
56 case TestOperandLifeTime::SUBGRAPH_OUTPUT:
57 case TestOperandLifeTime::NO_VALUE:
58 break;
59 case TestOperandLifeTime::CONSTANT_COPY:
60 case TestOperandLifeTime::CONSTANT_REFERENCE: {
61 const auto size = operand.data.size();
62 if (size == 0) {
63 lifetime = Operand::LifeTime::NO_VALUE;
64 } else {
65 location = (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY)
66 ? operandValues->append(operand.data.get<uint8_t>(), size)
67 : memoryBuilder->append(operand.data.get<void>(), size);
68 }
69 break;
70 }
71 case TestOperandLifeTime::SUBGRAPH:
72 NN_RET_CHECK(operand.data.get<uint32_t>() != nullptr);
73 NN_RET_CHECK_GE(operand.data.size(), sizeof(uint32_t));
74 location = {.offset = *operand.data.get<uint32_t>()};
75 break;
76 }
77
78 Operand::ExtraParams extraParams;
79 if (operand.type == TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
80 extraParams =
81 Operand::SymmPerChannelQuantParams{.scales = operand.channelQuant.scales,
82 .channelDim = operand.channelQuant.channelDim};
83 }
84
85 return Operand{
86 .type = type,
87 .dimensions = operand.dimensions,
88 .scale = operand.scale,
89 .zeroPoint = operand.zeroPoint,
90 .lifetime = lifetime,
91 .location = location,
92 .extraParams = std::move(extraParams),
93 };
94 }
95
createSubgraph(const TestSubgraph & testSubgraph,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)96 Result<Model::Subgraph> createSubgraph(const TestSubgraph& testSubgraph,
97 Model::OperandValues* operandValues,
98 ConstantMemoryBuilder* memoryBuilder) {
99 // Operands.
100 std::vector<Operand> operands;
101 operands.reserve(testSubgraph.operands.size());
102 for (const auto& operand : testSubgraph.operands) {
103 operands.push_back(NN_TRY(createOperand(operand, operandValues, memoryBuilder)));
104 }
105
106 // Operations.
107 std::vector<Operation> operations;
108 operations.reserve(testSubgraph.operations.size());
109 std::transform(testSubgraph.operations.begin(), testSubgraph.operations.end(),
110 std::back_inserter(operations), [](const TestOperation& op) -> Operation {
111 return {.type = static_cast<OperationType>(op.type),
112 .inputs = op.inputs,
113 .outputs = op.outputs};
114 });
115
116 return Model::Subgraph{.operands = std::move(operands),
117 .operations = std::move(operations),
118 .inputIndexes = testSubgraph.inputIndexes,
119 .outputIndexes = testSubgraph.outputIndexes};
120 }
121
122 } // namespace
123
createModel(const TestModel & testModel)124 GeneralResult<Model> createModel(const TestModel& testModel) {
125 Model::OperandValues operandValues;
126 ConstantMemoryBuilder memoryBuilder(0);
127
128 Model::Subgraph mainSubgraph =
129 NN_TRY(createSubgraph(testModel.main, &operandValues, &memoryBuilder));
130 std::vector<Model::Subgraph> refSubgraphs;
131 refSubgraphs.reserve(testModel.referenced.size());
132 for (const auto& testSubgraph : testModel.referenced) {
133 refSubgraphs.push_back(
134 NN_TRY(createSubgraph(testSubgraph, &operandValues, &memoryBuilder)));
135 }
136
137 // Shared memory.
138 std::vector<SharedMemory> pools;
139 if (!memoryBuilder.empty()) {
140 pools.push_back(NN_TRY(memoryBuilder.finish()));
141 }
142
143 return Model{.main = std::move(mainSubgraph),
144 .referenced = std::move(refSubgraphs),
145 .operandValues = std::move(operandValues),
146 .pools = std::move(pools),
147 .relaxComputationFloat32toFloat16 = testModel.isRelaxed};
148 }
149
createRequest(const TestModel & testModel)150 GeneralResult<Request> createRequest(const TestModel& testModel) {
151 // Model inputs.
152 std::vector<Request::Argument> inputs;
153 inputs.reserve(testModel.main.inputIndexes.size());
154 for (uint32_t operandIndex : testModel.main.inputIndexes) {
155 NN_RET_CHECK_LT(operandIndex, testModel.main.operands.size())
156 << "createRequest failed because inputIndex of operand " << operandIndex
157 << " exceeds number of operands " << testModel.main.operands.size();
158
159 const auto& op = testModel.main.operands[operandIndex];
160 Request::Argument requestArgument;
161 if (op.data.size() == 0) {
162 // Omitted input.
163 requestArgument = {.lifetime = Request::Argument::LifeTime::NO_VALUE};
164 } else {
165 const auto location = DataLocation{.pointer = op.data.get<void>(),
166 .length = static_cast<uint32_t>(op.data.size())};
167 requestArgument = {.lifetime = Request::Argument::LifeTime::POINTER,
168 .location = location,
169 .dimensions = op.dimensions};
170 }
171 inputs.push_back(std::move(requestArgument));
172 }
173
174 // Model outputs.
175 std::vector<Request::Argument> outputs;
176 outputs.reserve(testModel.main.outputIndexes.size());
177 MutableMemoryBuilder outputBuilder(0);
178 for (uint32_t operandIndex : testModel.main.outputIndexes) {
179 NN_RET_CHECK_LT(operandIndex, testModel.main.operands.size())
180 << "createRequest failed because outputIndex of operand " << operandIndex
181 << " exceeds number of operands " << testModel.main.operands.size();
182
183 const auto& op = testModel.main.operands[operandIndex];
184
185 // In the case of zero-sized output, we should at least provide a one-byte buffer.
186 // This is because zero-sized tensors are only supported internally to the driver, or
187 // reported in output shapes. It is illegal for the client to pre-specify a zero-sized
188 // tensor as model output. Otherwise, we will have two semantic conflicts:
189 // - "Zero dimension" conflicts with "unspecified dimension".
190 // - "Omitted operand buffer" conflicts with "zero-sized operand buffer".
191 size_t bufferSize = std::max<size_t>(op.data.size(), 1);
192
193 const DataLocation location = outputBuilder.append(bufferSize);
194 outputs.push_back({.lifetime = Request::Argument::LifeTime::POOL,
195 .location = location,
196 .dimensions = op.dimensions});
197 }
198
199 // Model pools.
200 std::vector<Request::MemoryPool> pools;
201 if (!outputBuilder.empty()) {
202 pools.push_back(NN_TRY(outputBuilder.finish()));
203 }
204
205 return Request{
206 .inputs = std::move(inputs), .outputs = std::move(outputs), .pools = std::move(pools)};
207 }
208
209 } // namespace android::nn::test
210