1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <android-base/logging.h>
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20
21 #include <sstream>
22 #include <string>
23 #include <vector>
24
25 #include "LSTM.h"
26 #include "NeuralNetworksWrapper.h"
27
28 namespace android {
29 namespace nn {
30 namespace wrapper {
31
32 using ::testing::Each;
33 using ::testing::FloatNear;
34 using ::testing::Matcher;
35
36 namespace {
37
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-6)38 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
39 float max_abs_error = 1.e-6) {
40 std::vector<Matcher<float>> matchers;
41 matchers.reserve(values.size());
42 for (const float& v : values) {
43 matchers.emplace_back(FloatNear(v, max_abs_error));
44 }
45 return matchers;
46 }
47
48 } // anonymous namespace
49
// Invokes ACTION(TensorName) for every data input of the LSTM op, in the
// operand order the LSTM operation consumes them (matching LSTMCell's
// k<Name>Tensor indices). Used below to stamp out per-tensor operand
// creation, setters, backing storage, and buffer bindings.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(CellGateBias)                         \
    ACTION(ForgetGateBias)                       \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

// Invokes ACTION(TensorName) for the four layer-normalization weight vectors,
// which the LSTM operand layout places after the scalar parameters.
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)
84
85 class LayerNormLSTMOpModel {
86 public:
LayerNormLSTMOpModel(uint32_t n_batch,uint32_t n_input,uint32_t n_cell,uint32_t n_output,bool use_cifg,bool use_peephole,bool use_projection_weights,bool use_projection_bias,float cell_clip,float proj_clip,const std::vector<std::vector<uint32_t>> & input_shapes0)87 LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
88 bool use_cifg, bool use_peephole, bool use_projection_weights,
89 bool use_projection_bias, float cell_clip, float proj_clip,
90 const std::vector<std::vector<uint32_t>>& input_shapes0)
91 : n_input_(n_input),
92 n_output_(n_output),
93 use_cifg_(use_cifg),
94 use_peephole_(use_peephole),
95 use_projection_weights_(use_projection_weights),
96 use_projection_bias_(use_projection_bias),
97 activation_(ActivationFn::kActivationTanh),
98 cell_clip_(cell_clip),
99 proj_clip_(proj_clip) {
100 std::vector<uint32_t> inputs;
101 std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
102
103 auto it = input_shapes.begin();
104
105 // Input and weights
106 #define AddInput(X) \
107 CHECK(it != input_shapes.end()); \
108 OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
109 inputs.push_back(model_.addOperand(&X##OpndTy));
110
111 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
112
113 // Parameters
114 OperandType ActivationOpndTy(Type::INT32, {});
115 inputs.push_back(model_.addOperand(&ActivationOpndTy));
116 OperandType CellClipOpndTy(Type::FLOAT32, {});
117 inputs.push_back(model_.addOperand(&CellClipOpndTy));
118 OperandType ProjClipOpndTy(Type::FLOAT32, {});
119 inputs.push_back(model_.addOperand(&ProjClipOpndTy));
120
121 FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);
122
123 #undef AddOperand
124
125 // Output and other intermediate state
126 std::vector<std::vector<uint32_t>> output_shapes{
127 {n_batch, n_cell * (use_cifg ? 3 : 4)},
128 {n_batch, n_output},
129 {n_batch, n_cell},
130 {n_batch, n_output},
131 };
132 std::vector<uint32_t> outputs;
133
134 auto it2 = output_shapes.begin();
135
136 #define AddOutput(X) \
137 CHECK(it2 != output_shapes.end()); \
138 OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
139 outputs.push_back(model_.addOperand(&X##OpndTy));
140
141 FOR_ALL_OUTPUT_TENSORS(AddOutput);
142
143 #undef AddOutput
144
145 model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
146 model_.identifyInputsAndOutputs(inputs, outputs);
147
148 Input_.insert(Input_.end(), n_batch * n_input, 0.f);
149 OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
150 CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
151
152 auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
153 uint32_t sz = 1;
154 for (uint32_t d : dims) {
155 sz *= d;
156 }
157 return sz;
158 };
159
160 it2 = output_shapes.begin();
161
162 #define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
163
164 FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
165
166 #undef ReserveOutput
167
168 model_.finish();
169 }
170
171 #define DefineSetter(X) \
172 void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
173
174 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
175 FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);
176
177 #undef DefineSetter
178
ResetOutputState()179 void ResetOutputState() {
180 std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
181 std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
182 }
183
ResetCellState()184 void ResetCellState() {
185 std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
186 std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
187 }
188
SetInput(int offset,const float * begin,const float * end)189 void SetInput(int offset, const float* begin, const float* end) {
190 for (; begin != end; begin++, offset++) {
191 Input_[offset] = *begin;
192 }
193 }
194
num_inputs() const195 uint32_t num_inputs() const { return n_input_; }
num_outputs() const196 uint32_t num_outputs() const { return n_output_; }
197
GetOutput() const198 const std::vector<float>& GetOutput() const { return Output_; }
199
Invoke()200 void Invoke() {
201 ASSERT_TRUE(model_.isValid());
202
203 OutputStateIn_.swap(OutputStateOut_);
204 CellStateIn_.swap(CellStateOut_);
205
206 Compilation compilation(&model_);
207 compilation.finish();
208 Execution execution(&compilation);
209 #define SetInputOrWeight(X) \
210 ASSERT_EQ( \
211 execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
212 Result::NO_ERROR);
213
214 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
215 FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);
216
217 #undef SetInputOrWeight
218
219 #define SetOutput(X) \
220 ASSERT_EQ( \
221 execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
222 Result::NO_ERROR);
223
224 FOR_ALL_OUTPUT_TENSORS(SetOutput);
225
226 #undef SetOutput
227
228 if (use_cifg_) {
229 execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
230 execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
231 }
232
233 if (use_peephole_) {
234 if (use_cifg_) {
235 execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
236 }
237 } else {
238 execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
239 execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
240 execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
241 }
242
243 if (use_projection_weights_) {
244 if (!use_projection_bias_) {
245 execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
246 }
247 } else {
248 execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
249 execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
250 }
251
252 ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
253 Result::NO_ERROR);
254 ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
255 Result::NO_ERROR);
256 ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
257 Result::NO_ERROR);
258
259 ASSERT_EQ(execution.compute(), Result::NO_ERROR);
260 }
261
262 private:
263 Model model_;
264 // Execution execution_;
265 const uint32_t n_input_;
266 const uint32_t n_output_;
267
268 const bool use_cifg_;
269 const bool use_peephole_;
270 const bool use_projection_weights_;
271 const bool use_projection_bias_;
272
273 const int activation_;
274 const float cell_clip_;
275 const float proj_clip_;
276
277 #define DefineTensor(X) std::vector<float> X##_;
278
279 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
280 FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
281 FOR_ALL_OUTPUT_TENSORS(DefineTensor);
282
283 #undef DefineTensor
284 };
285
TEST(LSTMOpTest,LayerNormNoCifgPeepholeProjectionNoClipping)286 TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
287 const int n_batch = 2;
288 const int n_input = 5;
289 // n_cell and n_output have the same size when there is no projection.
290 const int n_cell = 4;
291 const int n_output = 3;
292
293 LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
294 /*use_cifg=*/false, /*use_peephole=*/true,
295 /*use_projection_weights=*/true,
296 /*use_projection_bias=*/false,
297 /*cell_clip=*/0.0, /*proj_clip=*/0.0,
298 {
299 {n_batch, n_input}, // input tensor
300
301 {n_cell, n_input}, // input_to_input_weight tensor
302 {n_cell, n_input}, // input_to_forget_weight tensor
303 {n_cell, n_input}, // input_to_cell_weight tensor
304 {n_cell, n_input}, // input_to_output_weight tensor
305
306 {n_cell, n_output}, // recurrent_to_input_weight tensor
307 {n_cell, n_output}, // recurrent_to_forget_weight tensor
308 {n_cell, n_output}, // recurrent_to_cell_weight tensor
309 {n_cell, n_output}, // recurrent_to_output_weight tensor
310
311 {n_cell}, // cell_to_input_weight tensor
312 {n_cell}, // cell_to_forget_weight tensor
313 {n_cell}, // cell_to_output_weight tensor
314
315 {n_cell}, // input_gate_bias tensor
316 {n_cell}, // forget_gate_bias tensor
317 {n_cell}, // cell_bias tensor
318 {n_cell}, // output_gate_bias tensor
319
320 {n_output, n_cell}, // projection_weight tensor
321 {0}, // projection_bias tensor
322
323 {n_batch, n_output}, // output_state_in tensor
324 {n_batch, n_cell}, // cell_state_in tensor
325
326 {n_cell}, // input_layer_norm_weights tensor
327 {n_cell}, // forget_layer_norm_weights tensor
328 {n_cell}, // cell_layer_norm_weights tensor
329 {n_cell}, // output_layer_norm_weights tensor
330 });
331
332 lstm.SetInputToInputWeights({0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
333 -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});
334
335 lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
336 -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5});
337
338 lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
339 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6});
340
341 lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
342 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4});
343
344 lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});
345
346 lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});
347
348 lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});
349
350 lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});
351
352 lstm.SetRecurrentToInputWeights(
353 {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});
354
355 lstm.SetRecurrentToCellWeights(
356 {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});
357
358 lstm.SetRecurrentToForgetWeights(
359 {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});
360
361 lstm.SetRecurrentToOutputWeights(
362 {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});
363
364 lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
365 lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
366 lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});
367
368 lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});
369
370 lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
371 lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
372 lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
373 lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});
374
375 const std::vector<std::vector<float>> lstm_input = {
376 { // Batch0: 3 (input_sequence_size) * 5 (n_input)
377 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
378 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
379 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
380
381 { // Batch1: 3 (input_sequence_size) * 5 (n_input)
382 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
383 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
384 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
385 };
386
387 const std::vector<std::vector<float>> lstm_golden_output = {
388 {
389 // Batch0: 3 (input_sequence_size) * 3 (n_output)
390 0.0244077, 0.128027, -0.00170918, // seq 0
391 0.0137642, 0.140751, 0.0395835, // seq 1
392 -0.00459231, 0.155278, 0.0837377, // seq 2
393 },
394 {
395 // Batch1: 3 (input_sequence_size) * 3 (n_output)
396 -0.00692428, 0.0848741, 0.063445, // seq 0
397 -0.00403912, 0.139963, 0.072681, // seq 1
398 0.00752706, 0.161903, 0.0561371, // seq 2
399 }};
400
401 // Resetting cell_state and output_state
402 lstm.ResetCellState();
403 lstm.ResetOutputState();
404
405 const int input_sequence_size = lstm_input[0].size() / n_input;
406 for (int i = 0; i < input_sequence_size; i++) {
407 for (int b = 0; b < n_batch; ++b) {
408 const float* batch_start = lstm_input[b].data() + i * n_input;
409 const float* batch_end = batch_start + n_input;
410
411 lstm.SetInput(b * n_input, batch_start, batch_end);
412 }
413
414 lstm.Invoke();
415
416 std::vector<float> expected;
417 for (int b = 0; b < n_batch; ++b) {
418 const float* golden_start = lstm_golden_output[b].data() + i * n_output;
419 const float* golden_end = golden_start + n_output;
420 expected.insert(expected.end(), golden_start, golden_end);
421 }
422 EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
423 }
424 }
425
426 } // namespace wrapper
427 } // namespace nn
428 } // namespace android
429