# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for getting accuracy statistics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


class StreamingAccuracyStats(object):
  """Get streaming accuracy statistics every time a new command is found.

  Attributes:
    _how_many_gt: How many ground truths fall inside the evaluated time range.
    _how_many_gt_matched: How many distinct ground truths have been matched by
      at least one found command (correctly or wrongly).
    _how_many_fp: How many commands have been fired as false positive.
    _how_many_c: How many commands have been fired correctly.
    _how_many_w: How many commands have been fired wrongly.
    _gt_occurrence: A list of [label, timestamp] pairs recording which commands
      occur in the input audio stream and when, sorted by timestamp.
    _previous_c: A variable to record the last status of _how_many_c.
    _previous_w: A variable to record the last status of _how_many_w.
    _previous_fp: A variable to record the last status of _how_many_fp.
  """

  def __init__(self):
    """Init StreamingAccuracyStats with void or zero values."""
    self._how_many_gt = 0
    self._how_many_gt_matched = 0
    self._how_many_fp = 0
    self._how_many_c = 0
    self._how_many_w = 0
    self._gt_occurrence = []
    self._previous_c = 0
    self._previous_w = 0
    self._previous_fp = 0

  def read_ground_truth_file(self, file_name):
    """Load ground truth and timestamp pairs and store them in time order.

    Each useful line has the form `label,timestamp`. Lines that do not split
    into exactly two comma-separated fields are skipped. Loaded pairs are
    appended to any ground truths that were already stored.

    Args:
      file_name: Path of the ground truth file to read.

    Raises:
      ValueError: If the timestamp field of a two-field line is not a number.
    """
    with open(file_name, 'r') as f:
      for line in f:
        line_split = line.strip().split(',')
        if len(line_split) != 2:
          continue
        label = line_split[0]
        timestamp = round(float(line_split[1]))
        self._gt_occurrence.append([label, timestamp])
    # Matching below relies on ground truths being ordered by timestamp.
    self._gt_occurrence.sort(key=lambda item: item[1])

  def delta(self):
    """Compute delta of StreamingAccuracyStats against last status.

    Returns:
      One of '(False Positive)', '(Correct)' or '(Wrong)', describing which
      counter grew by exactly one since the previous call. The counters are
      checked in that order.

    Raises:
      ValueError: If none of the three counters grew by exactly one.
    """
    fp_delta = self._how_many_fp - self._previous_fp
    w_delta = self._how_many_w - self._previous_w
    c_delta = self._how_many_c - self._previous_c
    if fp_delta == 1:
      recognition_state = '(False Positive)'
    elif c_delta == 1:
      recognition_state = '(Correct)'
    elif w_delta == 1:
      recognition_state = '(Wrong)'
    else:
      raise ValueError('Unexpected state in statistics')
    # Update the previous status
    self._previous_c = self._how_many_c
    self._previous_w = self._how_many_w
    self._previous_fp = self._how_many_fp
    return recognition_state

  def calculate_accuracy_stats(self, found_words, up_to_time_ms,
                               time_tolerance_ms):
    """Calculate accuracy statistics when a new command is found.

    Given ground truth and corresponding predictions found by the model, figure
    out how many were correct. Take a tolerance time, so that only predictions
    up to a point in time are considered.

    All counters are recomputed from scratch on every call.

    Args:
      found_words: A list of all commands found up to now; each item is
        indexed as item[0] = label and item[1] = timestamp.
      up_to_time_ms: End timestamp of this audio piece, or -1 to consider all
        ground truths regardless of their timestamp.
      time_tolerance_ms: How many milliseconds a found command may be before or
        after a ground truth and still match it. It is also added to
        up_to_time_ms to get the latest ground truth time considered.
    """
    if up_to_time_ms == -1:
      latest_possible_time = np.inf
    else:
      latest_possible_time = up_to_time_ms + time_tolerance_ms

    # Ground truths are time-sorted, so stop at the first one past the limit.
    self._how_many_gt = 0
    for ground_truth in self._gt_occurrence:
      if ground_truth[1] > latest_possible_time:
        break
      self._how_many_gt += 1

    self._how_many_fp = 0
    self._how_many_c = 0
    self._how_many_w = 0
    # A set, so a ground truth hit by several found commands is only counted
    # once in _how_many_gt_matched.
    matched_gt_times = set()
    for found_word in found_words:
      found_label = found_word[0]
      found_time = found_word[1]
      earliest_time = found_time - time_tolerance_ms
      latest_time = found_time + time_tolerance_ms
      has_match_been_found = False
      for ground_truth in self._gt_occurrence:
        ground_truth_time = ground_truth[1]
        if (ground_truth_time > latest_time or
            ground_truth_time > latest_possible_time):
          break
        if ground_truth_time < earliest_time:
          continue
        # First ground truth inside the tolerance window decides the outcome:
        # correct only if the label agrees and nobody claimed it before.
        if (ground_truth[0] == found_label and
            ground_truth_time not in matched_gt_times):
          self._how_many_c += 1
        else:
          self._how_many_w += 1
        matched_gt_times.add(ground_truth_time)
        has_match_been_found = True
        break
      if not has_match_been_found:
        self._how_many_fp += 1
    self._how_many_gt_matched = len(matched_gt_times)

  def print_accuracy_stats(self):
    """Write a human-readable description of the statistics to the TF log."""
    if self._how_many_gt == 0:
      tf.compat.v1.logging.info('No ground truth yet, {} false positives'.format(
          self._how_many_fp))
    else:
      # All percentages are relative to the number of ground truths seen.
      any_match_percentage = self._how_many_gt_matched / self._how_many_gt * 100
      correct_match_percentage = self._how_many_c / self._how_many_gt * 100
      wrong_match_percentage = self._how_many_w / self._how_many_gt * 100
      false_positive_percentage = self._how_many_fp / self._how_many_gt * 100
      tf.compat.v1.logging.info(
          '{:.1f}% matched, {:.1f}% correct, {:.1f}% wrong, '
          '{:.1f}% false positive'.format(any_match_percentage,
                                          correct_match_percentage,
                                          wrong_match_percentage,
                                          false_positive_percentage))