1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <vector>
17 
18 #include <gtest/gtest.h>
19 #include "tensorflow/lite/kernels/test_util.h"
20 
21 namespace tflite {
22 namespace ops {
23 namespace experimental {
24 
25 TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU();
26 
27 namespace {
28 
29 using ::testing::ElementsAre;
30 using ::testing::ElementsAreArray;
31 
32 class GRUOpModel : public SingleOpModel {
33  public:
GRUOpModel(int n_batch,int n_input,int n_output,const std::vector<std::vector<int>> & input_shapes,const TensorType & weight_type=TensorType_FLOAT32)34   explicit GRUOpModel(int n_batch, int n_input, int n_output,
35                       const std::vector<std::vector<int>>& input_shapes,
36                       const TensorType& weight_type = TensorType_FLOAT32)
37       : n_batch_(n_batch), n_input_(n_input), n_output_(n_output) {
38     input_ = AddInput(TensorType_FLOAT32);
39     input_state_ =
40         AddVariableInput(TensorData{TensorType_FLOAT32, {n_batch, n_output}});
41     gate_weight_ = AddInput(TensorType_FLOAT32);
42     gate_bias_ = AddInput(TensorType_FLOAT32);
43     candidate_weight_ = AddInput(TensorType_FLOAT32);
44     candidate_bias_ = AddInput(TensorType_FLOAT32);
45 
46     output_ = AddOutput(TensorType_FLOAT32);
47     output_state_ = AddOutput(TensorType_FLOAT32);
48 
49     SetCustomOp("UNIDIRECTIONAL_SEQUENCE_GRU", {},
50                 Register_UNIDIRECTIONAL_SEQUENCE_GRU);
51     BuildInterpreter(input_shapes);
52   }
53 
SetInput(const std::vector<float> & f)54   void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
55 
SetInputState(const std::vector<float> & f)56   void SetInputState(const std::vector<float>& f) {
57     PopulateTensor(input_state_, f);
58   }
59 
SetGateWeight(const std::vector<float> & f)60   void SetGateWeight(const std::vector<float>& f) {
61     PopulateTensor(gate_weight_, f);
62   }
63 
SetGateBias(const std::vector<float> & f)64   void SetGateBias(const std::vector<float>& f) {
65     PopulateTensor(gate_bias_, f);
66   }
67 
SetCandidateWeight(const std::vector<float> & f)68   void SetCandidateWeight(const std::vector<float>& f) {
69     PopulateTensor(candidate_weight_, f);
70   }
71 
SetCandidateBias(const std::vector<float> & f)72   void SetCandidateBias(const std::vector<float>& f) {
73     PopulateTensor(candidate_bias_, f);
74   }
75 
GetOutputShape()76   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
77 
GetOutput()78   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
79 
num_batches()80   int num_batches() { return n_batch_; }
num_inputs()81   int num_inputs() { return n_input_; }
num_outputs()82   int num_outputs() { return n_output_; }
83 
84  private:
85   int input_;
86   int input_state_;
87   int gate_weight_;
88   int gate_bias_;
89   int candidate_weight_;
90   int candidate_bias_;
91 
92   int output_;
93   int output_state_;
94   int n_batch_;
95   int n_input_;
96   int n_output_;
97 };
98 
// Runs the GRU op over a 2-step sequence (batch 2, 2 input features, 3 output
// units). It checks the output shape and compares the flattened output with
// hard-coded expected values.
TEST(GRUTest, SimpleTest) {
  constexpr int n_time = 2;
  constexpr int n_batch = 2;
  constexpr int n_input = 2;
  constexpr int n_output = 3;

  // One shape per model input, in GRUOpModel's registration order.
  const std::vector<std::vector<int>> input_shapes = {
      {n_time, n_batch, n_input},          // input
      {n_batch, n_output},                 // input_state
      {2 * n_output, n_input + n_output},  // gate_weight
      {2 * n_output},                      // gate_bias
      {n_output, n_input + n_output},      // candidate_weight
      {n_output},                          // candidate_bias
  };
  GRUOpModel gru(n_batch, n_input, n_output, input_shapes);

  // All data is randomly generated. Matrices are written one row per line.
  const std::vector<float> input = {
      0.89495724, 0.34482682,  //
      0.68505806, 0.7135783,   //
      0.3167085,  0.93647677,  //
      0.47361764, 0.39643127,
  };
  const std::vector<float> input_state = {
      0.09992421, 0.3028481,  0.78305984,  //
      0.50438094, 0.11269058, 0.10244724,
  };
  const std::vector<float> gate_weight = {
      0.7256918,  0.8945897,  0.03285786, 0.42637166, 0.119376324,  //
      0.83035135, 0.16997327, 0.42302176, 0.77598256, 0.2660894,    //
      0.9587266,  0.6218451,  0.88164485, 0.12272458, 0.2699055,    //
      0.18399088, 0.21930052, 0.3374841,  0.70866305, 0.9523419,    //
      0.25170696, 0.60988617, 0.79823977, 0.64477515, 0.2602957,    //
      0.5053131,  0.93722224, 0.8451359,  0.97905475, 0.38669217,
  };
  const std::vector<float> gate_bias = {
      0.032708533, 0.018445263, 0.15320699, 0.8163046, 0.26683575, 0.1412022,
  };
  const std::vector<float> candidate_weight = {
      0.96165305, 0.95572084,  0.11534478,  0.96965164,  0.33562955,  //
      0.8680755,  0.003066936, 0.057793964, 0.8671354,   0.33354893,  //
      0.7313398,  0.78492093,  0.19530584,  0.116550304, 0.13599132,
  };
  const std::vector<float> candidate_bias = {
      0.89837056, 0.54769796, 0.63364106,
  };

  gru.SetInput(input);
  gru.SetInputState(input_state);
  gru.SetGateWeight(gate_weight);
  gru.SetGateBias(gate_bias);
  gru.SetCandidateWeight(candidate_weight);
  gru.SetCandidateBias(candidate_bias);

  gru.Invoke();

  const std::vector<float> expected_output = {
      0.20112592, 0.45286041, 0.80842507,  //
      0.59567153, 0.2619998,  0.22922856,  //
      0.27715868, 0.5247152,  0.82300174,  //
      0.65812796, 0.38217607, 0.3401444,
  };
  EXPECT_THAT(gru.GetOutputShape(), ElementsAre(n_time, n_batch, n_output));
  EXPECT_THAT(gru.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_output)));
}
140 
141 }  // namespace
142 }  // namespace experimental
143 }  // namespace ops
144 }  // namespace tflite
145 
main(int argc,char ** argv)146 int main(int argc, char** argv) {
147   ::tflite::LogToStderr();
148   ::testing::InitGoogleTest(&argc, argv);
149   return RUN_ALL_TESTS();
150 }
151