1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "HashtableLookup.h"
18 
19 #include "NeuralNetworksWrapper.h"
20 #include "gmock/gmock-matchers.h"
21 #include "gtest/gtest.h"
22 
23 using ::testing::FloatNear;
24 using ::testing::Matcher;
25 
26 namespace android {
27 namespace nn {
28 namespace wrapper {
29 
30 namespace {
31 
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-6)32 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
33                                            float max_abs_error=1.e-6) {
34   std::vector<Matcher<float>> matchers;
35   matchers.reserve(values.size());
36   for (const float& v : values) {
37     matchers.emplace_back(FloatNear(v, max_abs_error));
38   }
39   return matchers;
40 }
41 
42 }  // namespace
43 
44 using ::testing::ElementsAreArray;
45 
// X-macro listing the model's input operands as ACTION(Name, ElementType).
// Lookup and Key are created as TENSOR_INT32 operands and their byte sizes are
// computed with sizeof(ElementType), so int32_t states the width exactly
// instead of relying on sizeof(int) == 4.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
  ACTION(Lookup, int32_t)                        \
  ACTION(Key, int32_t)                           \
  ACTION(Value, float)

// X-macro listing the model's output operands as ACTION(Name, ElementType).
// There are no intermediate state tensors: Output receives the looked-up
// values and Hits receives one uint8_t flag per lookup.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
  ACTION(Output, float)                \
  ACTION(Hits, uint8_t)
55 
56 class HashtableLookupOpModel {
57  public:
HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,std::initializer_list<uint32_t> key_shape,std::initializer_list<uint32_t> value_shape)58     HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
59                            std::initializer_list<uint32_t> key_shape,
60                            std::initializer_list<uint32_t> value_shape) {
61     auto it_vs = value_shape.begin();
62     rows_ = *it_vs++;
63     features_ = *it_vs;
64 
65     std::vector<uint32_t> inputs;
66 
67     // Input and weights
68     OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
69     inputs.push_back(model_.addOperand(&LookupTy));
70 
71     OperandType KeyTy(Type::TENSOR_INT32, key_shape);
72     inputs.push_back(model_.addOperand(&KeyTy));
73 
74     OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
75     inputs.push_back(model_.addOperand(&ValueTy));
76 
77     // Output and other intermediate state
78     std::vector<uint32_t> outputs;
79 
80     std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
81     out_dim.push_back(features_);
82 
83     OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
84     outputs.push_back(model_.addOperand(&OutputOpndTy));
85 
86     OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
87     outputs.push_back(model_.addOperand(&HitsOpndTy));
88 
89     auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
90         uint32_t sz = 1;
91         for (uint32_t d : dims) { sz *= d; }
92         return sz;
93     };
94 
95     Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
96     Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
97     Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
98 
99     model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
100     model_.identifyInputsAndOutputs(inputs, outputs);
101 
102     model_.finish();
103   }
104 
Invoke()105   void Invoke() {
106     ASSERT_TRUE(model_.isValid());
107 
108     Compilation compilation(&model_);
109     compilation.finish();
110     Execution execution(&compilation);
111 
112 #define SetInputOrWeight(X, T)                                             \
113   ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
114                                sizeof(T) * X##_.size()),                   \
115             Result::NO_ERROR);
116 
117     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
118 
119 #undef SetInputOrWeight
120 
121 #define SetOutput(X, T)                                                     \
122   ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
123                                sizeof(T) * X##_.size()),                    \
124             Result::NO_ERROR);
125 
126     FOR_ALL_OUTPUT_TENSORS(SetOutput);
127 
128 #undef SetOutput
129 
130     ASSERT_EQ(execution.compute(), Result::NO_ERROR);
131   }
132 
133 #define DefineSetter(X, T)                       \
134   void Set##X(const std::vector<T>& f) {         \
135     X##_.insert(X##_.end(), f.begin(), f.end()); \
136   }
137 
138   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
139 
140 #undef DefineSetter
141 
SetHashtableValue(const std::function<float (uint32_t,uint32_t)> & function)142   void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
143     for (uint32_t i = 0; i < rows_; i++) {
144       for (uint32_t j = 0; j < features_; j++) {
145           Value_[i * features_ + j] = function(i, j);
146       }
147     }
148   }
149 
GetOutput() const150   const std::vector<float>& GetOutput() const { return Output_; }
GetHits() const151   const std::vector<uint8_t>& GetHits() const { return Hits_; }
152 
153  private:
154   Model model_;
155   uint32_t rows_;
156   uint32_t features_;
157 
158 #define DefineTensor(X, T) std::vector<T> X##_;
159 
160   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
161   FOR_ALL_OUTPUT_TENSORS(DefineTensor);
162 
163 #undef DefineTensor
164 };
165 
// End-to-end check: four lookups against three keys (one lookup has no
// matching key), with two float features per value row.
TEST(HashtableLookupOpTest, BlackBoxTest) {
  HashtableLookupOpModel m({4}, {3}, {3, 2});

  m.SetLookup({1234, -292, -11, 0});
  m.SetKey({-11, 0, 1234});
  // Value row i, feature j holds i + j / 10, e.g. row 2 is {2.0, 2.1}.
  m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });

  // Invoke() reports failures with ASSERT_*, which only return from Invoke();
  // stop the test here rather than checking outputs that were never computed.
  ASSERT_NO_FATAL_FAILURE(m.Invoke());

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                 2.0, 2.1,  // Lookup 1234 -> key at index 2
                                 0, 0,      // Lookup -292 -> not found
                                 0.0, 0.1,  // Lookup -11 -> key at index 0
                                 1.0, 1.1,  // Lookup 0 -> key at index 1
                             })));
  // 1 for each lookup that matched a key, 0 for the one that did not.
  EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
                             1, 0, 1, 1,
                         }));
}
186 
187 }  // namespace wrapper
188 }  // namespace nn
189 }  // namespace android
190