1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for speech models (Hotword, SpeakerId) using TFLite Ops.
16 
#include <cstddef>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "testing/base/public/googletest.h"
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/parse_testdata.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/testing/tflite_driver.h"
27 
28 namespace tflite {
29 namespace {
30 
// Relative path of the directory holding the speech models and their CSV /
// test_spec fixtures. It ends with '/' so that file names can be appended
// directly.
constexpr char kDataPath[] = "third_party/tensorflow/lite/models/testdata/";
32 
Init(const string & in_file_name,testing::TfLiteDriver * driver,std::ifstream * in_file)33 bool Init(const string& in_file_name, testing::TfLiteDriver* driver,
34           std::ifstream* in_file) {
35   driver->SetModelBaseDir(kDataPath);
36   in_file->open(string(kDataPath) + in_file_name, std::ifstream::in);
37   return in_file->is_open();
38 }
39 
40 // Converts a set of test files provided by the speech team into a single
41 // test_spec. Input CSV files are supposed to contain a number of sequences per
42 // line. Each sequence maps to a single invocation of the interpreter and the
43 // output tensor after all sequences have run is compared to the corresponding
44 // line in the output CSV file.
ConvertCsvData(const string & model_name,const string & in_name,const string & out_name,const string & input_tensor,const string & output_tensor,const string & persistent_tensors,int sequence_size,std::ostream * out)45 bool ConvertCsvData(const string& model_name, const string& in_name,
46                     const string& out_name, const string& input_tensor,
47                     const string& output_tensor,
48                     const string& persistent_tensors, int sequence_size,
49                     std::ostream* out) {
50   auto data_path = [](const string& s) { return string(kDataPath) + s; };
51 
52   *out << "load_model: \"" << data_path(model_name) << "\"" << std::endl;
53 
54   *out << "init_state: \"" << persistent_tensors << "\"" << std::endl;
55 
56   string in_file_name = data_path(in_name);
57   std::ifstream in_file(in_file_name);
58   if (!in_file.is_open()) {
59     std::cerr << "Failed to open " << in_file_name << std::endl;
60     return false;
61   }
62   string out_file_name = data_path(out_name);
63   std::ifstream out_file(out_file_name);
64   if (!out_file.is_open()) {
65     std::cerr << "Failed to open " << out_file_name << std::endl;
66     return false;
67   }
68 
69   int invocation_count = 0;
70   string in_values;
71   while (std::getline(in_file, in_values, '\n')) {
72     std::vector<string> input = testing::Split<string>(in_values, ",");
73     int num_sequences = input.size() / sequence_size;
74 
75     for (int j = 0; j < num_sequences; ++j) {
76       *out << "invoke {" << std::endl;
77       *out << "  id: " << invocation_count << std::endl;
78       *out << "  input: \"";
79       for (int k = 0; k < sequence_size; ++k) {
80         *out << input[k + j * sequence_size] << ",";
81       }
82       *out << "\"" << std::endl;
83 
84       if (j == num_sequences - 1) {
85         string out_values;
86         if (!std::getline(out_file, out_values, '\n')) {
87           std::cerr << "Not enough lines in " << out_file_name << std::endl;
88           return false;
89         }
90         *out << "  output: \"" << out_values << "\"" << std::endl;
91       }
92 
93       *out << "}" << std::endl;
94       ++invocation_count;
95     }
96   }
97   return true;
98 }
99 
100 class SpeechTest : public ::testing::TestWithParam<int> {
101  protected:
GetMaxInvocations()102   int GetMaxInvocations() { return GetParam(); }
103 };
104 
// Hotword model, rank-1 variant: CSV fixtures -> test_spec -> TFLite driver.
TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank1Test) {
  std::stringstream spec;
  ASSERT_TRUE(ConvertCsvData(
      "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
      "speech_hotword_model_out_rank1.csv", /*input_tensor=*/"0",
      /*output_tensor=*/"18", /*persistent_tensors=*/"4",
      /*sequence_size=*/40, &spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
116 
// Hotword model, rank-2 variant. It shares the input CSV with the rank-1
// test but has its own model and expected-output file.
TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) {
  std::stringstream spec;
  ASSERT_TRUE(ConvertCsvData(
      "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
      "speech_hotword_model_out_rank2.csv", /*input_tensor=*/"17",
      /*output_tensor=*/"18", /*persistent_tensors=*/"1",
      /*sequence_size=*/40, &spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
128 
// Speaker-id model: sequences of 80 values, six persistent-state tensors.
TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) {
  std::stringstream spec;
  ASSERT_TRUE(ConvertCsvData(
      "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
      "speech_speakerid_model_out.csv", /*input_tensor=*/"0",
      /*output_tensor=*/"63",
      /*persistent_tensors=*/"18,19,38,39,58,59",
      /*sequence_size=*/80, &spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
141 
// ASR acoustic model (float): sequences of 320 values, ten persistent
// tensors.
TEST_P(SpeechTest, AsrAmTest) {
  std::stringstream spec;
  ASSERT_TRUE(
      ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
                     "speech_asr_am_model_out.csv", /*input_tensor=*/"0",
                     /*output_tensor=*/"104",
                     /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
                     /*sequence_size=*/320, &spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
154 
// ASR acoustic model, int8 variant. It reuses the float model's input CSV
// and is checked against its own expected-output file.
TEST_P(SpeechTest, AsrAmQuantizedTest) {
  std::stringstream spec;
  ASSERT_TRUE(ConvertCsvData(
      "speech_asr_am_model_int8.tflite", "speech_asr_am_model_in.csv",
      "speech_asr_am_model_int8_out.csv", /*input_tensor=*/"0",
      /*output_tensor=*/"104",
      /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
      /*sequence_size=*/320, &spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
167 
168 // The original version of speech_asr_lm_model_test.cc ran a few sequences
169 // through the interpreter and stored the sum of all the output, which was them
170 // compared for correctness. In this test we are comparing all the intermediate
171 // results.
TEST_P(SpeechTest,DISABLED_AsrLmTest)172 TEST_P(SpeechTest, DISABLED_AsrLmTest) {
173   std::ifstream in_file;
174   testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
175   ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
176   ASSERT_TRUE(
177       testing::ParseAndRunTests(&in_file, &test_driver, GetMaxInvocations()))
178       << test_driver.GetErrorMessage();
179 }
180 
// Endpointer model: sequences of 320 values, four persistent tensors.
TEST_P(SpeechTest, DISABLED_EndpointerTest) {
  std::stringstream spec;
  ASSERT_TRUE(ConvertCsvData(
      "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
      "speech_endpointer_model_out.csv", /*input_tensor=*/"0",
      /*output_tensor=*/"56",
      /*persistent_tensors=*/"27,28,47,48",
      /*sequence_size=*/320, &spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
193 
// Text-to-speech model: sequences of 334 values, seven persistent tensors.
TEST_P(SpeechTest, DISABLED_TtsTest) {
  std::stringstream spec;
  ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
                             "speech_tts_model_in.csv",
                             "speech_tts_model_out.csv", /*input_tensor=*/"0",
                             /*output_tensor=*/"71",
                             /*persistent_tensors=*/"24,25,44,45,64,65,70",
                             /*sequence_size=*/334, &spec));

  testing::TfLiteDriver driver(/*use_nnapi=*/false);
  const int max_invocations = GetMaxInvocations();
  ASSERT_TRUE(testing::ParseAndRunTests(&spec, &driver, max_invocations))
      << driver.GetErrorMessage();
}
206 
207 // Define two instantiations. The "ShortTests" instantiations is used when
208 // running the tests on Android, in order to prevent timeouts (It takes about
209 // 200s just to bring up the Android emulator.)
210 static const int kAllInvocations = -1;
211 static const int kFirstFewInvocations = 10;
212 INSTANTIATE_TEST_SUITE_P(LongTests, SpeechTest,
213                          ::testing::Values(kAllInvocations));
214 INSTANTIATE_TEST_SUITE_P(ShortTests, SpeechTest,
215                          ::testing::Values(kFirstFewInvocations));
216 
217 }  // namespace
218 }  // namespace tflite
219