1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
16 #include <gtest/gtest.h>
17 #include "tensorflow/lite/interpreter.h"
18 #include "tensorflow/lite/kernels/test_util.h"
19 #include "tensorflow/lite/model.h"
20
21 namespace testing {
22 namespace internal {
23
24 // The CTS test fails to compile without this.
25 // TODO(b/130342510): Find a proper solution.
FormatMatcherDescription(bool unused_negation,const char * matcher_name,const Strings & unused_param_values)26 std::string FormatMatcherDescription(bool unused_negation,
27 const char* matcher_name,
28 const Strings& unused_param_values) {
29 return matcher_name;
30 }
31
32 } // namespace internal
33 } // namespace testing
34
35 namespace tflite {
36 namespace {
37
38 using ::testing::ElementsAre;
39 using ::testing::ElementsAreArray;
40
41 // TODO(b/110368244): figure out how to share the existing tests in kernels/ but
42 // with the delegation on. Also, add more unit tests to improve code coverage.
43
44 // This matcher uses 1 as maximum tolerance.
45 MATCHER(QuantizedNear, "") {
46 const int diff = abs(std::get<0>(arg) - std::get<1>(arg));
47 if (diff > 1) {
48 *result_listener << "Quantized values can be at most off by one: " << diff;
49 return false;
50 }
51 return true;
52 }
53
54 class SingleOpModelWithNNAPI : public SingleOpModel {
55 public:
SingleOpModelWithNNAPI()56 SingleOpModelWithNNAPI() {
57 this->SetApplyDelegate([](Interpreter* interpreter) {
58 interpreter->ModifyGraphWithDelegate(NnApiDelegate());
59 });
60 }
61
ResizeInputTensor(int tensor_index,const std::vector<int> & dims)62 TfLiteStatus ResizeInputTensor(int tensor_index,
63 const std::vector<int>& dims) {
64 return interpreter_->ResizeInputTensor(tensor_index, dims);
65 }
66
67 protected:
SetData(int index,TensorType type,const std::vector<float> & data)68 void SetData(int index, TensorType type, const std::vector<float>& data) {
69 switch (type) {
70 case TensorType_FLOAT32:
71 PopulateTensor(index, data);
72 break;
73 case TensorType_INT32:
74 QuantizeAndPopulate<int32_t>(index, data);
75 break;
76 case TensorType_UINT8:
77 QuantizeAndPopulate<uint8_t>(index, data);
78 break;
79 case TensorType_INT8:
80 QuantizeAndPopulate<int8_t>(index, data);
81 break;
82 default:
83 FAIL() << "Type not supported: " << type;
84 break;
85 }
86 }
87 };
88
89 class FloatAddOpModel : public SingleOpModelWithNNAPI {
90 public:
FloatAddOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type,bool allow_fp32_relax_to_fp16=false)91 FloatAddOpModel(const TensorData& input1, const TensorData& input2,
92 const TensorData& output,
93 ActivationFunctionType activation_type,
94 bool allow_fp32_relax_to_fp16 = false) {
95 input1_ = AddInput(input1);
96 input2_ = AddInput(input2);
97 output_ = AddOutput(output);
98 SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
99 CreateAddOptions(builder_, activation_type).Union());
100 BuildInterpreter({GetShape(input1_), GetShape(input2_)},
101 allow_fp32_relax_to_fp16);
102 }
103
input1()104 int input1() { return input1_; }
input2()105 int input2() { return input2_; }
106
GetOutput()107 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
108
109 protected:
110 int input1_;
111 int input2_;
112 int output_;
113 };
114
115 // Do a test with the NN API using no activation.
TEST(NNAPIDelegate,AddWithNoActivation)116 TEST(NNAPIDelegate, AddWithNoActivation) {
117 FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
118 {TensorType_FLOAT32, {1, 2, 2, 1}},
119 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
120 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
121 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
122 m.Invoke();
123 EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
124 }
125
126 // Do a test with the NN API using no activation.
127 // The test allows computing FP32 with FP16 precision. In this particular case,
128 // calculating in FP32 or FP16 should produce the same results.
TEST(NNAPIDelegate,AddWithNoActivationRelaxed)129 TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
130 FloatAddOpModel m(
131 {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
132 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE, true);
133 m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
134 m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
135 m.Invoke();
136 EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 1.0, 4.0, 6.0}));
137 }
138
139 // Do a test with the NN api with relu.
TEST(NNAPIDelegate,AddWithRelu)140 TEST(NNAPIDelegate, AddWithRelu) {
141 FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
142 {TensorType_FLOAT32, {1, 2, 2, 1}},
143 {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU);
144 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
145 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
146 m.Invoke();
147 EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3}));
148 }
149
150 // Verify that resize attempts fail.
151 // TODO(b/113110851): Verify success after the delegate supports resizing.
TEST(NNAPIDelegate,ResizeFails)152 TEST(NNAPIDelegate, ResizeFails) {
153 FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
154 {TensorType_FLOAT32, {1, 2, 2, 1}},
155 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
156 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
157 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
158 EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 3, 1}), kTfLiteError);
159 }
160
161 class FloatMulOpModel : public SingleOpModelWithNNAPI {
162 public:
FloatMulOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)163 FloatMulOpModel(const TensorData& input1, const TensorData& input2,
164 const TensorData& output,
165 ActivationFunctionType activation_type) {
166 input1_ = AddInput(input1);
167 input2_ = AddInput(input2);
168 output_ = AddOutput(output);
169 SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
170 CreateMulOptions(builder_, activation_type).Union());
171 BuildInterpreter({GetShape(input1_), GetShape(input2_)});
172 }
173
input1()174 int input1() { return input1_; }
input2()175 int input2() { return input2_; }
176
GetOutput()177 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
178
179 protected:
180 int input1_;
181 int input2_;
182 int output_;
183 };
184
// Runs MUL through the NN API delegate without a fused activation; products
// are compared with the default ArrayFloatNear tolerance.
TEST(NNAPIDelegate, MulWithNoActivation) {
  FloatMulOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
  model.Invoke();
  const auto expected = ArrayFloatNear({-0.2, 0.04, 0.21, 0.4});
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
195
196 class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
197 public:
FloatPoolingOpModel(BuiltinOperator type,const TensorData & input,int filter_width,int filter_height,const TensorData & output)198 FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
199 int filter_width, int filter_height,
200 const TensorData& output) {
201 input_ = AddInput(input);
202 output_ = AddOutput(output);
203
204 SetBuiltinOp(
205 type, BuiltinOptions_Pool2DOptions,
206 CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width,
207 filter_height, ActivationFunctionType_NONE)
208 .Union());
209
210 BuildInterpreter({GetShape(input_)});
211 }
212
SetInput(std::initializer_list<float> data)213 void SetInput(std::initializer_list<float> data) {
214 PopulateTensor(input_, data);
215 }
216
GetOutput()217 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
218
219 protected:
220 int input_;
221 int output_;
222 };
223
// AVERAGE_POOL_2D with a 2x2 filter and stride 2 averages each
// non-overlapping 2x2 window of the 2x4 input.
TEST(NNAPIDelegate, AveragePoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_AVERAGE_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  model.Invoke();
  // (0+6+3+2)/4 and (2+4+10+7)/4.
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({2.75, 5.75}));
}
236
// MAX_POOL_2D with a 2x2 filter and stride 2 keeps the largest value of each
// non-overlapping 2x2 window of the 2x4 input.
TEST(NNAPIDelegate, MaxPoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_MAX_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({6, 10}));
}
249
// L2_POOL_2D with a 2x2 filter and stride 2 produces the root mean square of
// each non-overlapping 2x2 window: sqrt(49/4) = 3.5, sqrt(169/4) = 6.5.
TEST(NNAPIDelegate, L2PoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_L2_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({3.5, 6.5}));
}
262
263 class ConvolutionOpModel : public SingleOpModelWithNNAPI {
264 public:
ConvolutionOpModel(const TensorData & input,const TensorData & filter,const TensorData & output,int stride_width=2,int stride_height=2,enum Padding padding=Padding_VALID,enum ActivationFunctionType activation=ActivationFunctionType_NONE,int dilation_width_factor=1,int dilation_height_factor=1)265 ConvolutionOpModel(
266 const TensorData& input, const TensorData& filter,
267 const TensorData& output, int stride_width = 2, int stride_height = 2,
268 enum Padding padding = Padding_VALID,
269 enum ActivationFunctionType activation = ActivationFunctionType_NONE,
270 int dilation_width_factor = 1, int dilation_height_factor = 1)
271 : input_type_(input.type), filter_type_(filter.type) {
272 input_ = AddInput(input);
273 filter_ = AddInput(filter);
274
275 int bias_size = GetShape(filter_)[0];
276 if (input.type == TensorType_FLOAT32) {
277 bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
278 } else {
279 // This is a quantized version. The scale of 'bias' depends on the scales
280 // of input and filter. Supposedly this is correctly set during quantized
281 // training.
282 auto bias_scale = GetScale(input_) * GetScale(filter_);
283 TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
284 bias_ = AddInput(bias);
285 }
286
287 output_ = AddOutput(output);
288
289 if (input_type_ != TensorType_FLOAT32) {
290 // The following is required by quantized inference. It is the unittest's
291 // responsibility to make sure the output scale falls into the correct
292 // range.
293 CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
294 }
295
296 SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
297 CreateConv2DOptions(
298 builder_, padding, stride_width, stride_height, activation,
299 dilation_width_factor, dilation_height_factor)
300 .Union());
301
302 BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
303 }
304
SetInput(std::initializer_list<float> data)305 void SetInput(std::initializer_list<float> data) {
306 SetData(input_, input_type_, data);
307 }
308
SetFilter(std::initializer_list<float> data)309 void SetFilter(std::initializer_list<float> data) {
310 SetData(filter_, filter_type_, data);
311 }
312
SetBias(std::initializer_list<float> data)313 void SetBias(std::initializer_list<float> data) {
314 const auto bias_type =
315 (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
316 SetData(bias_, bias_type, data);
317 }
318
GetOutput()319 std::vector<float> GetOutput() {
320 if (input_type_ == TensorType_FLOAT32) {
321 return ExtractVector<float>(output_);
322 } else {
323 return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
324 GetScale(output_), GetZeroPoint(output_));
325 }
326 }
327
GetQuantizedOutput()328 std::vector<uint8_t> GetQuantizedOutput() {
329 if (input_type_ == TensorType_FLOAT32) {
330 return {}; // Not supported.
331 } else {
332 return ExtractVector<uint8_t>(output_);
333 }
334 }
335
336 protected:
337 int input_;
338 int filter_;
339 int bias_;
340 int output_;
341
342 const TensorType input_type_;
343 const TensorType filter_type_;
344 };
345
346 // In this tests we set the input and output scales so that the results
347 // match exactly the 'non-quantized' version.
TEST(ConvolutionOpTest,SimpleTestQuantized)348 TEST(ConvolutionOpTest, SimpleTestQuantized) {
349 ConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
350 {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
351 {TensorType_UINT8, {}, -127, 128});
352 m.SetInput({
353 // First batch
354 1, 1, 1, 1, // row = 1
355 2, 2, 2, 2, // row = 2
356 // Second batch
357 1, 2, 3, 4, // row = 1
358 1, 2, 3, 4, // row = 2
359 });
360 m.SetFilter({
361 1, 2, 3, 4, // first 2x2 filter
362 -1, 1, -1, 1, // second 2x2 filter
363 -1, -1, 1, 1, // third 2x2 filter
364 });
365 m.SetBias({1, 2, 3});
366
367 m.Invoke();
368
369 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
370 {
371 18, 2, 5, // first batch, left
372 18, 2, 5, // first batch, right
373 17, 4, 3, // second batch, left
374 37, 4, 3, // second batch, right
375 },
376 1e-5)));
377 // For good measure, let's also verify the quantized values:
378 EXPECT_THAT(m.GetQuantizedOutput(), ElementsAreArray({
379 145, 129, 132, //
380 145, 129, 132, //
381 144, 131, 130, //
382 164, 131, 130, //
383 }));
384 }
385
// Float input with UINT8 weights quantized over [0, 64]. Weight quantization
// introduces rounding, so the outputs are compared with a 0.2 tolerance.
TEST(ConvolutionOpTest, FloatInputQuantizedWeights) {
  ConvolutionOpModel model({TensorType_FLOAT32, {2, 2, 4, 1}},
                           {TensorType_UINT8, {3, 2, 2, 1}, 0, 64},
                           {TensorType_FLOAT32, {}});
  model.SetInput({
      // First batch
      1, 1, 1, 2,  // row = 1
      2, 2, 2, 1,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,  // first 2x2 filter
      0, 1, 0, 1,  // second 2x2 filter
      0, 0, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  model.Invoke();

  const auto expected = ArrayFloatNear(
      {
          18, 5, 7,    // first batch, left
          16, 5, 6,    // first batch, right
          17, 6, 6,    // second batch, left
          37, 10, 10,  // second batch, right
      },
      0.2);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
416
// Float CONV_2D using the model's defaults: stride 2, VALID padding and no
// fused activation.
TEST(ConvolutionOpTest, NoActivation) {
  ConvolutionOpModel model({TensorType_FLOAT32, {2, 2, 4, 1}},
                           {TensorType_FLOAT32, {3, 2, 2, 1}},
                           {TensorType_FLOAT32, {}});

  model.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
      -1, -1, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     18, 2, 5,  // first batch, left
                                     18, 2, 5,  // first batch, right
                                     17, 4, 3,  // second batch, left
                                     37, 4, 3,  // second batch, right
                                 }));
}
446
447 class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
448 public:
DepthwiseConvolutionOpModel(const TensorData & input,const TensorData & filter,const TensorData & output)449 DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
450 const TensorData& output) {
451 input_ = AddInput(input);
452 filter_ = AddInput(filter);
453
454 int bias_size = GetShape(filter_)[3];
455 if (input.type == TensorType_FLOAT32) {
456 bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
457 } else {
458 // This is a quantized version. The scale of 'bias' depends on the scales
459 // of input and filter. Supposedly this is correctly set during quantized
460 // training.
461 auto bias_scale = GetScale(input_) * GetScale(filter_);
462 TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
463 bias_ = AddInput(bias);
464 }
465
466 output_ = AddOutput(output);
467
468 int input_depth = GetShape(input_)[3];
469 int output_depth = GetShape(filter_)[3];
470 int depth_mul = output_depth / input_depth;
471
472 SetBuiltinOp(
473 BuiltinOperator_DEPTHWISE_CONV_2D,
474 BuiltinOptions_DepthwiseConv2DOptions,
475 CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
476 ActivationFunctionType_NONE)
477 .Union());
478
479 BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
480 }
481
SetFilter(std::initializer_list<float> f)482 void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
483
SetBias(std::initializer_list<float> f)484 void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
485
SetInput(std::initializer_list<float> data)486 void SetInput(std::initializer_list<float> data) {
487 PopulateTensor(input_, data);
488 }
489
GetOutput()490 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
491
492 protected:
493 int input_;
494 int filter_;
495 int bias_;
496 int output_;
497 };
498
// Float DEPTHWISE_CONV_2D: a 2-channel input with a 4-deep filter, giving a
// depth multiplier of 2.
TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
  DepthwiseConvolutionOpModel model({TensorType_FLOAT32, {1, 3, 2, 2}},
                                    {TensorType_FLOAT32, {1, 2, 2, 4}},
                                    {TensorType_FLOAT32, {}});

  model.SetInput({
      1, 2, 7, 8,    // column 1
      3, 4, 9, 10,   // column 2
      5, 6, 11, 12,  // column 3
  });
  model.SetFilter({
      1, 2, 3, 4,        //
      -9, 10, -11, 12,   //
      5, 6, 7, 8,        //
      13, -14, 15, -16,  //
  });
  model.SetBias({1, 2, 3, 4});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     71, -34, 99, -20,  //
                                     91, -26, 127, -4,  //
                                 }));
}
524
525 class FullyConnectedOpModel : public SingleOpModelWithNNAPI {
526 public:
FullyConnectedOpModel(const TensorData & input,const TensorData & weights,const TensorData & output,enum ActivationFunctionType activation=ActivationFunctionType_NONE)527 FullyConnectedOpModel(
528 const TensorData& input, const TensorData& weights,
529 const TensorData& output,
530 enum ActivationFunctionType activation = ActivationFunctionType_NONE)
531 : input_type_(input.type), weights_type_(weights.type) {
532 input_ = AddInput(input);
533 weights_ = AddInput(weights);
534
535 const int units = weights.shape[0];
536 if (input.type == TensorType_FLOAT32) {
537 bias_ = AddInput({TensorType_FLOAT32, {units}});
538 } else {
539 // This is a quantized version. The scale of 'bias' depends on the scales
540 // of input and filter. Supposedly this is correctly set during quantized
541 // training.
542 auto bias_scale = GetScale(input_) * GetScale(weights_);
543 TensorData bias{TensorType_INT32, {units}, 0, 0, bias_scale};
544 bias_ = AddInput(bias);
545 }
546
547 output_ = AddOutput(output);
548
549 SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
550 BuiltinOptions_FullyConnectedOptions,
551 CreateFullyConnectedOptions(builder_, activation).Union());
552 BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
553 }
554
SetInput(std::initializer_list<float> data)555 void SetInput(std::initializer_list<float> data) {
556 SetData(input_, input_type_, data);
557 }
558
SetWeights(std::initializer_list<float> data)559 void SetWeights(std::initializer_list<float> data) {
560 SetData(weights_, weights_type_, data);
561 }
562
SetBias(std::initializer_list<float> data)563 void SetBias(std::initializer_list<float> data) {
564 const auto bias_type =
565 (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
566 SetData(bias_, bias_type, data);
567 }
568
GetOutput()569 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
570
571 protected:
572 int input_;
573 int weights_;
574 int bias_;
575 int output_;
576
577 const TensorType input_type_;
578 const TensorType weights_type_;
579 };
580
// Float FULLY_CONNECTED with 3 units over a batch of 2 ten-element inputs.
// Each output is dot(input_row, weights_row) + bias[unit].
TEST(FullyConnectedOpTest, SimpleTest) {
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_FLOAT32, {3, 10}},
                          /*output=*/{TensorType_FLOAT32});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
601
// Float input with UINT8 weights quantized over [0, 64]. Weight quantization
// introduces rounding error, which is presumably why the comparison
// tolerance is as loose as 1.3.
TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) {
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_UINT8, {3, 10}, 0, 64},
                          /*output=*/{TensorType_FLOAT32});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3)));
}
623
624 class SoftmaxOpModel : public SingleOpModelWithNNAPI {
625 public:
SoftmaxOpModel(int batches,int size,float beta)626 SoftmaxOpModel(int batches, int size, float beta)
627 : batches_(batches), input_size_(size), beta_(beta) {
628 input_ = AddInput(TensorType_FLOAT32);
629 output_ = AddOutput(TensorType_FLOAT32);
630 SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
631 CreateSoftmaxOptions(builder_, beta_).Union());
632 BuildInterpreter({{batches_, input_size_}});
633 }
634
SetInput(std::initializer_list<float> data)635 void SetInput(std::initializer_list<float> data) {
636 PopulateTensor(input_, data);
637 }
638
SetInput(int offset,float * begin,float * end)639 void SetInput(int offset, float* begin, float* end) {
640 PopulateTensor(input_, offset, begin, end);
641 }
642
GetOutput()643 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
644
645 private:
646 int input_;
647 int output_;
648
649 int batches_;
650 int input_size_;
651 float beta_;
652 };
653
// SOFTMAX over 2 batches of 5 values with beta = 1. The second batch is the
// negation of the first, so its probabilities are the first batch's reversed.
TEST(NNAPIDelegate, SoftmaxSimpleTest) {
  SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0);
  m.SetInput({
      1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(
      m.GetOutput(),
      ElementsAreArray(ArrayFloatNear(
          {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
           0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231},
          1e-6)));
}
670
671 class ReshapeOpModel : public SingleOpModelWithNNAPI {
672 public:
ReshapeOpModel(std::initializer_list<int> input_shape,std::initializer_list<int> new_shape)673 ReshapeOpModel(std::initializer_list<int> input_shape,
674 std::initializer_list<int> new_shape) {
675 input_ = AddInput(TensorType_FLOAT32);
676 new_shape_ = AddConstInput<int>(TensorType_INT32, new_shape,
677 {static_cast<int>(new_shape.size())});
678 output_ = AddOutput(TensorType_FLOAT32);
679 SetBuiltinOp(
680 BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
681 CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
682 .Union());
683 BuildInterpreter({input_shape, {static_cast<int>(new_shape.size())}});
684 }
685
SetInput(std::initializer_list<float> data)686 void SetInput(std::initializer_list<float> data) {
687 PopulateTensor<float>(input_, data);
688 }
GetOutput()689 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()690 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
691
692 private:
693 int input_;
694 int new_shape_;
695 int output_;
696 };
697
// Reshapes a 1x2x4x1 tensor to 2x2x2. The element order is unchanged; only
// the shape differs.
TEST(NNAPIDelegate, ReshapeSimpleTest) {
  ReshapeOpModel model(/*input_shape=*/{1, 2, 4, 1}, /*new_shape=*/{2, 2, 2});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
705
706 class SqueezeOpModel : public SingleOpModelWithNNAPI {
707 public:
SqueezeOpModel(const TensorData & input,const TensorData & output,std::initializer_list<int> axis)708 SqueezeOpModel(const TensorData& input, const TensorData& output,
709 std::initializer_list<int> axis) {
710 input_ = AddInput(input);
711 output_ = AddOutput(output);
712 SetBuiltinOp(
713 BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
714 CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
715 .Union());
716 BuildInterpreter({GetShape(input_)});
717 }
718
SetInput(std::initializer_list<float> data)719 void SetInput(std::initializer_list<float> data) {
720 PopulateTensor<float>(input_, data);
721 }
GetOutput()722 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()723 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
724
725 private:
726 int input_;
727 int new_shape_;
728 int output_;
729 };
730
// With no axis list, every size-1 dimension of 1x24x1 is removed, leaving
// a 24-element vector with the data unchanged.
TEST(NNAPIDelegate, SqueezeSimpleTest) {
  std::initializer_list<float> values = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  SqueezeOpModel model({TensorType_FLOAT32, {1, 24, 1}},
                       {TensorType_FLOAT32, {24}}, /*axis=*/{});
  model.SetInput(values);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({24}));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                        9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
746
// Squeezing only axis 2 of 1x24x1 keeps the leading size-1 dimension, giving
// a 1x24 output with the data unchanged.
TEST(NNAPIDelegate, SqueezeWithAxisTest) {
  std::initializer_list<float> values = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  SqueezeOpModel model({TensorType_FLOAT32, {1, 24, 1}},
                       {TensorType_FLOAT32, {24}}, /*axis=*/{2});
  model.SetInput(values);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 24}));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                        9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
762
763 class L2NormOpModel : public SingleOpModelWithNNAPI {
764 public:
L2NormOpModel(const TensorData & input,const TensorData & output,ActivationFunctionType activation_type)765 L2NormOpModel(const TensorData& input, const TensorData& output,
766 ActivationFunctionType activation_type) {
767 input_ = AddInput(input);
768 output_ = AddOutput(output);
769 SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
770 CreateL2NormOptions(builder_, activation_type).Union());
771 BuildInterpreter({GetShape(input_)});
772 }
773
SetInput(std::initializer_list<float> data)774 void SetInput(std::initializer_list<float> data) {
775 PopulateTensor<float>(input_, data);
776 }
GetOutput()777 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()778 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
779
780 private:
781 int input_;
782 int new_shape_;
783 int output_;
784 };
785
// The squares of the inputs sum to 4, so the L2 norm is 2 and each output is
// half of its input.
TEST(NNAPIDelegate, L2NormSimpleTest) {
  std::initializer_list<float> values = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1};
  L2NormOpModel model({TensorType_FLOAT32, {1, 1, 1, 6}},
                      {TensorType_FLOAT32, {1, 1, 1, 6}},
                      ActivationFunctionType_NONE);
  model.SetInput(values);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 6}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
797
798 class TransposeSimpleModel : public SingleOpModelWithNNAPI {
799 public:
TransposeSimpleModel(std::initializer_list<int> input_shape,std::initializer_list<int> perm_shape,std::initializer_list<int> perm)800 TransposeSimpleModel(std::initializer_list<int> input_shape,
801 std::initializer_list<int> perm_shape,
802 std::initializer_list<int> perm) {
803 input_ = AddInput(TensorType_FLOAT32);
804 perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
805 output_ = AddOutput(TensorType_FLOAT32);
806 SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
807 CreateTransposeOptions(builder_).Union());
808 BuildInterpreter({input_shape, perm_shape});
809 }
810
SetInput(std::initializer_list<float> data)811 void SetInput(std::initializer_list<float> data) {
812 PopulateTensor<float>(input_, data);
813 }
814
GetOutput()815 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()816 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
817
818 private:
819 int input_;
820 int perm_;
821 int output_;
822 };
823
// Permutation {2, 0, 1} maps a 2x3x4 input to a 4x2x3 output.
TEST(NNAPIDelegate, TransposeSimpleTest) {
  TransposeSimpleModel model(/*input_shape=*/{2, 3, 4}, /*perm_shape=*/{3},
                             /*perm=*/{2, 0, 1});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
}
834
835 class FloatSubOpModel : public SingleOpModelWithNNAPI {
836 public:
FloatSubOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)837 FloatSubOpModel(const TensorData& input1, const TensorData& input2,
838 const TensorData& output,
839 ActivationFunctionType activation_type) {
840 input1_ = AddInput(input1);
841 input2_ = AddInput(input2);
842 output_ = AddOutput(output);
843 SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions,
844 CreateMulOptions(builder_, activation_type).Union());
845 BuildInterpreter({GetShape(input1_), GetShape(input2_)});
846 }
847
input1()848 int input1() { return input1_; }
input2()849 int input2() { return input2_; }
850
GetOutput()851 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
852
853 protected:
854 int input1_;
855 int input2_;
856 int output_;
857 };
858
// Runs SUB (input1 - input2) through the NN API delegate without a fused
// activation.
TEST(NNAPIDelegate, SubWithNoActivation) {
  FloatSubOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
  model.Invoke();
  const auto expected = ArrayFloatNear({-2.1, 0.0, 0.4, 0.3});
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
869
870 class FloatDivOpModel : public SingleOpModelWithNNAPI {
871 public:
FloatDivOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)872 FloatDivOpModel(const TensorData& input1, const TensorData& input2,
873 const TensorData& output,
874 ActivationFunctionType activation_type) {
875 input1_ = AddInput(input1);
876 input2_ = AddInput(input2);
877 output_ = AddOutput(output);
878 SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions,
879 CreateMulOptions(builder_, activation_type).Union());
880 BuildInterpreter({GetShape(input1_), GetShape(input2_)});
881 }
882
input1()883 int input1() { return input1_; }
input2()884 int input2() { return input2_; }
885
GetOutput()886 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
887
888 protected:
889 int input1_;
890 int input2_;
891 int output_;
892 };
893
TEST(NNAPIDelegate,DivWithNoActivation)894 TEST(NNAPIDelegate, DivWithNoActivation) {
895 FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
896 {TensorType_FLOAT32, {1, 2, 2, 1}},
897 {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
898 m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.8, 0.8});
899 m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.4, 0.2});
900 m.Invoke();
901 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-20, 1, 2, 4})));
902 }
903
904 class BaseConcatenationOpModel : public SingleOpModelWithNNAPI {
905 public:
BaseConcatenationOpModel()906 BaseConcatenationOpModel() {}
BaseConcatenationOpModel(const TensorData & input_template,int axis,int num_inputs)907 BaseConcatenationOpModel(const TensorData& input_template, int axis,
908 int num_inputs) {
909 std::vector<std::vector<int>> all_input_shapes;
910 for (int i = 0; i < num_inputs; ++i) {
911 all_input_shapes.push_back(input_template.shape);
912 AddInput(input_template);
913 }
914 output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
915 input_template.max});
916 SetBuiltinOp(
917 BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
918 CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
919 .Union());
920 BuildInterpreter(all_input_shapes);
921 }
922
923 protected:
924 int output_;
925 };
926
927 class ConcatenationOpModel : public BaseConcatenationOpModel {
928 public:
929 using BaseConcatenationOpModel::BaseConcatenationOpModel;
SetInput(int index,std::initializer_list<float> data)930 void SetInput(int index, std::initializer_list<float> data) {
931 PopulateTensor(index, data);
932 }
GetOutput()933 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
934 };
935
TEST(NNAPIDelegate,ConcatenationThreeDimensionalOneInput)936 TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) {
937 ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
938 /*num_inputs=*/1);
939 m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
940 m0.Invoke();
941 EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7}));
942 }
943
TEST(NNAPIDelegate,ConcatenationFourInputs)944 TEST(NNAPIDelegate, ConcatenationFourInputs) {
945 ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
946 /*num_inputs=*/4);
947 m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
948 m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
949 m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
950 m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
951 m0.Invoke();
952 EXPECT_THAT(m0.GetOutput(),
953 ElementsAreArray({
954 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
955 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
956 }));
957 }
958
959 class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
960 public:
961 using BaseConcatenationOpModel::BaseConcatenationOpModel;
QuantizedConcatenationOpModel(const std::vector<TensorData> & input_template,int axis,int num_inputs,const TensorData & output_template)962 QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template,
963 int axis, int num_inputs,
964 const TensorData& output_template) {
965 std::vector<std::vector<int>> all_input_shapes;
966 CHECK_EQ(input_template.size(), num_inputs);
967 for (int i = 0; i < num_inputs; ++i) {
968 all_input_shapes.push_back(input_template[i].shape);
969 AddInput(input_template[i]);
970 }
971 output_ = AddOutput({output_template.type, /*shape=*/{},
972 output_template.min, output_template.max});
973 SetBuiltinOp(
974 BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
975 CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
976 .Union());
977 BuildInterpreter(all_input_shapes);
978 }
SetInput(int index,std::initializer_list<float> data)979 void SetInput(int index, std::initializer_list<float> data) {
980 QuantizeAndPopulate<uint8_t>(index, data);
981 }
GetOutput()982 std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
GetDequantizedOutput()983 std::vector<float> GetDequantizedOutput() {
984 return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
985 GetScale(output_), GetZeroPoint(output_));
986 }
987 };
988
TEST(NNAPIDelegate,ConcatenationFourInputsQuantized)989 TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) {
990 QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
991 /*axis=*/2,
992 /*num_inputs=*/4);
993
994 m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
995 m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
996 m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
997 m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
998 m0.Invoke();
999 EXPECT_THAT(m0.GetDequantizedOutput(),
1000 ElementsAreArray(ArrayFloatNear({
1001 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
1002 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
1003 })));
1004 EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
1005 137, 157, 138, 158, 139, 159, 140, 160, //
1006 167, 197, 168, 198, 169, 199, 170, 200, //
1007 }));
1008 }
1009
TEST(NNAPIDelegate,ConcatenationFourInputsQuantizedMixedRange)1010 TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) {
1011 QuantizedConcatenationOpModel m0({{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
1012 {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
1013 {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
1014 {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
1015 /*axis=*/2, /*num_inputs=*/4,
1016 {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8});
1017
1018 m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
1019 m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
1020 m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
1021 m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
1022 m0.Invoke();
1023 EXPECT_THAT(m0.GetDequantizedOutput(),
1024 ElementsAreArray(ArrayFloatNear({
1025 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
1026 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
1027 })));
1028 EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
1029 137, 157, 138, 158, 139, 159, 140, 160, //
1030 167, 197, 168, 198, 169, 199, 170, 200, //
1031 }));
1032 }
1033
1034 class DequantizeOpModel : public SingleOpModelWithNNAPI {
1035 public:
DequantizeOpModel(TensorType inputType,std::initializer_list<int> shape,float min,float max)1036 DequantizeOpModel(TensorType inputType, std::initializer_list<int> shape,
1037 float min, float max) {
1038 input_ = AddInput({inputType, shape, min, max});
1039 output_ = AddOutput({TensorType_FLOAT32, shape});
1040 SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
1041 CreateDequantizeOptions(builder_).Union());
1042
1043 BuildInterpreter({GetShape(input_)});
1044 }
1045
1046 template <typename T>
SetInput(std::initializer_list<T> data)1047 void SetInput(std::initializer_list<T> data) {
1048 PopulateTensor(input_, data);
1049 }
1050
GetOutput()1051 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1052
1053 private:
1054 int input_;
1055 int output_;
1056 };
1057
TEST(NNAPIDelegate,DequantizeFourDimensionalUint8)1058 TEST(NNAPIDelegate, DequantizeFourDimensionalUint8) {
1059 DequantizeOpModel m(TensorType_UINT8, {2, 5}, -63.5, 64);
1060
1061 m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
1062 m.Invoke();
1063 EXPECT_THAT(m.GetOutput(),
1064 ElementsAreArray(ArrayFloatNear(
1065 {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64})));
1066 }
1067
TEST(NNAPIDelegate,DequantizeFourDimensionalInt8Symm)1068 TEST(NNAPIDelegate, DequantizeFourDimensionalInt8Symm) {
1069 // [-64, 63.5] -> scale=0.5, zero_point=0 for INT8
1070 DequantizeOpModel m(TensorType_INT8, {2, 5}, -64, 63.5);
1071
1072 m.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
1073 m.Invoke();
1074 EXPECT_THAT(m.GetOutput(),
1075 ElementsAreArray(ArrayFloatNear(
1076 {-64, -63.5, -63, -62.5, -62, 61.5, 62, 62.5, 63, 63.5})));
1077 }
1078
1079 class FloorOpModel : public SingleOpModelWithNNAPI {
1080 public:
FloorOpModel(std::initializer_list<int> input_shape,TensorType input_type)1081 FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
1082 input_ = AddInput(TensorType_FLOAT32);
1083 output_ = AddOutput(TensorType_FLOAT32);
1084 SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
1085 BuildInterpreter({
1086 input_shape,
1087 });
1088 }
1089
input()1090 int input() { return input_; }
1091
GetOutput()1092 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1093 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1094
1095 private:
1096 int input_;
1097 int output_;
1098 };
1099
TEST(NNAPIDelegate,FloorSingleDim)1100 TEST(NNAPIDelegate, FloorSingleDim) {
1101 FloorOpModel model({2}, TensorType_FLOAT32);
1102 model.PopulateTensor<float>(model.input(), {8.5, 0.0});
1103 model.Invoke();
1104 EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0}));
1105 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
1106 }
1107
TEST(NNAPIDelegate,FloorMultiDims)1108 TEST(NNAPIDelegate, FloorMultiDims) {
1109 FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32);
1110 model.PopulateTensor<float>(model.input(), {
1111 0.0001,
1112 8.0001,
1113 0.9999,
1114 9.9999,
1115 0.5,
1116 -0.0001,
1117 -8.0001,
1118 -0.9999,
1119 -9.9999,
1120 -0.5,
1121 });
1122 model.Invoke();
1123 EXPECT_THAT(model.GetOutput(),
1124 ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
1125 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
1126 }
1127
1128 class LocalResponseNormOpModel : public SingleOpModelWithNNAPI {
1129 public:
LocalResponseNormOpModel(std::initializer_list<int> input_shape,int radius,float bias,float alpha,float beta)1130 LocalResponseNormOpModel(std::initializer_list<int> input_shape, int radius,
1131 float bias, float alpha, float beta) {
1132 input_ = AddInput(TensorType_FLOAT32);
1133 output_ = AddOutput(TensorType_FLOAT32);
1134 SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1135 BuiltinOptions_LocalResponseNormalizationOptions,
1136 CreateLocalResponseNormalizationOptions(builder_, radius, bias,
1137 alpha, beta)
1138 .Union());
1139 BuildInterpreter({input_shape});
1140 }
1141
SetInput(std::initializer_list<float> data)1142 void SetInput(std::initializer_list<float> data) {
1143 PopulateTensor(input_, data);
1144 }
1145
GetOutput()1146 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1147
1148 private:
1149 int input_;
1150 int output_;
1151 };
1152
TEST(NNAPIDelegate,LocalResponseNormSameAsL2Norm)1153 TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) {
1154 LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
1155 /*alpha=*/1.0, /*beta=*/0.5);
1156 m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
1157 m.Invoke();
1158 // The result is every input divided by 2.
1159 EXPECT_THAT(
1160 m.GetOutput(),
1161 ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})));
1162 }
1163
TEST(NNAPIDelegate,LocalResponseNormWithAlpha)1164 TEST(NNAPIDelegate, LocalResponseNormWithAlpha) {
1165 LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
1166 /*alpha=*/4.0, /*beta=*/0.5);
1167 m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
1168 m.Invoke();
1169 // The result is every input divided by 3.
1170 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
1171 {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025})));
1172 }
1173
TEST(NNAPIDelegate,LocalResponseNormWithBias)1174 TEST(NNAPIDelegate, LocalResponseNormWithBias) {
1175 LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0,
1176 /*alpha=*/4.0, /*beta=*/0.5);
1177 m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
1178 m.Invoke();
1179 // The result is every input divided by 5.
1180 EXPECT_THAT(
1181 m.GetOutput(),
1182 ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02})));
1183 }
1184
TEST(NNAPIDelegate,LocalResponseNormSmallRadius)1185 TEST(NNAPIDelegate, LocalResponseNormSmallRadius) {
1186 LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0,
1187 /*alpha=*/4.0, /*beta=*/0.5);
1188 m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
1189 m.Invoke();
1190 EXPECT_THAT(
1191 m.GetOutput(),
1192 ElementsAreArray(ArrayFloatNear(
1193 {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266})));
1194 }
1195
1196 class LSHProjectionOpModel : public SingleOpModelWithNNAPI {
1197 public:
LSHProjectionOpModel(LSHProjectionType type,std::initializer_list<int> hash_shape,std::initializer_list<int> input_shape,std::initializer_list<int> weight_shape)1198 LSHProjectionOpModel(LSHProjectionType type,
1199 std::initializer_list<int> hash_shape,
1200 std::initializer_list<int> input_shape,
1201 std::initializer_list<int> weight_shape) {
1202 hash_ = AddInput(TensorType_FLOAT32);
1203 input_ = AddInput(TensorType_INT32);
1204 if (weight_shape.size() > 0) {
1205 weight_ = AddInput(TensorType_FLOAT32);
1206 }
1207 output_ = AddOutput(TensorType_INT32);
1208
1209 SetBuiltinOp(BuiltinOperator_LSH_PROJECTION,
1210 BuiltinOptions_LSHProjectionOptions,
1211 CreateLSHProjectionOptions(builder_, type).Union());
1212 if (weight_shape.size() > 0) {
1213 BuildInterpreter({hash_shape, input_shape, weight_shape});
1214 } else {
1215 BuildInterpreter({hash_shape, input_shape});
1216 }
1217
1218 output_size_ = 1;
1219 for (int i : hash_shape) {
1220 output_size_ *= i;
1221 if (type == LSHProjectionType_SPARSE) {
1222 break;
1223 }
1224 }
1225 }
SetInput(std::initializer_list<int> data)1226 void SetInput(std::initializer_list<int> data) {
1227 PopulateTensor(input_, data);
1228 }
1229
SetHash(std::initializer_list<float> data)1230 void SetHash(std::initializer_list<float> data) {
1231 PopulateTensor(hash_, data);
1232 }
1233
SetWeight(std::initializer_list<float> f)1234 void SetWeight(std::initializer_list<float> f) { PopulateTensor(weight_, f); }
1235
GetOutput()1236 std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
1237
1238 private:
1239 int input_;
1240 int hash_;
1241 int weight_;
1242 int output_;
1243
1244 int output_size_;
1245 };
1246
TEST(NNAPIDelegate,LSHProjectionDense1DInputs)1247 TEST(NNAPIDelegate, LSHProjectionDense1DInputs) {
1248 LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5});
1249
1250 m.SetInput({12345, 54321, 67890, 9876, -12345678});
1251 m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
1252 m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0});
1253
1254 m.Invoke();
1255
1256 EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
1257 }
1258
TEST(NNAPIDelegate,LSHProjectionSparse1DInputs)1259 TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) {
1260 LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {});
1261
1262 m.SetInput({12345, 54321, 67890, 9876, -12345678});
1263 m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
1264
1265 m.Invoke();
1266
1267 EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
1268 }
1269
TEST(NNAPIDelegate,LSHProjectionSparse3DInputs)1270 TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) {
1271 LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5});
1272
1273 m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
1274 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543});
1275 m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
1276 m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78});
1277
1278 m.Invoke();
1279
1280 EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
1281 }
1282
1283 class BaseActivationsOpModel : public SingleOpModelWithNNAPI {
1284 public:
1285 // Most activations don't take any options, so this constructor works for
1286 // them.
BaseActivationsOpModel(BuiltinOperator type,TensorData input)1287 BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
1288 input_ = AddInput(input);
1289 if (input.type == TensorType_UINT8) {
1290 output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
1291 } else {
1292 output_ = AddOutput({input.type, {}});
1293 }
1294 SetBuiltinOp(type, BuiltinOptions_NONE, 0);
1295 BuildInterpreter({GetShape(input_)});
1296 }
1297
BaseActivationsOpModel(BuiltinOperator type,const TensorData & input,const TensorData & output)1298 BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
1299 const TensorData& output) {
1300 input_ = AddInput(input);
1301 output_ = AddOutput(output);
1302 SetBuiltinOp(type, BuiltinOptions_NONE, 0);
1303 BuildInterpreter({GetShape(input_)});
1304 }
1305
1306 protected:
1307 int input_;
1308 int output_;
1309 };
1310
1311 class FloatActivationsOpModel : public BaseActivationsOpModel {
1312 public:
1313 using BaseActivationsOpModel::BaseActivationsOpModel;
1314
SetInput(std::initializer_list<float> data)1315 void SetInput(std::initializer_list<float> data) {
1316 PopulateTensor(input_, data);
1317 }
GetOutput()1318 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1319 };
1320
// Two quantization steps of a 1/256-scale uint8 tensor.
constexpr float kQuantizedTolerance = 2 * (1. / 256);
1322
1323 class QuantizedActivationsOpModel : public BaseActivationsOpModel {
1324 public:
1325 using BaseActivationsOpModel::BaseActivationsOpModel;
1326
1327 template <typename T>
SetInput(std::initializer_list<float> data)1328 void SetInput(std::initializer_list<float> data) {
1329 QuantizeAndPopulate<T>(input_, data);
1330 }
1331 template <typename T>
1332
GetOutput()1333 std::vector<T> GetOutput() {
1334 return ExtractVector<T>(output_);
1335 }
1336 template <typename T>
GetDequantizedOutput()1337 std::vector<float> GetDequantizedOutput() {
1338 return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
1339 GetZeroPoint(output_));
1340 }
1341 };
1342
TEST(NNAPIDelegate,Relu)1343 TEST(NNAPIDelegate, Relu) {
1344 FloatActivationsOpModel m(BuiltinOperator_RELU,
1345 /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
1346 m.SetInput({
1347 0, -6, 2, 4, //
1348 3, -2, 10, 1, //
1349 });
1350 m.Invoke();
1351 EXPECT_THAT(m.GetOutput(), ElementsAreArray({
1352 0, 0, 2, 4, //
1353 3, 0, 10, 1, //
1354 }));
1355 }
1356
TEST(NNAPIDelegate,Relu1)1357 TEST(NNAPIDelegate, Relu1) {
1358 FloatActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1,
1359 /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
1360 m.SetInput({
1361 0.0, -0.6, 0.2, -0.4, //
1362 0.3, -2.0, 1.1, -0.1, //
1363 });
1364 m.Invoke();
1365 EXPECT_THAT(m.GetOutput(), ElementsAreArray({
1366 0.0, -0.6, 0.2, -0.4, //
1367 0.3, -1.0, 1.0, -0.1, //
1368 }));
1369 }
1370
TEST(NNAPIDelegate,Relu6)1371 TEST(NNAPIDelegate, Relu6) {
1372 FloatActivationsOpModel m(BuiltinOperator_RELU6,
1373 /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
1374 m.SetInput({
1375 0, -6, 2, 4, //
1376 3, -2, 10, 1, //
1377 });
1378 m.Invoke();
1379 EXPECT_THAT(m.GetOutput(), ElementsAreArray({
1380 0, 0, 2, 4, //
1381 3, 0, 6, 1, //
1382 }));
1383 }
1384
TEST(NNAPIDelegate,Tanh)1385 TEST(NNAPIDelegate, Tanh) {
1386 FloatActivationsOpModel m(BuiltinOperator_TANH,
1387 /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
1388 m.SetInput({
1389 0, -6, 2, 4, //
1390 3, -2, 10, 1, //
1391 });
1392 m.Invoke();
1393 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
1394 0, -0.9999877, 0.9640275, 0.999329, //
1395 0.99505475, -0.9640275, 1, 0.7615941, //
1396 })));
1397 }
1398
TEST(NNAPIDelegate,LogisticFloat)1399 TEST(NNAPIDelegate, LogisticFloat) {
1400 FloatActivationsOpModel m(BuiltinOperator_LOGISTIC,
1401 /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
1402 m.SetInput({
1403 0, -6, 2, 4, //
1404 3, -2, 10, 1, //
1405 });
1406 m.Invoke();
1407 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
1408 0.5, 0.002473, 0.880797, 0.982014, //
1409 0.952574, 0.119203, 0.999955, 0.731059, //
1410 })));
1411 }
1412
TEST(NNAPIDelegate,LogisticQuantized)1413 TEST(NNAPIDelegate, LogisticQuantized) {
1414 QuantizedActivationsOpModel m(
1415 BuiltinOperator_LOGISTIC,
1416 /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
1417 m.SetInput<uint8_t>({
1418 0, -6, 2, 4, //
1419 3, -2, 10, 1, //
1420 });
1421 m.Invoke();
1422 EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
1423 ElementsAreArray(ArrayFloatNear(
1424 {
1425 0.5, 0.002473, 0.880797, 0.982014, //
1426 0.952574, 0.119203, 0.999955, 0.731059, //
1427 },
1428 kQuantizedTolerance)));
1429 EXPECT_THAT(m.GetOutput<uint8_t>(),
1430 testing::Pointwise(QuantizedNear(),
1431 {128, 1, 227, 251, 244, 32, 255, 188}));
1432 }
1433
// NOTE(review): the RESIZE_BILINEAR model and tests below are compiled out
// with `#if 0`, and the reason is not recorded here. TODO: re-enable them or
// delete them.
#if 0
// Single RESIZE_BILINEAR op. When `size_data` is non-empty, the output size is
// a constant input; otherwise it is a runtime input set through SetSize().
class ResizeBilinearOpModel : public SingleOpModelWithNNAPI {
 public:
  ResizeBilinearOpModel(const TensorData& input,
                        std::initializer_list<int> size_data = {}) {
    bool const_size = size_data.size() != 0;
    input_ = AddInput(input);
    if (const_size) {
      size_ = AddConstInput(TensorType_INT32, size_data, {2});
    } else {
      size_ = AddInput({TensorType_INT32, {2}});
    }
    output_ = AddOutput(input.type);
    SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
                 BuiltinOptions_ResizeBilinearOptions,
                 CreateResizeBilinearOptions(builder_).Union());
    // A constant size tensor needs no runtime shape.
    if (const_size) {
      BuildInterpreter({GetShape(input_)});
    } else {
      BuildInterpreter({GetShape(input_), GetShape(size_)});
    }
  }

  template <typename T>
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor(input_, data);
  }
  // Only meaningful when the size is a runtime input.
  void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

 private:
  int input_;
  int size_;
  int output_;
};

// Each test below checks both the runtime-size and the constant-size variant.
TEST(NNAPIDelegate, ResizeBilinearHorizontal) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
  m.SetInput<float>({3, 6});
  m.SetSize({1, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
  const_m.SetInput<float>({3, 6});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}

TEST(NNAPIDelegate, ResizeBilinearVertical) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
  m.SetInput<float>({3, 9});
  m.SetSize({3, 1});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
  const_m.SetInput<float>({3, 9});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}

TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
  m.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
                                        3, 5, 6,    //
                                        7, 9, 10,   //
                                        9, 11, 12,  //
                                    })));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
  const_m.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
                                              3, 5, 6,    //
                                              7, 9, 10,   //
                                              9, 11, 12,  //
                                          })));
}
#endif  // 0
1531
1532 template <typename T>
1533 class PadOpModel : public SingleOpModelWithNNAPI {
1534 public:
SetInput(std::initializer_list<T> data)1535 void SetInput(std::initializer_list<T> data) {
1536 PopulateTensor<T>(input_, data);
1537 }
1538
SetQuantizedInput(std::initializer_list<float> data)1539 void SetQuantizedInput(std::initializer_list<float> data) {
1540 QuantizeAndPopulate<uint8_t>(input_, data);
1541 }
1542
SetQuantizedPadValue(float data)1543 void SetQuantizedPadValue(float data) {
1544 QuantizeAndPopulate<uint8_t>(constant_values_, {data});
1545 }
1546
SetPaddings(std::initializer_list<int> paddings)1547 void SetPaddings(std::initializer_list<int> paddings) {
1548 PopulateTensor<int>(paddings_, paddings);
1549 }
1550
GetOutput()1551 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetOutputShape()1552 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1553
GetDequantizedOutput()1554 std::vector<float> GetDequantizedOutput() {
1555 return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
1556 GetScale(output_), GetZeroPoint(output_));
1557 }
1558
1559 protected:
1560 int input_;
1561 int output_;
1562 int paddings_;
1563 int constant_values_;
1564 };
1565
1566 class PadOpConstModel : public PadOpModel<float> {
1567 public:
PadOpConstModel(const TensorData & input,std::initializer_list<int> paddings_shape,std::initializer_list<int> paddings,const TensorData & output)1568 PadOpConstModel(const TensorData& input,
1569 std::initializer_list<int> paddings_shape,
1570 std::initializer_list<int> paddings,
1571 const TensorData& output) {
1572 input_ = AddInput(input);
1573 paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
1574 output_ = AddOutput(output);
1575
1576 SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
1577 CreatePadOptions(builder_).Union());
1578 BuildInterpreter({input.shape});
1579 }
1580 };
1581
TEST(NNAPIDelegate,PadAdvancedConstTest)1582 TEST(NNAPIDelegate, PadAdvancedConstTest) {
1583 PadOpConstModel m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
1584 {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32});
1585 m.SetInput({1, 2, 3, 4, 5, 6});
1586 m.Invoke();
1587 EXPECT_THAT(m.GetOutput(),
1588 ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
1589 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
1590 EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
1591 }
1592
1593 class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI {
1594 public:
SetInput(std::initializer_list<float> data)1595 void SetInput(std::initializer_list<float> data) {
1596 PopulateTensor<float>(input_, data);
1597 }
1598
SetBlockShape(std::initializer_list<int> data)1599 void SetBlockShape(std::initializer_list<int> data) {
1600 PopulateTensor<int>(block_shape_, data);
1601 }
1602
SetPaddings(std::initializer_list<int> data)1603 void SetPaddings(std::initializer_list<int> data) {
1604 PopulateTensor<int>(paddings_, data);
1605 }
1606
GetOutput()1607 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1608 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1609
1610 protected:
1611 int input_;
1612 int block_shape_;
1613 int paddings_;
1614 int output_;
1615 };
1616
1617 class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
1618 public:
SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,std::initializer_list<int> block_shape,std::initializer_list<int> paddings)1619 SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,
1620 std::initializer_list<int> block_shape,
1621 std::initializer_list<int> paddings) {
1622 input_ = AddInput(TensorType_FLOAT32);
1623 block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
1624 paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2});
1625 output_ = AddOutput(TensorType_FLOAT32);
1626
1627 SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
1628 BuiltinOptions_SpaceToBatchNDOptions,
1629 CreateSpaceToBatchNDOptions(builder_).Union());
1630 BuildInterpreter({input_shape});
1631 }
1632 };
1633
TEST(NNAPIDelegate,SpaceToBatchNDSimpleConstTest)1634 TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) {
1635 SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
1636 m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
1637 m.Invoke();
1638 EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
1639 EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7,
1640 13, 15, 6, 8, 14, 16}));
1641 }
1642
TEST(NNAPIDelegate,SpaceToBatchNDMultipleInputBatchesConstTest)1643 TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) {
1644 SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
1645 m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
1646 m.Invoke();
1647 EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
1648 EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7,
1649 13, 15, 6, 8, 14, 16}));
1650 }
1651
TEST(NNAPIDelegate,SpaceToBatchNDSimplePaddingConstTest)1652 TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) {
1653 SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
1654 m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
1655 m.Invoke();
1656 EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
1657 EXPECT_THAT(m.GetOutput(), ElementsAreArray({
1658 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7,
1659 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
1660 }));
1661 }
1662
TEST(NNAPIDelegate,SpaceToBatchNDComplexPaddingConstTest)1663 TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) {
1664 SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
1665 m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
1666 m.Invoke();
1667 EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
1668 EXPECT_THAT(m.GetOutput(), ElementsAreArray({
1669 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0,
1670 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0,
1671 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
1672 }));
1673 }
1674
1675 template <typename input_type = float,
1676 TensorType tensor_input_type = TensorType_FLOAT32>
1677 class StridedSliceOpModel : public SingleOpModelWithNNAPI {
1678 public:
StridedSliceOpModel(std::initializer_list<int> input_shape,std::initializer_list<int> begin_shape,std::initializer_list<int> begin_data,std::initializer_list<int> end_shape,std::initializer_list<int> end_data,std::initializer_list<int> strides_shape,std::initializer_list<int> strides_data,int begin_mask,int end_mask,int ellipsis_mask,int new_axis_mask,int shrink_axis_mask)1679 StridedSliceOpModel(std::initializer_list<int> input_shape,
1680 std::initializer_list<int> begin_shape,
1681 std::initializer_list<int> begin_data,
1682 std::initializer_list<int> end_shape,
1683 std::initializer_list<int> end_data,
1684 std::initializer_list<int> strides_shape,
1685 std::initializer_list<int> strides_data, int begin_mask,
1686 int end_mask, int ellipsis_mask, int new_axis_mask,
1687 int shrink_axis_mask) {
1688 input_ = AddInput(tensor_input_type);
1689 begin_ = AddConstInput(TensorType_INT32, begin_data, begin_shape);
1690 end_ = AddConstInput(TensorType_INT32, end_data, end_shape);
1691 strides_ = AddConstInput(TensorType_INT32, strides_data, strides_shape);
1692 output_ = AddOutput(tensor_input_type);
1693 SetBuiltinOp(
1694 BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
1695 CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
1696 new_axis_mask, shrink_axis_mask)
1697 .Union());
1698 BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape});
1699 }
1700
SetInput(std::initializer_list<input_type> data)1701 void SetInput(std::initializer_list<input_type> data) {
1702 PopulateTensor<input_type>(input_, data);
1703 }
1704
GetOutput()1705 std::vector<input_type> GetOutput() {
1706 return ExtractVector<input_type>(output_);
1707 }
GetOutputShape()1708 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1709
1710 private:
1711 int input_;
1712 int begin_;
1713 int end_;
1714 int strides_;
1715 int output_;
1716 };
1717
TEST(StridedSliceOpTest,In1D)1718 TEST(StridedSliceOpTest, In1D) {
1719 StridedSliceOpModel<> m({4}, {1}, {1}, {1}, {3}, {1}, {1}, 0, 0, 0, 0, 0);
1720 m.SetInput({1, 2, 3, 4});
1721 m.Invoke();
1722 EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
1723 EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
1724 }
1725
TEST(StridedSliceOpTest, In1D_BeginMask) {
  // Bit 0 of begin_mask ignores begin_data for axis 0, so the slice starts at
  // the first element and stops before index 3.
  StridedSliceOpModel<> model(/*input_shape=*/{4}, /*begin_shape=*/{1},
                              /*begin_data=*/{1}, /*end_shape=*/{1},
                              /*end_data=*/{3}, /*strides_shape=*/{1},
                              /*strides_data=*/{1}, /*begin_mask=*/1,
                              /*end_mask=*/0, /*ellipsis_mask=*/0,
                              /*new_axis_mask=*/0, /*shrink_axis_mask=*/0);
  model.SetInput({1, 2, 3, 4});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3}));
}
1733
TEST(StridedSliceOpTest, In2D_Stride2) {
  // Takes every second row and every second column of a 2x3 matrix.
  StridedSliceOpModel<> model(/*input_shape=*/{2, 3}, /*begin_shape=*/{2},
                              /*begin_data=*/{0, 0}, /*end_shape=*/{2},
                              /*end_data=*/{2, 3}, /*strides_shape=*/{2},
                              /*strides_data=*/{2, 2}, /*begin_mask=*/0,
                              /*end_mask=*/0, /*ellipsis_mask=*/0,
                              /*new_axis_mask=*/0, /*shrink_axis_mask=*/0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3}));
}
1742
TEST(StridedSliceOpTest, In2D_EndMask) {
  // Bit 1 of end_mask ignores end_data for axis 1, so the whole second row
  // is returned.
  StridedSliceOpModel<> model(/*input_shape=*/{2, 3}, /*begin_shape=*/{2},
                              /*begin_data=*/{1, 0}, /*end_shape=*/{2},
                              /*end_data=*/{2, 2}, /*strides_shape=*/{2},
                              /*strides_data=*/{1, 1}, /*begin_mask=*/0,
                              /*end_mask=*/2, /*ellipsis_mask=*/0,
                              /*new_axis_mask=*/0, /*shrink_axis_mask=*/0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 5, 6}));
}
1751
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
  // Bit 2 of shrink_axis_mask drops the (size-1) last axis of the slice,
  // leaving a 2x3 result made of the first element along that axis.
  StridedSliceOpModel<> model(/*input_shape=*/{2, 3, 2}, /*begin_shape=*/{3},
                              /*begin_data=*/{0, 0, 0}, /*end_shape=*/{3},
                              /*end_data=*/{2, 3, 1}, /*strides_shape=*/{3},
                              /*strides_data=*/{1, 1, 1}, /*begin_mask=*/0,
                              /*end_mask=*/0, /*ellipsis_mask=*/0,
                              /*new_axis_mask=*/0, /*shrink_axis_mask=*/4);
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11}));
}
1760
// Input sequence for RnnBlackBoxTest: 128 floats. The test reads it
// input_size() (8) floats per time step and feeds the same slice to every
// batch, so this is presumably 16 time steps of 8 inputs each.
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};
1788
// Expected RNN outputs for RnnBlackBoxTest: 16 groups of 16 floats (256 in
// total). The test reads num_units() (16) floats per time step and compares
// that group against every batch, because all batches receive the same input.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};
1853
// Input weights for RnnBlackBoxTest: 128 floats, matching the
// {units_, input_size_} = {16, 8} weights tensor of RNNOpModel(2, 16, 8).
static std::initializer_list<float> rnn_weights = {
    0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
    0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
    0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
    -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
    -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
    -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
    -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
    0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
    0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
    0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
    -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
    0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
    -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
    -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
    0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
    0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
    0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
    -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
    0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
    0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
    -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
    0.277308,    0.415818};
1877
// Recurrent weights for RnnBlackBoxTest: a 16x16 matrix (row-major, 256
// floats) with 0.1 on the diagonal and 0 elsewhere. Every row below is
// "0.1 followed by 16 zeros", which places 0.1 at every 17th element.
static std::initializer_list<float> rnn_recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};
1895
// Bias for RnnBlackBoxTest: one float per unit (16 units).
static std::initializer_list<float> rnn_bias = {
    0.065691948, -0.69055247, 0.1107955,  -0.97084129, -0.23957068, -0.23566568,
    -0.389184,   0.47481549,  -0.4791103, 0.29931796,  0.10463274,  0.83918178,
    0.37197268,  0.61957061,  0.3956964,  -0.37609905};
1900
1901 class RNNOpModel : public SingleOpModelWithNNAPI {
1902 public:
RNNOpModel(int batches,int units,int size,const TensorType weights=TensorType_FLOAT32,const TensorType recurrent_weights=TensorType_FLOAT32)1903 RNNOpModel(int batches, int units, int size,
1904 const TensorType weights = TensorType_FLOAT32,
1905 const TensorType recurrent_weights = TensorType_FLOAT32)
1906 : batches_(batches), units_(units), input_size_(size) {
1907 input_ = AddInput(TensorType_FLOAT32);
1908 weights_ = AddInput(weights);
1909 recurrent_weights_ = AddInput(recurrent_weights);
1910 bias_ = AddInput(TensorType_FLOAT32);
1911 hidden_state_ = AddInput(TensorType_FLOAT32, true);
1912 output_ = AddOutput(TensorType_FLOAT32);
1913 SetBuiltinOp(
1914 BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
1915 CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
1916 BuildInterpreter({{batches_, input_size_}, // input tensor
1917 {units_, input_size_}, // weights tensor
1918 {units_, units_}, // recurrent weights tensor
1919 {units_}, // bias tensor
1920 {batches_, units_}}); // hidden state tensor
1921 }
1922
SetBias(std::initializer_list<float> f)1923 void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
1924
SetWeights(std::initializer_list<float> f)1925 void SetWeights(std::initializer_list<float> f) {
1926 PopulateTensor(weights_, f);
1927 }
1928
SetRecurrentWeights(std::initializer_list<float> f)1929 void SetRecurrentWeights(std::initializer_list<float> f) {
1930 PopulateTensor(recurrent_weights_, f);
1931 }
1932
SetInput(std::initializer_list<float> data)1933 void SetInput(std::initializer_list<float> data) {
1934 PopulateTensor(input_, data);
1935 }
1936
SetInput(int offset,float * begin,float * end)1937 void SetInput(int offset, float* begin, float* end) {
1938 PopulateTensor(input_, offset, begin, end);
1939 }
1940
GetOutput()1941 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1942
input_size()1943 int input_size() { return input_size_; }
num_units()1944 int num_units() { return units_; }
num_batches()1945 int num_batches() { return batches_; }
1946
1947 protected:
1948 int input_;
1949 int weights_;
1950 int recurrent_weights_;
1951 int bias_;
1952 int hidden_state_;
1953 int output_;
1954
1955 int batches_;
1956 int units_;
1957 int input_size_;
1958 };
1959
// Runs the RNN op step by step over rnn_input and compares every step against
// the matching group of rnn_golden_output.
TEST(NNAPIDelegate, RnnBlackBoxTest) {
  RNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  // Each time step feeds the SAME input_size() slice of rnn_input to every
  // batch, so one step consumes input_size() values of rnn_input, not
  // input_size() * num_batches(). Dividing by num_batches() as well would
  // silently skip the second half of the sequence and its golden outputs.
  const int input_sequence_size =
      sizeof(rnn_input) / sizeof(rnn_input[0]) / rnn.input_size();
  const int golden_size =
      sizeof(rnn_golden_output) / sizeof(rnn_golden_output[0]);
  // Guards the golden reads below: each step reads num_units() values.
  ASSERT_LE(input_sequence_size * rnn.num_units(), golden_size);

  for (int i = 0; i < input_sequence_size; i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // Both batches receive identical input.
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);

    rnn.Invoke();

    // ...and therefore must both produce the same golden output.
    float* golden_start = rnn_golden_output + i * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
1986
// Input sequence shared by the SVDF black-box tests: 60 floats, written as
// 10 groups of 2 x 3. VerifyGoldens consumes input_size() * num_batches()
// (3 * 2 = 6 for those tests) floats per step, so each group is one time step
// covering both batches.
static float svdf_input[] = {
    0.12609188,  -0.46347019, -0.89598465,
    0.35867718,  0.36897406,  0.73463392,

    0.14278367,  -1.64410412, -0.75222826,
    -0.57290924, 0.12729003,  0.7567004,

    0.49837467,  0.19278903,  0.26584083,
    0.17660543,  0.52949083,  -0.77931279,

    -0.11186574, 0.13164264,  -0.05349274,
    -0.72674477, -0.5683046,  0.55900657,

    -0.68892461, 0.37783599,  0.18263303,
    -0.63690937, 0.44483393,  -0.71817774,

    -0.81299269, -0.86831826, 1.43940818,
    -0.95760226, 1.82078898,  0.71135032,

    -1.45006323, -0.82251364, -1.69082689,
    -1.65087092, -1.89238167, 1.54172635,

    0.03966608,  -0.24936394, -0.77526885,
    2.06740379,  -1.51439476, 1.43768692,

    0.11771342,  -0.23761693, -0.65898693,
    0.31088525,  -1.55601168, -0.87661445,

    -0.89477462, 1.67204106,  -0.53235275,
    -0.6230064,  0.29819036,  1.06939757,
};
2018
// Expected SVDF outputs for BlackBoxTestRank1: 10 groups of 8 floats. Each
// group is one time step, num_units() * num_batches() = 4 * 2 values.
static float svdf_golden_output_rank_1[] = {
    0.014899,    -0.0517661,  -0.143725,   -0.00271883,
    -0.03004015, 0.09565311,  0.1587342,   0.00784263,

    0.068281,    -0.162217,   -0.152268,   0.00323521,
    0.01582633,  0.03858774,  -0.03001583, -0.02671271,

    -0.0317821,  -0.0333089,  0.0609602,   0.0333759,
    -0.01432795, 0.05524484,  0.1101355,   -0.02382665,

    -0.00623099, -0.077701,   -0.391193,   -0.0136691,
    -0.02333033, 0.02293761,  0.12338032,  0.04326871,

    0.201551,    -0.164607,   -0.179462,   -0.0592739,
    0.01064911,  -0.17503069, 0.07821996,  -0.00224009,

    0.0886511,   -0.0875401,  -0.269283,   0.0281379,
    -0.02282338, 0.09741908,  0.32973239,  0.12281385,

    -0.201174,   -0.586145,   -0.628624,   -0.0330412,
    0.24780814,  -0.39304617, -0.22473189, 0.02589256,

    -0.0839096,  -0.299329,   0.108746,    0.109808,
    0.10084175,  -0.06416984, 0.28936723,  0.0026358,

    0.419114,    -0.237824,   -0.422627,   0.175115,
    -0.2314795,  -0.18584411, -0.4228974,  -0.12928449,

    0.36726,     -0.522303,   -0.456502,   -0.175475,
    0.17012937,  -0.34447709, 0.38505614,  -0.28158101,
};
2050
// Expected SVDF outputs for BlackBoxTestRank2: 10 groups of 8 floats. Each
// group is one time step, num_units() * num_batches() = 4 * 2 values.
static float svdf_golden_output_rank_2[] = {
    -0.09623547, -0.10193135, 0.11083051,  -0.0347917,
    0.1141196,   0.12965347,  -0.12652366, 0.01007236,

    -0.16396809, -0.21247184, 0.11259045,  -0.04156673,
    0.10132131,  -0.06143532, -0.00924693, 0.10084561,

    0.01257364,  0.0506071,   -0.19287863, -0.07162561,
    -0.02033747, 0.22673416,  0.15487903,  0.02525555,

    -0.1411963,  -0.37054959, 0.01774767,  0.05867489,
    0.09607603,  -0.0141301,  -0.08995658, 0.12867066,

    -0.27142537, -0.16955489, 0.18521598,  -0.12528358,
    0.00331409,  0.11167502,  0.02218599,  -0.07309391,

    0.09593632,  -0.28361851, -0.0773851,  0.17199151,
    -0.00075242, 0.33691186,  -0.1536046,  0.16572715,

    -0.27916506, -0.27626723, 0.42615682,  0.3225764,
    -0.37472126, -0.55655634, -0.05013514, 0.289112,

    -0.24418658, 0.07540751,  -0.1940318,  -0.08911639,
    0.00732617,  0.46737891,  0.26449674,  0.24888524,

    -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    0.07736043,  -0.28260678, 0.15666828,  1.14949894,

    -0.57454878, -0.64704704, 0.73235172,  -0.34616736,
    0.21120001,  -0.22927976, 0.02455296,  -0.35906726,
};
2082
2083 class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
2084 public:
BaseSVDFOpModel(int batches,int units,int input_size,int memory_size,int rank,TensorType weights_feature_type=TensorType_FLOAT32,TensorType weights_time_type=TensorType_FLOAT32)2085 BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
2086 int rank,
2087 TensorType weights_feature_type = TensorType_FLOAT32,
2088 TensorType weights_time_type = TensorType_FLOAT32)
2089 : batches_(batches),
2090 units_(units),
2091 input_size_(input_size),
2092 memory_size_(memory_size),
2093 rank_(rank) {
2094 input_ = AddInput(TensorType_FLOAT32);
2095 weights_feature_ = AddInput(weights_feature_type);
2096 weights_time_ = AddInput(weights_time_type);
2097 // TODO(b/121383394) : figure out why optional bias causes TFLite segfault
2098 // when using NNAPI delegate.
2099 bias_ = AddInput(TensorType_FLOAT32);
2100 const int num_filters = units * rank;
2101 activation_state_ = AddInput(
2102 TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
2103 /*is_variable=*/true);
2104 output_ = AddOutput(TensorType_FLOAT32);
2105 SetBuiltinOp(
2106 BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
2107 CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
2108 BuildInterpreter({
2109 {batches_, input_size_}, // input tensor
2110 {units_ * rank, input_size_}, // weights_feature tensor
2111 {units_ * rank, memory_size_}, // weights_time tensor
2112 {units_}, // bias tensor
2113 {batches, memory_size * num_filters} // activation_state tensor
2114 });
2115 // TODO(b/121383394) : remove once the optional bias bug is fixed.
2116 PopulateTensor(bias_, std::vector<float>(units_));
2117 }
2118
2119 // Populates the weights_feature tensor.
SetWeightsFeature(std::initializer_list<float> f)2120 void SetWeightsFeature(std::initializer_list<float> f) {
2121 PopulateTensor(weights_feature_, f);
2122 }
2123
2124 // Populates the weights_time tensor.
SetWeightsTime(std::initializer_list<float> f)2125 void SetWeightsTime(std::initializer_list<float> f) {
2126 PopulateTensor(weights_time_, f);
2127 }
2128
2129 // Populates the input tensor.
SetInput(int offset,float * begin,float * end)2130 void SetInput(int offset, float* begin, float* end) {
2131 PopulateTensor(input_, offset, begin, end);
2132 }
2133
2134 // Extracts the output tensor from the SVDF op.
GetOutput()2135 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2136
input_size()2137 int input_size() { return input_size_; }
num_units()2138 int num_units() { return units_; }
num_batches()2139 int num_batches() { return batches_; }
2140
2141 protected:
2142 int input_;
2143 int weights_feature_;
2144 int weights_time_;
2145 int bias_;
2146 int activation_state_;
2147 int output_;
2148
2149 int batches_;
2150 int units_;
2151 int input_size_;
2152 int memory_size_;
2153 int rank_;
2154 };
2155
2156 class SVDFOpModel : public BaseSVDFOpModel {
2157 public:
2158 using BaseSVDFOpModel::BaseSVDFOpModel;
2159 };
2160
2161 class SVDFOpTest : public ::testing::Test {
2162 protected:
VerifyGoldens(float golden_input[],float golden_output[],int golden_size,BaseSVDFOpModel * svdf,float tolerance=1e-5)2163 void VerifyGoldens(float golden_input[], float golden_output[],
2164 int golden_size, BaseSVDFOpModel* svdf,
2165 float tolerance = 1e-5) {
2166 const int svdf_num_batches = svdf->num_batches();
2167 const int svdf_input_size = svdf->input_size();
2168 const int svdf_num_units = svdf->num_units();
2169 const int input_sequence_size =
2170 golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
2171 // Going over each input batch, setting the input tensor, invoking the SVDF
2172 // op and checking the output with the expected golden values.
2173 for (int i = 0; i < input_sequence_size; i++) {
2174 float* batch_start =
2175 golden_input + i * svdf_input_size * svdf_num_batches;
2176 float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
2177 svdf->SetInput(0, batch_start, batch_end);
2178
2179 svdf->Invoke();
2180
2181 const float* golden_start =
2182 golden_output + i * svdf_num_units * svdf_num_batches;
2183 const float* golden_end =
2184 golden_start + svdf_num_units * svdf_num_batches;
2185 std::vector<float> expected;
2186 expected.insert(expected.end(), golden_start, golden_end);
2187
2188 EXPECT_THAT(svdf->GetOutput(),
2189 ElementsAreArray(ArrayFloatNear(expected, tolerance)));
2190 }
2191 }
2192 };
2193
// Rank-1 SVDF with 2 batches, 4 units, 3 inputs and a memory of 10 steps,
// checked against svdf_golden_output_rank_1.
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/1);
  // 12 values: 4 filters (units * rank) x 3 inputs.
  svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                          0.22197971, 0.12416199, 0.27901134, 0.27557442,
                          0.3905206, -0.36137494, -0.06634006, -0.10640851});

  // 40 values: 4 filters x 10 memory slots.
  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});

  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf);
}
2217
// Rank-2 SVDF with 2 batches, 4 units, 3 inputs and a memory of 10 steps,
// checked against svdf_golden_output_rank_2.
TEST_F(SVDFOpTest, BlackBoxTestRank2) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/2);
  // 24 values: 8 filters (units * rank) x 3 inputs.
  svdf.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
                          0.12416199,  0.15785322,  0.27901134,  0.3905206,
                          0.21931258,  -0.36137494, -0.10640851, 0.31053296,
                          -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
                          0.15294972,  0.38031587,  0.27557442,  0.39635518,
                          -0.21580373, -0.06634006, -0.02702999, 0.27072677});

  // 80 values: 8 filters x 10 memory slots.
  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

       -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
       0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

       -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
       0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

       -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
       -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

       0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
       0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});

  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf);
}
2256
2257 class LSTMOpModel : public SingleOpModelWithNNAPI {
2258 public:
LSTMOpModel(int n_batch,int n_input,int n_cell,int n_output,bool use_cifg,bool use_peephole,bool use_projection_weights,bool use_projection_bias,float cell_clip,float proj_clip,const std::vector<std::vector<int>> & input_shapes,const TensorType weight_type)2259 LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
2260 bool use_peephole, bool use_projection_weights,
2261 bool use_projection_bias, float cell_clip, float proj_clip,
2262 const std::vector<std::vector<int>>& input_shapes,
2263 const TensorType weight_type)
2264 : n_batch_(n_batch),
2265 n_input_(n_input),
2266 n_cell_(n_cell),
2267 n_output_(n_output),
2268 weight_type_(weight_type) {
2269 input_ = AddInput(TensorType_FLOAT32);
2270
2271 if (use_cifg) {
2272 input_to_input_weights_ = AddNullInput();
2273 } else {
2274 input_to_input_weights_ = AddInput(weight_type);
2275 }
2276
2277 input_to_forget_weights_ = AddInput(weight_type);
2278 input_to_cell_weights_ = AddInput(weight_type);
2279 input_to_output_weights_ = AddInput(weight_type);
2280
2281 if (use_cifg) {
2282 recurrent_to_input_weights_ = AddNullInput();
2283 } else {
2284 recurrent_to_input_weights_ = AddInput(weight_type);
2285 }
2286
2287 recurrent_to_forget_weights_ = AddInput(weight_type);
2288 recurrent_to_cell_weights_ = AddInput(weight_type);
2289 recurrent_to_output_weights_ = AddInput(weight_type);
2290
2291 if (use_peephole) {
2292 if (use_cifg) {
2293 cell_to_input_weights_ = AddNullInput();
2294 } else {
2295 cell_to_input_weights_ = AddInput(weight_type);
2296 }
2297 cell_to_forget_weights_ = AddInput(weight_type);
2298 cell_to_output_weights_ = AddInput(weight_type);
2299 } else {
2300 cell_to_input_weights_ = AddNullInput();
2301 cell_to_forget_weights_ = AddNullInput();
2302 cell_to_output_weights_ = AddNullInput();
2303 }
2304
2305 if (use_cifg) {
2306 input_gate_bias_ = AddNullInput();
2307 } else {
2308 input_gate_bias_ = AddInput(TensorType_FLOAT32);
2309 }
2310 forget_gate_bias_ = AddInput(TensorType_FLOAT32);
2311 cell_bias_ = AddInput(TensorType_FLOAT32);
2312 output_gate_bias_ = AddInput(TensorType_FLOAT32);
2313
2314 if (use_projection_weights) {
2315 projection_weights_ = AddInput(weight_type);
2316 if (use_projection_bias) {
2317 projection_bias_ = AddInput(TensorType_FLOAT32);
2318 } else {
2319 projection_bias_ = AddNullInput();
2320 }
2321 } else {
2322 projection_weights_ = AddNullInput();
2323 projection_bias_ = AddNullInput();
2324 }
2325
2326 // Adding the 2 input state tensors.
2327 input_activation_state_ =
2328 AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
2329 input_cell_state_ =
2330 AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
2331
2332 output_ = AddOutput(TensorType_FLOAT32);
2333
2334 SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
2335 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
2336 cell_clip, proj_clip)
2337 .Union());
2338 BuildInterpreter(input_shapes);
2339 }
2340
SetInputToInputWeights(const std::vector<float> & f)2341 void SetInputToInputWeights(const std::vector<float>& f) {
2342 SetData(input_to_input_weights_, weight_type_, f);
2343 }
2344
SetInputToForgetWeights(const std::vector<float> & f)2345 void SetInputToForgetWeights(const std::vector<float>& f) {
2346 SetData(input_to_forget_weights_, weight_type_, f);
2347 }
2348
SetInputToCellWeights(const std::vector<float> & f)2349 void SetInputToCellWeights(const std::vector<float>& f) {
2350 SetData(input_to_cell_weights_, weight_type_, f);
2351 }
2352
SetInputToOutputWeights(const std::vector<float> & f)2353 void SetInputToOutputWeights(const std::vector<float>& f) {
2354 SetData(input_to_output_weights_, weight_type_, f);
2355 }
2356
SetRecurrentToInputWeights(const std::vector<float> & f)2357 void SetRecurrentToInputWeights(const std::vector<float>& f) {
2358 SetData(recurrent_to_input_weights_, weight_type_, f);
2359 }
2360
SetRecurrentToForgetWeights(const std::vector<float> & f)2361 void SetRecurrentToForgetWeights(const std::vector<float>& f) {
2362 SetData(recurrent_to_forget_weights_, weight_type_, f);
2363 }
2364
SetRecurrentToCellWeights(const std::vector<float> & f)2365 void SetRecurrentToCellWeights(const std::vector<float>& f) {
2366 SetData(recurrent_to_cell_weights_, weight_type_, f);
2367 }
2368
SetRecurrentToOutputWeights(const std::vector<float> & f)2369 void SetRecurrentToOutputWeights(const std::vector<float>& f) {
2370 SetData(recurrent_to_output_weights_, weight_type_, f);
2371 }
2372
SetCellToInputWeights(const std::vector<float> & f)2373 void SetCellToInputWeights(const std::vector<float>& f) {
2374 SetData(cell_to_input_weights_, weight_type_, f);
2375 }
2376
SetCellToForgetWeights(const std::vector<float> & f)2377 void SetCellToForgetWeights(const std::vector<float>& f) {
2378 SetData(cell_to_forget_weights_, weight_type_, f);
2379 }
2380
SetCellToOutputWeights(const std::vector<float> & f)2381 void SetCellToOutputWeights(const std::vector<float>& f) {
2382 SetData(cell_to_output_weights_, weight_type_, f);
2383 }
2384
SetInputGateBias(const std::vector<float> & f)2385 void SetInputGateBias(const std::vector<float>& f) {
2386 PopulateTensor(input_gate_bias_, f);
2387 }
2388
SetForgetGateBias(const std::vector<float> & f)2389 void SetForgetGateBias(const std::vector<float>& f) {
2390 PopulateTensor(forget_gate_bias_, f);
2391 }
2392
SetCellBias(const std::vector<float> & f)2393 void SetCellBias(const std::vector<float>& f) {
2394 PopulateTensor(cell_bias_, f);
2395 }
2396
SetOutputGateBias(const std::vector<float> & f)2397 void SetOutputGateBias(const std::vector<float>& f) {
2398 PopulateTensor(output_gate_bias_, f);
2399 }
2400
SetProjectionWeights(const std::vector<float> & f)2401 void SetProjectionWeights(const std::vector<float>& f) {
2402 SetData(projection_weights_, weight_type_, f);
2403 }
2404
SetProjectionBias(const std::vector<float> & f)2405 void SetProjectionBias(const std::vector<float>& f) {
2406 PopulateTensor(projection_bias_, f);
2407 }
2408
SetInput(int offset,const float * begin,const float * end)2409 void SetInput(int offset, const float* begin, const float* end) {
2410 PopulateTensor(input_, offset, const_cast<float*>(begin),
2411 const_cast<float*>(end));
2412 }
2413
GetOutput()2414 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2415
num_inputs()2416 int num_inputs() { return n_input_; }
num_outputs()2417 int num_outputs() { return n_output_; }
num_cells()2418 int num_cells() { return n_cell_; }
num_batches()2419 int num_batches() { return n_batch_; }
2420
2421 protected:
2422 int input_;
2423 int input_to_input_weights_;
2424 int input_to_forget_weights_;
2425 int input_to_cell_weights_;
2426 int input_to_output_weights_;
2427
2428 int recurrent_to_input_weights_;
2429 int recurrent_to_forget_weights_;
2430 int recurrent_to_cell_weights_;
2431 int recurrent_to_output_weights_;
2432
2433 int cell_to_input_weights_;
2434 int cell_to_forget_weights_;
2435 int cell_to_output_weights_;
2436
2437 int input_gate_bias_;
2438 int forget_gate_bias_;
2439 int cell_bias_;
2440 int output_gate_bias_;
2441
2442 int projection_weights_;
2443 int projection_bias_;
2444 int input_activation_state_;
2445 int input_cell_state_;
2446
2447 int output_;
2448 int output_state_;
2449 int cell_state_;
2450
2451 int n_batch_;
2452 int n_input_;
2453 int n_cell_;
2454 int n_output_;
2455
2456 private:
2457 const TensorType weight_type_;
2458 };
2459
2460 class BaseLstmTest : public ::testing::Test {
2461 protected:
2462 // Weights of the LSTM model. Some are optional.
2463 std::vector<float> input_to_input_weights_;
2464 std::vector<float> input_to_cell_weights_;
2465 std::vector<float> input_to_forget_weights_;
2466 std::vector<float> input_to_output_weights_;
2467 std::vector<float> input_gate_bias_;
2468 std::vector<float> cell_gate_bias_;
2469 std::vector<float> forget_gate_bias_;
2470 std::vector<float> output_gate_bias_;
2471 std::vector<float> recurrent_to_input_weights_;
2472 std::vector<float> recurrent_to_cell_weights_;
2473 std::vector<float> recurrent_to_forget_weights_;
2474 std::vector<float> recurrent_to_output_weights_;
2475 std::vector<float> cell_to_input_weights_;
2476 std::vector<float> cell_to_forget_weights_;
2477 std::vector<float> cell_to_output_weights_;
2478 std::vector<float> projection_weights_;
2479
2480 // LSTM input is stored as num_batch x num_inputs vector.
2481 std::vector<std::vector<float>> lstm_input_;
2482 // LSTM output is stored as num_batch x num_outputs vector.
2483 std::vector<std::vector<float>> lstm_golden_output_;
2484
2485 // Compares output up to tolerance to the result of the lstm given the input.
VerifyGoldens(const std::vector<std::vector<float>> & input,const std::vector<std::vector<float>> & output,LSTMOpModel * lstm,float tolerance=1e-5)2486 void VerifyGoldens(const std::vector<std::vector<float>>& input,
2487 const std::vector<std::vector<float>>& output,
2488 LSTMOpModel* lstm, float tolerance = 1e-5) {
2489 const int num_batches = input.size();
2490 EXPECT_GT(num_batches, 0);
2491 const int num_inputs = lstm->num_inputs();
2492 EXPECT_GT(num_inputs, 0);
2493 const int input_sequence_size = input[0].size() / num_inputs;
2494 EXPECT_GT(input_sequence_size, 0);
2495 for (int i = 0; i < input_sequence_size; ++i) {
2496 for (int b = 0; b < num_batches; ++b) {
2497 const float* batch_start = input[b].data() + i * num_inputs;
2498 const float* batch_end = batch_start + num_inputs;
2499
2500 lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end);
2501 }
2502
2503 lstm->Invoke();
2504
2505 const int num_outputs = lstm->num_outputs();
2506 std::vector<float> expected;
2507 for (int b = 0; b < num_batches; ++b) {
2508 const float* golden_start_batch = output[b].data() + i * num_outputs;
2509 const float* golden_end_batch = golden_start_batch + num_outputs;
2510 expected.insert(expected.end(), golden_start_batch, golden_end_batch);
2511 }
2512 EXPECT_THAT(lstm->GetOutput(),
2513 ElementsAreArray(ArrayFloatNear(expected, tolerance)));
2514 }
2515 }
2516 };
2517
2518 class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
SetUp()2519 void SetUp() override {
2520 input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
2521 -0.34550029, 0.04266912, -0.15680569,
2522 -0.34856534, 0.43890524};
2523 input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
2524 -0.20583314, 0.44344562, 0.22077113, -0.29909778};
2525 input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
2526 -0.31343272, -0.40032279, 0.44781327,
2527 0.01387155, -0.35593212};
2528 input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
2529 0.40525138, 0.44272184, 0.03897077,
2530 -0.1556896, 0.19487578};
2531 input_gate_bias_ = {0., 0., 0., 0.};
2532 cell_gate_bias_ = {0., 0., 0., 0.};
2533 forget_gate_bias_ = {1., 1., 1., 1.};
2534 output_gate_bias_ = {0., 0., 0., 0.};
2535
2536 recurrent_to_input_weights_ = {
2537 -0.0063535, -0.2042388, 0.31454784, -0.35746509,
2538 0.28902304, 0.08183324, -0.16555229, 0.02286911,
2539 -0.13566875, 0.03034258, 0.48091322, -0.12528998,
2540 0.24077177, -0.51332325, -0.33502164, 0.10629296};
2541
2542 recurrent_to_cell_weights_ = {
2543 -0.3407414, 0.24443203, -0.2078532, 0.26320225,
2544 0.05695659, -0.00123841, -0.4744786, -0.35869038,
2545 -0.06418842, -0.13502428, -0.501764, 0.22830659,
2546 -0.46367589, 0.26016325, -0.03894562, -0.16368064};
2547
2548 recurrent_to_forget_weights_ = {
2549 -0.48684245, -0.06655136, 0.42224967, 0.2112639,
2550 0.27654213, 0.20864892, -0.07646349, 0.45877004,
2551 0.00141793, -0.14609534, 0.36447752, 0.09196436,
2552 0.28053468, 0.01560611, -0.20127171, -0.01140004};
2553
2554 recurrent_to_output_weights_ = {
2555 0.43385774, -0.17194885, 0.2718237, 0.09215671,
2556 0.24107647, -0.39835793, 0.18212086, 0.01301402,
2557 0.48572797, -0.50656658, 0.20047462, -0.20607421,
2558 -0.51818722, -0.15390486, 0.0468148, 0.39922136};
2559
2560 lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
2561 lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
2562 -0.03716109, 0.12507336, 0.41193449, -0.20860538,
2563 -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
2564 }
2565 };
2566
// Builds a float LSTM with all four gates and no peephole/projection tensors,
// loads the fixture weights, and checks the outputs against the goldens.
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  constexpr int kBatches = 1;
  constexpr int kInputs = 2;
  // Without a projection layer the output width equals the number of cells.
  constexpr int kCells = 4;
  constexpr int kOutputs = 4;

  LSTMOpModel lstm(kBatches, kInputs, kCells, kOutputs,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       {kBatches, kInputs},  // input tensor

                       {kCells, kInputs},  // input_to_input_weight tensor
                       {kCells, kInputs},  // input_to_forget_weight tensor
                       {kCells, kInputs},  // input_to_cell_weight tensor
                       {kCells, kInputs},  // input_to_output_weight tensor

                       {kCells, kOutputs},  // recurrent_to_input_weight_tensor
                       {kCells, kOutputs},  // recurrent_to_forget_weight_tensor
                       {kCells, kOutputs},  // recurrent_to_cell_weight_tensor
                       {kCells, kOutputs},  // recurrent_to_output_weight_tensor

                       {0},  // cell_to_input_weight tensor
                       {0},  // cell_to_forget_weight tensor
                       {0},  // cell_to_output_weight tensor

                       {kCells},  // input_gate_bias tensor
                       {kCells},  // forget_gate_bias tensor
                       {kCells},  // cell_bias tensor
                       {kCells},  // output_gate_bias tensor

                       {0, 0},  // projection_weight tensor
                       {0},     // projection_bias tensor
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Weights, grouped by source and listed in gate order.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Biases.
  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
2623
2624 class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
SetUp()2625 void SetUp() override {
2626 input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
2627 0.05100781, 0.04717243, 0.48944736,
2628 -0.38535351, -0.17212132};
2629
2630 input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
2631 -0.3633365, -0.22755712, 0.28253698,
2632 0.24407166, 0.33826375};
2633
2634 input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
2635 -0.09426838, -0.44257352, 0.54939759,
2636 0.01533556, 0.42751634};
2637 cell_gate_bias_ = {0., 0., 0., 0.};
2638 forget_gate_bias_ = {1., 1., 1., 1.};
2639 output_gate_bias_ = {0., 0., 0., 0.};
2640
2641 recurrent_to_cell_weights_ = {
2642 0.54066205, -0.32668582, -0.43562764, -0.56094903,
2643 0.42957711, 0.01841056, -0.32764608, -0.33027974,
2644 -0.10826075, 0.20675004, 0.19069612, -0.03026325,
2645 -0.54532051, 0.33003211, 0.44901288, 0.21193194};
2646
2647 recurrent_to_forget_weights_ = {
2648 -0.13832897, -0.0515101, -0.2359007, -0.16661474,
2649 -0.14340827, 0.36986142, 0.23414481, 0.55899,
2650 0.10798943, -0.41174671, 0.17751795, -0.34484994,
2651 -0.35874045, -0.11352962, 0.27268326, 0.54058349};
2652
2653 recurrent_to_output_weights_ = {
2654 0.41613156, 0.42610586, -0.16495961, -0.5663873,
2655 0.30579174, -0.05115908, -0.33941799, 0.23364776,
2656 0.11178309, 0.09481031, -0.26424935, 0.46261835,
2657 0.50248802, 0.26114327, -0.43736315, 0.33149987};
2658
2659 cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
2660 0.31544167};
2661 cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
2662 -0.77109635};
2663
2664 lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
2665 lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
2666 -0.42312205, -0.01218222, 0.24201041, -0.08124574,
2667 -0.358325, -0.04621704, 0.21641694, -0.06471302}};
2668 }
2669 };
2670
// Builds a float LSTM with CIFG (zero-sized input-gate tensors), loads the
// fixture weights, and checks the outputs against the goldens.
TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  constexpr int kBatches = 1;
  constexpr int kInputs = 2;
  // Without a projection layer the output width equals the number of cells.
  constexpr int kCells = 4;
  constexpr int kOutputs = 4;

  // NOTE(review): despite the fixture name, peephole connections are enabled
  // here (use_peephole=true), and cell-to-forget / cell-to-output weights are
  // supplied below.
  LSTMOpModel lstm(kBatches, kInputs, kCells, kOutputs,
                   /*use_cifg=*/true, /*use_peephole=*/true,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       {kBatches, kInputs},  // input tensor

                       {0, 0},             // input_to_input_weight tensor
                       {kCells, kInputs},  // input_to_forget_weight tensor
                       {kCells, kInputs},  // input_to_cell_weight tensor
                       {kCells, kInputs},  // input_to_output_weight tensor

                       {0, 0},              // recurrent_to_input_weight tensor
                       {kCells, kOutputs},  // recurrent_to_forget_weight tensor
                       {kCells, kOutputs},  // recurrent_to_cell_weight tensor
                       {kCells, kOutputs},  // recurrent_to_output_weight tensor

                       {0},       // cell_to_input_weight tensor
                       {kCells},  // cell_to_forget_weight tensor
                       {kCells},  // cell_to_output_weight tensor

                       {0},       // input_gate_bias tensor
                       {kCells},  // forget_gate_bias tensor
                       {kCells},  // cell_bias tensor
                       {kCells},  // output_gate_bias tensor

                       {0, 0},  // projection_weight tensor
                       {0},     // projection_bias tensor
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Weights, grouped by source and listed in gate order.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  // Biases.
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
2727
2728 class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
SetUp()2729 void SetUp() override {
2730 input_to_input_weights_ = {
2731 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
2732 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
2733 -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
2734 -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
2735 -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
2736 -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
2737 -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
2738 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
2739 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
2740 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
2741 -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
2742 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
2743 -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
2744 -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
2745 -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
2746 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
2747 -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
2748 -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
2749 -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
2750 -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
2751
2752 input_to_forget_weights_ = {
2753 -0.0018401089, -0.004852237, 0.03698424, 0.014181704,
2754 0.028273236, -0.016726194, -0.05249759, -0.10204261,
2755 0.00861066, -0.040979505, -0.009899187, 0.01923892,
2756 -0.028177269, -0.08535103, -0.14585495, 0.10662567,
2757 -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
2758 0.0030784295, 0.076784775, 0.07463696, 0.094531395,
2759 0.0814421, -0.12257899, -0.033945758, -0.031303465,
2760 0.045630626, 0.06843887, -0.13492945, -0.012480007,
2761 -0.0811829, -0.07224499, -0.09628791, 0.045100946,
2762 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
2763 0.06958324, 0.034257296, 0.0482646, 0.06267997,
2764 0.052625068, 0.12784666, 0.07077897, 0.025725935,
2765 0.04165009, 0.07241905, 0.018668644, -0.037377294,
2766 -0.06277783, -0.08833636, -0.040120605, -0.011405586,
2767 -0.007808335, -0.010301386, -0.005102167, 0.027717464,
2768 0.05483423, 0.11449111, 0.11289652, 0.10939839,
2769 0.13396506, -0.08402166, -0.01901462, -0.044678304,
2770 -0.07720565, 0.014350063, -0.11757958, -0.0652038,
2771 -0.08185733, -0.076754324, -0.092614375, 0.10405491,
2772 0.052960336, 0.035755895, 0.035839386, -0.012540553,
2773 0.036881298, 0.02913376, 0.03420159, 0.05448447,
2774 -0.054523353, 0.02582715, 0.02327355, -0.011857179,
2775 -0.0011980024, -0.034641717, -0.026125094, -0.17582615,
2776 -0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
2777 -8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
2778
2779 input_to_cell_weights_ = {
2780 -0.04580283, -0.09549462, -0.032418985, -0.06454633,
2781 -0.043528453, 0.043018587, -0.049152344, -0.12418144,
2782 -0.078985475, -0.07596889, 0.019484362, -0.11434962,
2783 -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
2784 -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
2785 0.10665918, -0.032036792, -0.08505916, -0.10843358,
2786 -0.13002433, -0.036816437, -0.02130134, -0.016518239,
2787 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
2788 -0.10652836, -0.1037554, -0.13056071, -0.03266643,
2789 -0.033702414, -0.006473424, -0.04611692, 0.014419339,
2790 -0.025174323, 0.0396852, 0.081777506, 0.06157468,
2791 0.10210095, -0.009658194, 0.046511717, 0.03603906,
2792 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
2793 0.053568836, 0.06408714, 0.12835667, -0.008714329,
2794 -0.20211966, -0.12093674, 0.029450472, 0.2849013,
2795 -0.029227901, 0.1164364, -0.08560263, 0.09941786,
2796 -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
2797 -0.09720865, -0.11193351, -0.029155117, -0.017936034,
2798 -0.009768936, -0.04223324, -0.036159635, 0.06505112,
2799 -0.021742892, -0.023377212, -0.07221364, -0.06430552,
2800 0.05453865, 0.091149814, 0.06387331, 0.007518393,
2801 0.055960953, 0.069779344, 0.046411168, 0.10509911,
2802 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
2803 0.056955688, 0.06555285, 0.050801456, -0.009862683,
2804 0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
2805
2806 input_to_output_weights_ = {
2807 -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
2808 -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
2809 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
2810 -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
2811 -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
2812 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
2813 -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
2814 -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
2815 -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
2816 -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
2817 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
2818 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
2819 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
2820 -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
2821 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
2822 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
2823 -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
2824 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
2825 -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
2826 -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
2827
2828 input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
2829 0.053110216, -0.06928846, -0.13942584, -0.11816189,
2830 0.19483899, 0.03652339, -0.10250295, 0.036714908,
2831 -0.18426876, 0.036065217, 0.21810818, 0.02383196,
2832 -0.043370757, 0.08690144, -0.04444982, 0.00030581196};
2833
2834 forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
2835 0.11098921, 0.15378423, 0.09263801, 0.09790885,
2836 0.09508917, 0.061199076, 0.07665568, -0.015443159,
2837 -0.03499149, 0.046190713, 0.08895977, 0.10899629,
2838 0.40694186, 0.06030037, 0.012413437, -0.06108739};
2839
2840 cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
2841 -0.1483596, -0.10639995, -0.091433935, 0.058573797,
2842 -0.06809782, -0.07889636, -0.043246906, -0.09829136,
2843 -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
2844 0.016178843, 0.1749513, 0.13975595, 0.92058027};
2845
2846 output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
2847 0.027195795, 0.35373217, -0.018957434, 0.008907322,
2848 -0.0762701, 0.12018895, 0.04216877, 0.0022856654,
2849 0.040952638, 0.3147856, 0.08225149, -0.057416286,
2850 -0.14995944, -0.008040261, 0.13208859, 0.029760877};
2851
2852 recurrent_to_input_weights_ = {
2853 -0.001374326, -0.078856036, 0.10672688, 0.029162422,
2854 -0.11585556, 0.02557986, -0.13446963, -0.035785314,
2855 -0.01244275, 0.025961924, -0.02337298, -0.044228926,
2856 -0.055839065, -0.046598054, -0.010546039, -0.06900766,
2857 0.027239809, 0.022582639, -0.013296484, -0.05459212,
2858 0.08981, -0.045407712, 0.08682226, -0.06867011,
2859 -0.14390695, -0.02916037, 0.000996957, 0.091420636,
2860 0.14283475, -0.07390571, -0.06402044, 0.062524505,
2861 -0.093129106, 0.04860203, -0.08364217, -0.08119002,
2862 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
2863 -0.13732095, 0.012405723, -0.07551853, 0.06343048,
2864 0.12162708, -0.031923793, -0.014335606, 0.01790974,
2865 -0.10650317, -0.0724401, 0.08554849, -0.05727212,
2866 0.06556731, -0.042729504, -0.043227166, 0.011683251,
2867 -0.013082158, -0.029302018, -0.010899579, -0.062036745,
2868 -0.022509435, -0.00964907, -0.01567329, 0.04260106,
2869 -0.07787477, -0.11576462, 0.017356863, 0.048673786,
2870 -0.017577527, -0.05527947, -0.082487635, -0.040137455,
2871 -0.10820036, -0.04666372, 0.022746278, -0.07851417,
2872 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
2873 0.08944216, -0.0685835, 0.010513544, 0.07228705,
2874 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
2875 0.040414046, -0.1380399, 0.094208956, -0.05722982,
2876 0.012092817, -0.04989123, -0.086576, -0.003399834,
2877 -0.04696032, -0.045747425, 0.10091314, 0.048676282,
2878 -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
2879 0.09504992, 0.041799378, -0.049185462, -0.031518843,
2880 -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
2881 -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
2882 -0.10167381, 0.042500053, -0.01447153, 0.06464186,
2883 -0.017142897, 0.03312627, 0.009205989, 0.024138335,
2884 -0.011337001, 0.035530265, -0.010912711, 0.0706555,
2885 -0.005894094, 0.051841937, -0.1401738, -0.02351249,
2886 0.0365468, 0.07590991, 0.08838724, 0.021681072,
2887 -0.10086113, 0.019608743, -0.06195883, 0.077335775,
2888 0.023646897, -0.095322326, 0.02233014, 0.09756986,
2889 -0.048691444, -0.009579111, 0.07595467, 0.11480546,
2890 -0.09801813, 0.019894179, 0.08502348, 0.004032281,
2891 0.037211012, 0.068537936, -0.048005626, -0.091520436,
2892 -0.028379958, -0.01556313, 0.06554592, -0.045599163,
2893 -0.01672207, -0.020169014, -0.011877351, -0.20212261,
2894 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
2895 -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
2896 0.015963363, 0.00871737, 0.060130805, 0.028611384,
2897 0.10109069, -0.015060172, -0.07894427, 0.06401885,
2898 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
2899 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
2900 0.019899689, 0.006106124, -0.027092824, 0.0786356,
2901 0.05052217, -0.058925, -0.011402121, -0.024987547,
2902 -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
2903 -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
2904 -0.033664223, -0.07978348, -0.025200296, -0.017207067,
2905 -0.058403496, -0.055697463, 0.005798788, 0.12965427,
2906 -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
2907 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
2908 0.013806405, -0.017858358, -0.01008298, -0.07700066,
2909 -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
2910 0.062634714, -0.02338735, -0.039547626, -0.02050681,
2911 0.03385117, -0.083611414, 0.002862572, -0.09421313,
2912 0.058618143, -0.08598433, 0.00972939, 0.023867095,
2913 -0.053934585, -0.023203006, 0.07452513, -0.048767887,
2914 -0.07314807, -0.056307215, -0.10433547, -0.06440842,
2915 0.04328182, 0.04389765, -0.020006588, -0.09076438,
2916 -0.11652589, -0.021705797, 0.03345259, -0.010329105,
2917 -0.025767034, 0.013057034, -0.07316461, -0.10145612,
2918 0.06358255, 0.18531723, 0.07759293, 0.12006465,
2919 0.1305557, 0.058638252, -0.03393652, 0.09622831,
2920 -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
2921 -0.005644518, 0.06857898, -0.12598175, -0.035084512,
2922 0.03156317, -0.12794146, -0.031963028, 0.04692781,
2923 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
2924 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
2925 0.08184801, -0.019164372, 0.06791302, 0.034257166,
2926 -0.10307039, 0.021943003, 0.046745934, 0.0790918,
2927 -0.0265588, -0.007824208, 0.042546265, -0.00977924,
2928 -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
2929 -0.014512694, -0.08251313, 0.08861942, 0.13589665,
2930 0.026351685, 0.012641483, 0.07466548, 0.044301085,
2931 -0.045414884, -0.051112458, 0.03444247, -0.08502782,
2932 -0.04106223, -0.028126027, 0.028473156, 0.10467447};
2933
2934 recurrent_to_cell_weights_ = {
2935 -0.037322544, 0.018592842, 0.0056175636, -0.06253426,
2936 0.055647098, -0.05713207, -0.05626563, 0.005559383,
2937 0.03375411, -0.025757805, -0.088049285, 0.06017052,
2938 -0.06570978, 0.007384076, 0.035123326, -0.07920549,
2939 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
2940 0.08089997, 0.05143358, 0.038261272, 0.03339287,
2941 -0.027673481, 0.044746667, 0.028349208, 0.020090483,
2942 -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
2943 -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
2944 -0.10893326, 0.076739706, -0.08509834, -0.027997585,
2945 0.037871376, 0.01449768, -0.09002357, -0.06111149,
2946 -0.046195522, 0.0422062, -0.005683705, -0.1253618,
2947 -0.012925729, -0.04890792, 0.06985068, 0.037654128,
2948 0.03398274, -0.004781977, 0.007032333, -0.031787455,
2949 0.010868644, -0.031489216, 0.09525667, 0.013939797,
2950 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
2951 -0.048885044, -0.12722108, 0.035304096, 0.06554885,
2952 0.00972396, -0.039238118, -0.05159735, -0.11329045,
2953 0.1613692, -0.03750952, 0.06529313, -0.071974665,
2954 -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
2955 0.02786344, -0.014179351, 0.005264273, 0.14376344,
2956 0.015983658, 0.03406988, -0.06939408, 0.040699873,
2957 0.02111075, 0.09669095, 0.041345075, -0.08316494,
2958 -0.07684199, -0.045768797, 0.032298047, -0.041805092,
2959 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
2960 -0.024950314, 0.11574242, 0.04508852, -0.04335324,
2961 0.06760663, -0.027437469, 0.07216407, 0.06977076,
2962 -0.05438599, 0.034033038, -0.028602652, 0.05346137,
2963 0.043184172, -0.037189785, 0.10420091, 0.00882477,
2964 -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
2965 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
2966 0.04361412, -0.007001822, 0.09631092, -0.06702025,
2967 -0.042049985, -0.035070654, -0.04103342, -0.10273396,
2968 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
2969 -0.008264958, 0.042035464, 0.05891794, 0.029673764,
2970 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
2971 -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
2972 -0.04043371, -0.017094059, 0.07229206, -0.023670016,
2973 -0.052195564, -0.025616996, -0.01520939, 0.045104615,
2974 -0.007376126, 0.003533447, 0.006570588, 0.056037236,
2975 0.12436656, 0.051817212, 0.028532185, -0.08686856,
2976 0.11868599, 0.07663395, -0.07323171, 0.03463402,
2977 -0.050708205, -0.04458982, -0.11590894, 0.021273347,
2978 0.1251325, -0.15313013, -0.12224372, 0.17228661,
2979 0.023029093, 0.086124025, 0.006445803, -0.03496501,
2980 0.028332196, 0.04449512, -0.042436164, -0.026587414,
2981 -0.006041347, -0.09292539, -0.05678812, 0.03897832,
2982 0.09465633, 0.008115513, -0.02171956, 0.08304309,
2983 0.071401566, 0.019622514, 0.032163795, -0.004167056,
2984 0.02295182, 0.030739572, 0.056506045, 0.004612461,
2985 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
2986 -0.1335546, -0.030136576, 0.11584653, -0.014678886,
2987 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
2988 -0.0329582, 0.07922767, 0.029322514, 0.026405897,
2989 0.04207835, -0.07073373, 0.063781224, 0.0859677,
2990 -0.10925287, -0.07011058, 0.048005477, 0.03438226,
2991 -0.09606514, -0.006669445, -0.043381985, 0.04240257,
2992 -0.06955775, -0.06769346, 0.043903265, -0.026784198,
2993 -0.017840602, 0.024307009, -0.040079936, -0.019946516,
2994 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
2995 0.15978073, 0.10185836, 0.10298046, -0.015476589,
2996 -0.039390966, -0.072174534, 0.0739445, -0.1211869,
2997 -0.0347889, -0.07943156, 0.014809798, -0.12412325,
2998 -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
2999 -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
3000 -0.01514876, -0.056505352, -0.012800942, -0.06994386,
3001 0.012962922, -0.031234352, 0.07029052, 0.016418684,
3002 0.03618972, 0.055686004, -0.08663945, -0.017404709,
3003 -0.054761406, 0.029065743, 0.052404847, 0.020238016,
3004 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
3005 0.06262858, 0.009184685, 0.020785125, -0.043904778,
3006 -0.0270329, -0.03299152, -0.060088247, -0.015162964,
3007 -0.001828936, 0.12642565, -0.056757294, 0.013586685,
3008 0.09232601, -0.035886683, 0.06000002, 0.05229691,
3009 -0.052580316, -0.082029596, -0.010794592, 0.012947712,
3010 -0.036429964, -0.085508935, -0.13127148, -0.017744139,
3011 0.031502828, 0.036232427, -0.031581745, 0.023051167,
3012 -0.05325106, -0.03421577, 0.028793324, -0.034633752,
3013 -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
3014 -0.008799762, 0.056595087, 0.0022273948, 0.055752404};
3015
3016 recurrent_to_forget_weights_ = {
3017 -0.057784554, -0.026057621, -0.068447545, -0.022581743,
3018 0.14811787, 0.10826372, 0.09471067, 0.03987225,
3019 -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
3020 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
3021 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
3022 -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
3023 -0.06193199, 0.055729095, 0.03736828, 0.020123724,
3024 0.061878487, -0.04729229, 0.034919553, -0.07585433,
3025 -0.04421272, -0.044019096, 0.085488975, 0.04058006,
3026 -0.06890133, -0.030951202, -0.024628663, -0.07672815,
3027 0.034293607, 0.08556707, -0.05293577, -0.033561368,
3028 -0.04899627, 0.0241671, 0.015736353, -0.095442444,
3029 -0.029564252, 0.016493602, -0.035026584, 0.022337519,
3030 -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
3031 0.016435321, -0.03263031, -0.09543275, -0.047392778,
3032 0.013454138, 0.028934088, 0.01685226, -0.086110644,
3033 -0.046250615, -0.01847454, 0.047608484, 0.07339695,
3034 0.034546845, -0.04881143, 0.009128804, -0.08802852,
3035 0.03761666, 0.008096139, -0.014454086, 0.014361001,
3036 -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
3037 -0.06509276, -0.006021153, -0.08570962, -0.1451793,
3038 0.060212336, 0.055259194, 0.06974018, 0.049454916,
3039 -0.027794661, -0.08077226, -0.016179763, 0.1169753,
3040 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
3041 -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
3042 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
3043 -0.05695512, 0.047233116, 0.038937137, -0.06542224,
3044 0.014429736, -0.09719407, 0.13908425, -0.05379757,
3045 0.012321099, 0.082840554, -0.029899208, 0.044217527,
3046 0.059855383, 0.07711018, -0.045319796, 0.0948846,
3047 -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
3048 -0.13873616, 0.040668588, 0.034832682, -0.015319203,
3049 -0.018715994, 0.046002675, 0.0599172, -0.043107376,
3050 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
3051 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
3052 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
3053 0.052958444, 0.07558703, 0.04817258, 0.044462286,
3054 -0.015213451, -0.08783778, -0.0561384, -0.003008196,
3055 0.047060397, -0.002058388, 0.03429439, -0.018839769,
3056 0.024734668, 0.024614193, -0.042046934, 0.09597743,
3057 -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
3058 -0.02558259, -0.022822596, -0.023273505, -0.02464396,
3059 -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
3060 0.04383914, -0.046476185, 0.028658995, 0.060410924,
3061 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
3062 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
3063 0.015898481, 0.021362653, -0.030262267, 0.016587038,
3064 -0.011442813, 0.041154444, -0.007631438, -0.03423484,
3065 -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
3066 0.02318443, -0.041350313, 0.021485701, -0.10906167,
3067 -0.028218046, -0.00954771, 0.020531068, -0.11995105,
3068 -0.03672871, 0.024019798, 0.014255957, -0.05221243,
3069 -0.00661567, -0.04630967, 0.033188973, 0.10107534,
3070 -0.014027541, 0.030796422, -0.10270911, -0.035999842,
3071 0.15443139, 0.07684145, 0.036571592, -0.035900835,
3072 -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
3073 -0.03858649, 0.01849943, 0.13872518, 0.01503974,
3074 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
3075 -0.047401894, 0.03100163, -0.041533746, -0.10430945,
3076 0.044574402, -0.01425562, -0.024290353, 0.034563623,
3077 0.05866852, 0.023947537, -0.09445152, 0.035450947,
3078 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
3079 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
3080 0.03532124, -0.016341697, 0.09685456, -0.016764693,
3081 0.051808182, 0.05875331, -0.04536488, 0.001626336,
3082 -0.028892258, -0.01048663, -0.009793449, -0.017093895,
3083 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
3084 -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
3085 -0.01769146, 0.040995963, 0.02235177, -0.060430344,
3086 0.11475477, -0.023854522, 0.10071741, 0.0686208,
3087 -0.014250481, 0.034261297, 0.047418304, 0.08562733,
3088 -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
3089 0.04096551, 0.032249358, -0.08355519, -0.026823482,
3090 0.056386515, -0.010401743, -0.028396193, 0.08507674,
3091 0.014410365, 0.020995233, 0.17040324, 0.11511526,
3092 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
3093 -0.081302024, 0.017264642, -0.009585969, 0.09491168,
3094 -0.051313367, 0.054532815, -0.014298593, 0.10657464,
3095 0.007076659, 0.10964551, 0.0409152, 0.008275321,
3096 -0.07283536, 0.07937492, 0.04192024, -0.1075027};
3097
3098 recurrent_to_output_weights_ = {
3099 0.025825322, -0.05813119, 0.09495884, -0.045984812,
3100 -0.01255415, -0.0026479573, -0.08196161, -0.054914974,
3101 -0.0046604523, -0.029587349, -0.044576716, -0.07480124,
3102 -0.082868785, 0.023254942, 0.027502948, -0.0039728214,
3103 -0.08683098, -0.08116779, -0.014675607, -0.037924774,
3104 -0.023314456, -0.007401714, -0.09255757, 0.029460307,
3105 -0.08829125, -0.005139627, -0.08989442, -0.0555066,
3106 0.13596267, -0.025062224, -0.048351806, -0.03850004,
3107 0.07266485, -0.022414139, 0.05940088, 0.075114764,
3108 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
3109 -0.025980417, 0.072999895, 0.11091378, -0.081685916,
3110 0.014416728, 0.043229222, 0.034178585, -0.07530371,
3111 0.035837382, -0.085607, -0.007721233, -0.03287832,
3112 -0.043848954, -0.06404588, -0.06632928, -0.073643476,
3113 0.008214239, -0.045984086, 0.039764922, 0.03474462,
3114 0.060612556, -0.080590084, 0.049127717, 0.04151091,
3115 -0.030063879, 0.008801774, -0.023021035, -0.019558564,
3116 0.05158114, -0.010947698, -0.011825728, 0.0075720972,
3117 0.0699727, -0.0039981045, 0.069350146, 0.08799282,
3118 0.016156472, 0.035502106, 0.11695009, 0.006217345,
3119 0.13392477, -0.037875112, 0.025745004, 0.08940699,
3120 -0.00924166, 0.0046702605, -0.036598757, -0.08811812,
3121 0.10522024, -0.032441203, 0.008176899, -0.04454919,
3122 0.07058152, 0.0067963637, 0.039206743, 0.03259838,
3123 0.03725492, -0.09515802, 0.013326398, -0.052055415,
3124 -0.025676316, 0.03198509, -0.015951829, -0.058556724,
3125 0.036879618, 0.043357447, 0.028362012, -0.05908629,
3126 0.0059240665, -0.04995891, -0.019187413, 0.0276265,
3127 -0.01628143, 0.0025863599, 0.08800015, 0.035250366,
3128 -0.022165963, -0.07328642, -0.009415526, -0.07455109,
3129 0.11690406, 0.0363299, 0.07411125, 0.042103454,
3130 -0.009660886, 0.019076364, 0.018299393, -0.046004917,
3131 0.08891175, 0.0431396, -0.026327137, -0.051502608,
3132 0.08979574, -0.051670972, 0.04940282, -0.07491107,
3133 -0.021240504, 0.022596184, -0.034280192, 0.060163025,
3134 -0.058211457, -0.051837247, -0.01349775, -0.04639988,
3135 -0.035936575, -0.011681591, 0.064818054, 0.0073146066,
3136 -0.021745546, -0.043124277, -0.06471268, -0.07053354,
3137 -0.029321948, -0.05330136, 0.016933719, -0.053782392,
3138 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
3139 0.05693899, -0.053219706, 0.063698, 0.07977434,
3140 -0.07924483, 0.06936997, 0.0034815092, -0.007305279,
3141 -0.037325785, -0.07251102, -0.033633437, -0.08677009,
3142 0.091591336, -0.14165086, 0.021752775, 0.019683983,
3143 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
3144 -0.0024567875, -0.14345716, 0.010955264, -0.10234828,
3145 0.1183656, -0.0010731248, -0.023590032, -0.072285876,
3146 -0.0724771, -0.026382286, -0.0014920527, 0.042667855,
3147 0.0018776858, 0.02986552, 0.009814309, 0.0733756,
3148 0.12289186, 0.018043943, -0.0458958, 0.049412545,
3149 0.033632483, 0.05495232, 0.036686596, -0.013781798,
3150 -0.010036754, 0.02576849, -0.08307328, 0.010112348,
3151 0.042521734, -0.05869831, -0.071689695, 0.03876447,
3152 -0.13275425, -0.0352966, -0.023077697, 0.10285965,
3153 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
3154 -0.10292561, -0.032401145, 0.10053256, -0.026142767,
3155 -0.08271222, -0.0030240538, -0.016368777, 0.1070414,
3156 0.042672627, 0.013456989, -0.0437609, -0.022309763,
3157 0.11576483, 0.04108048, 0.061026827, -0.0190714,
3158 -0.0869359, 0.037901703, 0.0610107, 0.07202949,
3159 0.01675338, 0.086139716, -0.08795751, -0.014898893,
3160 -0.023771819, -0.01965048, 0.007955471, -0.043740474,
3161 0.03346837, -0.10549954, 0.090567775, 0.042013682,
3162 -0.03176985, 0.12569028, -0.02421228, -0.029526481,
3163 0.023851605, 0.031539805, 0.05292009, -0.02344001,
3164 -0.07811758, -0.08834428, 0.10094801, 0.16594367,
3165 -0.06861939, -0.021256343, -0.041093912, -0.06669611,
3166 0.035498552, 0.021757556, -0.09302526, -0.015403468,
3167 -0.06614931, -0.051798206, -0.013874718, 0.03630673,
3168 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
3169 0.03541868, -0.094149634, -0.034814864, 0.003128424,
3170 -0.020674974, -0.03944324, -0.008110165, -0.11113267,
3171 0.08484226, 0.043586485, 0.040582247, 0.0968012,
3172 -0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
3173 0.0326779, 0.041296225, 0.09164146, -0.047743853,
3174 -0.015952192, -0.034451712, 0.084197424, -0.05347844,
3175 -0.11768019, 0.085926116, -0.08251791, -0.045081906,
3176 0.0948852, 0.068401024, 0.024856757, 0.06978981,
3177 -0.057309967, -0.012775832, -0.0032452994, 0.01977615,
3178 -0.041040014, -0.024264973, 0.063464895, 0.05431621,
3179 };
3180
3181 cell_to_input_weights_ = {
3182 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
3183 -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
3184 -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
3185 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
3186
3187 cell_to_forget_weights_ = {
3188 -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
3189 -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
3190 -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
3191 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
3192
3193 cell_to_output_weights_ = {
3194 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
3195 -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
3196 -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
3197 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
3198
3199 projection_weights_ = {
3200 -0.009802181, 0.09401916, 0.0717386, -0.13895074,
3201 0.09641832, 0.060420845, 0.08539281, 0.054285463,
3202 0.061395317, 0.034448683, -0.042991187, 0.019801661,
3203 -0.16840284, -0.015726732, -0.23041931, -0.024478018,
3204 -0.10959692, -0.013875541, 0.18600968, -0.061274476,
3205 0.0138165, -0.08160894, -0.07661644, 0.032372914,
3206 0.16169067, 0.22465782, -0.03993472, -0.004017731,
3207 0.08633481, -0.28869787, 0.08682067, 0.17240396,
3208 0.014975425, 0.056431185, 0.031037588, 0.16702051,
3209 0.0077946745, 0.15140012, 0.29405436, 0.120285,
3210 -0.188994, -0.027265169, 0.043389652, -0.022061434,
3211 0.014777949, -0.20203483, 0.094781205, 0.19100232,
3212 0.13987629, -0.036132768, -0.06426278, -0.05108664,
3213 0.13221376, 0.009441198, -0.16715929, 0.15859416,
3214 -0.040437475, 0.050779544, -0.022187516, 0.012166504,
3215 0.027685808, -0.07675938, -0.0055694645, -0.09444123,
3216 0.0046453946, 0.050794356, 0.10770313, -0.20790008,
3217 -0.07149004, -0.11425117, 0.008225835, -0.035802525,
3218 0.14374903, 0.15262283, 0.048710253, 0.1847461,
3219 -0.007487823, 0.11000021, -0.09542012, 0.22619456,
3220 -0.029149994, 0.08527916, 0.009043713, 0.0042746216,
3221 0.016261552, 0.022461696, 0.12689082, -0.043589946,
3222 -0.12035478, -0.08361797, -0.050666027, -0.1248618,
3223 -0.1275799, -0.071875185, 0.07377272, 0.09944291,
3224 -0.18897448, -0.1593054, -0.06526116, -0.040107165,
3225 -0.004618631, -0.067624845, -0.007576253, 0.10727444,
3226 0.041546922, -0.20424393, 0.06907816, 0.050412357,
3227 0.00724631, 0.039827548, 0.12449835, 0.10747581,
3228 0.13708383, 0.09134148, -0.12617786, -0.06428341,
3229 0.09956831, 0.1208086, -0.14676677, -0.0727722,
3230 0.1126304, 0.010139365, 0.015571211, -0.038128063,
3231 0.022913318, -0.042050496, 0.16842307, -0.060597885,
3232 0.10531834, -0.06411776, -0.07451711, -0.03410368,
3233 -0.13393489, 0.06534304, 0.003620307, 0.04490757,
3234 0.05970546, 0.05197996, 0.02839995, 0.10434969,
3235 -0.013699693, -0.028353551, -0.07260381, 0.047201227,
3236 -0.024575593, -0.036445823, 0.07155557, 0.009672501,
3237 -0.02328883, 0.009533515, -0.03606021, -0.07421458,
3238 -0.028082801, -0.2678904, -0.13221288, 0.18419984,
3239 -0.13012612, -0.014588381, -0.035059117, -0.04824723,
3240 0.07830115, -0.056184657, 0.03277091, 0.025466874,
3241 0.14494097, -0.12522776, -0.098633975, -0.10766018,
3242 -0.08317623, 0.08594209, 0.07749552, 0.039474737,
3243 0.1776665, -0.07409566, -0.0477268, 0.29323658,
3244 0.10801441, 0.1154011, 0.013952499, 0.10739139,
3245 0.10708251, -0.051456142, 0.0074137426, -0.10430189,
3246 0.10034707, 0.045594677, 0.0635285, -0.0715442,
3247 -0.089667566, -0.10811871, 0.00026344223, 0.08298446,
3248 -0.009525053, 0.006585689, -0.24567553, -0.09450807,
3249 0.09648481, 0.026996298, -0.06419476, -0.04752702,
3250 -0.11063944, -0.23441927, -0.17608605, -0.052156363,
3251 0.067035615, 0.19271925, -0.0032889997, -0.043264326,
3252 0.09663576, -0.057112187, -0.10100678, 0.0628376,
3253 0.04447668, 0.017961001, -0.10094388, -0.10190601,
3254 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
3255 0.10539724, -0.04383912, -0.042349473, 0.08438151,
3256 -0.1947263, 0.02251204, 0.11216432, -0.10307853,
3257 0.17351969, -0.039091777, 0.08066188, -0.00561982,
3258 0.12633002, 0.11335965, -0.0088127935, -0.019777594,
3259 0.06864014, -0.059751723, 0.016233567, -0.06894641,
3260 -0.28651384, -0.004228674, 0.019708522, -0.16305895,
3261 -0.07468996, -0.0855457, 0.099339016, -0.07580735,
3262 -0.13775392, 0.08434318, 0.08330512, -0.12131499,
3263 0.031935584, 0.09180414, -0.08876437, -0.08049874,
3264 0.008753825, 0.03498998, 0.030215185, 0.03907079,
3265 0.089751154, 0.029194152, -0.03337423, -0.019092513,
3266 0.04331237, 0.04299654, -0.036394123, -0.12915532,
3267 0.09793732, 0.07512415, -0.11319543, -0.032502122,
3268 0.15661901, 0.07671967, -0.005491124, -0.19379048,
3269 -0.218606, 0.21448623, 0.017840758, 0.1416943,
3270 -0.07051762, 0.19488361, 0.02664691, -0.18104725,
3271 -0.09334311, 0.15026465, -0.15493552, -0.057762887,
3272 -0.11604192, -0.262013, -0.01391798, 0.012185008,
3273 0.11156489, -0.07483202, 0.06693364, -0.26151478,
3274 0.046425626, 0.036540434, -0.16435726, 0.17338543,
3275 -0.21401681, -0.11385144, -0.08283257, -0.069031075,
3276 0.030635102, 0.010969227, 0.11109743, 0.010919218,
3277 0.027526086, 0.13519906, 0.01891392, -0.046839405,
3278 -0.040167913, 0.017953383, -0.09700955, 0.0061885654,
3279 -0.07000971, 0.026893595, -0.038844477, 0.14543656};
3280
3281 lstm_input_ = {
3282 {// Batch0: 4 (input_sequence_size) * 5 (n_input)
3283 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
3284 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
3285 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
3286 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
3287
3288 {// Batch1: 4 (input_sequence_size) * 5 (n_input)
3289 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
3290 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
3291 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
3292 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
3293 };
3294
3295 lstm_golden_output_ = {
3296 {// Batch0: 4 (input_sequence_size) * 16 (n_output)
3297 -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
3298 -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
3299 -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
3300 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
3301 -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
3302 -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
3303 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
3304 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
3305 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
3306 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
3307 -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
3308 -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
3309 0.0286833, 0.00824207, 0.0264887, 0.0305169},
3310 {// Batch1: 4 (input_sequence_size) * 16 (n_output)
3311 -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
3312 -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
3313 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
3314 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
3315 -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
3316 -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
3317 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
3318 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
3319 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
3320 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
3321 -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
3322 -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
3323 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
3324 }
3325 };
3326
// Black-box check of a full-gate (non-CIFG) LSTM with peephole connections
// and a bias-free projection layer, using the weights, inputs and golden
// outputs stored in the fixture.
TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
  // Problem size: 2 batches, 5 input features, 20 cells, 16 projected outputs.
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 20;
  const int n_output = 16;

  // Both clip thresholds are passed as 0.0 here.
  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/true,
                   /*use_projection_weights=*/true,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       // Activations.
                       {n_batch, n_input},  // input tensor

                       // Input-to-gate weights: input, forget, cell, output.
                       {n_cell, n_input},
                       {n_cell, n_input},
                       {n_cell, n_input},
                       {n_cell, n_input},

                       // Recurrent-to-gate weights: input, forget, cell,
                       // output.
                       {n_cell, n_output},
                       {n_cell, n_output},
                       {n_cell, n_output},
                       {n_cell, n_output},

                       // Peephole (cell-to-gate) weights: input, forget,
                       // output.
                       {n_cell},
                       {n_cell},
                       {n_cell},

                       // Gate biases: input gate, forget gate, cell, output
                       // gate.
                       {n_cell},
                       {n_cell},
                       {n_cell},
                       {n_cell},

                       // Projection layer.
                       {n_output, n_cell},  // projection weights
                       {0},  // projection bias (use_projection_bias=false)
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Input-to-gate weights.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  // Recurrent-to-gate weights.
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole weights.
  lstm.SetCellToInputWeights(cell_to_input_weights_);
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  // Gate biases.
  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  // Projection weights (no projection bias in this configuration).
  lstm.SetProjectionWeights(projection_weights_);

  // Compare the model's outputs for lstm_input_ against lstm_golden_output_.
  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
3388
3389 class BaseReduceOpModel : public SingleOpModelWithNNAPI {
3390 public:
SetAxis(const std::vector<int> & data)3391 void SetAxis(const std::vector<int>& data) { PopulateTensor(axis_, data); }
3392
3393 template <class T>
SetInput(const std::vector<T> & data)3394 void SetInput(const std::vector<T>& data) {
3395 PopulateTensor(input_, data);
3396 }
3397
3398 template <class T>
GetOutput()3399 std::vector<T> GetOutput() {
3400 return ExtractVector<T>(output_);
3401 }
3402
GetDequantizedOutput()3403 std::vector<float> GetDequantizedOutput() {
3404 return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
3405 GetScale(output_), GetZeroPoint(output_));
3406 }
3407
GetOutputShape()3408 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
3409
Input()3410 int Input() { return input_; }
3411
3412 protected:
3413 int input_;
3414 int axis_;
3415 int output_;
3416 };
3417
3418 // Model for the tests case where axis is a const tensor.
3419 class MeanOpConstModel : public BaseReduceOpModel {
3420 public:
MeanOpConstModel(const TensorData & input,const TensorData & output,std::initializer_list<int> axis_shape,std::initializer_list<int> axis,bool keep_dims)3421 MeanOpConstModel(const TensorData& input, const TensorData& output,
3422 std::initializer_list<int> axis_shape,
3423 std::initializer_list<int> axis, bool keep_dims) {
3424 input_ = AddInput(input);
3425 axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
3426 output_ = AddOutput(output);
3427 SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
3428 CreateReducerOptions(builder_, keep_dims).Union());
3429 BuildInterpreter({GetShape(input_)});
3430 }
3431 };
3432
3433 // Tests for reduce_mean
TEST(NNAPIDelegate,MeanFloatNotKeepDims)3434 TEST(NNAPIDelegate, MeanFloatNotKeepDims) {
3435 std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
3436 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
3437 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
3438 MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
3439 {4}, {1, 0, -3, -3}, false);
3440 m.SetInput(data);
3441 m.Invoke();
3442 EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
3443 EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
3444 }
3445
TEST(NNAPIDelegate, MeanFloatKeepDims) {
  // A 4x3x2 tensor holding 1..24, averaged over axes 0 and 2. keep_dims
  // retains the reduced axes with size 1, giving a {1, 3, 1} result.
  const std::vector<float> input_values = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
      9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
      17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  MeanOpConstModel model({TensorType_FLOAT32, {4, 3, 2}},
                         {TensorType_FLOAT32, {3}},
                         /*axis_shape=*/{2}, /*axis=*/{0, 2},
                         /*keep_dims=*/true);
  model.SetInput(input_values);
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 1}));
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
}
3458
3459 class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI {
3460 public:
BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,std::initializer_list<int> weight_shape,TensorType weight_type=TensorType_FLOAT32)3461 BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
3462 std::initializer_list<int> weight_shape,
3463 TensorType weight_type = TensorType_FLOAT32) {
3464 input_ = AddInput(TensorType_INT32);
3465 weight_ = AddInput(weight_type);
3466 output_ = AddOutput(TensorType_FLOAT32);
3467 SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
3468 BuildInterpreter({index_shape, weight_shape});
3469 }
3470
SetInput(std::initializer_list<int> data)3471 void SetInput(std::initializer_list<int> data) {
3472 PopulateTensor(input_, data);
3473 }
3474
GetOutput()3475 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
3476
3477 protected:
3478 int input_;
3479 int weight_;
3480 int output_;
3481 };
3482
3483 class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
3484 public:
3485 using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
3486
Set3DWeightMatrix(const std::function<float (int,int,int)> & function)3487 void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
3488 TfLiteTensor* tensor = interpreter_->tensor(weight_);
3489 int rows = tensor->dims->data[0];
3490 int columns = tensor->dims->data[1];
3491 int features = tensor->dims->data[2];
3492 for (int i = 0; i < rows; i++) {
3493 for (int j = 0; j < columns; j++) {
3494 for (int k = 0; k < features; k++) {
3495 tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
3496 }
3497 }
3498 }
3499 }
3500 };
3501
TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) {
  // Three indices into a 3x2x4 weight table whose element (i, j, k) encodes
  // its own coordinates as i + j/10 + k/100.
  EmbeddingLookupOpModel model(/*index_shape=*/{3},
                               /*weight_shape=*/{3, 2, 4});
  model.SetInput({1, 0, 2});
  model.Set3DWeightMatrix([](int row, int col, int feature) {
    return row + col / 10.0f + feature / 100.0f;
  });

  model.Invoke();

  // Each selected row contributes 2 x 4 = 8 floats, in lookup order.
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
                  0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
                  2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
              })));
}
3517
3518 class HashtableLookupOpModel : public SingleOpModelWithNNAPI {
3519 public:
HashtableLookupOpModel(std::initializer_list<int> lookup_shape,std::initializer_list<int> key_shape,std::initializer_list<int> value_shape,TensorType type)3520 HashtableLookupOpModel(std::initializer_list<int> lookup_shape,
3521 std::initializer_list<int> key_shape,
3522 std::initializer_list<int> value_shape,
3523 TensorType type) {
3524 lookup_ = AddInput(TensorType_INT32);
3525 key_ = AddInput(TensorType_INT32);
3526 value_ = AddInput(type);
3527 output_ = AddOutput(type);
3528 hit_ = AddOutput(TensorType_UINT8);
3529 SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0);
3530 BuildInterpreter({lookup_shape, key_shape, value_shape});
3531 }
3532
SetLookup(std::initializer_list<int> data)3533 void SetLookup(std::initializer_list<int> data) {
3534 PopulateTensor<int>(lookup_, data);
3535 }
3536
SetHashtableKey(std::initializer_list<int> data)3537 void SetHashtableKey(std::initializer_list<int> data) {
3538 PopulateTensor<int>(key_, data);
3539 }
3540
SetHashtableValue(const std::vector<string> & content)3541 void SetHashtableValue(const std::vector<string>& content) {
3542 PopulateStringTensor(value_, content);
3543 }
3544
SetHashtableValue(const std::function<float (int)> & function)3545 void SetHashtableValue(const std::function<float(int)>& function) {
3546 TfLiteTensor* tensor = interpreter_->tensor(value_);
3547 int rows = tensor->dims->data[0];
3548 for (int i = 0; i < rows; i++) {
3549 tensor->data.f[i] = function(i);
3550 }
3551 }
3552
SetHashtableValue(const std::function<float (int,int)> & function)3553 void SetHashtableValue(const std::function<float(int, int)>& function) {
3554 TfLiteTensor* tensor = interpreter_->tensor(value_);
3555 int rows = tensor->dims->data[0];
3556 int features = tensor->dims->data[1];
3557 for (int i = 0; i < rows; i++) {
3558 for (int j = 0; j < features; j++) {
3559 tensor->data.f[i * features + j] = function(i, j);
3560 }
3561 }
3562 }
3563
GetStringOutput()3564 std::vector<string> GetStringOutput() {
3565 TfLiteTensor* output = interpreter_->tensor(output_);
3566 int num = GetStringCount(output);
3567 std::vector<string> result(num);
3568 for (int i = 0; i < num; i++) {
3569 auto ref = GetString(output, i);
3570 result[i] = string(ref.str, ref.len);
3571 }
3572 return result;
3573 }
3574
GetOutput()3575 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetHit()3576 std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); }
3577
3578 private:
3579 int lookup_;
3580 int key_;
3581 int value_;
3582 int output_;
3583 int hit_;
3584 };
3585
TEST(NNAPIDelegate, HashtableLookupTest2DInput) {
  // Keys {-11, 0, 1234} map to rows 0..2 of a 3x2 value table whose element
  // (i, j) is i + j/10. One of the four lookups (-292) has no matching key.
  HashtableLookupOpModel model(/*lookup_shape=*/{4}, /*key_shape=*/{3},
                               /*value_shape=*/{3, 2}, TensorType_FLOAT32);

  model.SetLookup({1234, -292, -11, 0});
  model.SetHashtableKey({-11, 0, 1234});
  model.SetHashtableValue(
      [](int row, int col) { return row + col / 10.0f; });

  model.Invoke();

  const std::vector<float> expected_values = {
      2.0, 2.1,  // 2-nd item (key 1234)
      0,   0,    // Not found (key -292)
      0.0, 0.1,  // 0-th item (key -11)
      1.0, 1.1,  // 1-st item (key 0)
  };
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetHit(), ElementsAreArray({1, 0, 1, 1}));
}
3608
TEST(NNAPIDelegate, HashtableLookupTest1DInput) {
  // Keys {-11, 0, 1234} map to a 1-D value table holding i*i/10 at position
  // i. One of the four lookups (-292) has no matching key.
  HashtableLookupOpModel model(/*lookup_shape=*/{4}, /*key_shape=*/{3},
                               /*value_shape=*/{3}, TensorType_FLOAT32);

  model.SetLookup({1234, -292, -11, 0});
  model.SetHashtableKey({-11, 0, 1234});
  model.SetHashtableValue([](int row) { return row * row / 10.0f; });

  model.Invoke();

  const std::vector<float> expected_values = {
      0.4,  // 2-nd item (key 1234)
      0,    // Not found (key -292)
      0.0,  // 0-th item (key -11)
      0.1,  // 1-st item (key 0)
  };
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetHit(), ElementsAreArray({1, 0, 1, 1}));
}
3631 } // namespace
3632 } // namespace tflite
3633
main(int argc,char ** argv)3634 int main(int argc, char** argv) {
3635 ::tflite::LogToStderr();
3636 ::testing::InitGoogleTest(&argc, argv);
3637 return RUN_ALL_TESTS();
3638 }
3639