1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for speech models (Hotword, SpeakerId) using TFLite Ops.
16
17 #include <memory>
18 #include <string>
19
20 #include <fstream>
21
22 #include "testing/base/public/googletest.h"
23 #include <gtest/gtest.h>
24 #include "tensorflow/lite/testing/parse_testdata.h"
25 #include "tensorflow/lite/testing/split.h"
26 #include "tensorflow/lite/testing/tflite_driver.h"
27
28 namespace tflite {
29 namespace {
30
31 const char kDataPath[] = "third_party/tensorflow/lite/models/testdata/";
32
Init(const string & in_file_name,testing::TfLiteDriver * driver,std::ifstream * in_file)33 bool Init(const string& in_file_name, testing::TfLiteDriver* driver,
34 std::ifstream* in_file) {
35 driver->SetModelBaseDir(kDataPath);
36 in_file->open(string(kDataPath) + in_file_name, std::ifstream::in);
37 return in_file->is_open();
38 }
39
40 // Converts a set of test files provided by the speech team into a single
41 // test_spec. Input CSV files are supposed to contain a number of sequences per
42 // line. Each sequence maps to a single invocation of the interpreter and the
43 // output tensor after all sequences have run is compared to the corresponding
44 // line in the output CSV file.
ConvertCsvData(const string & model_name,const string & in_name,const string & out_name,const string & input_tensor,const string & output_tensor,const string & persistent_tensors,int sequence_size,std::ostream * out)45 bool ConvertCsvData(const string& model_name, const string& in_name,
46 const string& out_name, const string& input_tensor,
47 const string& output_tensor,
48 const string& persistent_tensors, int sequence_size,
49 std::ostream* out) {
50 auto data_path = [](const string& s) { return string(kDataPath) + s; };
51
52 *out << "load_model: \"" << data_path(model_name) << "\"" << std::endl;
53
54 *out << "init_state: \"" << persistent_tensors << "\"" << std::endl;
55
56 string in_file_name = data_path(in_name);
57 std::ifstream in_file(in_file_name);
58 if (!in_file.is_open()) {
59 std::cerr << "Failed to open " << in_file_name << std::endl;
60 return false;
61 }
62 string out_file_name = data_path(out_name);
63 std::ifstream out_file(out_file_name);
64 if (!out_file.is_open()) {
65 std::cerr << "Failed to open " << out_file_name << std::endl;
66 return false;
67 }
68
69 int invocation_count = 0;
70 string in_values;
71 while (std::getline(in_file, in_values, '\n')) {
72 std::vector<string> input = testing::Split<string>(in_values, ",");
73 int num_sequences = input.size() / sequence_size;
74
75 for (int j = 0; j < num_sequences; ++j) {
76 *out << "invoke {" << std::endl;
77 *out << " id: " << invocation_count << std::endl;
78 *out << " input: \"";
79 for (int k = 0; k < sequence_size; ++k) {
80 *out << input[k + j * sequence_size] << ",";
81 }
82 *out << "\"" << std::endl;
83
84 if (j == num_sequences - 1) {
85 string out_values;
86 if (!std::getline(out_file, out_values, '\n')) {
87 std::cerr << "Not enough lines in " << out_file_name << std::endl;
88 return false;
89 }
90 *out << " output: \"" << out_values << "\"" << std::endl;
91 }
92
93 *out << "}" << std::endl;
94 ++invocation_count;
95 }
96 }
97 return true;
98 }
99
100 class SpeechTest : public ::testing::TestWithParam<int> {
101 protected:
GetMaxInvocations()102 int GetMaxInvocations() { return GetParam(); }
103 };
104
// Runs the rank-1 hotword model against the hotword CSV test vectors.
// The DISABLED_ name prefix keeps GoogleTest from running it by default.
TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank1Test) {
  // Build the test_spec in memory from the input/output CSV pair.
  std::stringstream test_spec;
  ASSERT_TRUE(ConvertCsvData(
      /*model_name=*/"speech_hotword_model_rank1.tflite",
      /*in_name=*/"speech_hotword_model_in.csv",
      /*out_name=*/"speech_hotword_model_out_rank1.csv",
      /*input_tensor=*/"0",
      /*output_tensor=*/"18",
      /*persistent_tensors=*/"4",
      /*sequence_size=*/40, &test_spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&test_spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
116
// Runs the rank-2 hotword model against the hotword CSV test vectors.
// The DISABLED_ name prefix keeps GoogleTest from running it by default.
TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) {
  // Build the test_spec in memory from the input/output CSV pair.
  std::stringstream test_spec;
  ASSERT_TRUE(ConvertCsvData(
      /*model_name=*/"speech_hotword_model_rank2.tflite",
      /*in_name=*/"speech_hotword_model_in.csv",
      /*out_name=*/"speech_hotword_model_out_rank2.csv",
      /*input_tensor=*/"17",
      /*output_tensor=*/"18",
      /*persistent_tensors=*/"1",
      /*sequence_size=*/40, &test_spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&test_spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
128
// Runs the speaker-id model against its CSV test vectors.
// The DISABLED_ name prefix keeps GoogleTest from running it by default.
TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) {
  // Build the test_spec in memory from the input/output CSV pair.
  std::stringstream test_spec;
  ASSERT_TRUE(ConvertCsvData(
      /*model_name=*/"speech_speakerid_model.tflite",
      /*in_name=*/"speech_speakerid_model_in.csv",
      /*out_name=*/"speech_speakerid_model_out.csv",
      /*input_tensor=*/"0",
      /*output_tensor=*/"63",
      /*persistent_tensors=*/"18,19,38,39,58,59",
      /*sequence_size=*/80, &test_spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&test_spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
141
// Runs the ASR acoustic model (speech_asr_am_model.tflite) against its CSV
// test vectors.
TEST_P(SpeechTest, AsrAmTest) {
  // Build the test_spec in memory from the input/output CSV pair.
  std::stringstream test_spec;
  ASSERT_TRUE(ConvertCsvData(
      /*model_name=*/"speech_asr_am_model.tflite",
      /*in_name=*/"speech_asr_am_model_in.csv",
      /*out_name=*/"speech_asr_am_model_out.csv",
      /*input_tensor=*/"0",
      /*output_tensor=*/"104",
      /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
      /*sequence_size=*/320, &test_spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&test_spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
154
// Runs the int8 variant of the ASR acoustic model. It shares the input CSV
// with AsrAmTest but is checked against its own expected-output CSV.
TEST_P(SpeechTest, AsrAmQuantizedTest) {
  // Build the test_spec in memory from the input/output CSV pair.
  std::stringstream test_spec;
  ASSERT_TRUE(ConvertCsvData(
      /*model_name=*/"speech_asr_am_model_int8.tflite",
      /*in_name=*/"speech_asr_am_model_in.csv",
      /*out_name=*/"speech_asr_am_model_int8_out.csv",
      /*input_tensor=*/"0",
      /*output_tensor=*/"104",
      /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
      /*sequence_size=*/320, &test_spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&test_spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
167
// The original version of speech_asr_lm_model_test.cc ran a few sequences
// through the interpreter and stored the sum of all the output, which was then
// compared for correctness. In this test we are comparing all the intermediate
// results.
TEST_P(SpeechTest, DISABLED_AsrLmTest) {
  // This test reads a ready-made test_spec file instead of generating one
  // from CSV data.
  std::ifstream spec_file;
  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &driver, &spec_file));

  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec_file, &driver, max_invocations))
      << driver.GetErrorMessage();
}
180
// Runs the endpointer model against its CSV test vectors.
// The DISABLED_ name prefix keeps GoogleTest from running it by default.
TEST_P(SpeechTest, DISABLED_EndpointerTest) {
  // Build the test_spec in memory from the input/output CSV pair.
  std::stringstream test_spec;
  ASSERT_TRUE(ConvertCsvData(
      /*model_name=*/"speech_endpointer_model.tflite",
      /*in_name=*/"speech_endpointer_model_in.csv",
      /*out_name=*/"speech_endpointer_model_out.csv",
      /*input_tensor=*/"0",
      /*output_tensor=*/"56",
      /*persistent_tensors=*/"27,28,47,48",
      /*sequence_size=*/320, &test_spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&test_spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
193
// Runs the TTS model against its CSV test vectors.
// The DISABLED_ name prefix keeps GoogleTest from running it by default.
TEST_P(SpeechTest, DISABLED_TtsTest) {
  // Build the test_spec in memory from the input/output CSV pair.
  std::stringstream test_spec;
  ASSERT_TRUE(ConvertCsvData(
      /*model_name=*/"speech_tts_model.tflite",
      /*in_name=*/"speech_tts_model_in.csv",
      /*out_name=*/"speech_tts_model_out.csv",
      /*input_tensor=*/"0",
      /*output_tensor=*/"71",
      /*persistent_tensors=*/"24,25,44,45,64,65,70",
      /*sequence_size=*/334, &test_spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&test_spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
206
// Define two instantiations. The "ShortTests" instantiation is used when
// running the tests on Android, in order to prevent timeouts. (It takes about
// 200s just to bring up the Android emulator.)
210 static const int kAllInvocations = -1;
211 static const int kFirstFewInvocations = 10;
212 INSTANTIATE_TEST_SUITE_P(LongTests, SpeechTest,
213 ::testing::Values(kAllInvocations));
214 INSTANTIATE_TEST_SUITE_P(ShortTests, SpeechTest,
215 ::testing::Values(kFirstFewInvocations));
216
217 } // namespace
218 } // namespace tflite
219