1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Top level driver for models and examples generated by test_generator.py
18
19 #include "Bridge.h"
20 #include "NeuralNetworksWrapper.h"
21 #include "TestHarness.h"
22
23 #include <gtest/gtest.h>
24 #include <cassert>
25 #include <cmath>
26 #include <fstream>
27 #include <iostream>
28 #include <map>
29
30 // Uncomment the following line to generate DOT graphs.
31 //
32 // #define GRAPH GRAPH
33
34 namespace generated_tests {
35 using namespace android::nn::wrapper;
36 using namespace test_helper;
37
// Forwards |name| and |model| to the bridge graph dumper when GRAPH is defined
// at compile time (see the "DOT graphs" note near the top of this file).
// Without GRAPH this is a no-op, hence the [[maybe_unused]] parameters.
void graphDump([[maybe_unused]] const char* name, [[maybe_unused]] const Model& model) {
#ifdef GRAPH
    // The bridge helper takes a ModelBuilder pointer, so the wrapper's handle
    // is reinterpreted as one before being passed along.
    const auto* builder =
            reinterpret_cast<const ::android::nn::ModelBuilder*>(model.getHandle());
    ::android::nn::bridge_tests::graphDump(name, builder);
#endif
}
45
46 template <typename T>
print(std::ostream & os,const MixedTyped & test)47 static void print(std::ostream& os, const MixedTyped& test) {
48 // dump T-typed inputs
49 for_each<T>(test, [&os](int idx, const std::vector<T>& f) {
50 os << " aliased_output" << idx << ": [";
51 for (size_t i = 0; i < f.size(); ++i) {
52 os << (i == 0 ? "" : ", ") << +f[i];
53 }
54 os << "],\n";
55 });
56 }
57
printAll(std::ostream & os,const MixedTyped & test)58 static void printAll(std::ostream& os, const MixedTyped& test) {
59 print<float>(os, test);
60 print<int32_t>(os, test);
61 print<uint8_t>(os, test);
62 }
63
64 // Test driver for those generated from ml/nn/runtime/test/spec
execute(std::function<void (Model *)> createModel,std::function<bool (int)> isIgnored,std::vector<MixedTypedExampleType> & examples,std::string dumpFile="")65 static void execute(std::function<void(Model*)> createModel,
66 std::function<bool(int)> isIgnored,
67 std::vector<MixedTypedExampleType>& examples,
68 std::string dumpFile = "") {
69 Model model;
70 createModel(&model);
71 model.finish();
72 graphDump("", model);
73 bool dumpToFile = !dumpFile.empty();
74
75 std::ofstream s;
76 if (dumpToFile) {
77 s.open(dumpFile, std::ofstream::trunc);
78 ASSERT_TRUE(s.is_open());
79 }
80
81 int exampleNo = 0;
82 Compilation compilation(&model);
83 compilation.finish();
84
85 // If in relaxed mode, set the error range to be 5ULP of FP16.
86 float fpRange = !model.isRelaxed() ? 1e-5f : 5.0f * 0.0009765625f;
87 for (auto& example : examples) {
88 SCOPED_TRACE(exampleNo);
89 // TODO: We leave it as a copy here.
90 // Should verify if the input gets modified by the test later.
91 MixedTyped inputs = example.first;
92 const MixedTyped& golden = example.second;
93
94 Execution execution(&compilation);
95
96 // Set all inputs
97 for_all(inputs, [&execution](int idx, const void* p, size_t s) {
98 const void* buffer = s == 0 ? nullptr : p;
99 ASSERT_EQ(Result::NO_ERROR, execution.setInput(idx, buffer, s));
100 });
101
102 MixedTyped test;
103 // Go through all typed outputs
104 resize_accordingly(golden, test);
105 for_all(test, [&execution](int idx, void* p, size_t s) {
106 void* buffer = s == 0 ? nullptr : p;
107 ASSERT_EQ(Result::NO_ERROR, execution.setOutput(idx, buffer, s));
108 });
109
110 Result r = execution.compute();
111 ASSERT_EQ(Result::NO_ERROR, r);
112
113 // Dump all outputs for the slicing tool
114 if (dumpToFile) {
115 s << "output" << exampleNo << " = {\n";
116 printAll(s, test);
117 // all outputs are done
118 s << "}\n";
119 }
120
121 // Filter out don't cares
122 MixedTyped filteredGolden = filter(golden, isIgnored);
123 MixedTyped filteredTest = filter(test, isIgnored);
124 // We want "close-enough" results for float
125
126 compare(filteredGolden, filteredTest, fpRange);
127 exampleNo++;
128 }
129 }
130
131 }; // namespace generated_tests
132
133 using namespace android::nn::wrapper;
134
135 // Mixed-typed examples
136 typedef test_helper::MixedTypedExampleType MixedTypedExample;
137
138 class GeneratedTests : public ::testing::Test {
139 protected:
SetUp()140 virtual void SetUp() {}
141 };
142
143 // Testcases generated from runtime/test/specs/*.mod.py
144 using namespace test_helper;
145 using namespace generated_tests;
146 #include "generated/all_generated_tests.cpp"
147 // End of testcases generated from runtime/test/specs/*.mod.py
148