1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <functional>
17 #include <memory>
18 #include <vector>
19
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow/lite/kernels/register.h"
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/model.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace experimental {
30
31 using ::testing::ElementsAre;
32 using ::testing::ElementsAreArray;
33
34 TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER();
35
36 namespace {
37
38 using ::testing::ElementsAre;
39 using ::testing::ElementsAreArray;
40
41 class CTCBeamSearchDecoderOpModel : public SingleOpModel {
42 public:
CTCBeamSearchDecoderOpModel(std::initializer_list<int> input_shape,std::initializer_list<int> sequence_length_shape,int beam_width,int top_paths,bool merge_repeated)43 CTCBeamSearchDecoderOpModel(std::initializer_list<int> input_shape,
44 std::initializer_list<int> sequence_length_shape,
45 int beam_width, int top_paths,
46 bool merge_repeated) {
47 inputs_ = AddInput(TensorType_FLOAT32);
48 sequence_length_ = AddInput(TensorType_INT32);
49
50 for (int i = 0; i < top_paths * 3; ++i) {
51 outputs_.push_back(AddOutput(TensorType_INT32));
52 }
53 outputs_.push_back(AddOutput(TensorType_FLOAT32));
54
55 flexbuffers::Builder fbb;
56 fbb.Map([&]() {
57 fbb.Int("beam_width", beam_width);
58 fbb.Int("top_paths", top_paths);
59 fbb.Bool("merge_repeated", merge_repeated);
60 });
61 fbb.Finish();
62 SetCustomOp("CTCBeamSearchDecoder", fbb.GetBuffer(),
63 Register_CTC_BEAM_SEARCH_DECODER);
64 BuildInterpreter({input_shape, sequence_length_shape});
65 }
66
inputs()67 int inputs() { return inputs_; }
68
sequence_length()69 int sequence_length() { return sequence_length_; }
70
GetDecodedOutpus()71 std::vector<std::vector<int>> GetDecodedOutpus() {
72 std::vector<std::vector<int>> outputs;
73 for (int i = 0; i < outputs_.size() - 1; ++i) {
74 outputs.push_back(ExtractVector<int>(outputs_[i]));
75 }
76 return outputs;
77 }
78
GetLogProbabilitiesOutput()79 std::vector<float> GetLogProbabilitiesOutput() {
80 return ExtractVector<float>(outputs_[outputs_.size() - 1]);
81 }
82
GetOutputShapes()83 std::vector<std::vector<int>> GetOutputShapes() {
84 std::vector<std::vector<int>> output_shapes;
85 for (const int output : outputs_) {
86 output_shapes.push_back(GetTensorShape(output));
87 }
88 return output_shapes;
89 }
90
91 private:
92 int inputs_;
93 int sequence_length_;
94 std::vector<int> outputs_;
95 };
96
// Runs the decoder on a {2, 1, 2} input with a single sequence-length entry,
// beam_width = 1, top_paths = 1 and merge_repeated = true.
TEST(CTCBeamSearchTest, SimpleTest) {
  CTCBeamSearchDecoderOpModel m({2, 1, 2}, {1}, 1, 1, true);
  m.PopulateTensor<float>(m.inputs(),
                          {-0.50922557, -1.35512652, -2.55445064, -1.58419356});
  m.PopulateTensor<int>(m.sequence_length(), {2});
  m.Invoke();

  // Make sure the output shapes are right. The count is a fatal assertion:
  // the element checks below index into the vector and would read out of
  // bounds if fewer tensors came back.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 4u);
  EXPECT_THAT(output_shapes[0], ElementsAre(1, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(1));
  EXPECT_THAT(output_shapes[2], ElementsAre(2));
  EXPECT_THAT(output_shapes[3], ElementsAre(1, 1));

  // Check decoded outputs (same reasoning for the fatal size assertion).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 3u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(0));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));

  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(ArrayFloatNear({-0.357094})));
}
122
// Runs the decoder on a {3, 3, 3} input with three sequence-length entries
// (all equal to 3), beam_width = 1, top_paths = 1 and merge_repeated = true.
TEST(CTCBeamSearchTest, MultiBatchTest) {
  CTCBeamSearchDecoderOpModel m({3, 3, 3}, {3}, 1, 1, true);
  m.PopulateTensor<float>(
      m.inputs(),
      {-0.63649208, -0.00487571, -0.04249819, -0.67754697, -1.0341399,
       -2.14717721, -0.77686821, -3.41973774, -0.05151402, -0.21482619,
       -0.57411168, -1.45039917, -0.73769373, -2.10941739, -0.44818325,
       -0.25287673, -2.80057302, -0.54748312, -0.73334867, -0.86537719,
       -0.2065197,  -0.18725838, -1.42770405, -0.86051965, -1.61642301,
       -2.07275114, -0.9201845});
  m.PopulateTensor<int>(m.sequence_length(), {3, 3, 3});
  m.Invoke();

  // Make sure the output shapes are right. The count is a fatal assertion:
  // the element checks below index into the vector and would read out of
  // bounds if fewer tensors came back.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 4u);
  EXPECT_THAT(output_shapes[0], ElementsAre(4, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(4));
  EXPECT_THAT(output_shapes[2], ElementsAre(2));
  EXPECT_THAT(output_shapes[3], ElementsAre(3, 1));

  // Check decoded outputs (same reasoning for the fatal size assertion).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 3u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 2, 0));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));

  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958})));
}
154
// Runs the decoder on a {3, 2, 5} input with two sequence-length entries,
// beam_width = 3, top_paths = 2 and merge_repeated = true. With two paths the
// model exposes 2 * 3 INT32 outputs plus the FLOAT32 one, i.e. 7 tensors.
TEST(CTCBeamSearchTest, MultiPathsTest) {
  CTCBeamSearchDecoderOpModel m({3, 2, 5}, {2}, 3, 2, true);
  m.PopulateTensor<float>(
      m.inputs(),
      {-2.206851,   -0.09542714, -0.2393415,  -3.81866197, -0.27241158,
       -0.20371124, -0.68236623, -1.1397166,  -0.17422639, -1.85224048,
       -0.9406037,  -0.32544678, -0.21846784, -0.38377237, -0.33498676,
       -0.10139782, -0.51886883, -0.21678554, -0.15267063, -1.91164412,
       -0.31328673, -0.27462716, -0.65975336, -1.53671973, -2.76554225,
       -0.23920634, -1.2370502,  -4.98751576, -3.12995717, -0.43129368});
  m.PopulateTensor<int>(m.sequence_length(), {3, 3});
  m.Invoke();

  // Make sure the output shapes are right. The count is a fatal assertion:
  // the element checks below index into the vector and would read out of
  // bounds if fewer tensors came back.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 7u);
  EXPECT_THAT(output_shapes[0], ElementsAre(4, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(3, 2));
  EXPECT_THAT(output_shapes[2], ElementsAre(4));
  EXPECT_THAT(output_shapes[3], ElementsAre(3));
  EXPECT_THAT(output_shapes[4], ElementsAre(2));
  EXPECT_THAT(output_shapes[5], ElementsAre(2));
  EXPECT_THAT(output_shapes[6], ElementsAre(2, 2));

  // Check decoded outputs (same reasoning for the fatal size assertion).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 6u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 1, 1));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(0, 0, 0, 1, 1, 0));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 2, 3, 0));
  EXPECT_THAT(decoded_outputs[3], ElementsAre(2, 1, 0));
  EXPECT_THAT(decoded_outputs[4], ElementsAre(2, 2));
  EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));

  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(
                  ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357})));
}
193
// Runs the decoder on a {3, 3, 4} input whose three sequence-length entries
// differ ({1, 2, 3}), with beam_width = 3, top_paths = 1 and
// merge_repeated = true.
TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
  CTCBeamSearchDecoderOpModel m({3, 3, 4}, {3}, 3, 1, true);
  m.PopulateTensor<float>(
      m.inputs(),
      {-1.26658163, -0.25760023, -0.03917975, -0.63772235, -0.03794756,
       -0.45063099, -0.27706473, -0.01569179, -0.59940385, -0.35700127,
       -0.48920721, -1.42635476, -1.3462478,  -0.02565498, -0.30179568,
       -0.6491698,  -0.55017719, -2.92291466, -0.92522973, -0.47592022,
       -0.07099135, -0.31575624, -0.86345281, -0.36017021, -0.79208612,
       -1.75306124, -0.65089224, -0.00912786, -0.42915003, -1.72606203,
       -1.66337589, -0.70800793, -2.52272352, -0.67329562, -2.49145522,
       -0.49786342});
  m.PopulateTensor<int>(m.sequence_length(), {1, 2, 3});
  m.Invoke();

  // Make sure the output shapes are right. The count is a fatal assertion:
  // the element checks below index into the vector and would read out of
  // bounds if fewer tensors came back.
  const std::vector<std::vector<int>> output_shapes = m.GetOutputShapes();
  ASSERT_EQ(output_shapes.size(), 4u);
  EXPECT_THAT(output_shapes[0], ElementsAre(3, 2));
  EXPECT_THAT(output_shapes[1], ElementsAre(3));
  EXPECT_THAT(output_shapes[2], ElementsAre(2));
  EXPECT_THAT(output_shapes[3], ElementsAre(3, 1));

  // Check decoded outputs (same reasoning for the fatal size assertion).
  const std::vector<std::vector<int>> decoded_outputs = m.GetDecodedOutpus();
  ASSERT_EQ(decoded_outputs.size(), 3u);
  EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 1, 0, 2, 0));
  EXPECT_THAT(decoded_outputs[1], ElementsAre(2, 0, 1));
  EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));

  // Check log probabilities output.
  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
              ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553})));
}
227
228 } // namespace
229 } // namespace experimental
230 } // namespace ops
231 } // namespace tflite
232
main(int argc,char ** argv)233 int main(int argc, char** argv) {
234 ::tflite::LogToStderr();
235 ::testing::InitGoogleTest(&argc, argv);
236 return RUN_ALL_TESTS();
237 }
238