1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "HashtableLookup.h"
18
19 #include "NeuralNetworksWrapper.h"
20 #include "gmock/gmock-matchers.h"
21 #include "gtest/gtest.h"
22
23 using ::testing::FloatNear;
24 using ::testing::Matcher;
25
26 namespace android {
27 namespace nn {
28 namespace wrapper {
29
30 namespace {
31
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-6)32 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
33 float max_abs_error=1.e-6) {
34 std::vector<Matcher<float>> matchers;
35 matchers.reserve(values.size());
36 for (const float& v : values) {
37 matchers.emplace_back(FloatNear(v, max_abs_error));
38 }
39 return matchers;
40 }
41
42 } // namespace
43
44 using ::testing::ElementsAreArray;
45
// X-macro listing every input tensor of the HASHTABLE_LOOKUP model as
// (Name, element type). ACTION(Name, T) is expanded once per entry; the
// Name is token-pasted into identifiers such as k<Name>Tensor, <Name>_ and
// Set<Name> by the model class below.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Lookup, int) \
    ACTION(Key, int) \
    ACTION(Value, float)

// Same X-macro pattern for every output tensor of the model, as
// (Name, element type).
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(Output, float) \
    ACTION(Hits, uint8_t)
55
56 class HashtableLookupOpModel {
57 public:
HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,std::initializer_list<uint32_t> key_shape,std::initializer_list<uint32_t> value_shape)58 HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
59 std::initializer_list<uint32_t> key_shape,
60 std::initializer_list<uint32_t> value_shape) {
61 auto it_vs = value_shape.begin();
62 rows_ = *it_vs++;
63 features_ = *it_vs;
64
65 std::vector<uint32_t> inputs;
66
67 // Input and weights
68 OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
69 inputs.push_back(model_.addOperand(&LookupTy));
70
71 OperandType KeyTy(Type::TENSOR_INT32, key_shape);
72 inputs.push_back(model_.addOperand(&KeyTy));
73
74 OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
75 inputs.push_back(model_.addOperand(&ValueTy));
76
77 // Output and other intermediate state
78 std::vector<uint32_t> outputs;
79
80 std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
81 out_dim.push_back(features_);
82
83 OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
84 outputs.push_back(model_.addOperand(&OutputOpndTy));
85
86 OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
87 outputs.push_back(model_.addOperand(&HitsOpndTy));
88
89 auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
90 uint32_t sz = 1;
91 for (uint32_t d : dims) { sz *= d; }
92 return sz;
93 };
94
95 Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
96 Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
97 Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
98
99 model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
100 model_.identifyInputsAndOutputs(inputs, outputs);
101
102 model_.finish();
103 }
104
Invoke()105 void Invoke() {
106 ASSERT_TRUE(model_.isValid());
107
108 Compilation compilation(&model_);
109 compilation.finish();
110 Execution execution(&compilation);
111
112 #define SetInputOrWeight(X, T) \
113 ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
114 sizeof(T) * X##_.size()), \
115 Result::NO_ERROR);
116
117 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
118
119 #undef SetInputOrWeight
120
121 #define SetOutput(X, T) \
122 ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
123 sizeof(T) * X##_.size()), \
124 Result::NO_ERROR);
125
126 FOR_ALL_OUTPUT_TENSORS(SetOutput);
127
128 #undef SetOutput
129
130 ASSERT_EQ(execution.compute(), Result::NO_ERROR);
131 }
132
133 #define DefineSetter(X, T) \
134 void Set##X(const std::vector<T>& f) { \
135 X##_.insert(X##_.end(), f.begin(), f.end()); \
136 }
137
138 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
139
140 #undef DefineSetter
141
SetHashtableValue(const std::function<float (uint32_t,uint32_t)> & function)142 void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
143 for (uint32_t i = 0; i < rows_; i++) {
144 for (uint32_t j = 0; j < features_; j++) {
145 Value_[i * features_ + j] = function(i, j);
146 }
147 }
148 }
149
GetOutput() const150 const std::vector<float>& GetOutput() const { return Output_; }
GetHits() const151 const std::vector<uint8_t>& GetHits() const { return Hits_; }
152
153 private:
154 Model model_;
155 uint32_t rows_;
156 uint32_t features_;
157
158 #define DefineTensor(X, T) std::vector<T> X##_;
159
160 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
161 FOR_ALL_OUTPUT_TENSORS(DefineTensor);
162
163 #undef DefineTensor
164 };
165
// End-to-end check of HASHTABLE_LOOKUP on a 3x2 value table indexed by three
// keys. Four lookups are issued, one of which is not among the keys.
TEST(HashtableLookupOpTest, BlackBoxTest) {
  HashtableLookupOpModel m({4}, {3}, {3, 2});

  // The key at position r selects row r of the value table.
  m.SetKey({-11, 0, 1234});
  m.SetLookup({1234, -292, -11, 0});
  // Cell (row, col) holds row + col / 10, so e.g. row 2 is {2.0, 2.1}.
  m.SetHashtableValue([](int row, int col) { return row + col / 10.0f; });

  m.Invoke();

  // One output row per lookup; the missed lookup is expected to yield zeros.
  const std::vector<float> expected_output = {
      2.0, 2.1,  // lookup 1234 -> row 2
      0,   0,    // lookup -292 -> not found
      0.0, 0.1,  // lookup -11  -> row 0
      1.0, 1.1,  // lookup 0    -> row 1
  };
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected_output)));

  // One hit flag per lookup: 1 when the lookup matched a key, 0 otherwise.
  const std::vector<uint8_t> expected_hits = {1, 0, 1, 1};
  EXPECT_EQ(m.GetHits(), expected_hits);
}
186
187 } // namespace wrapper
188 } // namespace nn
189 } // namespace android
190