1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <android-base/logging.h>
18 
19 #include <cstdlib>
20 #include <optional>
21 #include <utility>
22 
23 #include "NeuralNetworksWrapper.h"
24 #include "TestHarness.h"
25 
26 namespace {
27 
28 using ::android::nn::wrapper::Compilation;
29 using ::android::nn::wrapper::Execution;
30 using ::android::nn::wrapper::Model;
31 using ::android::nn::wrapper::OperandType;
32 using ::android::nn::wrapper::Result;
33 using ::android::nn::wrapper::SymmPerChannelQuantParams;
34 using ::android::nn::wrapper::Type;
35 using ::test_helper::TestModel;
36 using ::test_helper::TestOperand;
37 using ::test_helper::TestOperandLifeTime;
38 using ::test_helper::TestOperandType;
39 using ::test_helper::TestSubgraph;
40 
getOperandType(const TestOperand & op)41 OperandType getOperandType(const TestOperand& op) {
42     const auto& dims = op.dimensions;
43     if (op.type == TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
44         return OperandType(
45                 static_cast<Type>(op.type), dims,
46                 SymmPerChannelQuantParams(op.channelQuant.scales, op.channelQuant.channelDim));
47     } else {
48         return OperandType(static_cast<Type>(op.type), dims, op.scale, op.zeroPoint);
49     }
50 }
51 
// Per-subgraph state used while walking the subgraph reference graph.
enum class Visited : uint8_t {
    // The walk has not reached this subgraph.
    NOT_YET_VISITED = 0,
    // The subgraph is on the current walk path; reaching it again means a cycle.
    CURRENTLY_VISITING = 1,
    // The subgraph and everything it references have been fully processed.
    ALREADY_VISITED = 2,
};
57 
areSubgraphsAcyclic(const TestModel & testModel,size_t index,std::vector<Visited> * visited,std::vector<size_t> * order)58 bool areSubgraphsAcyclic(const TestModel& testModel, size_t index, std::vector<Visited>* visited,
59                          std::vector<size_t>* order) {
60     if (index >= visited->size()) return false;
61     Visited& status = (*visited)[index];
62 
63     if (status == Visited::CURRENTLY_VISITING) return false;
64     if (status == Visited::ALREADY_VISITED) return true;
65     status = Visited::CURRENTLY_VISITING;
66 
67     const auto& subgraph = index == 0 ? testModel.main : testModel.referenced[index - 1];
68     for (const auto& operand : subgraph.operands) {
69         if (operand.lifetime != TestOperandLifeTime::SUBGRAPH) continue;
70         if (operand.data.size() < sizeof(uint32_t)) return false;
71         if (operand.data.get<void>() == nullptr) return false;
72         const uint32_t subgraphIndex = *operand.data.get<uint32_t>();
73         if (!areSubgraphsAcyclic(testModel, subgraphIndex + 1, visited, order)) return false;
74     }
75 
76     status = Visited::ALREADY_VISITED;
77     order->push_back(index);
78     return true;
79 }
80 
getSubgraphOrder(const TestModel & testModel)81 std::optional<std::vector<size_t>> getSubgraphOrder(const TestModel& testModel) {
82     std::vector<Visited> visited(testModel.referenced.size() + 1, Visited::NOT_YET_VISITED);
83     std::vector<size_t> order;
84     order.reserve(visited.size());
85     if (!areSubgraphsAcyclic(testModel, 0, &visited, &order)) return std::nullopt;
86     return order;
87 }
88 
CreateSubgraph(const TestModel & testModel,size_t subgraphIndex,const std::vector<Model> & subgraphs)89 std::optional<Model> CreateSubgraph(const TestModel& testModel, size_t subgraphIndex,
90                                     const std::vector<Model>& subgraphs) {
91     const TestSubgraph& testSubgraph =
92             subgraphIndex == 0 ? testModel.main : testModel.referenced[subgraphIndex - 1];
93     Model model;
94 
95     // Operands.
96     for (const auto& operand : testSubgraph.operands) {
97         auto type = getOperandType(operand);
98         auto index = model.addOperand(&type);
99 
100         switch (operand.lifetime) {
101             case TestOperandLifeTime::CONSTANT_COPY:
102             case TestOperandLifeTime::CONSTANT_REFERENCE:
103                 model.setOperandValue(index, operand.data.get<void>(), operand.data.size());
104                 break;
105             case TestOperandLifeTime::NO_VALUE:
106                 model.setOperandValue(index, nullptr, 0);
107                 break;
108             case TestOperandLifeTime::SUBGRAPH: {
109                 const uint32_t referencedSubgraphIndex = *operand.data.get<uint32_t>();
110                 model.setOperandValueFromModel(index, &subgraphs[referencedSubgraphIndex]);
111             } break;
112             case TestOperandLifeTime::SUBGRAPH_INPUT:
113             case TestOperandLifeTime::SUBGRAPH_OUTPUT:
114             case TestOperandLifeTime::TEMPORARY_VARIABLE:
115                 // Nothing to do here.
116                 break;
117         }
118         if (!model.isValid()) return std::nullopt;
119     }
120 
121     // Operations.
122     for (const auto& operation : testSubgraph.operations) {
123         model.addOperation(static_cast<int>(operation.type), operation.inputs, operation.outputs);
124         if (!model.isValid()) return std::nullopt;
125     }
126 
127     // Inputs and outputs.
128     model.identifyInputsAndOutputs(testSubgraph.inputIndexes, testSubgraph.outputIndexes);
129     if (!model.isValid()) return std::nullopt;
130 
131     // Relaxed computation.
132     model.relaxComputationFloat32toFloat16(testModel.isRelaxed);
133     if (!model.isValid()) return std::nullopt;
134 
135     if (model.finish() != Result::NO_ERROR) {
136         return std::nullopt;
137     }
138 
139     return model;
140 }
141 
142 // The first Model returned is the main model. Any subsequent Models are referenced models.
CreateModels(const TestModel & testModel)143 std::optional<std::vector<Model>> CreateModels(const TestModel& testModel) {
144     auto subgraphOrder = getSubgraphOrder(testModel);
145     if (!subgraphOrder.has_value()) return std::nullopt;
146 
147     std::vector<Model> subgraphs(testModel.referenced.size() + 1);
148     for (size_t index : subgraphOrder.value()) {
149         auto subgraph = CreateSubgraph(testModel, index, subgraphs);
150         if (!subgraph.has_value()) return std::nullopt;
151         subgraphs[index] = std::move(subgraph).value();
152     }
153 
154     return subgraphs;
155 }
156 
CreateCompilation(const Model & model)157 std::optional<Compilation> CreateCompilation(const Model& model) {
158     Compilation compilation(&model);
159     if (compilation.finish() != Result::NO_ERROR) {
160         return std::nullopt;
161     }
162     return compilation;
163 }
164 
CreateExecution(const Compilation & compilation,const TestModel & testModel)165 std::optional<Execution> CreateExecution(const Compilation& compilation,
166                                          const TestModel& testModel) {
167     Execution execution(&compilation);
168 
169     // Model inputs.
170     for (uint32_t i = 0; i < testModel.main.inputIndexes.size(); i++) {
171         const auto& operand = testModel.main.operands[testModel.main.inputIndexes[i]];
172         if (execution.setInput(i, operand.data.get<void>(), operand.data.size()) !=
173             Result::NO_ERROR) {
174             return std::nullopt;
175         }
176     }
177 
178     // Model outputs.
179     for (uint32_t i = 0; i < testModel.main.outputIndexes.size(); i++) {
180         const auto& operand = testModel.main.operands[testModel.main.outputIndexes[i]];
181         if (execution.setOutput(i, const_cast<void*>(operand.data.get<void>()),
182                                 operand.data.size()) != Result::NO_ERROR) {
183             return std::nullopt;
184         }
185     }
186 
187     return execution;
188 }
189 
190 }  // anonymous namespace
191 
nnapiFuzzTest(const TestModel & testModel)192 void nnapiFuzzTest(const TestModel& testModel) {
193     // set up model
194     auto models = CreateModels(testModel);
195     if (!models.has_value() || models->empty()) {
196         return;
197     }
198 
199     // set up compilation
200     auto compilation = CreateCompilation(models->front());
201     if (!compilation.has_value()) {
202         return;
203     }
204 
205     // set up execution
206     auto execution = CreateExecution(*compilation, testModel);
207     if (!execution.has_value()) {
208         return;
209     }
210 
211     // perform execution
212     execution->compute();
213 }
214