1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
16 #include <gtest/gtest.h>
17 #include "tensorflow/lite/interpreter.h"
18 #include "tensorflow/lite/kernels/test_util.h"
19 #include "tensorflow/lite/model.h"
20 
21 namespace testing {
22 namespace internal {
23 
24 // The CTS test fails to compile without this.
25 // TODO(b/130342510): Find a proper solution.
FormatMatcherDescription(bool unused_negation,const char * matcher_name,const Strings & unused_param_values)26 std::string FormatMatcherDescription(bool unused_negation,
27                                      const char* matcher_name,
28                                      const Strings& unused_param_values) {
29   return matcher_name;
30 }
31 
32 }  // namespace internal
33 }  // namespace testing
34 
35 namespace tflite {
36 namespace {
37 
38 using ::testing::ElementsAre;
39 using ::testing::ElementsAreArray;
40 
41 // TODO(b/110368244): figure out how to share the existing tests in kernels/ but
42 // with the delegation on. Also, add more unit tests to improve code coverage.
43 
44 // This matcher uses 1 as maximum tolerance.
45 MATCHER(QuantizedNear, "") {
46   const int diff = abs(std::get<0>(arg) - std::get<1>(arg));
47   if (diff > 1) {
48     *result_listener << "Quantized values can be at most off by one: " << diff;
49     return false;
50   }
51   return true;
52 }
53 
// Base class for single-op test models that execute through the NNAPI
// delegate: the constructor registers a hook that attaches the NNAPI delegate
// to the interpreter, so derived models run on NNAPI instead of the TFLite
// reference kernels.
class SingleOpModelWithNNAPI : public SingleOpModel {
 public:
  SingleOpModelWithNNAPI() {
    this->SetApplyDelegate([](Interpreter* interpreter) {
      interpreter->ModifyGraphWithDelegate(NnApiDelegate());
    });
  }

  // Forwards a resize request to the underlying interpreter. Used by tests
  // that verify the delegate's (current lack of) resize support.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims) {
    return interpreter_->ResizeInputTensor(tensor_index, dims);
  }

 protected:
  // Populates tensor `index` with `data`, quantizing on the way in for the
  // supported quantized types. Fails the current test for any other type.
  void SetData(int index, TensorType type, const std::vector<float>& data) {
    switch (type) {
      case TensorType_FLOAT32:
        PopulateTensor(index, data);
        break;
      case TensorType_INT32:
        // INT32 is used here for quantized bias tensors.
        QuantizeAndPopulate<int32_t>(index, data);
        break;
      case TensorType_UINT8:
        QuantizeAndPopulate<uint8_t>(index, data);
        break;
      case TensorType_INT8:
        QuantizeAndPopulate<int8_t>(index, data);
        break;
      default:
        FAIL() << "Type not supported: " << type;
        break;
    }
  }
};
88 
89 class FloatAddOpModel : public SingleOpModelWithNNAPI {
90  public:
FloatAddOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type,bool allow_fp32_relax_to_fp16=false)91   FloatAddOpModel(const TensorData& input1, const TensorData& input2,
92                   const TensorData& output,
93                   ActivationFunctionType activation_type,
94                   bool allow_fp32_relax_to_fp16 = false) {
95     input1_ = AddInput(input1);
96     input2_ = AddInput(input2);
97     output_ = AddOutput(output);
98     SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
99                  CreateAddOptions(builder_, activation_type).Union());
100     BuildInterpreter({GetShape(input1_), GetShape(input2_)},
101                      allow_fp32_relax_to_fp16);
102   }
103 
input1()104   int input1() { return input1_; }
input2()105   int input2() { return input2_; }
106 
GetOutput()107   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
108 
109  protected:
110   int input1_;
111   int input2_;
112   int output_;
113 };
114 
115 // Do a test with the NN API using no activation.
// Do a test with the NN API using no activation.
TEST(NNAPIDelegate, AddWithNoActivation) {
  // Both operands share the same 1x2x2x1 float spec.
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
}
125 
126 // Do a test with the NN API using no activation.
127 // The test allows computing FP32 with FP16 precision. In this particular case,
128 // calculating in FP32 or FP16 should produce the same results.
// Do a test with the NN API using no activation.
// The test allows computing FP32 with FP16 precision. In this particular case,
// calculating in FP32 or FP16 should produce the same results.
TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE,
                    /*allow_fp32_relax_to_fp16=*/true);
  m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
  m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 1.0, 4.0, 6.0}));
}
138 
139 // Do a test with the NN api with relu.
// Do a test with the NN api with relu.
TEST(NNAPIDelegate, AddWithRelu) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_RELU);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  m.Invoke();
  // RELU clamps the first (negative) sum to zero.
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3}));
}
149 
150 // Verify that resize attempts fail.
151 // TODO(b/113110851): Verify success after the delegate supports resizing.
// Verify that resize attempts fail.
// TODO(b/113110851): Verify success after the delegate supports resizing.
TEST(NNAPIDelegate, ResizeFails) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatAddOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 3, 1}), kTfLiteError);
}
160 
161 class FloatMulOpModel : public SingleOpModelWithNNAPI {
162  public:
FloatMulOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)163   FloatMulOpModel(const TensorData& input1, const TensorData& input2,
164                   const TensorData& output,
165                   ActivationFunctionType activation_type) {
166     input1_ = AddInput(input1);
167     input2_ = AddInput(input2);
168     output_ = AddOutput(output);
169     SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
170                  CreateMulOptions(builder_, activation_type).Union());
171     BuildInterpreter({GetShape(input1_), GetShape(input2_)});
172   }
173 
input1()174   int input1() { return input1_; }
input2()175   int input2() { return input2_; }
176 
GetOutput()177   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
178 
179  protected:
180   int input1_;
181   int input2_;
182   int output_;
183 };
184 
// Element-wise multiply through NNAPI, no activation.
TEST(NNAPIDelegate, MulWithNoActivation) {
  const TensorData operand{TensorType_FLOAT32, {1, 2, 2, 1}};
  FloatMulOpModel m(operand, operand, {TensorType_FLOAT32, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
}
195 
196 class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
197  public:
FloatPoolingOpModel(BuiltinOperator type,const TensorData & input,int filter_width,int filter_height,const TensorData & output)198   FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
199                       int filter_width, int filter_height,
200                       const TensorData& output) {
201     input_ = AddInput(input);
202     output_ = AddOutput(output);
203 
204     SetBuiltinOp(
205         type, BuiltinOptions_Pool2DOptions,
206         CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width,
207                             filter_height, ActivationFunctionType_NONE)
208             .Union());
209 
210     BuildInterpreter({GetShape(input_)});
211   }
212 
SetInput(std::initializer_list<float> data)213   void SetInput(std::initializer_list<float> data) {
214     PopulateTensor(input_, data);
215   }
216 
GetOutput()217   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
218 
219  protected:
220   int input_;
221   int output_;
222 };
223 
TEST(NNAPIDelegate, AveragePoolWithNoActivation) {
  FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D,
                        /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                        /*filter_width=*/2, /*filter_height=*/2,
                        /*output=*/{TensorType_FLOAT32, {}});
  // Two rows of four values; the op averages each 2x2 window.
  m.SetInput({0, 6, 2, 4, 3, 2, 10, 7});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAre(2.75, 5.75));
}
236 
TEST(NNAPIDelegate, MaxPoolWithNoActivation) {
  FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D,
                        /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                        /*filter_width=*/2, /*filter_height=*/2,
                        /*output=*/{TensorType_FLOAT32, {}});
  // Two rows of four values; the op takes the max of each 2x2 window.
  m.SetInput({0, 6, 2, 4, 3, 2, 10, 7});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAre(6, 10));
}
249 
TEST(NNAPIDelegate, L2PoolWithNoActivation) {
  FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D,
                        /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                        /*filter_width=*/2, /*filter_height=*/2,
                        /*output=*/{TensorType_FLOAT32, {}});
  // Two rows of four values; the op takes the RMS of each 2x2 window.
  m.SetInput({0, 6, 2, 4, 3, 2, 10, 7});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAre(3.5, 6.5));
}
262 
// Wraps a single CONV_2D op. Supports both float and quantized configurations
// (uint8 input/filter with an int32 bias); the mode is inferred from the
// input tensor's type.
class ConvolutionOpModel : public SingleOpModelWithNNAPI {
 public:
  ConvolutionOpModel(
      const TensorData& input, const TensorData& filter,
      const TensorData& output, int stride_width = 2, int stride_height = 2,
      enum Padding padding = Padding_VALID,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE,
      int dilation_width_factor = 1, int dilation_height_factor = 1)
      : input_type_(input.type), filter_type_(filter.type) {
    input_ = AddInput(input);
    filter_ = AddInput(filter);

    // One bias element per output channel (filter dimension 0).
    int bias_size = GetShape(filter_)[0];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(filter_);
      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    if (input_type_ != TensorType_FLOAT32) {
      // The following is required by quantized inference. It is the unittest's
      // responsibility to make sure the output scale falls into the correct
      // range.
      CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
    }

    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
                 CreateConv2DOptions(
                     builder_, padding, stride_width, stride_height, activation,
                     dilation_width_factor, dilation_height_factor)
                     .Union());

    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  // Setters quantize `data` on the way in when the model is quantized.
  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  void SetFilter(std::initializer_list<float> data) {
    SetData(filter_, filter_type_, data);
  }

  void SetBias(std::initializer_list<float> data) {
    // Quantized models use an int32 bias; float models reuse the input type.
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  // Returns the output, dequantized to float for quantized models.
  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

  // Returns the raw uint8 output; empty (unsupported) for float models.
  std::vector<uint8_t> GetQuantizedOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return {};  // Not supported.
    } else {
      return ExtractVector<uint8_t>(output_);
    }
  }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType filter_type_;
};
345 
346 // In this tests we set the input and output scales so that the results
347 // match exactly the 'non-quantized' version.
// In this tests we set the input and output scales so that the results
// match exactly the 'non-quantized' version.
TEST(ConvolutionOpTest, SimpleTestQuantized) {
  ConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
                       {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
                       {TensorType_UINT8, {}, -127, 128});
  m.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  m.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
      -1, -1, 1, 1,  // third 2x2 filter
  });
  m.SetBias({1, 2, 3});

  m.Invoke();

  // Dequantized output should match the float reference within quantization
  // noise.
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     18, 2, 5,  // first batch, left
                                     18, 2, 5,  // first batch, right
                                     17, 4, 3,  // second batch, left
                                     37, 4, 3,  // second batch, right
                                 },
                                 1e-5)));
  // For good  measure, let's also verify the quantized values:
  EXPECT_THAT(m.GetQuantizedOutput(), ElementsAreArray({
                                          145, 129, 132,  //
                                          145, 129, 132,  //
                                          144, 131, 130,  //
                                          164, 131, 130,  //
                                      }));
}
385 
// Hybrid convolution: float input/output with quantized (uint8) weights.
// The wider tolerance accounts for weight-quantization error.
TEST(ConvolutionOpTest, FloatInputQuantizedWeights) {
  ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
                       {TensorType_UINT8, {3, 2, 2, 1}, 0, 64},
                       {TensorType_FLOAT32, {}});
  m.SetInput({
      // First batch
      1, 1, 1, 2,  // row = 1
      2, 2, 2, 1,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  m.SetFilter({
      1, 2, 3, 4,  // first 2x2 filter
      0, 1, 0, 1,  // second 2x2 filter
      0, 0, 1, 1,  // third 2x2 filter
  });
  m.SetBias({1, 2, 3});

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     18, 5, 7,    // first batch, left
                                     16, 5, 6,    // first batch, right
                                     17, 6, 6,    // second batch, left
                                     37, 10, 10,  // second batch, right
                                 },
                                 0.2)));
}
416 
// Plain float convolution, stride 2, VALID padding, no activation.
TEST(ConvolutionOpTest, NoActivation) {
  ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
                       {TensorType_FLOAT32, {3, 2, 2, 1}},
                       {TensorType_FLOAT32, {}});

  m.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  m.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
      -1, -1, 1, 1,  // third 2x2 filter
  });
  m.SetBias({1, 2, 3});

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 18, 2, 5,  // first batch, left
                                 18, 2, 5,  // first batch, right
                                 17, 4, 3,  // second batch, left
                                 37, 4, 3,  // second batch, right
                             }));
}
446 
447 class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
448  public:
DepthwiseConvolutionOpModel(const TensorData & input,const TensorData & filter,const TensorData & output)449   DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
450                               const TensorData& output) {
451     input_ = AddInput(input);
452     filter_ = AddInput(filter);
453 
454     int bias_size = GetShape(filter_)[3];
455     if (input.type == TensorType_FLOAT32) {
456       bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
457     } else {
458       // This is a quantized version. The scale of 'bias' depends on the scales
459       // of input and filter. Supposedly this is correctly set during quantized
460       // training.
461       auto bias_scale = GetScale(input_) * GetScale(filter_);
462       TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
463       bias_ = AddInput(bias);
464     }
465 
466     output_ = AddOutput(output);
467 
468     int input_depth = GetShape(input_)[3];
469     int output_depth = GetShape(filter_)[3];
470     int depth_mul = output_depth / input_depth;
471 
472     SetBuiltinOp(
473         BuiltinOperator_DEPTHWISE_CONV_2D,
474         BuiltinOptions_DepthwiseConv2DOptions,
475         CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
476                                      ActivationFunctionType_NONE)
477             .Union());
478 
479     BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
480   }
481 
SetFilter(std::initializer_list<float> f)482   void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
483 
SetBias(std::initializer_list<float> f)484   void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
485 
SetInput(std::initializer_list<float> data)486   void SetInput(std::initializer_list<float> data) {
487     PopulateTensor(input_, data);
488   }
489 
GetOutput()490   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
491 
492  protected:
493   int input_;
494   int filter_;
495   int bias_;
496   int output_;
497 };
498 
// Float depthwise convolution with depth multiplier 2 (2 input channels,
// 4 output channels), no activation.
TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
  DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
                                {TensorType_FLOAT32, {1, 2, 2, 4}},
                                {TensorType_FLOAT32, {}});

  m.SetInput({
      1, 2, 7, 8,    // column 1
      3, 4, 9, 10,   // column 2
      5, 6, 11, 12,  // column 3
  });
  m.SetFilter({
      1, 2, 3, 4,        //
      -9, 10, -11, 12,   //
      5, 6, 7, 8,        //
      13, -14, 15, -16,  //
  });
  m.SetBias({1, 2, 3, 4});

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 71, -34, 99, -20,  //
                                 91, -26, 127, -4,  //
                             }));
}
524 
525 class FullyConnectedOpModel : public SingleOpModelWithNNAPI {
526  public:
FullyConnectedOpModel(const TensorData & input,const TensorData & weights,const TensorData & output,enum ActivationFunctionType activation=ActivationFunctionType_NONE)527   FullyConnectedOpModel(
528       const TensorData& input, const TensorData& weights,
529       const TensorData& output,
530       enum ActivationFunctionType activation = ActivationFunctionType_NONE)
531       : input_type_(input.type), weights_type_(weights.type) {
532     input_ = AddInput(input);
533     weights_ = AddInput(weights);
534 
535     const int units = weights.shape[0];
536     if (input.type == TensorType_FLOAT32) {
537       bias_ = AddInput({TensorType_FLOAT32, {units}});
538     } else {
539       // This is a quantized version. The scale of 'bias' depends on the scales
540       // of input and filter. Supposedly this is correctly set during quantized
541       // training.
542       auto bias_scale = GetScale(input_) * GetScale(weights_);
543       TensorData bias{TensorType_INT32, {units}, 0, 0, bias_scale};
544       bias_ = AddInput(bias);
545     }
546 
547     output_ = AddOutput(output);
548 
549     SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
550                  BuiltinOptions_FullyConnectedOptions,
551                  CreateFullyConnectedOptions(builder_, activation).Union());
552     BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
553   }
554 
SetInput(std::initializer_list<float> data)555   void SetInput(std::initializer_list<float> data) {
556     SetData(input_, input_type_, data);
557   }
558 
SetWeights(std::initializer_list<float> data)559   void SetWeights(std::initializer_list<float> data) {
560     SetData(weights_, weights_type_, data);
561   }
562 
SetBias(std::initializer_list<float> data)563   void SetBias(std::initializer_list<float> data) {
564     const auto bias_type =
565         (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
566     SetData(bias_, bias_type, data);
567   }
568 
GetOutput()569   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
570 
571  protected:
572   int input_;
573   int weights_;
574   int bias_;
575   int output_;
576 
577   const TensorType input_type_;
578   const TensorType weights_type_;
579 };
580 
// Float fully-connected: 2 batches x 10 inputs against 3 identical units.
TEST(FullyConnectedOpTest, SimpleTest) {
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_FLOAT32, {3, 10}},
                          /*output=*/{TensorType_FLOAT32});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
601 
// Hybrid fully-connected: float input/output with quantized (uint8) weights.
// The wide tolerance accounts for weight-quantization error.
TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) {
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_UINT8, {3, 10}, 0, 64},
                          /*output=*/{TensorType_FLOAT32});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3)));
}
623 
624 class SoftmaxOpModel : public SingleOpModelWithNNAPI {
625  public:
SoftmaxOpModel(int batches,int size,float beta)626   SoftmaxOpModel(int batches, int size, float beta)
627       : batches_(batches), input_size_(size), beta_(beta) {
628     input_ = AddInput(TensorType_FLOAT32);
629     output_ = AddOutput(TensorType_FLOAT32);
630     SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
631                  CreateSoftmaxOptions(builder_, beta_).Union());
632     BuildInterpreter({{batches_, input_size_}});
633   }
634 
SetInput(std::initializer_list<float> data)635   void SetInput(std::initializer_list<float> data) {
636     PopulateTensor(input_, data);
637   }
638 
SetInput(int offset,float * begin,float * end)639   void SetInput(int offset, float* begin, float* end) {
640     PopulateTensor(input_, offset, begin, end);
641   }
642 
GetOutput()643   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
644 
645  private:
646   int input_;
647   int output_;
648 
649   int batches_;
650   int input_size_;
651   float beta_;
652 };
653 
// Softmax over two 5-element batches; the second batch is the negation of
// the first, so its probabilities are the first batch's reversed.
TEST(NNAPIDelegate, SoftmaxSimpleTest) {
  SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0);
  m.SetInput({
      1.0, 2.0, 3.0, 4.0, 5.0,       // b = 0
      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(
      m.GetOutput(),
      ElementsAreArray(ArrayFloatNear(
          {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
           0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231},
          1e-6)));
}
670 
671 class ReshapeOpModel : public SingleOpModelWithNNAPI {
672  public:
ReshapeOpModel(std::initializer_list<int> input_shape,std::initializer_list<int> new_shape)673   ReshapeOpModel(std::initializer_list<int> input_shape,
674                  std::initializer_list<int> new_shape) {
675     input_ = AddInput(TensorType_FLOAT32);
676     new_shape_ = AddConstInput<int>(TensorType_INT32, new_shape,
677                                     {static_cast<int>(new_shape.size())});
678     output_ = AddOutput(TensorType_FLOAT32);
679     SetBuiltinOp(
680         BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
681         CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
682             .Union());
683     BuildInterpreter({input_shape, {static_cast<int>(new_shape.size())}});
684   }
685 
SetInput(std::initializer_list<float> data)686   void SetInput(std::initializer_list<float> data) {
687     PopulateTensor<float>(input_, data);
688   }
GetOutput()689   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()690   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
691 
692  private:
693   int input_;
694   int new_shape_;
695   int output_;
696 };
697 
// Reshape 1x2x4x1 -> 2x2x2; element order is preserved.
TEST(NNAPIDelegate, ReshapeSimpleTest) {
  ReshapeOpModel m(/*input_shape=*/{1, 2, 4, 1}, /*new_shape=*/{2, 2, 2});
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 2, 2));
}
705 
706 class SqueezeOpModel : public SingleOpModelWithNNAPI {
707  public:
SqueezeOpModel(const TensorData & input,const TensorData & output,std::initializer_list<int> axis)708   SqueezeOpModel(const TensorData& input, const TensorData& output,
709                  std::initializer_list<int> axis) {
710     input_ = AddInput(input);
711     output_ = AddOutput(output);
712     SetBuiltinOp(
713         BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
714         CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
715             .Union());
716     BuildInterpreter({GetShape(input_)});
717   }
718 
SetInput(std::initializer_list<float> data)719   void SetInput(std::initializer_list<float> data) {
720     PopulateTensor<float>(input_, data);
721   }
GetOutput()722   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()723   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
724 
725  private:
726   int input_;
727   int new_shape_;
728   int output_;
729 };
730 
// With an empty axis list, every 1-sized dimension is squeezed away and the
// values pass through unchanged.
TEST(NNAPIDelegate, SqueezeSimpleTest) {
  const std::initializer_list<float> data = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}},
                   /*axis=*/{});
  m.SetInput(data);
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(24));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(data));
}
746 
// Squeezing only axis 2 keeps the leading 1-sized batch dimension.
TEST(NNAPIDelegate, SqueezeWithAxisTest) {
  const std::initializer_list<float> data = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}},
                   /*axis=*/{2});
  m.SetInput(data);
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 24));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(data));
}
762 
763 class L2NormOpModel : public SingleOpModelWithNNAPI {
764  public:
L2NormOpModel(const TensorData & input,const TensorData & output,ActivationFunctionType activation_type)765   L2NormOpModel(const TensorData& input, const TensorData& output,
766                 ActivationFunctionType activation_type) {
767     input_ = AddInput(input);
768     output_ = AddOutput(output);
769     SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
770                  CreateL2NormOptions(builder_, activation_type).Union());
771     BuildInterpreter({GetShape(input_)});
772   }
773 
SetInput(std::initializer_list<float> data)774   void SetInput(std::initializer_list<float> data) {
775     PopulateTensor<float>(input_, data);
776   }
GetOutput()777   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()778   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
779 
780  private:
781   int input_;
782   int new_shape_;
783   int output_;
784 };
785 
TEST(NNAPIDelegate, L2NormSimpleTest) {
  // ||v||_2 of the input is 2.0, so every element is halved.
  L2NormOpModel m({TensorType_FLOAT32, {1, 1, 1, 6}},
                  {TensorType_FLOAT32, {1, 1, 1, 6}},
                  ActivationFunctionType_NONE);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 6}));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
797 
// Test harness wrapping a single TRANSPOSE op whose permutation is a
// constant (baked-in) tensor.
class TransposeSimpleModel : public SingleOpModelWithNNAPI {
 public:
  // `perm` is registered as a constant INT32 input with shape `perm_shape`;
  // only the float data tensor is populated at run time.
  // NOTE: tensor indices depend on the AddInput/AddConstInput/AddOutput
  // call order below — do not reorder.
  TransposeSimpleModel(std::initializer_list<int> input_shape,
                       std::initializer_list<int> perm_shape,
                       std::initializer_list<int> perm) {
    input_ = AddInput(TensorType_FLOAT32);
    perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
                 CreateTransposeOptions(builder_).Union());
    BuildInterpreter({input_shape, perm_shape});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int perm_;
  int output_;
};
823 
TEST(NNAPIDelegate, TransposeSimpleTest) {
  // Permutation {2, 0, 1} maps shape {2, 3, 4} to {4, 2, 3}.
  TransposeSimpleModel model({2, 3, 4}, {3}, {2, 0, 1});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
}
834 
835 class FloatSubOpModel : public SingleOpModelWithNNAPI {
836  public:
FloatSubOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)837   FloatSubOpModel(const TensorData& input1, const TensorData& input2,
838                   const TensorData& output,
839                   ActivationFunctionType activation_type) {
840     input1_ = AddInput(input1);
841     input2_ = AddInput(input2);
842     output_ = AddOutput(output);
843     SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions,
844                  CreateMulOptions(builder_, activation_type).Union());
845     BuildInterpreter({GetShape(input1_), GetShape(input2_)});
846   }
847 
input1()848   int input1() { return input1_; }
input2()849   int input2() { return input2_; }
850 
GetOutput()851   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
852 
853  protected:
854   int input1_;
855   int input2_;
856   int output_;
857 };
858 
TEST(NNAPIDelegate, SubWithNoActivation) {
  // Element-wise input1 - input2 with no fused activation.
  FloatSubOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear({-2.1, 0.0, 0.4, 0.3})));
}
869 
870 class FloatDivOpModel : public SingleOpModelWithNNAPI {
871  public:
FloatDivOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)872   FloatDivOpModel(const TensorData& input1, const TensorData& input2,
873                   const TensorData& output,
874                   ActivationFunctionType activation_type) {
875     input1_ = AddInput(input1);
876     input2_ = AddInput(input2);
877     output_ = AddOutput(output);
878     SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions,
879                  CreateMulOptions(builder_, activation_type).Union());
880     BuildInterpreter({GetShape(input1_), GetShape(input2_)});
881   }
882 
input1()883   int input1() { return input1_; }
input2()884   int input2() { return input2_; }
885 
GetOutput()886   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
887 
888  protected:
889   int input1_;
890   int input2_;
891   int output_;
892 };
893 
TEST(NNAPIDelegate, DivWithNoActivation) {
  // Element-wise input1 / input2 with no fused activation.
  FloatDivOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.8, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.4, 0.2});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear({-20, 1, 2, 4})));
}
903 
// Base test harness wrapping a single CONCATENATION op with `num_inputs`
// identically-shaped inputs joined along `axis`.
class BaseConcatenationOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseConcatenationOpModel() {}
  // All inputs share `input_template`'s type/shape/quantization; the output
  // reuses the template's min/max so input and output scales match.
  BaseConcatenationOpModel(const TensorData& input_template, int axis,
                           int num_inputs) {
    std::vector<std::vector<int>> all_input_shapes;
    for (int i = 0; i < num_inputs; ++i) {
      all_input_shapes.push_back(input_template.shape);
      AddInput(input_template);
    }
    output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
                         input_template.max});
    SetBuiltinOp(
        BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
        CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
            .Union());
    BuildInterpreter(all_input_shapes);
  }

 protected:
  int output_;
};
926 
// Float variant: populates inputs and reads the output as plain floats.
class ConcatenationOpModel : public BaseConcatenationOpModel {
 public:
  using BaseConcatenationOpModel::BaseConcatenationOpModel;
  // `index` is the input tensor index (0-based, in AddInput order).
  void SetInput(int index, std::initializer_list<float> data) {
    PopulateTensor(index, data);
  }
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
935 
TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) {
  // Concatenating a single input is an identity operation.
  ConcatenationOpModel m({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
                         /*num_inputs=*/1);
  m.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 4, 7}));
}
943 
TEST(NNAPIDelegate, ConcatenationFourInputs) {
  // Four {2, 1, 2} inputs joined along the last axis interleave per row.
  ConcatenationOpModel m({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
                         /*num_inputs=*/4);
  for (int i = 0; i < 4; ++i) {
    const float d = i * 0.1f;
    m.SetInput(i, {1.0f + d, 3.0f + d, 4.0f + d, 7.0f + d});
  }
  m.Invoke();
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
}
958 
// Quantized (uint8) variant: inputs may each carry their own quantization
// parameters, and the output quantization is specified independently.
class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
 public:
  using BaseConcatenationOpModel::BaseConcatenationOpModel;
  QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template,
                                int axis, int num_inputs,
                                const TensorData& output_template) {
    std::vector<std::vector<int>> all_input_shapes;
    // One template per input is required.
    CHECK_EQ(input_template.size(), num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      all_input_shapes.push_back(input_template[i].shape);
      AddInput(input_template[i]);
    }
    output_ = AddOutput({output_template.type, /*shape=*/{},
                         output_template.min, output_template.max});
    SetBuiltinOp(
        BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
        CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
            .Union());
    BuildInterpreter(all_input_shapes);
  }
  // Quantizes `data` with input `index`'s scale/zero-point before writing.
  void SetInput(int index, std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(index, data);
  }
  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
  // Output mapped back to float using the output tensor's quantization.
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }
};
988 
TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) {
  // All four inputs and the output share the same [-12.7, 12.8] range.
  QuantizedConcatenationOpModel model({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
                                      /*axis=*/2,
                                      /*num_inputs=*/4);
  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  model.Invoke();
  EXPECT_THAT(model.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  // Raw uint8 values for the shared quantization parameters.
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     137, 157, 138, 158, 139, 159, 140, 160,  //
                                     167, 197, 168, 198, 169, 199, 170, 200,  //
                                 }));
}
1009 
TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) {
  // Each input has a different quantization range; the op must requantize
  // everything into the output's [-12.7, 12.8] range.
  QuantizedConcatenationOpModel model(
      {{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
       {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
      /*axis=*/2, /*num_inputs=*/4,
      {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8});
  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  model.Invoke();
  EXPECT_THAT(model.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     137, 157, 138, 158, 139, 159, 140, 160,  //
                                     167, 197, 168, 198, 169, 199, 170, 200,  //
                                 }));
}
1033 
1034 class DequantizeOpModel : public SingleOpModelWithNNAPI {
1035  public:
DequantizeOpModel(TensorType inputType,std::initializer_list<int> shape,float min,float max)1036   DequantizeOpModel(TensorType inputType, std::initializer_list<int> shape,
1037                     float min, float max) {
1038     input_ = AddInput({inputType, shape, min, max});
1039     output_ = AddOutput({TensorType_FLOAT32, shape});
1040     SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
1041                  CreateDequantizeOptions(builder_).Union());
1042 
1043     BuildInterpreter({GetShape(input_)});
1044   }
1045 
1046   template <typename T>
SetInput(std::initializer_list<T> data)1047   void SetInput(std::initializer_list<T> data) {
1048     PopulateTensor(input_, data);
1049   }
1050 
GetOutput()1051   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1052 
1053  private:
1054   int input_;
1055   int output_;
1056 };
1057 
// NOTE(review): the test name says "FourDimensional" but the shape here is
// 2-D {2, 5}; renamed tests break --gtest_filter users, so only flagging it.
TEST(NNAPIDelegate, DequantizeFourDimensionalUint8) {
  // [-63.5, 64] over uint8 gives scale 0.5, zero_point 127.
  DequantizeOpModel m(TensorType_UINT8, {2, 5}, -63.5, 64);

  m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64})));
}
1067 
TEST(NNAPIDelegate, DequantizeFourDimensionalInt8Symm) {
  // Symmetric int8: range [-64, 63.5] yields scale=0.5 and zero_point=0.
  DequantizeOpModel model(TensorType_INT8, {2, 5}, -64, 63.5);
  model.SetInput<int8_t>(
      {-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {-64, -63.5, -63, -62.5, -62, 61.5, 62, 62.5, 63, 63.5})));
}
1078 
1079 class FloorOpModel : public SingleOpModelWithNNAPI {
1080  public:
FloorOpModel(std::initializer_list<int> input_shape,TensorType input_type)1081   FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
1082     input_ = AddInput(TensorType_FLOAT32);
1083     output_ = AddOutput(TensorType_FLOAT32);
1084     SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
1085     BuildInterpreter({
1086         input_shape,
1087     });
1088   }
1089 
input()1090   int input() { return input_; }
1091 
GetOutput()1092   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1093   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1094 
1095  private:
1096   int input_;
1097   int output_;
1098 };
1099 
TEST(NNAPIDelegate, FloorSingleDim) {
  // floor() of a 1-D tensor; shape is preserved.
  FloorOpModel m({2}, TensorType_FLOAT32);
  m.PopulateTensor<float>(m.input(), {8.5, 0.0});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({8, 0}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
}
1107 
TEST(NNAPIDelegate, FloorMultiDims) {
  // Negative values floor toward -infinity, e.g. floor(-0.0001) == -1.
  FloorOpModel m({2, 1, 1, 5}, TensorType_FLOAT32);
  m.PopulateTensor<float>(m.input(),
                          {0.0001, 8.0001, 0.9999, 9.9999, 0.5,  //
                           -0.0001, -8.0001, -0.9999, -9.9999, -0.5});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
}
1127 
// Test harness wrapping a single LOCAL_RESPONSE_NORMALIZATION op.
class LocalResponseNormOpModel : public SingleOpModelWithNNAPI {
 public:
  // `radius`, `bias`, `alpha`, `beta` parameterize the LRN formula baked
  // into the op's options table.
  LocalResponseNormOpModel(std::initializer_list<int> input_shape, int radius,
                           float bias, float alpha, float beta) {
    input_ = AddInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
                 BuiltinOptions_LocalResponseNormalizationOptions,
                 CreateLocalResponseNormalizationOptions(builder_, radius, bias,
                                                         alpha, beta)
                     .Union());
    BuildInterpreter({input_shape});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 private:
  int input_;
  int output_;
};
1152 
TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) {
  // With bias=0, alpha=1, beta=0.5 and a radius covering the whole depth,
  // LRN degenerates to L2 normalization: every input is divided by 2.
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                                 /*alpha=*/1.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})));
}
1163 
TEST(NNAPIDelegate, LocalResponseNormWithAlpha) {
  // alpha=4 scales the squared-sum term so each input is divided by 4.
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                                 /*alpha=*/4.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025})));
}
1173 
TEST(NNAPIDelegate, LocalResponseNormWithBias) {
  // bias=9 with alpha=4 makes the denominator sqrt(9 + 4*4) = 5.
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0,
                                 /*alpha=*/4.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02})));
}
1184 
TEST(NNAPIDelegate, LocalResponseNormSmallRadius) {
  // radius=2 only covers a sliding window, so each element sees a different
  // normalization denominator.
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0,
                                 /*alpha=*/4.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray(ArrayFloatNear(
          {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266})));
}
1195 
1196 class LSHProjectionOpModel : public SingleOpModelWithNNAPI {
1197  public:
LSHProjectionOpModel(LSHProjectionType type,std::initializer_list<int> hash_shape,std::initializer_list<int> input_shape,std::initializer_list<int> weight_shape)1198   LSHProjectionOpModel(LSHProjectionType type,
1199                        std::initializer_list<int> hash_shape,
1200                        std::initializer_list<int> input_shape,
1201                        std::initializer_list<int> weight_shape) {
1202     hash_ = AddInput(TensorType_FLOAT32);
1203     input_ = AddInput(TensorType_INT32);
1204     if (weight_shape.size() > 0) {
1205       weight_ = AddInput(TensorType_FLOAT32);
1206     }
1207     output_ = AddOutput(TensorType_INT32);
1208 
1209     SetBuiltinOp(BuiltinOperator_LSH_PROJECTION,
1210                  BuiltinOptions_LSHProjectionOptions,
1211                  CreateLSHProjectionOptions(builder_, type).Union());
1212     if (weight_shape.size() > 0) {
1213       BuildInterpreter({hash_shape, input_shape, weight_shape});
1214     } else {
1215       BuildInterpreter({hash_shape, input_shape});
1216     }
1217 
1218     output_size_ = 1;
1219     for (int i : hash_shape) {
1220       output_size_ *= i;
1221       if (type == LSHProjectionType_SPARSE) {
1222         break;
1223       }
1224     }
1225   }
SetInput(std::initializer_list<int> data)1226   void SetInput(std::initializer_list<int> data) {
1227     PopulateTensor(input_, data);
1228   }
1229 
SetHash(std::initializer_list<float> data)1230   void SetHash(std::initializer_list<float> data) {
1231     PopulateTensor(hash_, data);
1232   }
1233 
SetWeight(std::initializer_list<float> f)1234   void SetWeight(std::initializer_list<float> f) { PopulateTensor(weight_, f); }
1235 
GetOutput()1236   std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
1237 
1238  private:
1239   int input_;
1240   int hash_;
1241   int weight_;
1242   int output_;
1243 
1244   int output_size_;
1245 };
1246 
TEST(NNAPIDelegate, LSHProjectionDense1DInputs) {
  // Dense projection emits one bit per hash-table entry (3 x 2 = 6 bits).
  LSHProjectionOpModel model(LSHProjectionType_DENSE, {3, 2}, {5}, {5});
  model.SetInput({12345, 54321, 67890, 9876, -12345678});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  model.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
}
1258 
TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) {
  // Sparse projection (no weight input) emits one index per hash row,
  // offset by row * 2^bits (written as base + bits below).
  LSHProjectionOpModel model(LSHProjectionType_SPARSE, {3, 2}, {5}, {});
  model.SetInput({12345, 54321, 67890, 9876, -12345678});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
}
1269 
TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) {
  // Sparse projection over a 3-D input with per-feature weights.
  LSHProjectionOpModel model(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5});
  model.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
                  9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  model.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
}
1282 
1283 class BaseActivationsOpModel : public SingleOpModelWithNNAPI {
1284  public:
1285   // Most activations don't take any options, so this constructor works for
1286   // them.
BaseActivationsOpModel(BuiltinOperator type,TensorData input)1287   BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
1288     input_ = AddInput(input);
1289     if (input.type == TensorType_UINT8) {
1290       output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
1291     } else {
1292       output_ = AddOutput({input.type, {}});
1293     }
1294     SetBuiltinOp(type, BuiltinOptions_NONE, 0);
1295     BuildInterpreter({GetShape(input_)});
1296   }
1297 
BaseActivationsOpModel(BuiltinOperator type,const TensorData & input,const TensorData & output)1298   BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
1299                          const TensorData& output) {
1300     input_ = AddInput(input);
1301     output_ = AddOutput(output);
1302     SetBuiltinOp(type, BuiltinOptions_NONE, 0);
1303     BuildInterpreter({GetShape(input_)});
1304   }
1305 
1306  protected:
1307   int input_;
1308   int output_;
1309 };
1310 
// Float variant of the activation harness.
class FloatActivationsOpModel : public BaseActivationsOpModel {
 public:
  using BaseActivationsOpModel::BaseActivationsOpModel;

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
1320 
// Tolerance of two quantization steps for uint8 activations quantized with
// scale 1/256 (the default output quantization set in BaseActivationsOpModel).
const float kQuantizedTolerance = 2 * (1. / 256);
1322 
1323 class QuantizedActivationsOpModel : public BaseActivationsOpModel {
1324  public:
1325   using BaseActivationsOpModel::BaseActivationsOpModel;
1326 
1327   template <typename T>
SetInput(std::initializer_list<float> data)1328   void SetInput(std::initializer_list<float> data) {
1329     QuantizeAndPopulate<T>(input_, data);
1330   }
1331   template <typename T>
1332 
GetOutput()1333   std::vector<T> GetOutput() {
1334     return ExtractVector<T>(output_);
1335   }
1336   template <typename T>
GetDequantizedOutput()1337   std::vector<float> GetDequantizedOutput() {
1338     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
1339                          GetZeroPoint(output_));
1340   }
1341 };
1342 
TEST(NNAPIDelegate, Relu) {
  // max(0, x) element-wise.
  FloatActivationsOpModel model(BuiltinOperator_RELU,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0, 0, 2, 4,   //
                                     3, 0, 10, 1,  //
                                 }));
}
1356 
TEST(NNAPIDelegate, Relu1) {
  // Clamps element-wise to [-1, 1].
  FloatActivationsOpModel model(BuiltinOperator_RELU_N1_TO_1,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0.0, -0.6, 0.2, -0.4,  //
                                     0.3, -1.0, 1.0, -0.1,  //
                                 }));
}
1370 
TEST(NNAPIDelegate, Relu6) {
  // Clamps element-wise to [0, 6].
  FloatActivationsOpModel model(BuiltinOperator_RELU6,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0, 0, 2, 4,  //
                                     3, 0, 6, 1,  //
                                 }));
}
1384 
TEST(NNAPIDelegate, Tanh) {
  // tanh(x) element-wise; saturating values compare within float tolerance.
  FloatActivationsOpModel model(BuiltinOperator_TANH,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                     0, -0.9999877, 0.9640275, 0.999329,    //
                                     0.99505475, -0.9640275, 1, 0.7615941,  //
                                 })));
}
1398 
TEST(NNAPIDelegate, LogisticFloat) {
  // sigmoid(x) = 1 / (1 + e^-x), element-wise.
  FloatActivationsOpModel model(BuiltinOperator_LOGISTIC,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear({
                  0.5, 0.002473, 0.880797, 0.982014,       //
                  0.952574, 0.119203, 0.999955, 0.731059,  //
              })));
}
1412 
TEST(NNAPIDelegate, LogisticQuantized) {
  // Quantized sigmoid over uint8 input in [-10, 10]; output uses the
  // default [0, 1) quantization (scale 1/256).
  QuantizedActivationsOpModel model(
      BuiltinOperator_LOGISTIC,
      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
  model.SetInput<uint8_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  model.Invoke();
  // Dequantized values match the float reference within quantized tolerance.
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      0.5, 0.002473, 0.880797, 0.982014,       //
                      0.952574, 0.119203, 0.999955, 0.731059,  //
                  },
                  kQuantizedTolerance)));
  // Raw quantized values may differ by at most one step.
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              testing::Pointwise(QuantizedNear(),
                                 {128, 1, 227, 251, 244, 32, 255, 188}));
}
1433 
// NOTE(review): this ResizeBilinear harness and its tests are compiled out
// with `#if 0` — presumably pending NNAPI support or a known failure; there
// is no TODO/bug reference here, so confirm the reason before re-enabling.
#if 0
class ResizeBilinearOpModel : public SingleOpModelWithNNAPI {
 public:
  ResizeBilinearOpModel(const TensorData& input,
                        std::initializer_list<int> size_data = {}) {
    bool const_size = size_data.size() != 0;
    input_ = AddInput(input);
    if (const_size) {
      size_ = AddConstInput(TensorType_INT32, size_data, {2});
    } else {
      size_ = AddInput({TensorType_INT32, {2}});
    }
    output_ = AddOutput(input.type);
    SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
                 BuiltinOptions_ResizeBilinearOptions,
                 CreateResizeBilinearOptions(builder_).Union());
    if (const_size) {
      BuildInterpreter({GetShape(input_)});
    } else {
      BuildInterpreter({GetShape(input_), GetShape(size_)});
    }
  }

  template <typename T>
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor(input_, data);
  }
  void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

 private:
  int input_;
  int size_;
  int output_;
};

TEST(NNAPIDelegate, ResizeBilinearHorizontal) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
  m.SetInput<float>({3, 6});
  m.SetSize({1, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
  const_m.SetInput<float>({3, 6});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}

TEST(NNAPIDelegate, ResizeBilinearVertical) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
  m.SetInput<float>({3, 9});
  m.SetSize({3, 1});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
  const_m.SetInput<float>({3, 9});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}

TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
  m.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
                                        3, 5, 6,    //
                                        7, 9, 10,   //
                                        9, 11, 12,  //
                                    })));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
  const_m.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
                                              3, 5, 6,    //
                                              7, 9, 10,   //
                                              9, 11, 12,  //
                                          })));
}
#endif
1531 
// Model wrapper shared by the PAD tests; T is the tensor element type.
// Subclass constructors are responsible for creating the tensors and
// assigning the index members below.
template <typename T>
class PadOpModel : public SingleOpModelWithNNAPI {
 public:
  // Populates the input tensor with raw values of type T.
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor<T>(input_, data);
  }

  // Quantizes `data` into the uint8 input tensor.
  void SetQuantizedInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(input_, data);
  }

  // Quantizes the scalar pad value into the constant_values tensor.
  void SetQuantizedPadValue(float data) {
    QuantizeAndPopulate<uint8_t>(constant_values_, {data});
  }

  // Populates the paddings tensor ([before, after] per dimension).
  void SetPaddings(std::initializer_list<int> paddings) {
    PopulateTensor<int>(paddings_, paddings);
  }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  // Dequantizes the uint8 output back to floats for comparison.
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }

 protected:
  // Tensor indices assigned by the subclass constructors.
  int input_;
  int output_;
  int paddings_;
  int constant_values_;
};
1565 
// Float PAD model whose paddings are a compile-time constant tensor
// (not fed at runtime), matching what the NNAPI delegate expects.
class PadOpConstModel : public PadOpModel<float> {
 public:
  PadOpConstModel(const TensorData& input,
                  std::initializer_list<int> paddings_shape,
                  std::initializer_list<int> paddings,
                  const TensorData& output) {
    input_ = AddInput(input);
    // Paddings are baked into the model as a constant int32 tensor.
    paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
                 CreatePadOptions(builder_).Union());
    BuildInterpreter({input.shape});
  }
};
1581 
TEST(NNAPIDelegate, PadAdvancedConstTest) {
  // Pad a 1x2x3x1 tensor: 2 rows after in the height dim, 1 before and
  // 3 after in the width dim; paddings are a constant {4, 2} tensor.
  PadOpConstModel model({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                        {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
1592 
// Model wrapper shared by the SPACE_TO_BATCH_ND tests. Subclass
// constructors create the tensors and assign the index members below.
class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI {
 public:
  // Populates the float input tensor.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

  // Populates the block_shape tensor.
  void SetBlockShape(std::initializer_list<int> data) {
    PopulateTensor<int>(block_shape_, data);
  }

  // Populates the paddings tensor.
  void SetPaddings(std::initializer_list<int> data) {
    PopulateTensor<int>(paddings_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 protected:
  // Tensor indices assigned by the subclass constructors.
  int input_;
  int block_shape_;
  int paddings_;
  int output_;
};
1616 
// SPACE_TO_BATCH_ND model whose block_shape and paddings are constant
// tensors baked into the model.
class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
 public:
  SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,
                             std::initializer_list<int> block_shape,
                             std::initializer_list<int> paddings) {
    input_ = AddInput(TensorType_FLOAT32);
    // block_shape holds 2 elements; paddings is a 2x2 [before, after] table.
    block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
    paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2});
    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
                 BuiltinOptions_SpaceToBatchNDOptions,
                 CreateSpaceToBatchNDOptions(builder_).Union());
    BuildInterpreter({input_shape});
  }
};
1633 
TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) {
  // 2x2 blocks, no padding: 1x4x4x1 -> 4x2x2x1.
  SpaceToBatchNDOpConstModel model({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(
                  {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16}));
}
1642 
TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) {
  // 2x2 blocks on an input that already has 2 batches: 2x2x4x1 -> 8x1x2x1.
  SpaceToBatchNDOpConstModel model({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(
                  {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16}));
}
1651 
TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) {
  // 3x2 blocks with padding {1, 0, 2, 0}: 1x5x2x1 -> 6x2x2x1.
  SpaceToBatchNDOpConstModel model({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7,
                                     0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
                                 }));
}
1662 
TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) {
  // 3x2 blocks with asymmetric padding {1, 1, 2, 4}: 1x4x2x1 -> 6x2x4x1.
  SpaceToBatchNDOpConstModel model({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6,
                                     0, 0, 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0,
                                     0, 8, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4,
                                     0, 0, 0, 0, 0, 0,
                                 }));
}
1674 
// Model wrapper for STRIDED_SLICE tests. begin/end/strides are constant
// tensors; the four mask ints are passed through to the op options.
template <typename input_type = float,
          TensorType tensor_input_type = TensorType_FLOAT32>
class StridedSliceOpModel : public SingleOpModelWithNNAPI {
 public:
  StridedSliceOpModel(std::initializer_list<int> input_shape,
                      std::initializer_list<int> begin_shape,
                      std::initializer_list<int> begin_data,
                      std::initializer_list<int> end_shape,
                      std::initializer_list<int> end_data,
                      std::initializer_list<int> strides_shape,
                      std::initializer_list<int> strides_data, int begin_mask,
                      int end_mask, int ellipsis_mask, int new_axis_mask,
                      int shrink_axis_mask) {
    input_ = AddInput(tensor_input_type);
    // begin/end/strides are baked in as constant int32 tensors.
    begin_ = AddConstInput(TensorType_INT32, begin_data, begin_shape);
    end_ = AddConstInput(TensorType_INT32, end_data, end_shape);
    strides_ = AddConstInput(TensorType_INT32, strides_data, strides_shape);
    output_ = AddOutput(tensor_input_type);
    SetBuiltinOp(
        BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
        CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
                                  new_axis_mask, shrink_axis_mask)
            .Union());
    BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape});
  }

  // Populates the input tensor.
  void SetInput(std::initializer_list<input_type> data) {
    PopulateTensor<input_type>(input_, data);
  }

  std::vector<input_type> GetOutput() {
    return ExtractVector<input_type>(output_);
  }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  // Tensor indices assigned in the constructor.
  int input_;
  int begin_;
  int end_;
  int strides_;
  int output_;
};
1717 
TEST(StridedSliceOpTest, In1D) {
  // Slice [1:3] of a 4-element vector.
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 0, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3}));
}
1725 
TEST(StridedSliceOpTest, In1D_BeginMask) {
  // begin_mask = 1 ignores the begin value for dim 0, so this is [:3].
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 1, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3}));
}
1733 
TEST(StridedSliceOpTest, In2D_Stride2) {
  // Stride 2 in both dims of a 2x3 matrix: keeps rows/cols 0, 2, ...
  StridedSliceOpModel<> model({2, 3}, {2}, {0, 0}, {2}, {2, 3}, {2}, {2, 2},
                              0, 0, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3}));
}
1742 
TEST(StridedSliceOpTest, In2D_EndMask) {
  // end_mask = 2 ignores the end value for dim 1, so the whole row is kept.
  StridedSliceOpModel<> model({2, 3}, {2}, {1, 0}, {2}, {2, 2}, {2}, {1, 1},
                              0, 2, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 5, 6}));
}
1751 
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
  // shrink_axis_mask = 4 drops the last dimension (taking index 0 there),
  // so the 2x3x2 input collapses to 2x3.
  StridedSliceOpModel<> model({2, 3, 2}, {3}, {0, 0, 0}, {3}, {2, 3, 1}, {3},
                              {1, 1, 1}, 0, 0, 0, 0, 4);
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11}));
}
1760 
// Input sequence for RnnBlackBoxTest. Each step consumes input_size()
// (= 8) floats; the test feeds the same slice to both batches.
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};
1788 
// Expected RNN outputs for RnnBlackBoxTest, one group of num_units()
// (= 16) floats per step; the test expects both batches to match the
// same golden slice.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};
1853 
// Input-to-hidden weights for RnnBlackBoxTest; fed into the {units,
// input_size} = {16, 8} weights tensor.
static std::initializer_list<float> rnn_weights = {
    0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
    0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
    0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
    -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
    -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
    -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
    -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
    0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
    0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
    0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
    -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
    0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
    -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
    -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
    0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
    0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
    0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
    -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
    0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
    0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
    -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
    0.277308,    0.415818};
1877 
// Recurrent weights for RnnBlackBoxTest: a {units, units} = {16, 16}
// matrix equal to 0.1 * identity (0.1 on the diagonal, 0 elsewhere).
static std::initializer_list<float> rnn_recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};
1895 
// Bias vector for RnnBlackBoxTest; fed into the {units} = {16} bias tensor.
static std::initializer_list<float> rnn_bias = {
    0.065691948, -0.69055247, 0.1107955,  -0.97084129, -0.23957068, -0.23566568,
    -0.389184,   0.47481549,  -0.4791103, 0.29931796,  0.10463274,  0.83918178,
    0.37197268,  0.61957061,  0.3956964,  -0.37609905};
1900 
// Model wrapper for a basic RNN op with RELU activation. The weight and
// recurrent-weight tensor types are parameterized so hybrid variants can
// reuse this class.
class RNNOpModel : public SingleOpModelWithNNAPI {
 public:
  RNNOpModel(int batches, int units, int size,
             const TensorType weights = TensorType_FLOAT32,
             const TensorType recurrent_weights = TensorType_FLOAT32)
      : batches_(batches), units_(units), input_size_(size) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_ = AddInput(weights);
    recurrent_weights_ = AddInput(recurrent_weights);
    bias_ = AddInput(TensorType_FLOAT32);
    // Hidden state is a variable tensor carried across invocations.
    hidden_state_ = AddInput(TensorType_FLOAT32, true);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
    BuildInterpreter({{batches_, input_size_},  // input tensor
                      {units_, input_size_},    // weights tensor
                      {units_, units_},         // recurrent weights tensor
                      {units_},                 // bias tensor
                      {batches_, units_}});     // hidden state tensor
  }

  // Populates the bias tensor.
  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }

  // Populates the input-to-hidden weights tensor.
  void SetWeights(std::initializer_list<float> f) {
    PopulateTensor(weights_, f);
  }

  // Populates the recurrent (hidden-to-hidden) weights tensor.
  void SetRecurrentWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_weights_, f);
  }

  // Populates the whole input tensor.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  // Populates [begin, end) into the input tensor starting at `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices assigned in the constructor.
  int input_;
  int weights_;
  int recurrent_weights_;
  int bias_;
  int hidden_state_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
};
1959 
TEST(NNAPIDelegate, RnnBlackBoxTest) {
  RNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int num_steps = sizeof(rnn_input) / sizeof(float) /
                        (rnn.input_size() * rnn.num_batches());

  for (int step = 0; step < num_steps; ++step) {
    // Feed the same input slice to both batches.
    float* input_begin = rnn_input + step * rnn.input_size();
    float* input_end = input_begin + rnn.input_size();
    rnn.SetInput(0, input_begin, input_end);
    rnn.SetInput(rnn.input_size(), input_begin, input_end);

    rnn.Invoke();

    // Both batches are expected to produce the same golden slice.
    float* golden_begin = rnn_golden_output + step * rnn.num_units();
    float* golden_end = golden_begin + rnn.num_units();
    std::vector<float> expected(golden_begin, golden_end);
    expected.insert(expected.end(), golden_begin, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
1986 
// Input sequence for the SVDF black-box tests; one group per step with
// num_batches() x input_size() (= 2 x 3) floats.
static float svdf_input[] = {
    0.12609188,  -0.46347019, -0.89598465,
    0.35867718,  0.36897406,  0.73463392,

    0.14278367,  -1.64410412, -0.75222826,
    -0.57290924, 0.12729003,  0.7567004,

    0.49837467,  0.19278903,  0.26584083,
    0.17660543,  0.52949083,  -0.77931279,

    -0.11186574, 0.13164264,  -0.05349274,
    -0.72674477, -0.5683046,  0.55900657,

    -0.68892461, 0.37783599,  0.18263303,
    -0.63690937, 0.44483393,  -0.71817774,

    -0.81299269, -0.86831826, 1.43940818,
    -0.95760226, 1.82078898,  0.71135032,

    -1.45006323, -0.82251364, -1.69082689,
    -1.65087092, -1.89238167, 1.54172635,

    0.03966608,  -0.24936394, -0.77526885,
    2.06740379,  -1.51439476, 1.43768692,

    0.11771342,  -0.23761693, -0.65898693,
    0.31088525,  -1.55601168, -0.87661445,

    -0.89477462, 1.67204106,  -0.53235275,
    -0.6230064,  0.29819036,  1.06939757,
};
2018 
// Golden outputs for the rank-1 SVDF test; one group per step with
// num_batches() x num_units() (= 2 x 4) floats.
static float svdf_golden_output_rank_1[] = {
    0.014899,    -0.0517661,  -0.143725,   -0.00271883,
    -0.03004015, 0.09565311,  0.1587342,   0.00784263,

    0.068281,    -0.162217,   -0.152268,   0.00323521,
    0.01582633,  0.03858774,  -0.03001583, -0.02671271,

    -0.0317821,  -0.0333089,  0.0609602,   0.0333759,
    -0.01432795, 0.05524484,  0.1101355,   -0.02382665,

    -0.00623099, -0.077701,   -0.391193,   -0.0136691,
    -0.02333033, 0.02293761,  0.12338032,  0.04326871,

    0.201551,    -0.164607,   -0.179462,   -0.0592739,
    0.01064911,  -0.17503069, 0.07821996,  -0.00224009,

    0.0886511,   -0.0875401,  -0.269283,   0.0281379,
    -0.02282338, 0.09741908,  0.32973239,  0.12281385,

    -0.201174,   -0.586145,   -0.628624,   -0.0330412,
    0.24780814,  -0.39304617, -0.22473189, 0.02589256,

    -0.0839096,  -0.299329,   0.108746,    0.109808,
    0.10084175,  -0.06416984, 0.28936723,  0.0026358,

    0.419114,    -0.237824,   -0.422627,   0.175115,
    -0.2314795,  -0.18584411, -0.4228974,  -0.12928449,

    0.36726,     -0.522303,   -0.456502,   -0.175475,
    0.17012937,  -0.34447709, 0.38505614,  -0.28158101,
};
2050 
// Golden outputs for the rank-2 SVDF test; same layout as the rank-1
// goldens (num_batches() x num_units() floats per step).
static float svdf_golden_output_rank_2[] = {
    -0.09623547, -0.10193135, 0.11083051,  -0.0347917,
    0.1141196,   0.12965347,  -0.12652366, 0.01007236,

    -0.16396809, -0.21247184, 0.11259045,  -0.04156673,
    0.10132131,  -0.06143532, -0.00924693, 0.10084561,

    0.01257364,  0.0506071,   -0.19287863, -0.07162561,
    -0.02033747, 0.22673416,  0.15487903,  0.02525555,

    -0.1411963,  -0.37054959, 0.01774767,  0.05867489,
    0.09607603,  -0.0141301,  -0.08995658, 0.12867066,

    -0.27142537, -0.16955489, 0.18521598,  -0.12528358,
    0.00331409,  0.11167502,  0.02218599,  -0.07309391,

    0.09593632,  -0.28361851, -0.0773851,  0.17199151,
    -0.00075242, 0.33691186,  -0.1536046,  0.16572715,

    -0.27916506, -0.27626723, 0.42615682,  0.3225764,
    -0.37472126, -0.55655634, -0.05013514, 0.289112,

    -0.24418658, 0.07540751,  -0.1940318,  -0.08911639,
    0.00732617,  0.46737891,  0.26449674,  0.24888524,

    -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    0.07736043,  -0.28260678, 0.15666828,  1.14949894,

    -0.57454878, -0.64704704, 0.73235172,  -0.34616736,
    0.21120001,  -0.22927976, 0.02455296,  -0.35906726,
};
2082 
// Model wrapper shared by the SVDF tests; the weight tensor types are
// parameterized so quantized/hybrid variants can reuse this class.
class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
                  int rank,
                  TensorType weights_feature_type = TensorType_FLOAT32,
                  TensorType weights_time_type = TensorType_FLOAT32)
      : batches_(batches),
        units_(units),
        input_size_(input_size),
        memory_size_(memory_size),
        rank_(rank) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_feature_ = AddInput(weights_feature_type);
    weights_time_ = AddInput(weights_time_type);
    // TODO(b/121383394) : figure out why optional bias causes TFLite segfault
    // when using NNAPI delegate.
    bias_ = AddInput(TensorType_FLOAT32);
    const int num_filters = units * rank;
    // Activation state is a variable tensor carried across invocations.
    activation_state_ = AddInput(
        TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
        /*is_variable=*/true);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
        CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
    BuildInterpreter({
        {batches_, input_size_},              // input tensor
        {units_ * rank, input_size_},         // weights_feature tensor
        {units_ * rank, memory_size_},        // weights_time tensor
        {units_},                             // bias tensor
        {batches, memory_size * num_filters}  // activation_state tensor
    });
    // TODO(b/121383394) : remove once the optional bias bug is fixed.
    PopulateTensor(bias_, std::vector<float>(units_));
  }

  // Populates the weights_feature tensor.
  void SetWeightsFeature(std::initializer_list<float> f) {
    PopulateTensor(weights_feature_, f);
  }

  // Populates the weights_time tensor.
  void SetWeightsTime(std::initializer_list<float> f) {
    PopulateTensor(weights_time_, f);
  }

  // Populates the input tensor.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // Extracts the output tensor from the SVDF op.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices assigned in the constructor.
  int input_;
  int weights_feature_;
  int weights_time_;
  int bias_;
  int activation_state_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
  int memory_size_;
  int rank_;
};
2155 
// Plain float SVDF model; inherits all behavior from BaseSVDFOpModel.
class SVDFOpModel : public BaseSVDFOpModel {
 public:
  using BaseSVDFOpModel::BaseSVDFOpModel;
};
2160 
2161 class SVDFOpTest : public ::testing::Test {
2162  protected:
VerifyGoldens(float golden_input[],float golden_output[],int golden_size,BaseSVDFOpModel * svdf,float tolerance=1e-5)2163   void VerifyGoldens(float golden_input[], float golden_output[],
2164                      int golden_size, BaseSVDFOpModel* svdf,
2165                      float tolerance = 1e-5) {
2166     const int svdf_num_batches = svdf->num_batches();
2167     const int svdf_input_size = svdf->input_size();
2168     const int svdf_num_units = svdf->num_units();
2169     const int input_sequence_size =
2170         golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
2171     // Going over each input batch, setting the input tensor, invoking the SVDF
2172     // op and checking the output with the expected golden values.
2173     for (int i = 0; i < input_sequence_size; i++) {
2174       float* batch_start =
2175           golden_input + i * svdf_input_size * svdf_num_batches;
2176       float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
2177       svdf->SetInput(0, batch_start, batch_end);
2178 
2179       svdf->Invoke();
2180 
2181       const float* golden_start =
2182           golden_output + i * svdf_num_units * svdf_num_batches;
2183       const float* golden_end =
2184           golden_start + svdf_num_units * svdf_num_batches;
2185       std::vector<float> expected;
2186       expected.insert(expected.end(), golden_start, golden_end);
2187 
2188       EXPECT_THAT(svdf->GetOutput(),
2189                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
2190     }
2191   }
2192 };
2193 
// Runs a rank-1 SVDF (2 batches, 4 units, input size 3, memory depth 10)
// against file-level golden input/output arrays.
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/1);
  // 4 units x 3 inputs = 12 feature weights.
  svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                          0.22197971, 0.12416199, 0.27901134, 0.27557442,
                          0.3905206, -0.36137494, -0.06634006, -0.10640851});

  // 40 time weights, grouped as 4 runs of memory_size (10) values.
  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});

  // svdf_input / svdf_golden_output_rank_1 are golden arrays defined
  // elsewhere in this file.
  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf);
}
2217 
// Same model dimensions as the rank-1 test but with rank 2; the weight
// counts (24 feature, 80 time) are consistent with units * rank = 8 filters.
TEST_F(SVDFOpTest, BlackBoxTestRank2) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/2);
  // 8 filters x 3 inputs = 24 feature weights.
  svdf.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
                          0.12416199,  0.15785322,  0.27901134,  0.3905206,
                          0.21931258,  -0.36137494, -0.10640851, 0.31053296,
                          -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
                          0.15294972,  0.38031587,  0.27557442,  0.39635518,
                          -0.21580373, -0.06634006, -0.02702999, 0.27072677});

  // 80 time weights, grouped as 8 runs of memory_size (10) values.
  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

       -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
       0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

       -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
       0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

       -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
       -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

       0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
       0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});

  // svdf_input / svdf_golden_output_rank_2 are golden arrays defined
  // elsewhere in this file.
  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf);
}
2256 
// Builds a single LSTM op for execution through the NNAPI delegate.
// The constructor registers inputs in a fixed positional order (the order
// the LSTM kernel indexes them by); optional tensors that are disabled are
// registered with AddNullInput() so the positions still line up.
class LSTMOpModel : public SingleOpModelWithNNAPI {
 public:
  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, float cell_clip, float proj_clip,
              const std::vector<std::vector<int>>& input_shapes,
              const TensorType weight_type)
      : n_batch_(n_batch),
        n_input_(n_input),
        n_cell_(n_cell),
        n_output_(n_output),
        weight_type_(weight_type) {
    input_ = AddInput(TensorType_FLOAT32);

    // Input-to-gate weights. Under CIFG the input gate is coupled to the
    // forget gate, so its weights are omitted.
    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ = AddInput(weight_type);
    }

    input_to_forget_weights_ = AddInput(weight_type);
    input_to_cell_weights_ = AddInput(weight_type);
    input_to_output_weights_ = AddInput(weight_type);

    // Recurrent (output-to-gate) weights.
    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddInput(weight_type);
    }

    recurrent_to_forget_weights_ = AddInput(weight_type);
    recurrent_to_cell_weights_ = AddInput(weight_type);
    recurrent_to_output_weights_ = AddInput(weight_type);

    // Peephole (cell-to-gate) weights; all three are null when peephole is
    // disabled, and the input-gate one is also null under CIFG.
    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput(weight_type);
      }
      cell_to_forget_weights_ = AddInput(weight_type);
      cell_to_output_weights_ = AddInput(weight_type);
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    // Gate biases are always float, regardless of weight_type.
    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput(TensorType_FLOAT32);
    }
    forget_gate_bias_ = AddInput(TensorType_FLOAT32);
    cell_bias_ = AddInput(TensorType_FLOAT32);
    output_gate_bias_ = AddInput(TensorType_FLOAT32);

    // Optional projection layer (bias only meaningful with weights).
    if (use_projection_weights) {
      projection_weights_ = AddInput(weight_type);
      if (use_projection_bias) {
        projection_bias_ = AddInput(TensorType_FLOAT32);
      } else {
        projection_bias_ = AddNullInput();
      }
    } else {
      projection_weights_ = AddNullInput();
      projection_bias_ = AddNullInput();
    }

    // Adding the 2 input state tensors (marked as variable inputs).
    input_activation_state_ =
        AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
    input_cell_state_ =
        AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);

    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
                                   cell_clip, proj_clip)
                     .Union());
    BuildInterpreter(input_shapes);
  }

  // Weight setters route through SetData with weight_type_ so the same test
  // fixture works for float and quantized weight variants.
  void SetInputToInputWeights(const std::vector<float>& f) {
    SetData(input_to_input_weights_, weight_type_, f);
  }

  void SetInputToForgetWeights(const std::vector<float>& f) {
    SetData(input_to_forget_weights_, weight_type_, f);
  }

  void SetInputToCellWeights(const std::vector<float>& f) {
    SetData(input_to_cell_weights_, weight_type_, f);
  }

  void SetInputToOutputWeights(const std::vector<float>& f) {
    SetData(input_to_output_weights_, weight_type_, f);
  }

  void SetRecurrentToInputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_input_weights_, weight_type_, f);
  }

  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
    SetData(recurrent_to_forget_weights_, weight_type_, f);
  }

  void SetRecurrentToCellWeights(const std::vector<float>& f) {
    SetData(recurrent_to_cell_weights_, weight_type_, f);
  }

  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_output_weights_, weight_type_, f);
  }

  void SetCellToInputWeights(const std::vector<float>& f) {
    SetData(cell_to_input_weights_, weight_type_, f);
  }

  void SetCellToForgetWeights(const std::vector<float>& f) {
    SetData(cell_to_forget_weights_, weight_type_, f);
  }

  void SetCellToOutputWeights(const std::vector<float>& f) {
    SetData(cell_to_output_weights_, weight_type_, f);
  }

  // Bias setters always populate float tensors directly.
  void SetInputGateBias(const std::vector<float>& f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(const std::vector<float>& f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(const std::vector<float>& f) {
    PopulateTensor(cell_bias_, f);
  }

  void SetOutputGateBias(const std::vector<float>& f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(const std::vector<float>& f) {
    SetData(projection_weights_, weight_type_, f);
  }

  void SetProjectionBias(const std::vector<float>& f) {
    PopulateTensor(projection_bias_, f);
  }

  // Writes [begin, end) into the input tensor starting at `offset`. The
  // const_casts are needed because PopulateTensor takes non-const pointers;
  // the source range is presumably only read — TODO confirm.
  void SetInput(int offset, const float* begin, const float* end) {
    PopulateTensor(input_, offset, const_cast<float*>(begin),
                   const_cast<float*>(end));
  }

  // Extracts the output tensor as a flat float vector.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  // Accessors for the model dimensions supplied at construction.
  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_cells() { return n_cell_; }
  int num_batches() { return n_batch_; }

 protected:
  // Tensor indices, in the LSTM op's positional input order.
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;
  int input_activation_state_;
  int input_cell_state_;

  int output_;
  // NOTE(review): output_state_ and cell_state_ are never assigned in this
  // class as visible here — possibly vestigial; confirm before removing.
  int output_state_;
  int cell_state_;

  // Model dimensions.
  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;

 private:
  // Tensor type used for all weight tensors (biases stay float).
  const TensorType weight_type_;
};
2459 
2460 class BaseLstmTest : public ::testing::Test {
2461  protected:
2462   // Weights of the LSTM model. Some are optional.
2463   std::vector<float> input_to_input_weights_;
2464   std::vector<float> input_to_cell_weights_;
2465   std::vector<float> input_to_forget_weights_;
2466   std::vector<float> input_to_output_weights_;
2467   std::vector<float> input_gate_bias_;
2468   std::vector<float> cell_gate_bias_;
2469   std::vector<float> forget_gate_bias_;
2470   std::vector<float> output_gate_bias_;
2471   std::vector<float> recurrent_to_input_weights_;
2472   std::vector<float> recurrent_to_cell_weights_;
2473   std::vector<float> recurrent_to_forget_weights_;
2474   std::vector<float> recurrent_to_output_weights_;
2475   std::vector<float> cell_to_input_weights_;
2476   std::vector<float> cell_to_forget_weights_;
2477   std::vector<float> cell_to_output_weights_;
2478   std::vector<float> projection_weights_;
2479 
2480   // LSTM input is stored as num_batch x num_inputs vector.
2481   std::vector<std::vector<float>> lstm_input_;
2482   // LSTM output is stored as num_batch x num_outputs vector.
2483   std::vector<std::vector<float>> lstm_golden_output_;
2484 
2485   // Compares output up to tolerance to the result of the lstm given the input.
VerifyGoldens(const std::vector<std::vector<float>> & input,const std::vector<std::vector<float>> & output,LSTMOpModel * lstm,float tolerance=1e-5)2486   void VerifyGoldens(const std::vector<std::vector<float>>& input,
2487                      const std::vector<std::vector<float>>& output,
2488                      LSTMOpModel* lstm, float tolerance = 1e-5) {
2489     const int num_batches = input.size();
2490     EXPECT_GT(num_batches, 0);
2491     const int num_inputs = lstm->num_inputs();
2492     EXPECT_GT(num_inputs, 0);
2493     const int input_sequence_size = input[0].size() / num_inputs;
2494     EXPECT_GT(input_sequence_size, 0);
2495     for (int i = 0; i < input_sequence_size; ++i) {
2496       for (int b = 0; b < num_batches; ++b) {
2497         const float* batch_start = input[b].data() + i * num_inputs;
2498         const float* batch_end = batch_start + num_inputs;
2499 
2500         lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end);
2501       }
2502 
2503       lstm->Invoke();
2504 
2505       const int num_outputs = lstm->num_outputs();
2506       std::vector<float> expected;
2507       for (int b = 0; b < num_batches; ++b) {
2508         const float* golden_start_batch = output[b].data() + i * num_outputs;
2509         const float* golden_end_batch = golden_start_batch + num_outputs;
2510         expected.insert(expected.end(), golden_start_batch, golden_end_batch);
2511       }
2512       EXPECT_THAT(lstm->GetOutput(),
2513                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
2514     }
2515   }
2516 };
2517 
// Weights and golden data for an LSTM with no CIFG, no peephole, no
// projection and no clipping (dimensions set in the matching TEST_F:
// 1 batch, 2 inputs, 4 cells, 4 outputs).
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  void SetUp() override {
    // Input-to-gate weights, each 4 cells x 2 inputs = 8 values.
    input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
                               -0.34550029, 0.04266912,  -0.15680569,
                               -0.34856534, 0.43890524};
    input_to_cell_weights_ = {-0.50013041, 0.1370284,  0.11810488, 0.2013163,
                              -0.20583314, 0.44344562, 0.22077113, -0.29909778};
    input_to_forget_weights_ = {0.09701663,  0.20334584,  -0.50592935,
                                -0.31343272, -0.40032279, 0.44781327,
                                0.01387155,  -0.35593212};
    input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
                                0.40525138,  0.44272184,  0.03897077,
                                -0.1556896,  0.19487578};
    input_gate_bias_ = {0., 0., 0., 0.};
    cell_gate_bias_ = {0., 0., 0., 0.};
    // Only the forget gate gets a non-zero bias.
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    // Recurrent weights, each 4 cells x 4 outputs = 16 values.
    recurrent_to_input_weights_ = {
        -0.0063535,  -0.2042388,  0.31454784,  -0.35746509,
        0.28902304,  0.08183324,  -0.16555229, 0.02286911,
        -0.13566875, 0.03034258,  0.48091322,  -0.12528998,
        0.24077177,  -0.51332325, -0.33502164, 0.10629296};

    recurrent_to_cell_weights_ = {
        -0.3407414,  0.24443203,  -0.2078532,  0.26320225,
        0.05695659,  -0.00123841, -0.4744786,  -0.35869038,
        -0.06418842, -0.13502428, -0.501764,   0.22830659,
        -0.46367589, 0.26016325,  -0.03894562, -0.16368064};

    recurrent_to_forget_weights_ = {
        -0.48684245, -0.06655136, 0.42224967,  0.2112639,
        0.27654213,  0.20864892,  -0.07646349, 0.45877004,
        0.00141793,  -0.14609534, 0.36447752,  0.09196436,
        0.28053468,  0.01560611,  -0.20127171, -0.01140004};

    recurrent_to_output_weights_ = {
        0.43385774,  -0.17194885, 0.2718237,  0.09215671,
        0.24107647,  -0.39835793, 0.18212086, 0.01301402,
        0.48572797,  -0.50656658, 0.20047462, -0.20607421,
        -0.51818722, -0.15390486, 0.0468148,  0.39922136};

    // One batch: three time steps of two inputs each.
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    // Expected output: three time steps of four outputs each.
    lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
                            -0.03716109, 0.12507336, 0.41193449, -0.20860538,
                            -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
  }
};
2566 
// Black-box check of a float LSTM with every optional feature disabled.
// Null-tensor positions carry {0} / {0, 0} shapes so the positional input
// list still lines up with the op's expectations.
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       {n_batch, n_input},  // input tensor

                       {n_cell, n_input},  // input_to_input_weight tensor
                       {n_cell, n_input},  // input_to_forget_weight tensor
                       {n_cell, n_input},  // input_to_cell_weight tensor
                       {n_cell, n_input},  // input_to_output_weight tensor

                       {n_cell, n_output},  // recurrent_to_input_weight_tensor
                       {n_cell, n_output},  // recurrent_to_forget_weight_tensor
                       {n_cell, n_output},  // recurrent_to_cell_weight_tensor
                       {n_cell, n_output},  // recurrent_to_output_weight_tensor

                       {0},  // cell_to_input_weight tensor
                       {0},  // cell_to_forget_weight tensor
                       {0},  // cell_to_output_weight tensor

                       {n_cell},  // input_gate_bias tensor
                       {n_cell},  // forget_gate_bias tensor
                       {n_cell},  // cell_bias tensor
                       {n_cell},  // output_gate_bias tensor

                       {0, 0},  // projection_weight tensor
                       {0},     // projection_bias tensor
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Load the weights/biases prepared by SetUp(), then compare against the
  // golden output sequence.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
2623 
// Weights and golden data for a CIFG LSTM (no input-gate weights/bias).
// NOTE(review): despite the "NoPeephole" in the name, cell_to_forget /
// cell_to_output peephole weights ARE set here and the matching TEST_F
// passes use_peephole=true — the class name looks stale; confirm upstream.
class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  void SetUp() override {
    // Input-to-gate weights, each 4 cells x 2 inputs = 8 values (input gate
    // omitted under CIFG).
    input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
                              0.05100781,  0.04717243,  0.48944736,
                              -0.38535351, -0.17212132};

    input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
                                -0.3633365,  -0.22755712, 0.28253698,
                                0.24407166,  0.33826375};

    input_to_output_weights_ = {0.10725588,  -0.02335852, -0.55932593,
                                -0.09426838, -0.44257352, 0.54939759,
                                0.01533556,  0.42751634};
    cell_gate_bias_ = {0., 0., 0., 0.};
    // Only the forget gate gets a non-zero bias.
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    // Recurrent weights, each 4 cells x 4 outputs = 16 values.
    recurrent_to_cell_weights_ = {
        0.54066205,  -0.32668582, -0.43562764, -0.56094903,
        0.42957711,  0.01841056,  -0.32764608, -0.33027974,
        -0.10826075, 0.20675004,  0.19069612,  -0.03026325,
        -0.54532051, 0.33003211,  0.44901288,  0.21193194};

    recurrent_to_forget_weights_ = {
        -0.13832897, -0.0515101,  -0.2359007, -0.16661474,
        -0.14340827, 0.36986142,  0.23414481, 0.55899,
        0.10798943,  -0.41174671, 0.17751795, -0.34484994,
        -0.35874045, -0.11352962, 0.27268326, 0.54058349};

    recurrent_to_output_weights_ = {
        0.41613156, 0.42610586,  -0.16495961, -0.5663873,
        0.30579174, -0.05115908, -0.33941799, 0.23364776,
        0.11178309, 0.09481031,  -0.26424935, 0.46261835,
        0.50248802, 0.26114327,  -0.43736315, 0.33149987};

    // Peephole weights, one value per cell.
    cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
                               0.31544167};
    cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
                               -0.77109635};

    // One batch: three time steps of two inputs each.
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    // Expected output: three time steps of four outputs each.
    lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
                            -0.42312205, -0.01218222, 0.24201041, -0.08124574,
                            -0.358325, -0.04621704, 0.21641694, -0.06471302}};
  }
};
2670 
// Black-box check of a float CIFG LSTM.
// NOTE(review): use_peephole is true here even though the fixture name says
// "NoPeephole"; the fixture does set peephole weights, so the behavior is
// self-consistent — only the name looks stale.
TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/true, /*use_peephole=*/true,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       {n_batch, n_input},  // input tensor

                       {0, 0},             // input_to_input_weight tensor
                       {n_cell, n_input},  // input_to_forget_weight tensor
                       {n_cell, n_input},  // input_to_cell_weight tensor
                       {n_cell, n_input},  // input_to_output_weight tensor

                       {0, 0},              // recurrent_to_input_weight tensor
                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
                       {n_cell, n_output},  // recurrent_to_output_weight tensor

                       {0},       // cell_to_input_weight tensor
                       {n_cell},  // cell_to_forget_weight tensor
                       {n_cell},  // cell_to_output_weight tensor

                       {0},       // input_gate_bias tensor
                       {n_cell},  // forget_gate_bias tensor
                       {n_cell},  // cell_bias tensor
                       {n_cell},  // output_gate_bias tensor

                       {0, 0},  // projection_weight tensor
                       {0},     // projection_bias tensor
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Load the weights/biases prepared by SetUp() (input-gate tensors are
  // omitted under CIFG), then compare against the golden output sequence.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
2727 
2728 class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
SetUp()2729   void SetUp() override {
2730     input_to_input_weights_ = {
2731         0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
2732         0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
2733         -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
2734         -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
2735         -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
2736         -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
2737         -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
2738         0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
2739         0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
2740         0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
2741         -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
2742         0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
2743         -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
2744         -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
2745         -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
2746         0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
2747         -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
2748         -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
2749         -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
2750         -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677};
2751 
2752     input_to_forget_weights_ = {
2753         -0.0018401089, -0.004852237, 0.03698424,    0.014181704,
2754         0.028273236,   -0.016726194, -0.05249759,   -0.10204261,
2755         0.00861066,    -0.040979505, -0.009899187,  0.01923892,
2756         -0.028177269,  -0.08535103,  -0.14585495,   0.10662567,
2757         -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
2758         0.0030784295,  0.076784775,  0.07463696,    0.094531395,
2759         0.0814421,     -0.12257899,  -0.033945758,  -0.031303465,
2760         0.045630626,   0.06843887,   -0.13492945,   -0.012480007,
2761         -0.0811829,    -0.07224499,  -0.09628791,   0.045100946,
2762         0.0012300825,  0.013964662,  0.099372394,   0.02543059,
2763         0.06958324,    0.034257296,  0.0482646,     0.06267997,
2764         0.052625068,   0.12784666,   0.07077897,    0.025725935,
2765         0.04165009,    0.07241905,   0.018668644,   -0.037377294,
2766         -0.06277783,   -0.08833636,  -0.040120605,  -0.011405586,
2767         -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
2768         0.05483423,    0.11449111,   0.11289652,    0.10939839,
2769         0.13396506,    -0.08402166,  -0.01901462,   -0.044678304,
2770         -0.07720565,   0.014350063,  -0.11757958,   -0.0652038,
2771         -0.08185733,   -0.076754324, -0.092614375,  0.10405491,
2772         0.052960336,   0.035755895,  0.035839386,   -0.012540553,
2773         0.036881298,   0.02913376,   0.03420159,    0.05448447,
2774         -0.054523353,  0.02582715,   0.02327355,    -0.011857179,
2775         -0.0011980024, -0.034641717, -0.026125094,  -0.17582615,
2776         -0.15923657,   -0.27486774,  -0.0006143371, 0.0001771948,
2777         -8.470171e-05, 0.02651807,   0.045790765,   0.06956496};
2778 
2779     input_to_cell_weights_ = {
2780         -0.04580283,   -0.09549462,   -0.032418985,  -0.06454633,
2781         -0.043528453,  0.043018587,   -0.049152344,  -0.12418144,
2782         -0.078985475,  -0.07596889,   0.019484362,   -0.11434962,
2783         -0.0074034138, -0.06314844,   -0.092981495,  0.0062155537,
2784         -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
2785         0.10665918,    -0.032036792,  -0.08505916,   -0.10843358,
2786         -0.13002433,   -0.036816437,  -0.02130134,   -0.016518239,
2787         0.0047691227,  -0.0025825808, 0.066017866,   0.029991534,
2788         -0.10652836,   -0.1037554,    -0.13056071,   -0.03266643,
2789         -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
2790         -0.025174323,  0.0396852,     0.081777506,   0.06157468,
2791         0.10210095,    -0.009658194,  0.046511717,   0.03603906,
2792         0.0069369148,  0.015960095,   -0.06507666,   0.09551598,
2793         0.053568836,   0.06408714,    0.12835667,    -0.008714329,
2794         -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
2795         -0.029227901,  0.1164364,     -0.08560263,   0.09941786,
2796         -0.036999565,  -0.028842626,  -0.0033637602, -0.017012902,
2797         -0.09720865,   -0.11193351,   -0.029155117,  -0.017936034,
2798         -0.009768936,  -0.04223324,   -0.036159635,  0.06505112,
2799         -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
2800         0.05453865,    0.091149814,   0.06387331,    0.007518393,
2801         0.055960953,   0.069779344,   0.046411168,   0.10509911,
2802         0.07463894,    0.0075130584,  0.012850982,   0.04555431,
2803         0.056955688,   0.06555285,    0.050801456,   -0.009862683,
2804         0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042};
2805 
2806     input_to_output_weights_ = {
2807         -0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
2808         -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
2809         0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
2810         -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
2811         -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
2812         0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
2813         -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
2814         -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
2815         -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
2816         -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
2817         0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
2818         0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
2819         0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
2820         -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
2821         0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
2822         0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
2823         -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
2824         0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
2825         -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
2826         -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956};
2827 
2828     input_gate_bias_ = {0.02234832,   0.14757581,  0.18176508,  0.10380666,
2829                         0.053110216,  -0.06928846, -0.13942584, -0.11816189,
2830                         0.19483899,   0.03652339,  -0.10250295, 0.036714908,
2831                         -0.18426876,  0.036065217, 0.21810818,  0.02383196,
2832                         -0.043370757, 0.08690144,  -0.04444982, 0.00030581196};
2833 
2834     forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
2835                          0.11098921,  0.15378423,   0.09263801,  0.09790885,
2836                          0.09508917,  0.061199076,  0.07665568,  -0.015443159,
2837                          -0.03499149, 0.046190713,  0.08895977,  0.10899629,
2838                          0.40694186,  0.06030037,   0.012413437, -0.06108739};
2839 
2840     cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
2841                        -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
2842                        -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
2843                        -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
2844                        0.016178843,  0.1749513,    0.13975595,   0.92058027};
2845 
2846     output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469,   0.12648113,
2847                          0.027195795, 0.35373217,    -0.018957434, 0.008907322,
2848                          -0.0762701,  0.12018895,    0.04216877,   0.0022856654,
2849                          0.040952638, 0.3147856,     0.08225149,   -0.057416286,
2850                          -0.14995944, -0.008040261,  0.13208859,   0.029760877};
2851 
2852     recurrent_to_input_weights_ = {
2853         -0.001374326,   -0.078856036,   0.10672688,    0.029162422,
2854         -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
2855         -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
2856         -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
2857         0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
2858         0.08981,        -0.045407712,   0.08682226,    -0.06867011,
2859         -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
2860         0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
2861         -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
2862         0.009352075,    0.22920375,     0.0016303885,  0.11583097,
2863         -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
2864         0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
2865         -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
2866         0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
2867         -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
2868         -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
2869         -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
2870         -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
2871         -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
2872         0.01068115,     0.032956902,    0.022433773,   0.0026891115,
2873         0.08944216,     -0.0685835,     0.010513544,   0.07228705,
2874         0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
2875         0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
2876         0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
2877         -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
2878         -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
2879         0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
2880         -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
2881         -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
2882         -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
2883         -0.017142897,   0.03312627,     0.009205989,   0.024138335,
2884         -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
2885         -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
2886         0.0365468,      0.07590991,     0.08838724,    0.021681072,
2887         -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
2888         0.023646897,    -0.095322326,   0.02233014,    0.09756986,
2889         -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
2890         -0.09801813,    0.019894179,    0.08502348,    0.004032281,
2891         0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
2892         -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
2893         -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
2894         0.010889619,    0.0047078193,   0.038385306,   0.08540671,
2895         -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
2896         0.015963363,    0.00871737,     0.060130805,   0.028611384,
2897         0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
2898         0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
2899         0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
2900         0.019899689,    0.006106124,    -0.027092824,  0.0786356,
2901         0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
2902         -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
2903         -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
2904         -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
2905         -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
2906         -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
2907         0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
2908         0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
2909         -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
2910         0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
2911         0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
2912         0.058618143,    -0.08598433,    0.00972939,    0.023867095,
2913         -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
2914         -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
2915         0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
2916         -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
2917         -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
2918         0.06358255,     0.18531723,     0.07759293,    0.12006465,
2919         0.1305557,      0.058638252,    -0.03393652,   0.09622831,
2920         -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
2921         -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
2922         0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
2923         0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
2924         0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
2925         0.08184801,     -0.019164372,   0.06791302,    0.034257166,
2926         -0.10307039,    0.021943003,    0.046745934,   0.0790918,
2927         -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
2928         -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
2929         -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
2930         0.026351685,    0.012641483,    0.07466548,    0.044301085,
2931         -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
2932         -0.04106223,    -0.028126027,   0.028473156,   0.10467447};
2933 
2934     recurrent_to_cell_weights_ = {
2935         -0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
2936         0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
2937         0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
2938         -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
2939         0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
2940         0.08089997,     0.05143358,    0.038261272,   0.03339287,
2941         -0.027673481,   0.044746667,   0.028349208,   0.020090483,
2942         -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
2943         -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
2944         -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
2945         0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
2946         -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
2947         -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
2948         0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
2949         0.010868644,    -0.031489216,  0.09525667,    0.013939797,
2950         0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
2951         -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
2952         0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
2953         0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
2954         -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
2955         0.02786344,     -0.014179351,  0.005264273,   0.14376344,
2956         0.015983658,    0.03406988,    -0.06939408,   0.040699873,
2957         0.02111075,     0.09669095,    0.041345075,   -0.08316494,
2958         -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
2959         0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
2960         -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
2961         0.06760663,     -0.027437469,  0.07216407,    0.06977076,
2962         -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
2963         0.043184172,    -0.037189785,  0.10420091,    0.00882477,
2964         -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
2965         0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
2966         0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
2967         -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
2968         0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
2969         -0.008264958,   0.042035464,   0.05891794,    0.029673764,
2970         0.0063542654,   0.044788733,   0.054816857,   0.062257513,
2971         -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
2972         -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
2973         -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
2974         -0.007376126,   0.003533447,   0.006570588,   0.056037236,
2975         0.12436656,     0.051817212,   0.028532185,   -0.08686856,
2976         0.11868599,     0.07663395,    -0.07323171,   0.03463402,
2977         -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
2978         0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
2979         0.023029093,    0.086124025,   0.006445803,   -0.03496501,
2980         0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
2981         -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
2982         0.09465633,     0.008115513,   -0.02171956,   0.08304309,
2983         0.071401566,    0.019622514,   0.032163795,   -0.004167056,
2984         0.02295182,     0.030739572,   0.056506045,   0.004612461,
2985         0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
2986         -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
2987         0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
2988         -0.0329582,     0.07922767,    0.029322514,   0.026405897,
2989         0.04207835,     -0.07073373,   0.063781224,   0.0859677,
2990         -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
2991         -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
2992         -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
2993         -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
2994         0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
2995         0.15978073,     0.10185836,    0.10298046,    -0.015476589,
2996         -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
2997         -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
2998         -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
2999         -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
3000         -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
3001         0.012962922,    -0.031234352,  0.07029052,    0.016418684,
3002         0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
3003         -0.054761406,   0.029065743,   0.052404847,   0.020238016,
3004         0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
3005         0.06262858,     0.009184685,   0.020785125,   -0.043904778,
3006         -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
3007         -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
3008         0.09232601,     -0.035886683,  0.06000002,    0.05229691,
3009         -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
3010         -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
3011         0.031502828,    0.036232427,   -0.031581745,  0.023051167,
3012         -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
3013         -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
3014         -0.008799762,   0.056595087,   0.0022273948,  0.055752404};
3015 
3016     recurrent_to_forget_weights_ = {
3017         -0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
3018         0.14811787,    0.10826372,    0.09471067,     0.03987225,
3019         -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
3020         0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
3021         0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
3022         -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
3023         -0.06193199,   0.055729095,   0.03736828,     0.020123724,
3024         0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
3025         -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
3026         -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
3027         0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
3028         -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
3029         -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
3030         -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
3031         0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
3032         0.013454138,   0.028934088,   0.01685226,     -0.086110644,
3033         -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
3034         0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
3035         0.03761666,    0.008096139,   -0.014454086,   0.014361001,
3036         -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
3037         -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
3038         0.060212336,   0.055259194,   0.06974018,     0.049454916,
3039         -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
3040         0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
3041         -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
3042         0.0042065294,  0.03881498,    0.019844765,    0.041858196,
3043         -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
3044         0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
3045         0.012321099,   0.082840554,   -0.029899208,   0.044217527,
3046         0.059855383,   0.07711018,    -0.045319796,   0.0948846,
3047         -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
3048         -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
3049         -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
3050         0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
3051         0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
3052         0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
3053         0.052958444,   0.07558703,    0.04817258,     0.044462286,
3054         -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
3055         0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
3056         0.024734668,   0.024614193,   -0.042046934,   0.09597743,
3057         -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
3058         -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
3059         -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
3060         0.04383914,    -0.046476185,  0.028658995,    0.060410924,
3061         0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
3062         0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
3063         0.015898481,   0.021362653,   -0.030262267,   0.016587038,
3064         -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
3065         -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
3066         0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
3067         -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
3068         -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
3069         -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
3070         -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
3071         0.15443139,    0.07684145,    0.036571592,    -0.035900835,
3072         -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
3073         -0.03858649,   0.01849943,    0.13872518,     0.01503974,
3074         0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
3075         -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
3076         0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
3077         0.05866852,    0.023947537,   -0.09445152,    0.035450947,
3078         0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
3079         0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
3080         0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
3081         0.051808182,   0.05875331,    -0.04536488,    0.001626336,
3082         -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
3083         0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
3084         -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
3085         -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
3086         0.11475477,    -0.023854522,  0.10071741,     0.0686208,
3087         -0.014250481,  0.034261297,   0.047418304,    0.08562733,
3088         -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
3089         0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
3090         0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
3091         0.014410365,   0.020995233,   0.17040324,     0.11511526,
3092         0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
3093         -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
3094         -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
3095         0.007076659,   0.10964551,    0.0409152,      0.008275321,
3096         -0.07283536,   0.07937492,    0.04192024,     -0.1075027};
3097 
3098     recurrent_to_output_weights_ = {
3099         0.025825322,   -0.05813119,   0.09495884,     -0.045984812,
3100         -0.01255415,   -0.0026479573, -0.08196161,    -0.054914974,
3101         -0.0046604523, -0.029587349,  -0.044576716,   -0.07480124,
3102         -0.082868785,  0.023254942,   0.027502948,    -0.0039728214,
3103         -0.08683098,   -0.08116779,   -0.014675607,   -0.037924774,
3104         -0.023314456,  -0.007401714,  -0.09255757,    0.029460307,
3105         -0.08829125,   -0.005139627,  -0.08989442,    -0.0555066,
3106         0.13596267,    -0.025062224,  -0.048351806,   -0.03850004,
3107         0.07266485,    -0.022414139,  0.05940088,     0.075114764,
3108         0.09597592,    -0.010211725,  -0.0049794707,  -0.011523867,
3109         -0.025980417,  0.072999895,   0.11091378,     -0.081685916,
3110         0.014416728,   0.043229222,   0.034178585,    -0.07530371,
3111         0.035837382,   -0.085607,     -0.007721233,   -0.03287832,
3112         -0.043848954,  -0.06404588,   -0.06632928,    -0.073643476,
3113         0.008214239,   -0.045984086,  0.039764922,    0.03474462,
3114         0.060612556,   -0.080590084,  0.049127717,    0.04151091,
3115         -0.030063879,  0.008801774,   -0.023021035,   -0.019558564,
3116         0.05158114,    -0.010947698,  -0.011825728,   0.0075720972,
3117         0.0699727,     -0.0039981045, 0.069350146,    0.08799282,
3118         0.016156472,   0.035502106,   0.11695009,     0.006217345,
3119         0.13392477,    -0.037875112,  0.025745004,    0.08940699,
3120         -0.00924166,   0.0046702605,  -0.036598757,   -0.08811812,
3121         0.10522024,    -0.032441203,  0.008176899,    -0.04454919,
3122         0.07058152,    0.0067963637,  0.039206743,    0.03259838,
3123         0.03725492,    -0.09515802,   0.013326398,    -0.052055415,
3124         -0.025676316,  0.03198509,    -0.015951829,   -0.058556724,
3125         0.036879618,   0.043357447,   0.028362012,    -0.05908629,
3126         0.0059240665,  -0.04995891,   -0.019187413,   0.0276265,
3127         -0.01628143,   0.0025863599,  0.08800015,     0.035250366,
3128         -0.022165963,  -0.07328642,   -0.009415526,   -0.07455109,
3129         0.11690406,    0.0363299,     0.07411125,     0.042103454,
3130         -0.009660886,  0.019076364,   0.018299393,    -0.046004917,
3131         0.08891175,    0.0431396,     -0.026327137,   -0.051502608,
3132         0.08979574,    -0.051670972,  0.04940282,     -0.07491107,
3133         -0.021240504,  0.022596184,   -0.034280192,   0.060163025,
3134         -0.058211457,  -0.051837247,  -0.01349775,    -0.04639988,
3135         -0.035936575,  -0.011681591,  0.064818054,    0.0073146066,
3136         -0.021745546,  -0.043124277,  -0.06471268,    -0.07053354,
3137         -0.029321948,  -0.05330136,   0.016933719,    -0.053782392,
3138         0.13747959,    -0.1361751,    -0.11569455,    0.0033329215,
3139         0.05693899,    -0.053219706,  0.063698,       0.07977434,
3140         -0.07924483,   0.06936997,    0.0034815092,   -0.007305279,
3141         -0.037325785,  -0.07251102,   -0.033633437,   -0.08677009,
3142         0.091591336,   -0.14165086,   0.021752775,    0.019683983,
3143         0.0011612234,  -0.058154266,  0.049996935,    0.0288841,
3144         -0.0024567875, -0.14345716,   0.010955264,    -0.10234828,
3145         0.1183656,     -0.0010731248, -0.023590032,   -0.072285876,
3146         -0.0724771,    -0.026382286,  -0.0014920527,  0.042667855,
3147         0.0018776858,  0.02986552,    0.009814309,    0.0733756,
3148         0.12289186,    0.018043943,   -0.0458958,     0.049412545,
3149         0.033632483,   0.05495232,    0.036686596,    -0.013781798,
3150         -0.010036754,  0.02576849,    -0.08307328,    0.010112348,
3151         0.042521734,   -0.05869831,   -0.071689695,   0.03876447,
3152         -0.13275425,   -0.0352966,    -0.023077697,   0.10285965,
3153         0.084736146,   0.15568255,    -0.00040734606, 0.027835453,
3154         -0.10292561,   -0.032401145,  0.10053256,     -0.026142767,
3155         -0.08271222,   -0.0030240538, -0.016368777,   0.1070414,
3156         0.042672627,   0.013456989,   -0.0437609,     -0.022309763,
3157         0.11576483,    0.04108048,    0.061026827,    -0.0190714,
3158         -0.0869359,    0.037901703,   0.0610107,      0.07202949,
3159         0.01675338,    0.086139716,   -0.08795751,    -0.014898893,
3160         -0.023771819,  -0.01965048,   0.007955471,    -0.043740474,
3161         0.03346837,    -0.10549954,   0.090567775,    0.042013682,
3162         -0.03176985,   0.12569028,    -0.02421228,    -0.029526481,
3163         0.023851605,   0.031539805,   0.05292009,     -0.02344001,
3164         -0.07811758,   -0.08834428,   0.10094801,     0.16594367,
3165         -0.06861939,   -0.021256343,  -0.041093912,   -0.06669611,
3166         0.035498552,   0.021757556,   -0.09302526,    -0.015403468,
3167         -0.06614931,   -0.051798206,  -0.013874718,   0.03630673,
3168         0.010412845,   -0.08077351,   0.046185967,    0.0035662893,
3169         0.03541868,    -0.094149634,  -0.034814864,   0.003128424,
3170         -0.020674974,  -0.03944324,   -0.008110165,   -0.11113267,
3171         0.08484226,    0.043586485,   0.040582247,    0.0968012,
3172         -0.065249965,  -0.028036479,  0.0050708856,   0.0017462453,
3173         0.0326779,     0.041296225,   0.09164146,     -0.047743853,
3174         -0.015952192,  -0.034451712,  0.084197424,    -0.05347844,
3175         -0.11768019,   0.085926116,   -0.08251791,    -0.045081906,
3176         0.0948852,     0.068401024,   0.024856757,    0.06978981,
3177         -0.057309967,  -0.012775832,  -0.0032452994,  0.01977615,
3178         -0.041040014,  -0.024264973,  0.063464895,    0.05431621,
3179     };
3180 
3181     cell_to_input_weights_ = {
3182         0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
3183         -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
3184         -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
3185         0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175};
3186 
3187     cell_to_forget_weights_ = {
3188         -0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
3189         -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
3190         -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
3191         0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355};
3192 
3193     cell_to_output_weights_ = {
3194         0.08286371,  -0.08261836, -0.51210177, 0.002913762, 0.17764764,
3195         -0.5495371,  -0.08460716, -0.24552552, 0.030037103, 0.04123544,
3196         -0.11940523, 0.007358328, 0.1890978,   0.4833202,   -0.34441817,
3197         0.36312827,  -0.26375428, 0.1457655,   -0.19724406, 0.15548733};
3198 
3199     projection_weights_ = {
3200         -0.009802181, 0.09401916,   0.0717386,     -0.13895074,
3201         0.09641832,   0.060420845,  0.08539281,    0.054285463,
3202         0.061395317,  0.034448683,  -0.042991187,  0.019801661,
3203         -0.16840284,  -0.015726732, -0.23041931,   -0.024478018,
3204         -0.10959692,  -0.013875541, 0.18600968,    -0.061274476,
3205         0.0138165,    -0.08160894,  -0.07661644,   0.032372914,
3206         0.16169067,   0.22465782,   -0.03993472,   -0.004017731,
3207         0.08633481,   -0.28869787,  0.08682067,    0.17240396,
3208         0.014975425,  0.056431185,  0.031037588,   0.16702051,
3209         0.0077946745, 0.15140012,   0.29405436,    0.120285,
3210         -0.188994,    -0.027265169, 0.043389652,   -0.022061434,
3211         0.014777949,  -0.20203483,  0.094781205,   0.19100232,
3212         0.13987629,   -0.036132768, -0.06426278,   -0.05108664,
3213         0.13221376,   0.009441198,  -0.16715929,   0.15859416,
3214         -0.040437475, 0.050779544,  -0.022187516,  0.012166504,
3215         0.027685808,  -0.07675938,  -0.0055694645, -0.09444123,
3216         0.0046453946, 0.050794356,  0.10770313,    -0.20790008,
3217         -0.07149004,  -0.11425117,  0.008225835,   -0.035802525,
3218         0.14374903,   0.15262283,   0.048710253,   0.1847461,
3219         -0.007487823, 0.11000021,   -0.09542012,   0.22619456,
3220         -0.029149994, 0.08527916,   0.009043713,   0.0042746216,
3221         0.016261552,  0.022461696,  0.12689082,    -0.043589946,
3222         -0.12035478,  -0.08361797,  -0.050666027,  -0.1248618,
3223         -0.1275799,   -0.071875185, 0.07377272,    0.09944291,
3224         -0.18897448,  -0.1593054,   -0.06526116,   -0.040107165,
3225         -0.004618631, -0.067624845, -0.007576253,  0.10727444,
3226         0.041546922,  -0.20424393,  0.06907816,    0.050412357,
3227         0.00724631,   0.039827548,  0.12449835,    0.10747581,
3228         0.13708383,   0.09134148,   -0.12617786,   -0.06428341,
3229         0.09956831,   0.1208086,    -0.14676677,   -0.0727722,
3230         0.1126304,    0.010139365,  0.015571211,   -0.038128063,
3231         0.022913318,  -0.042050496, 0.16842307,    -0.060597885,
3232         0.10531834,   -0.06411776,  -0.07451711,   -0.03410368,
3233         -0.13393489,  0.06534304,   0.003620307,   0.04490757,
3234         0.05970546,   0.05197996,   0.02839995,    0.10434969,
3235         -0.013699693, -0.028353551, -0.07260381,   0.047201227,
3236         -0.024575593, -0.036445823, 0.07155557,    0.009672501,
3237         -0.02328883,  0.009533515,  -0.03606021,   -0.07421458,
3238         -0.028082801, -0.2678904,   -0.13221288,   0.18419984,
3239         -0.13012612,  -0.014588381, -0.035059117,  -0.04824723,
3240         0.07830115,   -0.056184657, 0.03277091,    0.025466874,
3241         0.14494097,   -0.12522776,  -0.098633975,  -0.10766018,
3242         -0.08317623,  0.08594209,   0.07749552,    0.039474737,
3243         0.1776665,    -0.07409566,  -0.0477268,    0.29323658,
3244         0.10801441,   0.1154011,    0.013952499,   0.10739139,
3245         0.10708251,   -0.051456142, 0.0074137426,  -0.10430189,
3246         0.10034707,   0.045594677,  0.0635285,     -0.0715442,
3247         -0.089667566, -0.10811871,  0.00026344223, 0.08298446,
3248         -0.009525053, 0.006585689,  -0.24567553,   -0.09450807,
3249         0.09648481,   0.026996298,  -0.06419476,   -0.04752702,
3250         -0.11063944,  -0.23441927,  -0.17608605,   -0.052156363,
3251         0.067035615,  0.19271925,   -0.0032889997, -0.043264326,
3252         0.09663576,   -0.057112187, -0.10100678,   0.0628376,
3253         0.04447668,   0.017961001,  -0.10094388,   -0.10190601,
3254         0.18335468,   0.10494553,   -0.052095775,  -0.0026118709,
3255         0.10539724,   -0.04383912,  -0.042349473,  0.08438151,
3256         -0.1947263,   0.02251204,   0.11216432,    -0.10307853,
3257         0.17351969,   -0.039091777, 0.08066188,    -0.00561982,
3258         0.12633002,   0.11335965,   -0.0088127935, -0.019777594,
3259         0.06864014,   -0.059751723, 0.016233567,   -0.06894641,
3260         -0.28651384,  -0.004228674, 0.019708522,   -0.16305895,
3261         -0.07468996,  -0.0855457,   0.099339016,   -0.07580735,
3262         -0.13775392,  0.08434318,   0.08330512,    -0.12131499,
3263         0.031935584,  0.09180414,   -0.08876437,   -0.08049874,
3264         0.008753825,  0.03498998,   0.030215185,   0.03907079,
3265         0.089751154,  0.029194152,  -0.03337423,   -0.019092513,
3266         0.04331237,   0.04299654,   -0.036394123,  -0.12915532,
3267         0.09793732,   0.07512415,   -0.11319543,   -0.032502122,
3268         0.15661901,   0.07671967,   -0.005491124,  -0.19379048,
3269         -0.218606,    0.21448623,   0.017840758,   0.1416943,
3270         -0.07051762,  0.19488361,   0.02664691,    -0.18104725,
3271         -0.09334311,  0.15026465,   -0.15493552,   -0.057762887,
3272         -0.11604192,  -0.262013,    -0.01391798,   0.012185008,
3273         0.11156489,   -0.07483202,  0.06693364,    -0.26151478,
3274         0.046425626,  0.036540434,  -0.16435726,   0.17338543,
3275         -0.21401681,  -0.11385144,  -0.08283257,   -0.069031075,
3276         0.030635102,  0.010969227,  0.11109743,    0.010919218,
3277         0.027526086,  0.13519906,   0.01891392,    -0.046839405,
3278         -0.040167913, 0.017953383,  -0.09700955,   0.0061885654,
3279         -0.07000971,  0.026893595,  -0.038844477,  0.14543656};
3280 
3281     lstm_input_ = {
3282         {// Batch0: 4 (input_sequence_size) * 5 (n_input)
3283          0.787926, 0.151646, 0.071352, 0.118426, 0.458058,   // step 0
3284          0.596268, 0.998386, 0.568695, 0.864524, 0.571277,   // step 1
3285          0.073204, 0.296072, 0.743333, 0.069199, 0.045348,   // step 2
3286          0.867394, 0.291279, 0.013714, 0.482521, 0.626339},  // step 3
3287 
3288         {// Batch1: 4 (input_sequence_size) * 5 (n_input)
3289          0.295743, 0.544053, 0.690064, 0.858138, 0.497181,  // step 0
3290          0.642421, 0.524260, 0.134799, 0.003639, 0.162482,  // step 1
3291          0.640394, 0.930399, 0.050782, 0.432485, 0.988078,  // step 2
3292          0.082922, 0.563329, 0.865614, 0.333232, 0.259916}  // step 3
3293     };
3294 
3295     lstm_golden_output_ = {
3296         {// Batch0: 4 (input_sequence_size) * 16 (n_output)
3297          -0.00396806, 0.029352,     -0.00279226, 0.0159977,   -0.00835576,
3298          -0.0211779,  0.0283512,    -0.0114597,  0.00907307,  -0.0244004,
3299          -0.0152191,  -0.0259063,   0.00914318,  0.00415118,  0.017147,
3300          0.0134203,   -0.0166936,   0.0381209,   0.000889694, 0.0143363,
3301          -0.0328911,  -0.0234288,   0.0333051,   -0.012229,   0.0110322,
3302          -0.0457725,  -0.000832209, -0.0202817,  0.0327257,   0.0121308,
3303          0.0155969,   0.0312091,    -0.0213783,  0.0350169,   0.000324794,
3304          0.0276012,   -0.0263374,   -0.0371449,  0.0446149,   -0.0205474,
3305          0.0103729,   -0.0576349,   -0.0150052,  -0.0292043,  0.0376827,
3306          0.0136115,   0.0243435,    0.0354492,   -0.0189322,  0.0464512,
3307          -0.00251373, 0.0225745,    -0.0308346,  -0.0317124,  0.0460407,
3308          -0.0189395,  0.0149363,    -0.0530162,  -0.0150767,  -0.0340193,
3309          0.0286833,   0.00824207,   0.0264887,   0.0305169},
3310         {// Batch1: 4 (input_sequence_size) * 16 (n_output)
3311          -0.013869,    0.0287268,   -0.00334693, 0.00733398,  -0.0287926,
3312          -0.0186926,   0.0193662,   -0.0115437,  0.00422612,  -0.0345232,
3313          0.00223253,   -0.00957321, 0.0210624,   0.013331,    0.0150954,
3314          0.02168,      -0.0141913,  0.0322082,   0.00227024,  0.0260507,
3315          -0.0188721,   -0.0296489,  0.0399134,   -0.0160509,  0.0116039,
3316          -0.0447318,   -0.0150515,  -0.0277406,  0.0316596,   0.0118233,
3317          0.0214762,    0.0293641,   -0.0204549,  0.0450315,   -0.00117378,
3318          0.0167673,    -0.0375007,  -0.0238314,  0.038784,    -0.0174034,
3319          0.0131743,    -0.0506589,  -0.0048447,  -0.0240239,  0.0325789,
3320          0.00790065,   0.0220157,   0.0333314,   -0.0264787,  0.0387855,
3321          -0.000764675, 0.0217599,   -0.037537,   -0.0335206,  0.0431679,
3322          -0.0211424,   0.010203,    -0.062785,   -0.00832363, -0.025181,
3323          0.0412031,    0.0118723,   0.0239643,   0.0394009}};
3324   }
3325 };
3326 
TEST_F(NoCifgPeepholeProjectionClippingLstmTest,LstmBlackBoxTest)3327 TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
3328   const int n_batch = 2;
3329   const int n_input = 5;
3330   const int n_cell = 20;
3331   const int n_output = 16;
3332 
3333   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
3334                    /*use_cifg=*/false, /*use_peephole=*/true,
3335                    /*use_projection_weights=*/true,
3336                    /*use_projection_bias=*/false,
3337                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
3338                    {
3339                        {n_batch, n_input},  // input tensor
3340 
3341                        {n_cell, n_input},  // input_to_input_weight tensor
3342                        {n_cell, n_input},  // input_to_forget_weight tensor
3343                        {n_cell, n_input},  // input_to_cell_weight tensor
3344                        {n_cell, n_input},  // input_to_output_weight tensor
3345 
3346                        {n_cell, n_output},  // recurrent_to_input_weight tensor
3347                        {n_cell, n_output},  // recurrent_to_forget_weight tensor
3348                        {n_cell, n_output},  // recurrent_to_cell_weight tensor
3349                        {n_cell, n_output},  // recurrent_to_output_weight tensor
3350 
3351                        {n_cell},  // cell_to_input_weight tensor
3352                        {n_cell},  // cell_to_forget_weight tensor
3353                        {n_cell},  // cell_to_output_weight tensor
3354 
3355                        {n_cell},  // input_gate_bias tensor
3356                        {n_cell},  // forget_gate_bias tensor
3357                        {n_cell},  // cell_bias tensor
3358                        {n_cell},  // output_gate_bias tensor
3359 
3360                        {n_output, n_cell},  // projection_weight tensor
3361                        {0},                 // projection_bias tensor
3362                    },
3363                    /*weight_type=*/TensorType_FLOAT32);
3364 
3365   lstm.SetInputToInputWeights(input_to_input_weights_);
3366   lstm.SetInputToCellWeights(input_to_cell_weights_);
3367   lstm.SetInputToForgetWeights(input_to_forget_weights_);
3368   lstm.SetInputToOutputWeights(input_to_output_weights_);
3369 
3370   lstm.SetInputGateBias(input_gate_bias_);
3371   lstm.SetCellBias(cell_gate_bias_);
3372   lstm.SetForgetGateBias(forget_gate_bias_);
3373   lstm.SetOutputGateBias(output_gate_bias_);
3374 
3375   lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
3376   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
3377   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
3378   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
3379 
3380   lstm.SetCellToInputWeights(cell_to_input_weights_);
3381   lstm.SetCellToForgetWeights(cell_to_forget_weights_);
3382   lstm.SetCellToOutputWeights(cell_to_output_weights_);
3383 
3384   lstm.SetProjectionWeights(projection_weights_);
3385 
3386   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
3387 }
3388 
3389 class BaseReduceOpModel : public SingleOpModelWithNNAPI {
3390  public:
SetAxis(const std::vector<int> & data)3391   void SetAxis(const std::vector<int>& data) { PopulateTensor(axis_, data); }
3392 
3393   template <class T>
SetInput(const std::vector<T> & data)3394   void SetInput(const std::vector<T>& data) {
3395     PopulateTensor(input_, data);
3396   }
3397 
3398   template <class T>
GetOutput()3399   std::vector<T> GetOutput() {
3400     return ExtractVector<T>(output_);
3401   }
3402 
GetDequantizedOutput()3403   std::vector<float> GetDequantizedOutput() {
3404     return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
3405                                GetScale(output_), GetZeroPoint(output_));
3406   }
3407 
GetOutputShape()3408   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
3409 
Input()3410   int Input() { return input_; }
3411 
3412  protected:
3413   int input_;
3414   int axis_;
3415   int output_;
3416 };
3417 
// Model for the test case where the axis is a constant tensor.
class MeanOpConstModel : public BaseReduceOpModel {
 public:
  // Builds a single-op MEAN model. `axis` is baked in as a constant INT32
  // tensor of shape `axis_shape`; `keep_dims` selects whether reduced
  // dimensions are retained with size 1 in the output.
  // NOTE: tensor registration order below fixes the op's tensor indices —
  // do not reorder the Add* calls.
  MeanOpConstModel(const TensorData& input, const TensorData& output,
                   std::initializer_list<int> axis_shape,
                   std::initializer_list<int> axis, bool keep_dims) {
    input_ = AddInput(input);
    axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
                 CreateReducerOptions(builder_, keep_dims).Union());
    BuildInterpreter({GetShape(input_)});
  }
};
3432 
3433 // Tests for reduce_mean
TEST(NNAPIDelegate,MeanFloatNotKeepDims)3434 TEST(NNAPIDelegate, MeanFloatNotKeepDims) {
3435   std::vector<float> data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
3436                              9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
3437                              17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
3438   MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
3439                      {4}, {1, 0, -3, -3}, false);
3440   m.SetInput(data);
3441   m.Invoke();
3442   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
3443   EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
3444 }
3445 
// Reduces a 4x3x2 tensor over axes {0, 2}, keeping the reduced dimensions
// so the output shape is 1x3x1.
TEST(NNAPIDelegate, MeanFloatKeepDims) {
  const std::vector<float> input = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,
                                    7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
                                    13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
                                    19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
                     {2}, {0, 2}, /*keep_dims=*/true);
  m.SetInput(input);
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
}
3458 
// Common plumbing for EMBEDDING_LOOKUP models: an INT32 index input, a weight
// table input, and a FLOAT32 output.
class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI {
 public:
  // Builds a single-op EMBEDDING_LOOKUP model.
  // NOTE: tensor registration order below fixes the op's tensor indices —
  // do not reorder the Add* calls.
  BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
                             std::initializer_list<int> weight_shape,
                             TensorType weight_type = TensorType_FLOAT32) {
    input_ = AddInput(TensorType_INT32);
    weight_ = AddInput(weight_type);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
    BuildInterpreter({index_shape, weight_shape});
  }

  // Writes the lookup indices into the index tensor.
  void SetInput(std::initializer_list<int> data) {
    PopulateTensor(input_, data);
  }

  // Reads the output tensor back as floats.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input_;
  int weight_;
  int output_;
};
3482 
3483 class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
3484  public:
3485   using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
3486 
Set3DWeightMatrix(const std::function<float (int,int,int)> & function)3487   void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
3488     TfLiteTensor* tensor = interpreter_->tensor(weight_);
3489     int rows = tensor->dims->data[0];
3490     int columns = tensor->dims->data[1];
3491     int features = tensor->dims->data[2];
3492     for (int i = 0; i < rows; i++) {
3493       for (int j = 0; j < columns; j++) {
3494         for (int k = 0; k < features; k++) {
3495           tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
3496         }
3497       }
3498     }
3499   }
3500 };
3501 
// Looks up rows {1, 0, 2} of a 3x2x4 table whose values encode their own
// coordinates (row + column/10 + feature/100), so the output makes the
// gather order directly visible.
TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) {
  EmbeddingLookupOpModel m({3}, {3, 2, 4});
  m.SetInput({1, 0, 2});
  m.Set3DWeightMatrix([](int row, int col, int feat) {
    return row + col / 10.0f + feat / 100.0f;
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
                  0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
                  2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
              })));
}
3517 
3518 class HashtableLookupOpModel : public SingleOpModelWithNNAPI {
3519  public:
HashtableLookupOpModel(std::initializer_list<int> lookup_shape,std::initializer_list<int> key_shape,std::initializer_list<int> value_shape,TensorType type)3520   HashtableLookupOpModel(std::initializer_list<int> lookup_shape,
3521                          std::initializer_list<int> key_shape,
3522                          std::initializer_list<int> value_shape,
3523                          TensorType type) {
3524     lookup_ = AddInput(TensorType_INT32);
3525     key_ = AddInput(TensorType_INT32);
3526     value_ = AddInput(type);
3527     output_ = AddOutput(type);
3528     hit_ = AddOutput(TensorType_UINT8);
3529     SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0);
3530     BuildInterpreter({lookup_shape, key_shape, value_shape});
3531   }
3532 
SetLookup(std::initializer_list<int> data)3533   void SetLookup(std::initializer_list<int> data) {
3534     PopulateTensor<int>(lookup_, data);
3535   }
3536 
SetHashtableKey(std::initializer_list<int> data)3537   void SetHashtableKey(std::initializer_list<int> data) {
3538     PopulateTensor<int>(key_, data);
3539   }
3540 
SetHashtableValue(const std::vector<string> & content)3541   void SetHashtableValue(const std::vector<string>& content) {
3542     PopulateStringTensor(value_, content);
3543   }
3544 
SetHashtableValue(const std::function<float (int)> & function)3545   void SetHashtableValue(const std::function<float(int)>& function) {
3546     TfLiteTensor* tensor = interpreter_->tensor(value_);
3547     int rows = tensor->dims->data[0];
3548     for (int i = 0; i < rows; i++) {
3549       tensor->data.f[i] = function(i);
3550     }
3551   }
3552 
SetHashtableValue(const std::function<float (int,int)> & function)3553   void SetHashtableValue(const std::function<float(int, int)>& function) {
3554     TfLiteTensor* tensor = interpreter_->tensor(value_);
3555     int rows = tensor->dims->data[0];
3556     int features = tensor->dims->data[1];
3557     for (int i = 0; i < rows; i++) {
3558       for (int j = 0; j < features; j++) {
3559         tensor->data.f[i * features + j] = function(i, j);
3560       }
3561     }
3562   }
3563 
GetStringOutput()3564   std::vector<string> GetStringOutput() {
3565     TfLiteTensor* output = interpreter_->tensor(output_);
3566     int num = GetStringCount(output);
3567     std::vector<string> result(num);
3568     for (int i = 0; i < num; i++) {
3569       auto ref = GetString(output, i);
3570       result[i] = string(ref.str, ref.len);
3571     }
3572     return result;
3573   }
3574 
GetOutput()3575   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetHit()3576   std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); }
3577 
3578  private:
3579   int lookup_;
3580   int key_;
3581   int value_;
3582   int output_;
3583   int hit_;
3584 };
3585 
// 2-D value table: each found lookup returns a full row; unknown keys
// (-292 here) produce a zero row and a 0 in the hit mask.
TEST(NNAPIDelegate, HashtableLookupTest2DInput) {
  HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32);

  m.SetLookup({1234, -292, -11, 0});
  m.SetHashtableKey({-11, 0, 1234});
  // Row values encode their own coordinates: row + col/10.
  m.SetHashtableValue([](int row, int col) { return row + col / 10.0f; });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                 2.0, 2.1,  // 2-nd item
                                 0, 0,      // Not found
                                 0.0, 0.1,  // 0-th item
                                 1.0, 1.1,  // 1-st item
                             })));
  EXPECT_THAT(m.GetHit(), ElementsAreArray({
                              1,
                              0,
                              1,
                              1,
                          }));
}
3608 
// 1-D value table: each found lookup returns a scalar; unknown keys
// (-292 here) produce 0 and a 0 in the hit mask.
TEST(NNAPIDelegate, HashtableLookupTest1DInput) {
  HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32);

  m.SetLookup({1234, -292, -11, 0});
  m.SetHashtableKey({-11, 0, 1234});
  // Value for row i is i*i/10, so each row is distinguishable.
  m.SetHashtableValue([](int row) { return row * row / 10.0f; });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                 0.4,  // 2-nd item
                                 0,    // Not found
                                 0.0,  // 0-th item
                                 0.1,  // 1-st item
                             })));
  EXPECT_THAT(m.GetHit(), ElementsAreArray({
                              1,
                              0,
                              1,
                              1,
                          }));
}
3631 }  // namespace
3632 }  // namespace tflite
3633 
main(int argc,char ** argv)3634 int main(int argc, char** argv) {
3635   ::tflite::LogToStderr();
3636   ::testing::InitGoogleTest(&argc, argv);
3637   return RUN_ALL_TESTS();
3638 }
3639