/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <android-base/logging.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <sstream>
#include <string>
#include <vector>

#include "LSTM.h"
#include "NeuralNetworksWrapper.h"

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

}  // anonymous namespace

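// The FOR_ALL_* X-macro lists below name every tensor exactly once; the
// AddInput/AddOutput, DefineSetter and DefineTensor macros further down expand
// them so that operands, setters and backing buffers stay consistent.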
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(CellGateBias)                         \
    ACTION(ForgetGateBias)                       \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)

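// Test harness that assembles an ANEURALNETWORKS_LSTM model with layer
// normalization: the 20 data tensors listed above, 3 scalar parameters
// (activation, cell clip, projection clip) and the 4 layer-norm weight vectors
// form the inputs; scratch buffer, output state, cell state and output are the
// outputs.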
class LayerNormLSTMOpModel {
   public:
    LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
                         bool use_cifg, bool use_peephole, bool use_projection_weights,
                         bool use_projection_bias, float cell_clip, float proj_clip,
                         const std::vector<std::vector<uint32_t>>& input_shapes0)
        : n_input_(n_input),
          n_output_(n_output),
          use_cifg_(use_cifg),
          use_peephole_(use_peephole),
          use_projection_weights_(use_projection_weights),
          use_projection_bias_(use_projection_bias),
          activation_(ActivationFn::kActivationTanh),
          cell_clip_(cell_clip),
          proj_clip_(proj_clip) {
        std::vector<uint32_t> inputs;
        std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);

        auto it = input_shapes.begin();

        // Input and weights
#define AddInput(X)                                     \
    CHECK(it != input_shapes.end());                    \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
    inputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);

        // Parameters
        OperandType ActivationOpndTy(Type::INT32, {});
        inputs.push_back(model_.addOperand(&ActivationOpndTy));
        OperandType CellClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&CellClipOpndTy));
        OperandType ProjClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&ProjClipOpndTy));

        FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);

#undef AddInput

        // Output and other intermediate state
        std::vector<std::vector<uint32_t>> output_shapes{
                {n_batch, n_cell * (use_cifg ? 3 : 4)},  // scratch buffer (3 gates with CIFG, else 4)
                {n_batch, n_output},                     // output state out
                {n_batch, n_cell},                       // cell state out
                {n_batch, n_output},                     // output
        };
        std::vector<uint32_t> outputs;

        auto it2 = output_shapes.begin();

#define AddOutput(X)                                     \
    CHECK(it2 != output_shapes.end());                   \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
    outputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_OUTPUT_TENSORS(AddOutput);

#undef AddOutput

        model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        Input_.insert(Input_.end(), n_batch * n_input, 0.f);
        OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
        CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);

        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
            uint32_t sz = 1;
            for (uint32_t d : dims) {
                sz *= d;
            }
            return sz;
        };

        it2 = output_shapes.begin();

#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);

        FOR_ALL_OUTPUT_TENSORS(ReserveOutput);

#undef ReserveOutput

        model_.finish();
    }

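    // DefineSetter generates a Set<TensorName>(const std::vector<float>&)
    // method for every weight and bias tensor; each setter appends data to the
    // backing vector that Invoke() later passes to Execution::setInput().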
#define DefineSetter(X) \
    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);

#undef DefineSetter

    void ResetOutputState() {
        std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
        std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
    }

    void ResetCellState() {
        std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
        std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
    }

    void SetInput(int offset, const float* begin, const float* end) {
        for (; begin != end; begin++, offset++) {
            Input_[offset] = *begin;
        }
    }

    uint32_t num_inputs() const { return n_input_; }
    uint32_t num_outputs() const { return n_output_; }

    const std::vector<float>& GetOutput() const { return Output_; }

    void Invoke() {
        ASSERT_TRUE(model_.isValid());

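        // Feed the previous step's output/cell state back in as this step's
        // input state, so consecutive Invoke() calls chain the recurrence.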
        OutputStateIn_.swap(OutputStateOut_);
        CellStateIn_.swap(CellStateOut_);

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);
#define SetInputOrWeight(X)                                                                       \
    ASSERT_EQ(                                                                                    \
            execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
        FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                                               \
    ASSERT_EQ(                                                                                     \
            execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

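        // Optional tensors that this configuration does not use are passed as
        // nullptr with length 0, which marks the operand as omitted.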
        if (use_cifg_) {
            execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
        }

        if (use_peephole_) {
            if (use_cifg_) {
                execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
        }

        if (use_projection_weights_) {
            if (!use_projection_bias_) {
                execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
        }

        ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

   private:
    Model model_;
    // Execution execution_;
    const uint32_t n_input_;
    const uint32_t n_output_;

    const bool use_cifg_;
    const bool use_peephole_;
    const bool use_projection_weights_;
    const bool use_projection_bias_;

    const int activation_;
    const float cell_clip_;
    const float proj_clip_;

#define DefineTensor(X) std::vector<float> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

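// Exercises the float32 layer-norm LSTM cell with peephole connections and a
// projection layer, but without CIFG or clipping. Roughly (a sketch; see
// LSTM.cpp for the exact formulation), each gate layer-normalizes its
// pre-activation (W*x + R*h_prev, plus any peephole term), scales it by the
// per-gate layer-norm weights, adds the gate bias, and then applies the
// activation.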
TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
    const int n_batch = 2;
    const int n_input = 5;
    // n_cell and n_output differ here because a projection layer is used; they
    // would have to be equal if there were no projection.
    const int n_cell = 4;
    const int n_output = 3;

    LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                              /*use_cifg=*/false, /*use_peephole=*/true,
                              /*use_projection_weights=*/true,
                              /*use_projection_bias=*/false,
                              /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                              {
                                      {n_batch, n_input},  // input tensor

                                      {n_cell, n_input},  // input_to_input_weight tensor
                                      {n_cell, n_input},  // input_to_forget_weight tensor
                                      {n_cell, n_input},  // input_to_cell_weight tensor
                                      {n_cell, n_input},  // input_to_output_weight tensor

                                      {n_cell, n_output},  // recurrent_to_input_weight tensor
                                      {n_cell, n_output},  // recurrent_to_forget_weight tensor
                                      {n_cell, n_output},  // recurrent_to_cell_weight tensor
                                      {n_cell, n_output},  // recurrent_to_output_weight tensor

                                      {n_cell},  // cell_to_input_weight tensor
                                      {n_cell},  // cell_to_forget_weight tensor
                                      {n_cell},  // cell_to_output_weight tensor

                                      {n_cell},  // input_gate_bias tensor
                                      {n_cell},  // forget_gate_bias tensor
                                      {n_cell},  // cell_bias tensor
                                      {n_cell},  // output_gate_bias tensor

                                      {n_output, n_cell},  // projection_weight tensor
                                      {0},                 // projection_bias tensor

                                      {n_batch, n_output},  // output_state_in tensor
                                      {n_batch, n_cell},    // cell_state_in tensor

                                      {n_cell},  // input_layer_norm_weights tensor
                                      {n_cell},  // forget_layer_norm_weights tensor
                                      {n_cell},  // cell_layer_norm_weights tensor
                                      {n_cell},  // output_layer_norm_weights tensor
                              });

    lstm.SetInputToInputWeights({0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
                                 -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});

    lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
                                  -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5});

    lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
                                0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6});

    lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
                                  0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4});

    lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});

    lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});

    lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});

    lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});

    lstm.SetRecurrentToInputWeights(
            {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});

    lstm.SetRecurrentToCellWeights(
            {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});

    lstm.SetRecurrentToForgetWeights(
            {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});

    lstm.SetRecurrentToOutputWeights(
            {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});

    lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
    lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
    lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});

    lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});

    lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
    lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
    lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
    lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});

    const std::vector<std::vector<float>> lstm_input = {
            {                           // Batch0: 3 (input_sequence_size) * 5 (n_input)
             0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
             0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
             0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

            {                           // Batch1: 3 (input_sequence_size) * 5 (n_input)
             0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
             0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
             0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };

    const std::vector<std::vector<float>> lstm_golden_output = {
            {
                    // Batch0: 3 (input_sequence_size) * 3 (n_output)
                    0.0244077, 0.128027, -0.00170918,  // seq 0
                    0.0137642, 0.140751, 0.0395835,    // seq 1
                    -0.00459231, 0.155278, 0.0837377,  // seq 2
            },
            {
                    // Batch1: 3 (input_sequence_size) * 3 (n_output)
                    -0.00692428, 0.0848741, 0.063445,  // seq 0
                    -0.00403912, 0.139963, 0.072681,   // seq 1
                    0.00752706, 0.161903, 0.0561371,   // seq 2
            }};

    // Resetting cell_state and output_state
    lstm.ResetCellState();
    lstm.ResetOutputState();

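    // Feed one time step per iteration: pack the current step for both batches,
    // run the cell, and compare the projected output against the golden values.
    // The recurrent state is carried across iterations inside Invoke().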
    const int input_sequence_size = lstm_input[0].size() / n_input;
    for (int i = 0; i < input_sequence_size; i++) {
        for (int b = 0; b < n_batch; ++b) {
            const float* batch_start = lstm_input[b].data() + i * n_input;
            const float* batch_end = batch_start + n_input;

            lstm.SetInput(b * n_input, batch_start, batch_end);
        }

        lstm.Invoke();

        std::vector<float> expected;
        for (int b = 0; b < n_batch; ++b) {
            const float* golden_start = lstm_golden_output[b].data() + i * n_output;
            const float* golden_end = golden_start + n_output;
            expected.insert(expected.end(), golden_start, golden_end);
        }
        EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android