1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
17 #define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
18 
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"
23 
24 namespace tensorflow {
25 
26 struct StreamingAccuracyStats {
StreamingAccuracyStatsStreamingAccuracyStats27   StreamingAccuracyStats()
28       : how_many_ground_truth_words(0),
29         how_many_ground_truth_matched(0),
30         how_many_false_positives(0),
31         how_many_correct_words(0),
32         how_many_wrong_words(0) {}
33   int32 how_many_ground_truth_words;
34   int32 how_many_ground_truth_matched;
35   int32 how_many_false_positives;
36   int32 how_many_correct_words;
37   int32 how_many_wrong_words;
38 };
39 
40 // Takes a file name, and loads a list of expected word labels and times from
41 // it, as comma-separated variables.
42 Status ReadGroundTruthFile(const string& file_name,
43                            std::vector<std::pair<string, int64>>* result);
44 
45 // Given ground truth labels and corresponding predictions found by a model,
46 // figure out how many were correct. Takes a time limit, so that only
47 // predictions up to a point in time are considered, in case we're evaluating
48 // accuracy when the model has only been run on part of the stream.
49 void CalculateAccuracyStats(
50     const std::vector<std::pair<string, int64>>& ground_truth_list,
51     const std::vector<std::pair<string, int64>>& found_words,
52     int64 up_to_time_ms, int64 time_tolerance_ms,
53     StreamingAccuracyStats* stats);
54 
55 // Writes a human-readable description of the statistics to stdout.
56 void PrintAccuracyStats(const StreamingAccuracyStats& stats);
57 
58 }  // namespace tensorflow
59 
60 #endif  // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
61