/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/*

Tool to create accuracy statistics from running an audio recognition model on a
continuous stream of samples.

This is designed to be an environment for running experiments on new models and
settings, to understand the effects they will have in a real application. You
need to supply it with a long audio file containing sounds you want to
recognize, and a text file listing the labels of each sound along with the
times at which they occur. With this information, and a frozen model, the tool
will process the audio stream, apply the model, and keep track of how many
mistakes and successes the model achieved.

The matched percentage is the number of sounds that were correctly classified,
as a percentage of the total number of sounds listed in the ground truth file.
A classification counts as correct when the right label is chosen within a
short time of the expected ground truth entry, where the time tolerance is
controlled by the 'time_tolerance_ms' command line flag.

The wrong percentage is how many sounds triggered a detection (that is, the
classifier decided they weren't silence or background noise) but were given the
wrong label. This is also a percentage of the total number of ground truth
sounds.

The false positive percentage is how many sounds were detected when there was
only silence or background noise present. This is also expressed as a
percentage of the total number of ground truth sounds, though since it can be
large it may go above 100%.
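
As a worked example with hypothetical numbers: if the ground truth file lists
10 sounds, and the tool records 7 correct detections, 2 detections with the
wrong label, and 3 detections during silence, it would report 70% matched, 20%
wrong, and 30% false positives.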

The easiest way to get an audio file and labels to test with is by using the
'generate_streaming_test_wav' script. This will synthesize a test file with
randomly placed sounds and background noise, and output a text file with the
ground truth.
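
One way to invoke it (see that script itself for the flags controlling its
output locations) is:

bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav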

If you want to test natural data, you need to use a .wav with the same sample
rate as your model (often 16,000 samples per second), and note down when the
sounds occur in it. Save this information out as a comma-separated text file,
where the first column is the label and the second is the time in milliseconds
from the start of the file at which it occurs, matching the format described
for the ground_truth flag below.
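
For example, a ground truth file might contain lines like these (the labels and
times here are hypothetical):

on,400
off,1850
stop,3120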

Here's an example of how to run the tool:

bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
--wav=/tmp/streaming_test_bg.wav \
--graph=/tmp/conv_frozen.pb \
--labels=/tmp/speech_commands_train/conv_labels.txt \
--ground_truth=/tmp/streaming_test_labels.txt --verbose \
--clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \
--suppression_ms=500 --time_tolerance_ms=1500
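
With these settings, a one-second window of audio is run through the model
every 30ms (the default clip_stride_ms), the scores are averaged over the last
500ms, any non-silence label scoring above 0.70 triggers a detection, repeat
detections within 500ms are suppressed, and a detection counts as correct if it
falls within 1500ms of a ground truth entry.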

*/

#include <fstream>
#include <iomanip>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/examples/speech_commands/accuracy_utils.h"
#include "tensorflow/examples/speech_commands/recognize_commands.h"

// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::int32;
using tensorflow::int64;
using tensorflow::string;
using tensorflow::uint16;
using tensorflow::uint32;

namespace {

// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
                 std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::GraphDef graph_def;
  Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_file_name, "'");
  }
  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  Status session_create_status = (*session)->Create(graph_def);
  if (!session_create_status.ok()) {
    return session_create_status;
  }
  return Status::OK();
}

// Takes a file name and loads a list of labels from it, one per line,
// returning a vector of the strings. For this example the label files
// typically contain entries like "_silence_", "_unknown_", "yes", and "no".
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Labels file '", file_name,
                                        "' not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  return Status::OK();
}

}  // namespace

int main(int argc, char* argv[]) {
  string wav = "";
  string graph = "";
  string labels = "";
  string ground_truth = "";
  string input_data_name = "decoded_sample_data:0";
  string input_rate_name = "decoded_sample_data:1";
  string output_name = "labels_softmax";
  int32 clip_duration_ms = 1000;
  int32 clip_stride_ms = 30;
  int32 average_window_ms = 500;
  int32 time_tolerance_ms = 750;
  int32 suppression_ms = 1500;
  float detection_threshold = 0.7f;
  bool verbose = false;
  std::vector<Flag> flag_list = {
      Flag("wav", &wav, "audio file to be identified"),
      Flag("graph", &graph, "model to be executed"),
      Flag("labels", &labels, "path to file containing labels"),
      Flag("ground_truth", &ground_truth,
           "path to file containing correct times and labels of words in the "
           "audio as <word>,<timestamp in ms> lines"),
      Flag("input_data_name", &input_data_name,
           "name of input data node in model"),
      Flag("input_rate_name", &input_rate_name,
           "name of input sample rate node in model"),
      Flag("output_name", &output_name, "name of output node in model"),
      Flag("clip_duration_ms", &clip_duration_ms,
           "length of recognition window"),
      Flag("average_window_ms", &average_window_ms,
           "length of window to smooth results over"),
      Flag("time_tolerance_ms", &time_tolerance_ms,
           "maximum gap allowed between a recognition and ground truth"),
      Flag("suppression_ms", &suppression_ms,
           "how long to ignore further detections after a recognition"),
      Flag("clip_stride_ms", &clip_stride_ms, "how often to run recognition"),
      Flag("detection_threshold", &detection_threshold,
           "what score is required to trigger detection of a word"),
      Flag("verbose", &verbose, "whether to log extra debugging information"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  Status load_graph_status = LoadGraph(graph, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  // Load the human-readable list of labels, one per output class.
  std::vector<string> labels_list;
  Status read_labels_status = ReadLabelsFile(labels, &labels_list);
  if (!read_labels_status.ok()) {
    LOG(ERROR) << read_labels_status;
    return -1;
  }

  // Load the (label, timestamp in ms) pairs recording when each word occurs.
  std::vector<std::pair<string, tensorflow::int64>> ground_truth_list;
  Status read_ground_truth_status =
      tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list);
  if (!read_ground_truth_status.ok()) {
    LOG(ERROR) << read_ground_truth_status;
    return -1;
  }

  // Read the audio file into memory and decode it into a vector of float
  // samples, checking that it's the mono data this tool expects.
  string wav_string;
  Status read_wav_status = tensorflow::ReadFileToString(
      tensorflow::Env::Default(), wav, &wav_string);
  if (!read_wav_status.ok()) {
    LOG(ERROR) << read_wav_status;
    return -1;
  }
  std::vector<float> audio_data;
  uint32 sample_count;
  uint16 channel_count;
  uint32 sample_rate;
  Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector(
      wav_string, &audio_data, &sample_count, &channel_count, &sample_rate);
  if (!decode_wav_status.ok()) {
    LOG(ERROR) << decode_wav_status;
    return -1;
  }
  if (channel_count != 1) {
    LOG(ERROR) << "Only mono .wav files can be used, but input has "
               << channel_count << " channels.";
    return -1;
  }

  const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
  const int64 clip_stride_samples = (clip_stride_ms * sample_rate) / 1000;
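  // As a worked example, with the default 1000ms window and 30ms stride on a
  // 16kHz file, these work out to 16000-sample clips advanced 480 samples at
  // a time.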
  Tensor audio_data_tensor(tensorflow::DT_FLOAT,
                           tensorflow::TensorShape({clip_duration_samples, 1}));

  Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({}));
  sample_rate_tensor.scalar<int32>()() = sample_rate;

  // RecognizeCommands smooths the raw per-clip scores over time and applies
  // the detection threshold and suppression window.
  tensorflow::RecognizeCommands recognize_commands(
      labels_list, average_window_ms, detection_threshold, suppression_ms);

  std::vector<std::pair<string, int64>> all_found_words;
  tensorflow::StreamingAccuracyStats previous_stats;

  // Slide a clip_duration_samples window across the audio, advancing by
  // clip_stride_samples each step, stopping once a full window no longer fits.
  const int64 audio_data_end = (sample_count - clip_duration_samples);
  for (int64 audio_data_offset = 0; audio_data_offset < audio_data_end;
       audio_data_offset += clip_stride_samples) {
    const float* input_start = &(audio_data[audio_data_offset]);
    const float* input_end = input_start + clip_duration_samples;
    std::copy(input_start, input_end, audio_data_tensor.flat<float>().data());

    // Actually run the audio through the model.
    std::vector<Tensor> outputs;
    Status run_status = session->Run({{input_data_name, audio_data_tensor},
                                      {input_rate_name, sample_rate_tensor}},
                                     {output_name}, {}, &outputs);
    if (!run_status.ok()) {
      LOG(ERROR) << "Running model failed: " << run_status;
      return -1;
    }

    // Feed the latest scores into the smoothing and thresholding logic to
    // decide whether a new command has been recognized.
    const int64 current_time_ms = (audio_data_offset * 1000) / sample_rate;
    string found_command;
    float score;
    bool is_new_command;
    Status recognize_status = recognize_commands.ProcessLatestResults(
        outputs[0], current_time_ms, &found_command, &score, &is_new_command);
    if (!recognize_status.ok()) {
      LOG(ERROR) << "Recognition processing failed: " << recognize_status;
      return -1;
    }

    if (is_new_command && (found_command != "_silence_")) {
      all_found_words.push_back({found_command, current_time_ms});
      if (verbose) {
        // Recalculate the running statistics and compare them against the
        // previous values to work out how this single recognition was scored.
        tensorflow::StreamingAccuracyStats stats;
        tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words,
                                           current_time_ms, time_tolerance_ms,
                                           &stats);
        int32 false_positive_delta = stats.how_many_false_positives -
                                     previous_stats.how_many_false_positives;
        int32 correct_delta = stats.how_many_correct_words -
                              previous_stats.how_many_correct_words;
        int32 wrong_delta =
            stats.how_many_wrong_words - previous_stats.how_many_wrong_words;
        string recognition_state;
        if (false_positive_delta == 1) {
          recognition_state = " (False Positive)";
        } else if (correct_delta == 1) {
          recognition_state = " (Correct)";
        } else if (wrong_delta == 1) {
          recognition_state = " (Wrong)";
        } else {
          LOG(ERROR) << "Unexpected state in statistics";
        }
        LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score
                  << recognition_state;
        previous_stats = stats;
        tensorflow::PrintAccuracyStats(stats);
      }
    }
  }

  // Calculate the final statistics over the whole stream; passing -1 as the
  // current time tells CalculateAccuracyStats to consider every ground truth
  // entry rather than cutting off at a given point.
  tensorflow::StreamingAccuracyStats stats;
  tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1,
                                     time_tolerance_ms, &stats);
  tensorflow::PrintAccuracyStats(stats);

  return 0;
}