1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Utils related to keras metrics.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import weakref
24
25from enum import Enum
26
27from tensorflow.python.distribute import distribution_strategy_context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.keras.utils.generic_utils import to_list
31from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import nn_ops
37from tensorflow.python.ops import weights_broadcast_ops
38from tensorflow.python.util import tf_decorator
39
# Sentinel far below every valid prediction/threshold in [0, 1]. `_filter_top_k`
# writes it into entries outside the top k so they fall below any such threshold.
NEG_INF = -1.0e10
41
42
class Reduction(Enum):
  """How a metric collapses its weighted values into a single scalar.

  Members:

  * `SUM`: add up all weighted values.
  * `SUM_OVER_BATCH_SIZE`: the sum of weighted values divided by the number
        of elements.
  * `WEIGHTED_MEAN`: the sum of weighted values divided by the sum of the
        weights.
  """
  # Member order is significant for iteration; keep it stable.
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
  WEIGHTED_MEAN = 'weighted_mean'
56
57
def update_state_wrapper(update_state_fn):
  """Wraps a metric's `update_state()` so its update op is registered.

  The returned function calls `update_state_fn` and, when that call produces
  an op, hands the op to `metric_obj.add_update(..., inputs=True)` before
  returning it.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    A decorator-wrapped function, built with `tf_decorator.make_decorator`,
    whose first positional argument is the metric object.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Runs `update_state_fn` and records the resulting op on `metric_obj`."""
    op = update_state_fn(*args, **kwargs)
    # In eager execution the update runs immediately and nothing is returned,
    # so there is no op to register.
    if op is None:
      return op
    metric_obj.add_update(op, inputs=True)
    return op

  return tf_decorator.make_decorator(update_state_fn, decorated)
77
78
def result_wrapper(result_fn):
  """Wraps a metric's `result()` so it is computed in cross-replica context.

  Computing a result is idempotent: it only reads the metric state variables.
  When called from inside a replica context, the wrapped function leaves
  replica mode through `replica_context.merge_call()` so that the state
  variables are aggregated across devices before `result_fn` runs. When there
  is no replica context, `result_fn` is called directly.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    A decorator-wrapped function, built with `tf_decorator.make_decorator`,
    that ignores its first positional argument and forwards the rest to
    `result_fn`.
  """

  def decorated(_, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()

    if replica_context is not None:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # `merge_call` invokes its function with the distribution object as the
      # first argument; this adapter absorbs it so `result_fn` does not need
      # that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *merge_args):
        # `merge_fn` arrives as a per-device value whose entries are identical
        # copies of `result_fn`; any one of them will do.
        merged_result = distribution.unwrap(merge_fn)[0](*merge_args)
        # Identity makes the result a fresh tensor so that control
        # dependencies on the `update_state` op apply to it.
        return array_ops.identity(merged_result)

      return replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)

    # Already in cross-replica context: compute the result directly.
    return array_ops.identity(result_fn(*args))

  return tf_decorator.make_decorator(result_fn, decorated)
127
128
def weakmethod(method):
  """Creates a weak reference to the bound method.

  The returned function behaves like `method` but does not keep the instance
  that `method` is bound to alive.

  Args:
    method: A bound method.

  Returns:
    A function carrying `method`'s name and docstring that re-binds the
    underlying function to the (weakly referenced) instance on each call.

  Raises:
    ReferenceError: From the returned function, if it is called after the
      instance has been garbage collected.
  """
  # `__func__`/`__self__` exist on bound methods in Python 2.6+ and Python 3,
  # unlike the Python 2-only `im_func`/`im_self`/`im_class`.
  func = method.__func__
  cls = type(method.__self__)
  instance_ref = weakref.ref(method.__self__)

  # Wrap `func`, not `method`: `functools.wraps` stores its argument in
  # `__wrapped__`, and a bound method strongly references its instance, which
  # would defeat the weak reference.
  @functools.wraps(func)
  def inner(*args, **kwargs):
    instance = instance_ref()
    if instance is None:
      # Without this check the function would be re-bound to `None`.
      raise ReferenceError('weakly-referenced object no longer exists')
    return func.__get__(instance, cls)(*args, **kwargs)

  del method
  return inner
142
143
def assert_thresholds_range(thresholds):
  """Validates that every threshold lies in the closed interval [0, 1].

  Args:
    thresholds: An iterable of threshold values, or `None` (nothing to check).

  Raises:
    ValueError: If any threshold is `None`, NaN, or outside [0, 1].
  """
  if thresholds is None:
    return
  # `not 0 <= t <= 1` (rather than `t < 0 or t > 1`) also rejects NaN, for
  # which every ordered comparison is False.
  invalid_thresholds = [
      t for t in thresholds if t is None or not 0 <= t <= 1
  ]
  if invalid_thresholds:
    raise ValueError(
        'Threshold values must be in [0, 1]. Invalid values: {}'.format(
            invalid_thresholds))
151
152
def parse_init_thresholds(thresholds, default_threshold=0.5):
  """Normalizes a metric's `thresholds` constructor argument to a list.

  Args:
    thresholds: `None`, or a value (or list of values) passed through
      `to_list`. User-supplied values are range-checked.
    default_threshold: Used when `thresholds` is `None`. It is not
      range-checked.

  Returns:
    The list produced by `to_list`.

  Raises:
    ValueError: Propagated from `assert_thresholds_range` for invalid
      user-supplied thresholds.
  """
  if thresholds is None:
    return to_list(default_threshold)
  thresholds = to_list(thresholds)
  assert_thresholds_range(thresholds)
  return thresholds
158
159
class ConfusionMatrix(Enum):
  """Identifiers for the four confusion-matrix counters.

  Each member's value is the counter's short abbreviation. The members (not
  their string values) are the keys accepted by the `variables_to_update`
  argument of `update_confusion_matrix_variables`.
  """
  TRUE_POSITIVES = 'tp'    # label positive, prediction above threshold
  FALSE_POSITIVES = 'fp'   # label negative, prediction above threshold
  TRUE_NEGATIVES = 'tn'    # label negative, prediction at/below threshold
  FALSE_NEGATIVES = 'fn'   # label positive, prediction at/below threshold
165
166
class AUCCurve(Enum):
  """Which curve an AUC is computed over: ROC or precision-recall (PR)."""
  ROC = 'ROC'
  PR = 'PR'

  @staticmethod
  def from_str(key):
    """Parses `key` ('pr'/'PR' or 'roc'/'ROC') into an `AUCCurve`.

    Raises:
      ValueError: If `key` is none of the accepted spellings.
    """
    if key in ('pr', 'PR'):
      return AUCCurve.PR
    if key in ('roc', 'ROC'):
      return AUCCurve.ROC
    raise ValueError('Invalid AUC curve value "%s".' % key)
180
181
class AUCSummationMethod(Enum):
  """Type of AUC summation method.

  See https://en.wikipedia.org/wiki/Riemann_sum for background.

  Contains the following values:
  * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
    `PR` curve, interpolates (true/false) positives but not the ratio that is
    precision (see Davis & Goadrich 2006 for details).
  * 'minoring': Applies left summation for increasing intervals and right
    summation for decreasing intervals.
  * 'majoring': Applies right summation for increasing intervals and left
    summation for decreasing intervals.
  """
  INTERPOLATION = 'interpolation'
  MAJORING = 'majoring'
  MINORING = 'minoring'

  @staticmethod
  def from_str(key):
    """Parses a lowercase or capitalized name into an `AUCSummationMethod`.

    Raises:
      ValueError: If `key` is none of the accepted spellings.
    """
    # Checked in the same order as the members are declared.
    spellings = (
        (AUCSummationMethod.INTERPOLATION, ('interpolation', 'Interpolation')),
        (AUCSummationMethod.MAJORING, ('majoring', 'Majoring')),
        (AUCSummationMethod.MINORING, ('minoring', 'Minoring')),
    )
    for method, names in spellings:
      if key in names:
        return method
    raise ValueError('Invalid AUC summation method value "%s".' % key)
210
211
def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None):
  """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positive: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positive: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary whose keys are `ConfusionMatrix` members
      (`TRUE_POSITIVES`, `FALSE_POSITIVES`, `TRUE_NEGATIVES`,
      `FALSE_NEGATIVES`; the plain strings 'tp', 'fp', ... are NOT accepted)
      and whose values are the variables to update. Each variable receives
      one summed count per threshold via `assign_add`.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value or a python list of float thresholds in
      `[0, 1]`, or NEG_INF (used when top_k is set). Normalized with
      `to_list`; integer values are accepted and treated as float32.
      NOTE(review): tuple support depends on `to_list` expanding tuples;
      confirm.
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).

  Returns:
    Update op grouping one `assign_add` per requested variable, or `None` if
    `variables_to_update` is `None`.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and cannot be broadcast to `y_pred`, or if
      `variables_to_update` contains no valid key or any invalid key.
  """
  if variables_to_update is None:
    return
  y_true = math_ops.cast(y_true, dtype=dtypes.float32)
  y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  # Keys must be `ConfusionMatrix` members; build the list once.
  valid_keys = list(ConfusionMatrix)
  if not any(key in valid_keys for key in variables_to_update):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(valid_keys, variables_to_update.keys()))

  invalid_keys = [key for key in variables_to_update if key not in valid_keys]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, valid_keys))

  # The range assertions only gate ops created inside this scope.
  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          y_pred,
          math_ops.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      check_ops.assert_less_equal(
          y_pred,
          math_ops.cast(1.0, dtype=y_pred.dtype),
          message='predictions must be <= 1')
  ]):
    y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
        y_pred, y_true, sample_weight)

  if top_k is not None:
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  thresholds = to_list(thresholds)
  num_thresholds = len(thresholds)
  num_predictions = array_ops.size(y_pred)

  # Flatten predictions and labels into one row each: [1, num_predictions].
  predictions_2d = array_ops.reshape(y_pred, [1, -1])
  labels_2d = array_ops.reshape(
      math_ops.cast(y_true, dtype=dtypes.bool), [1, -1])

  # [num_thresholds, num_predictions] matrix of thresholds. The explicit
  # float32 dtype keeps all-integer thresholds (e.g. `[0, 1]`) from producing
  # an int32 tensor that cannot be compared with the float32 predictions.
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(
          array_ops.constant(thresholds, dtype=dtypes.float32), 1),
      array_ops.stack([1, num_predictions]))

  # Repeat the prediction row once per threshold.
  preds_tiled = array_ops.tile(predictions_2d, [num_thresholds, 1])

  # A prediction counts as positive when strictly above the threshold.
  pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled)

  # Repeat the label row once per threshold.
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])

  if sample_weight is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
  else:
    weights_tiled = None

  update_ops = []

  def weighted_assign_add(label, pred, weights, var):
    # Count (optionally weighted) entries where both conditions hold, summed
    # per threshold row, and accumulate into `var`.
    label_and_pred = math_ops.cast(
        math_ops.logical_and(label, pred), dtype=dtypes.float32)
    if weights is not None:
      label_and_pred *= weights
    return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))

  loop_vars = {
      ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
  }
  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  # Negated masks are only built when some requested counter needs them;
  # `pred_is_neg` is always defined when `update_tn` is set.
  if update_fn or update_tn:
    pred_is_neg = math_ops.logical_not(pred_is_pos)
    loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

  if update_fp or update_tn:
    label_is_neg = math_ops.logical_not(label_is_pos)
    loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
    if update_tn:
      loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)

  for matrix_cond, (label, pred) in loop_vars.items():
    if matrix_cond in variables_to_update:
      update_ops.append(
          weighted_assign_add(label, pred, weights_tiled,
                              variables_to_update[matrix_cond]))
  return control_flow_ops.group(update_ops)
365
366
def _filter_top_k(x, k):
  """Filters top-k values in the last dim of x and sets the rest to NEG_INF.

  Used for computing top-k prediction values in dense labels (which has the same
  shape as predictions) for recall and precision top-k metrics.

  Args:
    x: tensor with at least one dimension; the last dimension may be unknown
      at graph-construction time. Expected to be float32, matching the
      float32 mask produced by `one_hot`.
    k: the number of values to keep.

  Returns:
    tensor with same shape and dtype as x, holding x's values at the top-k
    positions of the last dimension and NEG_INF everywhere else.
  """
  _, top_k_idx = nn_ops.top_k(x, k, sorted=False)
  # Use the dynamic size of the last axis: the static `x.shape[-1]` is None
  # when that dimension is unknown, which `one_hot` cannot use as a depth.
  top_k_mask = math_ops.reduce_sum(
      array_ops.one_hot(top_k_idx, array_ops.shape(x)[-1], axis=-1), axis=-2)
  return x * top_k_mask + NEG_INF * (1 - top_k_mask)
384