1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <android-base/logging.h>
18
19 #include <cstdlib>
20 #include <optional>
21 #include <utility>
22
23 #include "NeuralNetworksWrapper.h"
24 #include "TestHarness.h"
25
26 namespace {
27
28 using ::android::nn::wrapper::Compilation;
29 using ::android::nn::wrapper::Execution;
30 using ::android::nn::wrapper::Model;
31 using ::android::nn::wrapper::OperandType;
32 using ::android::nn::wrapper::Result;
33 using ::android::nn::wrapper::SymmPerChannelQuantParams;
34 using ::android::nn::wrapper::Type;
35 using ::test_helper::TestModel;
36 using ::test_helper::TestOperand;
37 using ::test_helper::TestOperandLifeTime;
38 using ::test_helper::TestOperandType;
39 using ::test_helper::TestSubgraph;
40
getOperandType(const TestOperand & op)41 OperandType getOperandType(const TestOperand& op) {
42 const auto& dims = op.dimensions;
43 if (op.type == TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
44 return OperandType(
45 static_cast<Type>(op.type), dims,
46 SymmPerChannelQuantParams(op.channelQuant.scales, op.channelQuant.channelDim));
47 } else {
48 return OperandType(static_cast<Type>(op.type), dims, op.scale, op.zeroPoint);
49 }
50 }
51
// Per-subgraph traversal state used by the depth-first walk in areSubgraphsAcyclic.
enum class Visited : uint8_t {
    // The walk has not reached this subgraph.
    NOT_YET_VISITED = 0,
    // The walk has entered this subgraph and has not yet finished with the subgraphs it
    // references; reaching it again in this state means the references form a cycle.
    CURRENTLY_VISITING = 1,
    // The walk has finished with this subgraph and everything it references.
    ALREADY_VISITED = 2,
};
57
areSubgraphsAcyclic(const TestModel & testModel,size_t index,std::vector<Visited> * visited,std::vector<size_t> * order)58 bool areSubgraphsAcyclic(const TestModel& testModel, size_t index, std::vector<Visited>* visited,
59 std::vector<size_t>* order) {
60 if (index >= visited->size()) return false;
61 Visited& status = (*visited)[index];
62
63 if (status == Visited::CURRENTLY_VISITING) return false;
64 if (status == Visited::ALREADY_VISITED) return true;
65 status = Visited::CURRENTLY_VISITING;
66
67 const auto& subgraph = index == 0 ? testModel.main : testModel.referenced[index - 1];
68 for (const auto& operand : subgraph.operands) {
69 if (operand.lifetime != TestOperandLifeTime::SUBGRAPH) continue;
70 if (operand.data.size() < sizeof(uint32_t)) return false;
71 if (operand.data.get<void>() == nullptr) return false;
72 const uint32_t subgraphIndex = *operand.data.get<uint32_t>();
73 if (!areSubgraphsAcyclic(testModel, subgraphIndex + 1, visited, order)) return false;
74 }
75
76 status = Visited::ALREADY_VISITED;
77 order->push_back(index);
78 return true;
79 }
80
getSubgraphOrder(const TestModel & testModel)81 std::optional<std::vector<size_t>> getSubgraphOrder(const TestModel& testModel) {
82 std::vector<Visited> visited(testModel.referenced.size() + 1, Visited::NOT_YET_VISITED);
83 std::vector<size_t> order;
84 order.reserve(visited.size());
85 if (!areSubgraphsAcyclic(testModel, 0, &visited, &order)) return std::nullopt;
86 return order;
87 }
88
CreateSubgraph(const TestModel & testModel,size_t subgraphIndex,const std::vector<Model> & subgraphs)89 std::optional<Model> CreateSubgraph(const TestModel& testModel, size_t subgraphIndex,
90 const std::vector<Model>& subgraphs) {
91 const TestSubgraph& testSubgraph =
92 subgraphIndex == 0 ? testModel.main : testModel.referenced[subgraphIndex - 1];
93 Model model;
94
95 // Operands.
96 for (const auto& operand : testSubgraph.operands) {
97 auto type = getOperandType(operand);
98 auto index = model.addOperand(&type);
99
100 switch (operand.lifetime) {
101 case TestOperandLifeTime::CONSTANT_COPY:
102 case TestOperandLifeTime::CONSTANT_REFERENCE:
103 model.setOperandValue(index, operand.data.get<void>(), operand.data.size());
104 break;
105 case TestOperandLifeTime::NO_VALUE:
106 model.setOperandValue(index, nullptr, 0);
107 break;
108 case TestOperandLifeTime::SUBGRAPH: {
109 const uint32_t referencedSubgraphIndex = *operand.data.get<uint32_t>();
110 model.setOperandValueFromModel(index, &subgraphs[referencedSubgraphIndex]);
111 } break;
112 case TestOperandLifeTime::SUBGRAPH_INPUT:
113 case TestOperandLifeTime::SUBGRAPH_OUTPUT:
114 case TestOperandLifeTime::TEMPORARY_VARIABLE:
115 // Nothing to do here.
116 break;
117 }
118 if (!model.isValid()) return std::nullopt;
119 }
120
121 // Operations.
122 for (const auto& operation : testSubgraph.operations) {
123 model.addOperation(static_cast<int>(operation.type), operation.inputs, operation.outputs);
124 if (!model.isValid()) return std::nullopt;
125 }
126
127 // Inputs and outputs.
128 model.identifyInputsAndOutputs(testSubgraph.inputIndexes, testSubgraph.outputIndexes);
129 if (!model.isValid()) return std::nullopt;
130
131 // Relaxed computation.
132 model.relaxComputationFloat32toFloat16(testModel.isRelaxed);
133 if (!model.isValid()) return std::nullopt;
134
135 if (model.finish() != Result::NO_ERROR) {
136 return std::nullopt;
137 }
138
139 return model;
140 }
141
142 // The first Model returned is the main model. Any subsequent Models are referenced models.
CreateModels(const TestModel & testModel)143 std::optional<std::vector<Model>> CreateModels(const TestModel& testModel) {
144 auto subgraphOrder = getSubgraphOrder(testModel);
145 if (!subgraphOrder.has_value()) return std::nullopt;
146
147 std::vector<Model> subgraphs(testModel.referenced.size() + 1);
148 for (size_t index : subgraphOrder.value()) {
149 auto subgraph = CreateSubgraph(testModel, index, subgraphs);
150 if (!subgraph.has_value()) return std::nullopt;
151 subgraphs[index] = std::move(subgraph).value();
152 }
153
154 return subgraphs;
155 }
156
CreateCompilation(const Model & model)157 std::optional<Compilation> CreateCompilation(const Model& model) {
158 Compilation compilation(&model);
159 if (compilation.finish() != Result::NO_ERROR) {
160 return std::nullopt;
161 }
162 return compilation;
163 }
164
CreateExecution(const Compilation & compilation,const TestModel & testModel)165 std::optional<Execution> CreateExecution(const Compilation& compilation,
166 const TestModel& testModel) {
167 Execution execution(&compilation);
168
169 // Model inputs.
170 for (uint32_t i = 0; i < testModel.main.inputIndexes.size(); i++) {
171 const auto& operand = testModel.main.operands[testModel.main.inputIndexes[i]];
172 if (execution.setInput(i, operand.data.get<void>(), operand.data.size()) !=
173 Result::NO_ERROR) {
174 return std::nullopt;
175 }
176 }
177
178 // Model outputs.
179 for (uint32_t i = 0; i < testModel.main.outputIndexes.size(); i++) {
180 const auto& operand = testModel.main.operands[testModel.main.outputIndexes[i]];
181 if (execution.setOutput(i, const_cast<void*>(operand.data.get<void>()),
182 operand.data.size()) != Result::NO_ERROR) {
183 return std::nullopt;
184 }
185 }
186
187 return execution;
188 }
189
190 } // anonymous namespace
191
nnapiFuzzTest(const TestModel & testModel)192 void nnapiFuzzTest(const TestModel& testModel) {
193 // set up model
194 auto models = CreateModels(testModel);
195 if (!models.has_value() || models->empty()) {
196 return;
197 }
198
199 // set up compilation
200 auto compilation = CreateCompilation(models->front());
201 if (!compilation.has_value()) {
202 return;
203 }
204
205 // set up execution
206 auto execution = CreateExecution(*compilation, testModel);
207 if (!execution.has_value()) {
208 return;
209 }
210
211 // perform execution
212 execution->compute();
213 }
214