1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <memory>
18 #include <vector>
19 
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow/lite/kernels/register.h"
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/model.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace experimental {
30 
31 using ::testing::ElementsAre;
32 using ::testing::ElementsAreArray;
33 
34 TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER();
35 
36 namespace {
37 
38 using ::testing::ElementsAre;
39 using ::testing::ElementsAreArray;
40 
41 class CTCBeamSearchDecoderOpModel : public SingleOpModel {
42  public:
CTCBeamSearchDecoderOpModel(std::initializer_list<int> input_shape,std::initializer_list<int> sequence_length_shape,int beam_width,int top_paths,bool merge_repeated)43   CTCBeamSearchDecoderOpModel(std::initializer_list<int> input_shape,
44                               std::initializer_list<int> sequence_length_shape,
45                               int beam_width, int top_paths,
46                               bool merge_repeated) {
47     inputs_ = AddInput(TensorType_FLOAT32);
48     sequence_length_ = AddInput(TensorType_INT32);
49 
50     for (int i = 0; i < top_paths * 3; ++i) {
51       outputs_.push_back(AddOutput(TensorType_INT32));
52     }
53     outputs_.push_back(AddOutput(TensorType_FLOAT32));
54 
55     flexbuffers::Builder fbb;
56     fbb.Map([&]() {
57       fbb.Int("beam_width", beam_width);
58       fbb.Int("top_paths", top_paths);
59       fbb.Bool("merge_repeated", merge_repeated);
60     });
61     fbb.Finish();
62     SetCustomOp("CTCBeamSearchDecoder", fbb.GetBuffer(),
63                 Register_CTC_BEAM_SEARCH_DECODER);
64     BuildInterpreter({input_shape, sequence_length_shape});
65   }
66 
inputs()67   int inputs() { return inputs_; }
68 
sequence_length()69   int sequence_length() { return sequence_length_; }
70 
GetDecodedOutpus()71   std::vector<std::vector<int>> GetDecodedOutpus() {
72     std::vector<std::vector<int>> outputs;
73     for (int i = 0; i < outputs_.size() - 1; ++i) {
74       outputs.push_back(ExtractVector<int>(outputs_[i]));
75     }
76     return outputs;
77   }
78 
GetLogProbabilitiesOutput()79   std::vector<float> GetLogProbabilitiesOutput() {
80     return ExtractVector<float>(outputs_[outputs_.size() - 1]);
81   }
82 
GetOutputShapes()83   std::vector<std::vector<int>> GetOutputShapes() {
84     std::vector<std::vector<int>> output_shapes;
85     for (const int output : outputs_) {
86       output_shapes.push_back(GetTensorShape(output));
87     }
88     return output_shapes;
89   }
90 
91  private:
92   int inputs_;
93   int sequence_length_;
94   std::vector<int> outputs_;
95 };
96 
// Runs the decoder on a {2, 1, 2} input with beam_width = 1, top_paths = 1
// and merge_repeated = true, then checks output shapes, decoded outputs and
// the log-probabilities output.
TEST(CTCBeamSearchTest, SimpleTest) {
  CTCBeamSearchDecoderOpModel m({2, 1, 2}, {1}, 1, 1, true);
  m.PopulateTensor<float>(m.inputs(),
                          {-0.50922557, -1.35512652, -2.55445064, -1.58419356});
  m.PopulateTensor<int>(m.sequence_length(), {2});
  m.Invoke();

  // Make sure the output shapes are right. The count is a hard assertion:
  // the indexed checks below would read out of bounds if it were wrong.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 4u);
  EXPECT_THAT(output_shapes[0], ElementsAre(1, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(1));
  EXPECT_THAT(output_shapes[2], ElementsAre(2));
  EXPECT_THAT(output_shapes[3], ElementsAre(1, 1));

  // Check decoded outputs (same reasoning for asserting the count first).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 3u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(0));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));
  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(ArrayFloatNear({-0.357094})));
}
122 
// Runs the decoder on a {3, 3, 3} input with three sequence lengths of 3,
// beam_width = 1, top_paths = 1 and merge_repeated = true.
TEST(CTCBeamSearchTest, MultiBatchTest) {
  CTCBeamSearchDecoderOpModel m({3, 3, 3}, {3}, 1, 1, true);
  m.PopulateTensor<float>(
      m.inputs(),
      {-0.63649208, -0.00487571, -0.04249819, -0.67754697, -1.0341399,
       -2.14717721, -0.77686821, -3.41973774, -0.05151402, -0.21482619,
       -0.57411168, -1.45039917, -0.73769373, -2.10941739, -0.44818325,
       -0.25287673, -2.80057302, -0.54748312, -0.73334867, -0.86537719,
       -0.2065197,  -0.18725838, -1.42770405, -0.86051965, -1.61642301,
       -2.07275114, -0.9201845});
  m.PopulateTensor<int>(m.sequence_length(), {3, 3, 3});
  m.Invoke();

  // Make sure the output shapes are right. The count is a hard assertion:
  // the indexed checks below would read out of bounds if it were wrong.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 4u);
  EXPECT_THAT(output_shapes[0], ElementsAre(4, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(4));
  EXPECT_THAT(output_shapes[2], ElementsAre(2));
  EXPECT_THAT(output_shapes[3], ElementsAre(3, 1));

  // Check decoded outputs (same reasoning for asserting the count first).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 3u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 2, 0));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));
  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958})));
}
154 
// Runs the decoder on a {3, 2, 5} input with beam_width = 3 and
// top_paths = 2, so the model exposes 2 * 3 decoded outputs plus the
// log-probabilities output (7 outputs in total).
TEST(CTCBeamSearchTest, MultiPathsTest) {
  CTCBeamSearchDecoderOpModel m({3, 2, 5}, {2}, 3, 2, true);
  m.PopulateTensor<float>(
      m.inputs(),
      {-2.206851,   -0.09542714, -0.2393415,  -3.81866197, -0.27241158,
       -0.20371124, -0.68236623, -1.1397166,  -0.17422639, -1.85224048,
       -0.9406037,  -0.32544678, -0.21846784, -0.38377237, -0.33498676,
       -0.10139782, -0.51886883, -0.21678554, -0.15267063, -1.91164412,
       -0.31328673, -0.27462716, -0.65975336, -1.53671973, -2.76554225,
       -0.23920634, -1.2370502,  -4.98751576, -3.12995717, -0.43129368});
  m.PopulateTensor<int>(m.sequence_length(), {3, 3});
  m.Invoke();

  // Make sure the output shapes are right. The count is a hard assertion:
  // the indexed checks below would read out of bounds if it were wrong.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 7u);
  EXPECT_THAT(output_shapes[0], ElementsAre(4, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(3, 2));
  EXPECT_THAT(output_shapes[2], ElementsAre(4));
  EXPECT_THAT(output_shapes[3], ElementsAre(3));
  EXPECT_THAT(output_shapes[4], ElementsAre(2));
  EXPECT_THAT(output_shapes[5], ElementsAre(2));
  EXPECT_THAT(output_shapes[6], ElementsAre(2, 2));

  // Check decoded outputs (same reasoning for asserting the count first).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 6u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 1, 1));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(0, 0, 0, 1, 1, 0));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 2, 3, 0));
  EXPECT_THAT(decoded_outputs[3], ElementsAre(2, 1, 0));
  EXPECT_THAT(decoded_outputs[4], ElementsAre(2, 2));
  EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));
  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(
                  ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357})));
}
193 
// Runs the decoder on a {3, 3, 4} input whose sequence-length tensor holds
// three different values (1, 2, 3), with beam_width = 3 and top_paths = 1.
TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
  CTCBeamSearchDecoderOpModel m({3, 3, 4}, {3}, 3, 1, true);
  m.PopulateTensor<float>(
      m.inputs(),
      {-1.26658163, -0.25760023, -0.03917975, -0.63772235, -0.03794756,
       -0.45063099, -0.27706473, -0.01569179, -0.59940385, -0.35700127,
       -0.48920721, -1.42635476, -1.3462478,  -0.02565498, -0.30179568,
       -0.6491698,  -0.55017719, -2.92291466, -0.92522973, -0.47592022,
       -0.07099135, -0.31575624, -0.86345281, -0.36017021, -0.79208612,
       -1.75306124, -0.65089224, -0.00912786, -0.42915003, -1.72606203,
       -1.66337589, -0.70800793, -2.52272352, -0.67329562, -2.49145522,
       -0.49786342});
  m.PopulateTensor<int>(m.sequence_length(), {1, 2, 3});
  m.Invoke();

  // Make sure the output shapes are right. The count is a hard assertion:
  // the indexed checks below would read out of bounds if it were wrong.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 4u);
  EXPECT_THAT(output_shapes[0], ElementsAre(3, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(3));
  EXPECT_THAT(output_shapes[2], ElementsAre(2));
  EXPECT_THAT(output_shapes[3], ElementsAre(3, 1));

  // Check decoded outputs (same reasoning for asserting the count first).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 3u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 1, 0, 2, 0));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(2, 0, 1));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));
  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553})));
}
227 
228 }  // namespace
229 }  // namespace experimental
230 }  // namespace ops
231 }  // namespace tflite
232 
main(int argc,char ** argv)233 int main(int argc, char** argv) {
234   ::tflite::LogToStderr();
235   ::testing::InitGoogleTest(&argc, argv);
236   return RUN_ALL_TESTS();
237 }
238