/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>
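// <cmath> is only needed by the illustrative ReferenceGruStep sketch added
// further below; the original test itself does not require it.
#include <cmath>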

#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/test_util.h"

namespace tflite {
namespace ops {
namespace experimental {

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU();

namespace {

using ::testing::ElementsAre;
using ::testing::ElementsAreArray;

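// Test harness that wraps the custom UNIDIRECTIONAL_SEQUENCE_GRU op in a
// SingleOpModel. The op takes six inputs, registered in this order:
//   input            [n_time, n_batch, n_input]
//   input_state      [n_batch, n_output]               (variable tensor)
//   gate_weight      [2 * n_output, n_input + n_output]
//   gate_bias        [2 * n_output]
//   candidate_weight [n_output, n_input + n_output]
//   candidate_bias   [n_output]
// and produces output [n_time, n_batch, n_output] plus an output state.
// The shapes listed are the ones SimpleTest passes via `input_shapes`; the
// harness itself simply forwards whatever shapes the caller supplies.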
class GRUOpModel : public SingleOpModel {
 public:
  explicit GRUOpModel(int n_batch, int n_input, int n_output,
                      const std::vector<std::vector<int>>& input_shapes,
                      const TensorType& weight_type = TensorType_FLOAT32)
      : n_batch_(n_batch), n_input_(n_input), n_output_(n_output) {
    input_ = AddInput(TensorType_FLOAT32);
    input_state_ =
        AddVariableInput(TensorData{TensorType_FLOAT32, {n_batch, n_output}});
    gate_weight_ = AddInput(TensorType_FLOAT32);
    gate_bias_ = AddInput(TensorType_FLOAT32);
    candidate_weight_ = AddInput(TensorType_FLOAT32);
    candidate_bias_ = AddInput(TensorType_FLOAT32);

    output_ = AddOutput(TensorType_FLOAT32);
    output_state_ = AddOutput(TensorType_FLOAT32);

    SetCustomOp("UNIDIRECTIONAL_SEQUENCE_GRU", {},
                Register_UNIDIRECTIONAL_SEQUENCE_GRU);
    BuildInterpreter(input_shapes);
  }

  void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }

  void SetInputState(const std::vector<float>& f) {
    PopulateTensor(input_state_, f);
  }

  void SetGateWeight(const std::vector<float>& f) {
    PopulateTensor(gate_weight_, f);
  }

  void SetGateBias(const std::vector<float>& f) {
    PopulateTensor(gate_bias_, f);
  }

  void SetCandidateWeight(const std::vector<float>& f) {
    PopulateTensor(candidate_weight_, f);
  }

  void SetCandidateBias(const std::vector<float>& f) {
    PopulateTensor(candidate_bias_, f);
  }

  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_batches() { return n_batch_; }
  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }

 private:
  int input_;
  int input_state_;
  int gate_weight_;
  int gate_bias_;
  int candidate_weight_;
  int candidate_bias_;

  int output_;
  int output_state_;
  int n_batch_;
  int n_input_;
  int n_output_;
};

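// The sketch below spells out the standard GRU cell update (Cho et al., 2014)
// that the weight shapes above suggest: a coupled reset/update gate matrix
// `gate_weight` applied to [input, state], and a candidate matrix
// `candidate_weight` applied to [input, r * state]. It is provided only as
// documentation of the expected math and is not asserted to match the
// kernel's exact implementation; the gate ordering and the blend direction
// are assumptions.
[[maybe_unused]] std::vector<float> ReferenceGruStep(
    const std::vector<float>& input,             // [n_input]
    const std::vector<float>& state,             // [n_output]
    const std::vector<float>& gate_weight,       // [2*n_output, n_input+n_output]
    const std::vector<float>& gate_bias,         // [2*n_output]
    const std::vector<float>& candidate_weight,  // [n_output, n_input+n_output]
    const std::vector<float>& candidate_bias,    // [n_output]
    int n_input, int n_output) {
  auto sigmoid = [](float x) { return 1.f / (1.f + std::exp(-x)); };
  // Gate input is the concatenation [input, state].
  std::vector<float> xh(input);
  xh.insert(xh.end(), state.begin(), state.end());
  // Reset (r) and update (u) gates from the coupled gate matrix.
  std::vector<float> r(n_output), u(n_output);
  for (int o = 0; o < 2 * n_output; ++o) {
    float acc = gate_bias[o];
    for (int i = 0; i < n_input + n_output; ++i) {
      acc += gate_weight[o * (n_input + n_output) + i] * xh[i];
    }
    (o < n_output ? r[o] : u[o - n_output]) = sigmoid(acc);
  }
  // Candidate activation uses [input, r * state].
  std::vector<float> xrh(input);
  for (int o = 0; o < n_output; ++o) xrh.push_back(r[o] * state[o]);
  std::vector<float> h_new(n_output);
  for (int o = 0; o < n_output; ++o) {
    float acc = candidate_bias[o];
    for (int i = 0; i < n_input + n_output; ++i) {
      acc += candidate_weight[o * (n_input + n_output) + i] * xrh[i];
    }
    // Blend previous state and candidate with the update gate.
    h_new[o] = u[o] * state[o] + (1.f - u[o]) * std::tanh(acc);
  }
  return h_new;
}
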
TEST(GRUTest, SimpleTest) {
  const int n_time = 2;
  const int n_batch = 2;
  const int n_input = 2;
  const int n_output = 3;

  GRUOpModel m(n_batch, n_input, n_output,
               {{n_time, n_batch, n_input},
                {n_batch, n_output},
                {2 * n_output, n_input + n_output},
                {2 * n_output},
                {n_output, n_input + n_output},
                {n_output}});
  // All data is randomly generated.
  m.SetInput({0.89495724, 0.34482682, 0.68505806, 0.7135783, 0.3167085,
              0.93647677, 0.47361764, 0.39643127});
  m.SetInputState(
      {0.09992421, 0.3028481, 0.78305984, 0.50438094, 0.11269058, 0.10244724});
  m.SetGateWeight({0.7256918, 0.8945897, 0.03285786, 0.42637166, 0.119376324,
                   0.83035135, 0.16997327, 0.42302176, 0.77598256, 0.2660894,
                   0.9587266, 0.6218451, 0.88164485, 0.12272458, 0.2699055,
                   0.18399088, 0.21930052, 0.3374841, 0.70866305, 0.9523419,
                   0.25170696, 0.60988617, 0.79823977, 0.64477515, 0.2602957,
                   0.5053131, 0.93722224, 0.8451359, 0.97905475, 0.38669217});
  m.SetGateBias(
      {0.032708533, 0.018445263, 0.15320699, 0.8163046, 0.26683575, 0.1412022});
  m.SetCandidateWeight({0.96165305, 0.95572084, 0.11534478, 0.96965164,
                        0.33562955, 0.8680755, 0.003066936, 0.057793964,
                        0.8671354, 0.33354893, 0.7313398, 0.78492093,
                        0.19530584, 0.116550304, 0.13599132});
  m.SetCandidateBias({0.89837056, 0.54769796, 0.63364106});

  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAre(n_time, n_batch, n_output));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {0.20112592, 0.45286041, 0.80842507, 0.59567153, 0.2619998,
                   0.22922856, 0.27715868, 0.5247152, 0.82300174, 0.65812796,
                   0.38217607, 0.3401444})));
}

}  // namespace
}  // namespace experimental
}  // namespace ops
}  // namespace tflite

int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}