# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for computing streaming accuracy statistics against ground truth."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import tensorflow as tf
22
23
class StreamingAccuracyStats(object):
  """Get streaming accuracy statistics every time a new command is found.

  Attributes:
    _how_many_gt: How many ground truths occur up to the latest time
      considered by the last call to calculate_accuracy_stats.
    _how_many_gt_matched: How many distinct ground truths have been matched
      by at least one found command.
    _how_many_fp: How many commands have been fired as false positive.
    _how_many_c: How many commands have been fired correctly.
    _how_many_w: How many commands have been fired wrongly.
    _gt_occurrence: A list of [label, timestamp] pairs recording which
      commands occur in the input audio stream and when, sorted by timestamp.
    _previous_c: The value of _how_many_c recorded by the last delta() call.
    _previous_w: The value of _how_many_w recorded by the last delta() call.
    _previous_fp: The value of _how_many_fp recorded by the last delta() call.
  """

  def __init__(self):
    """Init StreamingAccuracyStats with empty or zero values."""
    self._how_many_gt = 0
    self._how_many_gt_matched = 0
    self._how_many_fp = 0
    self._how_many_c = 0
    self._how_many_w = 0
    self._gt_occurrence = []
    self._previous_c = 0
    self._previous_w = 0
    self._previous_fp = 0

  def read_ground_truth_file(self, file_name):
    """Load ground truth (label, timestamp) pairs and keep them in time order.

    Each line is expected to be `label,timestamp`. Lines that do not split into
    exactly two comma-separated fields (e.g. blank lines) are skipped.
    Timestamps are rounded to the nearest integer and are compared against the
    millisecond values passed to calculate_accuracy_stats. New entries are
    appended to any ground truth loaded earlier, then everything is re-sorted.

    Args:
      file_name: Path of the ground truth file to read.
    """
    with open(file_name, 'r') as f:
      for line in f:
        line_split = line.strip().split(',')
        if len(line_split) != 2:
          continue
        # Strip the label so padding around the comma (e.g. "yes , 1000")
        # does not stop it from ever matching a predicted label.
        label = line_split[0].strip()
        timestamp = round(float(line_split[1]))
        self._gt_occurrence.append([label, timestamp])
    # list.sort is stable, so entries with equal timestamps keep file order.
    self._gt_occurrence.sort(key=lambda item: item[1])

  def delta(self):
    """Classify the latest change in statistics since the previous call.

    Compares the current false-positive / correct / wrong counters with the
    values recorded by the previous call, then records the current values.

    Returns:
      '(False Positive)', '(Correct)' or '(Wrong)': the first of those
      counters (checked in that order) that grew by exactly one.

    Raises:
      ValueError: If none of the counters grew by exactly one. The recorded
        previous values are left unchanged in that case.
    """
    fp_delta = self._how_many_fp - self._previous_fp
    w_delta = self._how_many_w - self._previous_w
    c_delta = self._how_many_c - self._previous_c
    if fp_delta == 1:
      recognition_state = '(False Positive)'
    elif c_delta == 1:
      recognition_state = '(Correct)'
    elif w_delta == 1:
      recognition_state = '(Wrong)'
    else:
      raise ValueError('Unexpected state in statistics')
    # Update the previous status.
    self._previous_c = self._how_many_c
    self._previous_w = self._how_many_w
    self._previous_fp = self._how_many_fp
    return recognition_state

  def calculate_accuracy_stats(self, found_words, up_to_time_ms,
                               time_tolerance_ms):
    """Calculate accuracy statistics when a new command is found.

    Given the ground truth and the predictions found by the model so far,
    figure out how many were correct. Only ground truths no later than
    `up_to_time_ms + time_tolerance_ms` are considered. Each found word is
    compared with the earliest ground truth within `time_tolerance_ms` of it:
    - If the labels agree and that ground truth has not been claimed by an
      earlier found word, the found word is correct.
    - Any other ground truth in the window makes the found word wrong.
    - If there is no ground truth in the window, the found word is a false
      positive.
    All counters are recomputed from scratch on every call.

    Args:
      found_words: A sequence of all (label, timestamp_ms) commands found up
        to now, in the order they were found.
      up_to_time_ms: End timestamp of this audio piece, or -1 to consider all
        ground truth.
      time_tolerance_ms: The tolerance in milliseconds before and after a
        found word's timestamp (and after up_to_time_ms) within which a
        ground truth can be matched.
    """
    if up_to_time_ms == -1:
      latest_possible_time = np.inf
    else:
      latest_possible_time = up_to_time_ms + time_tolerance_ms
    self._how_many_gt = 0
    for ground_truth in self._gt_occurrence:
      ground_truth_time = ground_truth[1]
      if ground_truth_time > latest_possible_time:
        break
      self._how_many_gt += 1
    self._how_many_fp = 0
    self._how_many_c = 0
    self._how_many_w = 0
    # Indices into self._gt_occurrence of ground truths that have been matched.
    # A set counts each ground truth once, so _how_many_gt_matched can never
    # exceed _how_many_gt (a list would count repeat hits on the same one).
    matched_gt_indices = set()
    for found_word in found_words:
      found_label = found_word[0]
      found_time = found_word[1]
      earliest_time = found_time - time_tolerance_ms
      latest_time = found_time + time_tolerance_ms
      has_matched_been_found = False
      for gt_index, ground_truth in enumerate(self._gt_occurrence):
        ground_truth_time = ground_truth[1]
        # Ground truth is time-sorted, so nothing later can fall in the window.
        if (ground_truth_time > latest_time or
            ground_truth_time > latest_possible_time):
          break
        if ground_truth_time < earliest_time:
          continue
        # Only the earliest ground truth inside the window is considered.
        if (ground_truth[0] == found_label and
            gt_index not in matched_gt_indices):
          self._how_many_c += 1
        else:
          self._how_many_w += 1
        matched_gt_indices.add(gt_index)
        has_matched_been_found = True
        break
      if not has_matched_been_found:
        self._how_many_fp += 1
    self._how_many_gt_matched = len(matched_gt_indices)

  def print_accuracy_stats(self):
    """Log a human-readable summary of the statistics.

    The summary is emitted with tf.compat.v1.logging.info. All percentages
    are relative to the number of ground truths seen so far. If there is no
    ground truth yet, only the false positive count is reported.
    """
    if self._how_many_gt == 0:
      tf.compat.v1.logging.info('No ground truth yet, {} false positives'.format(
          self._how_many_fp))
      return
    any_match_percentage = self._how_many_gt_matched / self._how_many_gt * 100
    correct_match_percentage = self._how_many_c / self._how_many_gt * 100
    wrong_match_percentage = self._how_many_w / self._how_many_gt * 100
    false_positive_percentage = self._how_many_fp / self._how_many_gt * 100
    tf.compat.v1.logging.info(
        '{:.1f}% matched, {:.1f}% correct, {:.1f}% wrong, '
        '{:.1f}% false positive'.format(any_match_percentage,
                                        correct_match_percentage,
                                        wrong_match_percentage,
                                        false_positive_percentage))
153