1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "LSHProjection.h"
18 
19 #include "NeuralNetworksWrapper.h"
20 #include "gmock/gmock-generated-matchers.h"
21 #include "gmock/gmock-matchers.h"
22 #include "gtest/gtest.h"
23 
24 using ::testing::FloatNear;
25 using ::testing::Matcher;
26 
27 namespace android {
28 namespace nn {
29 namespace wrapper {
30 
31 using ::testing::ElementsAre;
32 
// X-macro enumerating every input tensor of the op under test as
// (Name, element type). Each entry expands to ACTION(Name, T); within this file
// it is used to generate Set##Name() setters that append to a member buffer
// Name##_, and to bind that buffer to LSHProjection::k##Name##Tensor.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Hash, float)                          \
    ACTION(Input, int)                           \
    ACTION(Weight, float)

// X-macro enumerating the op's output tensor(s) as (Name, element type); used to
// bind the member buffer Name##_ to LSHProjection::k##Name##Tensor as an output.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, int)
40 
41 class LSHProjectionOpModel {
42    public:
LSHProjectionOpModel(LSHProjectionType type,std::initializer_list<uint32_t> hash_shape,std::initializer_list<uint32_t> input_shape,std::initializer_list<uint32_t> weight_shape)43     LSHProjectionOpModel(LSHProjectionType type, std::initializer_list<uint32_t> hash_shape,
44                          std::initializer_list<uint32_t> input_shape,
45                          std::initializer_list<uint32_t> weight_shape)
46         : type_(type) {
47         std::vector<uint32_t> inputs;
48 
49         OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape);
50         inputs.push_back(model_.addOperand(&HashTy));
51         OperandType InputTy(Type::TENSOR_INT32, input_shape);
52         inputs.push_back(model_.addOperand(&InputTy));
53         OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape);
54         inputs.push_back(model_.addOperand(&WeightTy));
55 
56         OperandType TypeParamTy(Type::INT32, {});
57         inputs.push_back(model_.addOperand(&TypeParamTy));
58 
59         std::vector<uint32_t> outputs;
60 
61         auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
62             uint32_t sz = 1;
63             for (uint32_t d : dims) {
64                 sz *= d;
65             }
66             return sz;
67         };
68 
69         uint32_t outShapeDimension = 0;
70         if (type == LSHProjectionType_SPARSE || type == LSHProjectionType_SPARSE_DEPRECATED) {
71             auto it = hash_shape.begin();
72             Output_.insert(Output_.end(), *it, 0.f);
73             outShapeDimension = *it;
74         } else {
75             Output_.insert(Output_.end(), multiAll(hash_shape), 0.f);
76             outShapeDimension = multiAll(hash_shape);
77         }
78 
79         OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension});
80         outputs.push_back(model_.addOperand(&OutputTy));
81 
82         model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs);
83         model_.identifyInputsAndOutputs(inputs, outputs);
84 
85         model_.finish();
86     }
87 
88 #define DefineSetter(X, T) \
89     void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
90 
91     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
92 
93 #undef DefineSetter
94 
GetOutput() const95     const std::vector<int>& GetOutput() const { return Output_; }
96 
Invoke()97     void Invoke() {
98         ASSERT_TRUE(model_.isValid());
99 
100         Compilation compilation(&model_);
101         compilation.finish();
102         Execution execution(&compilation);
103 
104 #define SetInputOrWeight(X, T)                                                                     \
105     ASSERT_EQ(                                                                                     \
106             execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), sizeof(T) * X##_.size()), \
107             Result::NO_ERROR);
108 
109         FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
110 
111 #undef SetInputOrWeight
112 
113 #define SetOutput(X, T)                                                     \
114     ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \
115                                   sizeof(T) * X##_.size()),                 \
116               Result::NO_ERROR);
117 
118         FOR_ALL_OUTPUT_TENSORS(SetOutput);
119 
120 #undef SetOutput
121 
122         ASSERT_EQ(execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)),
123                   Result::NO_ERROR);
124 
125         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
126     }
127 
128    private:
129     Model model_;
130     LSHProjectionType type_;
131 
132     std::vector<float> Hash_;
133     std::vector<int> Input_;
134     std::vector<float> Weight_;
135     std::vector<int> Output_;
136 };  // namespace wrapper
137 
TEST(LSHProjectionOpTest2,DenseWithThreeInputs)138 TEST(LSHProjectionOpTest2, DenseWithThreeInputs) {
139     LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3});
140 
141     m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
142     m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
143     m.SetWeight({0.12, 0.34, 0.56});
144 
145     m.Invoke();
146 
147     EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0));
148 }
149 
TEST(LSHProjectionOpTest2,SparseDeprecatedWithTwoInputs)150 TEST(LSHProjectionOpTest2, SparseDeprecatedWithTwoInputs) {
151     LSHProjectionOpModel m(LSHProjectionType_SPARSE_DEPRECATED, {4, 2}, {3, 2}, {0});
152 
153     m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
154     m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
155 
156     m.Invoke();
157 
158     EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0));
159 }
160 
TEST(LSHProjectionOpTest2,SparseWithTwoInputs)161 TEST(LSHProjectionOpTest2, SparseWithTwoInputs) {
162     LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {0});
163 
164     m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
165     m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
166 
167     m.Invoke();
168 
169     EXPECT_THAT(m.GetOutput(), ElementsAre(1, 6, 10, 12));
170 }
171 
172 }  // namespace wrapper
173 }  // namespace nn
174 }  // namespace android
175