1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
/* Header-only library of helpers for the test harness.
 * See frameworks/ml/nn/runtime/test/TestGenerated.cpp for how this is used.
 */
20 #ifndef ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
21 #define ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
22 
23 #include <gtest/gtest.h>
24 
25 #include <cmath>
26 #include <functional>
27 #include <map>
28 #include <tuple>
29 #include <vector>
30 
31 namespace test_helper {
32 
// Cap on how many mismatches compare() reports individually before it stops
// emitting per-element failure messages.
constexpr size_t gMaximumNumberOfErrorMessages = 10;
34 
// Per-type operand tables: operand index -> flat buffer of element values.
using Float32Operands = std::map<int, std::vector<float>>;
using Int32Operands = std::map<int, std::vector<int32_t>>;
using Quant8Operands = std::map<int, std::vector<uint8_t>>;

// One operand table per supported element type, in a fixed order:
//   slot 0: ANEURALNETWORKS_TENSOR_FLOAT32
//   slot 1: ANEURALNETWORKS_TENSOR_INT32
//   slot 2: ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
using MixedTyped = std::tuple<Float32Operands, Int32Operands, Quant8Operands>;

// A pair of MixedTyped operand sets making up one example.
using MixedTypedExampleType = std::pair<MixedTyped, MixedTyped>;
44 
// Maps an element type to the slot it occupies inside the MixedTyped tuple.
// The primary template deliberately has no `index` member, so asking for the
// slot of an element type without a specialization fails to compile.
template <typename T>
struct MixedTypedIndex {};

// float elements live in tuple slot 0.
template <>
struct MixedTypedIndex<float> {
    static constexpr size_t index = 0;
};

// int32_t elements live in tuple slot 1.
template <>
struct MixedTypedIndex<int32_t> {
    static constexpr size_t index = 1;
};

// uint8_t elements live in tuple slot 2.
template <>
struct MixedTypedIndex<uint8_t> {
    static constexpr size_t index = 2;
};
60 
61 // Go through all index-value pairs of a given input type
62 template <typename T>
63 inline void for_each(const MixedTyped& idx_and_data,
64                      std::function<void(int, const std::vector<T>&)> execute) {
65     for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
66         execute(i.first, i.second);
67     }
68 }
69 
70 // non-const variant of for_each
71 template <typename T>
72 inline void for_each(MixedTyped& idx_and_data,
73                      std::function<void(int, std::vector<T>&)> execute) {
74     for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
75         execute(i.first, i.second);
76     }
77 }
78 
79 // internal helper for for_all
80 template <typename T>
81 inline void for_all_internal(
82         MixedTyped& idx_and_data,
83         std::function<void(int, void*, size_t)> execute_this) {
84     for_each<T>(idx_and_data, [&execute_this](int idx, std::vector<T>& m) {
85         execute_this(idx, static_cast<void*>(m.data()), m.size() * sizeof(T));
86     });
87 }
88 
89 // Go through all index-value pairs of all input types
90 // expects a functor that takes (int index, void *raw data, size_t sz)
91 inline void for_all(MixedTyped& idx_and_data,
92                     std::function<void(int, void*, size_t)> execute_this) {
93     for_all_internal<float>(idx_and_data, execute_this);
94     for_all_internal<int32_t>(idx_and_data, execute_this);
95     for_all_internal<uint8_t>(idx_and_data, execute_this);
96 }
97 
98 // Const variant of internal helper for for_all
99 template <typename T>
100 inline void for_all_internal(
101         const MixedTyped& idx_and_data,
102         std::function<void(int, const void*, size_t)> execute_this) {
103     for_each<T>(idx_and_data, [&execute_this](int idx, const std::vector<T>& m) {
104         execute_this(idx, static_cast<const void*>(m.data()), m.size() * sizeof(T));
105     });
106 }
107 
108 // Go through all index-value pairs (const variant)
109 // expects a functor that takes (int index, const void *raw data, size_t sz)
110 inline void for_all(
111         const MixedTyped& idx_and_data,
112         std::function<void(int, const void*, size_t)> execute_this) {
113     for_all_internal<float>(idx_and_data, execute_this);
114     for_all_internal<int32_t>(idx_and_data, execute_this);
115     for_all_internal<uint8_t>(idx_and_data, execute_this);
116 }
117 
118 // Helper template - resize test output per golden
119 template <typename ty, size_t tuple_index>
120 void resize_accordingly_(const MixedTyped& golden, MixedTyped& test) {
121     std::function<void(int, const std::vector<ty>&)> execute =
122             [&test](int index, const std::vector<ty>& m) {
123                 auto& t = std::get<tuple_index>(test);
124                 t[index].resize(m.size());
125             };
126     for_each<ty>(golden, execute);
127 }
128 
129 inline void resize_accordingly(const MixedTyped& golden, MixedTyped& test) {
130     resize_accordingly_<float, 0>(golden, test);
131     resize_accordingly_<int32_t, 1>(golden, test);
132     resize_accordingly_<uint8_t, 2>(golden, test);
133 }
134 
135 template <typename ty, size_t tuple_index>
136 void filter_internal(const MixedTyped& golden, MixedTyped* filtered,
137                      std::function<bool(int)> is_ignored) {
138     for_each<ty>(golden,
139                  [filtered, &is_ignored](int index, const std::vector<ty>& m) {
140                      auto& g = std::get<tuple_index>(*filtered);
141                      if (!is_ignored(index)) g[index] = m;
142                  });
143 }
144 
145 inline MixedTyped filter(const MixedTyped& golden,
146                          std::function<bool(int)> is_ignored) {
147     MixedTyped filtered;
148     filter_internal<float, 0>(golden, &filtered, is_ignored);
149     filter_internal<int32_t, 1>(golden, &filtered, is_ignored);
150     filter_internal<uint8_t, 2>(golden, &filtered, is_ignored);
151     return filtered;
152 }
153 
154 // Compare results
155 #define VECTOR_TYPE(x) \
156     typename std::tuple_element<x, MixedTyped>::type::mapped_type
157 #define VALUE_TYPE(x) VECTOR_TYPE(x)::value_type
158 template <size_t tuple_index>
159 void compare_(
160         const MixedTyped& golden, const MixedTyped& test,
161         std::function<void(VALUE_TYPE(tuple_index), VALUE_TYPE(tuple_index))>
162                 cmp) {
163     for_each<VALUE_TYPE(tuple_index)>(
164             golden,
165             [&test, &cmp](int index, const VECTOR_TYPE(tuple_index) & m) {
166                 const auto& test_operands = std::get<tuple_index>(test);
167                 const auto& test_ty = test_operands.find(index);
168                 ASSERT_NE(test_ty, test_operands.end());
169                 for (unsigned int i = 0; i < m.size(); i++) {
170                     SCOPED_TRACE(testing::Message()
171                                  << "When comparing element " << i);
172                     cmp(m[i], test_ty->second[i]);
173                 }
174             });
175 }
176 #undef VALUE_TYPE
177 #undef VECTOR_TYPE
178 inline void compare(const MixedTyped& golden, const MixedTyped& test, float fpRange = 1e-5f) {
179     size_t totalNumberOfErrors = 0;
180     compare_<0>(golden, test, [&totalNumberOfErrors, fpRange](float g, float t) {
181         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
182             EXPECT_NEAR(g, t, fpRange);
183         }
184         if (std::abs(g - t) > fpRange) {
185             totalNumberOfErrors++;
186         }
187     });
188     compare_<1>(golden, test, [&totalNumberOfErrors](int32_t g, int32_t t) {
189         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
190             EXPECT_EQ(g, t);
191         }
192         if (g != t) {
193             totalNumberOfErrors++;
194         }
195     });
196     compare_<2>(golden, test, [&totalNumberOfErrors](uint8_t g, uint8_t t) {
197         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
198             EXPECT_NEAR(g, t, 1);
199         }
200         if (std::abs(g - t) > 1) {
201             totalNumberOfErrors++;
202         }
203     });
204     EXPECT_EQ(size_t{0}, totalNumberOfErrors);
205 }
206 
207 };  // namespace test_helper
208 
209 #endif  // ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
210