1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 
20 #include <functional>
21 #include <vector>
22 
23 #include "EmbeddingLookup.h"
24 #include "NeuralNetworksWrapper.h"
25 
26 using ::testing::FloatNear;
27 using ::testing::Matcher;
28 
29 namespace android {
30 namespace nn {
31 namespace wrapper {
32 
33 namespace {
34 
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-6)35 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
36                                            float max_abs_error = 1.e-6) {
37     std::vector<Matcher<float>> matchers;
38     matchers.reserve(values.size());
39     for (const float& v : values) {
40         matchers.emplace_back(FloatNear(v, max_abs_error));
41     }
42     return matchers;
43 }
44 
45 }  // namespace
46 
47 using ::testing::ElementsAreArray;
48 
// X-macro enumerating the model's input tensors as (Name, C++ element type).
// APPLY(Name, T) is expanded once per tensor. In this file it declares the
// Name##_ buffers, generates the Set##Name() setters, and binds each buffer to
// EmbeddingLookup::k##Name##Tensor when the model is executed.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(APPLY) \
    APPLY(Value, float)                         \
    APPLY(Lookup, int)

// X-macro enumerating the model's output tensors, in the same (Name, type) form.
#define FOR_ALL_OUTPUT_TENSORS(APPLY) \
    APPLY(Output, float)
55 
56 class EmbeddingLookupOpModel {
57    public:
EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape,std::initializer_list<uint32_t> weight_shape)58     EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape,
59                            std::initializer_list<uint32_t> weight_shape) {
60         auto it = weight_shape.begin();
61         rows_ = *it++;
62         columns_ = *it++;
63         features_ = *it;
64 
65         std::vector<uint32_t> inputs;
66 
67         OperandType LookupTy(Type::TENSOR_INT32, index_shape);
68         inputs.push_back(model_.addOperand(&LookupTy));
69 
70         OperandType ValueTy(Type::TENSOR_FLOAT32, weight_shape);
71         inputs.push_back(model_.addOperand(&ValueTy));
72 
73         std::vector<uint32_t> outputs;
74 
75         OperandType OutputOpndTy(Type::TENSOR_FLOAT32, weight_shape);
76         outputs.push_back(model_.addOperand(&OutputOpndTy));
77 
78         auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
79             uint32_t sz = 1;
80             for (uint32_t d : dims) {
81                 sz *= d;
82             }
83             return sz;
84         };
85 
86         Value_.insert(Value_.end(), multiAll(weight_shape), 0.f);
87         Output_.insert(Output_.end(), multiAll(weight_shape), 0.f);
88 
89         model_.addOperation(ANEURALNETWORKS_EMBEDDING_LOOKUP, inputs, outputs);
90         model_.identifyInputsAndOutputs(inputs, outputs);
91 
92         model_.finish();
93     }
94 
Invoke()95     void Invoke() {
96         ASSERT_TRUE(model_.isValid());
97 
98         Compilation compilation(&model_);
99         compilation.finish();
100         Execution execution(&compilation);
101 
102 #define SetInputOrWeight(X, T)                                               \
103     ASSERT_EQ(execution.setInput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
104                                  sizeof(T) * X##_.size()),                   \
105               Result::NO_ERROR);
106 
107         FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
108 
109 #undef SetInputOrWeight
110 
111 #define SetOutput(X, T)                                                       \
112     ASSERT_EQ(execution.setOutput(EmbeddingLookup::k##X##Tensor, X##_.data(), \
113                                   sizeof(T) * X##_.size()),                   \
114               Result::NO_ERROR);
115 
116         FOR_ALL_OUTPUT_TENSORS(SetOutput);
117 
118 #undef SetOutput
119 
120         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
121     }
122 
123 #define DefineSetter(X, T) \
124     void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
125 
126     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
127 
128 #undef DefineSetter
129 
Set3DWeightMatrix(const std::function<float (int,int,int)> & function)130     void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
131         for (uint32_t i = 0; i < rows_; i++) {
132             for (uint32_t j = 0; j < columns_; j++) {
133                 for (uint32_t k = 0; k < features_; k++) {
134                     Value_[(i * columns_ + j) * features_ + k] = function(i, j, k);
135                 }
136             }
137         }
138     }
139 
GetOutput() const140     const std::vector<float>& GetOutput() const { return Output_; }
141 
142    private:
143     Model model_;
144     uint32_t rows_;
145     uint32_t columns_;
146     uint32_t features_;
147 
148 #define DefineTensor(X, T) std::vector<T> X##_;
149 
150     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
151     FOR_ALL_OUTPUT_TENSORS(DefineTensor);
152 
153 #undef DefineTensor
154 };
155 
156 // TODO: write more tests that exercise the details of the op, such as
157 // lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
    // Three lookups into a 3x2x4 (rows x columns x features) weight tensor.
    EmbeddingLookupOpModel model({3}, {3, 2, 4});
    model.SetLookup({1, 0, 2});

    // Encode each element's coordinates in its value (row + col/10 + feat/100)
    // so the gathered rows are easy to check by eye.
    model.Set3DWeightMatrix(
            [](int row, int col, int feat) { return row + col / 10.0f + feat / 100.0f; });

    model.Invoke();

    // Rows should come back in lookup order: 1, 0, 2.
    const std::vector<float> expected = {
            1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
            0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
            2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
    };
    EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
171 
172 }  // namespace wrapper
173 }  // namespace nn
174 }  // namespace android
175