1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
"""Utils related to keras metrics."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
22import functools
23import weakref
25from enum import Enum
27from tensorflow.python.distribute import distribution_strategy_context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.keras import backend
31from tensorflow.python.keras.utils import losses_utils
32from tensorflow.python.keras.utils import tf_utils
33from tensorflow.python.keras.utils.generic_utils import to_list
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import check_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import gen_math_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import nn_ops
40from tensorflow.python.ops import weights_broadcast_ops
41from tensorflow.python.ops.ragged import ragged_tensor
42from tensorflow.python.util import tf_decorator
# Large negative sentinel. `_filter_top_k` writes it into every position that is
# not among the top k, so those predictions never compare greater than a
# threshold; it is also accepted as a threshold value when `top_k` is set.
NEG_INF = -1e10
class Reduction(Enum):
  """How a metric combines its weighted values into one scalar.

  Members:

  * `SUM`: plain scalar sum of the weighted values.
  * `SUM_OVER_BATCH_SIZE`: scalar sum of the weighted values, divided by the
    number of elements.
  * `WEIGHTED_MEAN`: scalar sum of the weighted values, divided by the sum of
    the weights.
  """
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
  WEIGHTED_MEAN = 'weighted_mean'
def update_state_wrapper(update_state_fn):
  """Decorator to wrap metric `update_state()` with `add_update()`.

  The decorated function first rejects updates issued from a replica context of
  a TPU strategy when any of the metric's weights was not created in that
  strategy's scope. It then runs `update_state_fn` inside
  `tf_utils.graph_context_for_symbolic_tensors` and registers the returned op
  (when not `None`) with `metric_obj.add_update()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    Decorated function that wraps `update_state_fn()` with `add_update()`.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `add_update()`."""
    strategy = distribution_strategy_context.get_strategy()
    # TODO(b/142574744): Remove this check if a better solution is found for
    # declaring keras Metric outside of TPUStrategy and then updating it per
    # replica.
    # The strategy and replica-context conditions do not depend on the weight,
    # so evaluate them once instead of once per weight.
    if (backend.is_tpu_strategy(strategy) and
        not distribution_strategy_context.in_cross_replica_context()):
      if any(not strategy.extended.variable_created_in_scope(weight)
             for weight in metric_obj.weights):
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUstrategy scope. ')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
      update_op = update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op)
    return update_op

  return tf_decorator.make_decorator(update_state_fn, decorated)
def result_wrapper(result_fn):
  """Decorator that routes a metric's `result()` through `merge_call()`.

  Computing a result is idempotent: it only derives the metric value from the
  state variables. When a distribution strategy is active and `result()` is
  invoked from a replica context, the call is wrapped in that replica
  context's `merge_call()`. The value is then computed in cross-replica mode,
  where state variables distributed across replicas/devices are aggregated.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """

  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    strategy_active = distribution_strategy_context.has_strategy()
    ctx = distribution_strategy_context.get_replica_context()

    if strategy_active and ctx is not None:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      def _cross_replica_result(distribution, fn_per_replica, *fn_args):
        # `merge_call` passes the distribution object first, then a
        # `PerReplica` wrapper around the function given below. Every entry
        # is the same function, so the first local copy is used.
        local_fn = distribution.experimental_local_results(fn_per_replica)[0]
        # The identity keeps the control dependency between the update op
        # from `update_state` and the result when `local_fn` returns a tensor.
        return array_ops.identity(local_fn(*fn_args))

      # Leave replica mode and compute the value in cross-replica mode.
      value = ctx.merge_call(_cross_replica_result, args=(result_fn,) + args)
    else:
      value = array_ops.identity(result_fn(*args))

    # Saved for the train/test execution functions: this result op was built
    # with a control dependency on the metric's updates.
    metric_obj._call_result = value
    return value

  return tf_decorator.make_decorator(result_fn, decorated)
def weakmethod(method):
  """Creates a weak reference to the bound method.

  The returned callable holds only a weak reference to the method's instance.
  On every call it re-binds the underlying function to that instance, so
  wrapping a method does not keep its object alive.

  Args:
    method: A bound method.

  Returns:
    A callable that forwards its arguments to `method`.

  Raises:
    ReferenceError: When the returned callable is invoked after the bound
      instance has been garbage collected.
  """
  # `__func__` / `__self__` exist on bound methods in Python 2.6+ and 3; the
  # `im_func` / `im_self` / `im_class` aliases are Python 2 only.
  func = method.__func__
  instance = method.__self__
  cls = getattr(method, 'im_class', type(instance))
  instance_ref = weakref.ref(instance)

  # Wrap `func`, not `method`: `functools.wraps` stores its argument in
  # `__wrapped__`, and a stored bound method would keep the instance alive.
  @functools.wraps(func)
  def inner(*args, **kwargs):
    target = instance_ref()
    if target is None:
      # Without this check `func.__get__(None, cls)` yields the plain function
      # and the call would silently receive shifted arguments.
      raise ReferenceError('weakly-referenced object no longer exists')
    return func.__get__(target, cls)(*args, **kwargs)

  return inner
def assert_thresholds_range(thresholds):
  """Validates that every threshold lies in the closed range [0, 1].

  Args:
    thresholds: An iterable of threshold values, or `None` (nothing to check).

  Raises:
    ValueError: Listing, in order, every threshold that is `None`, below 0 or
      above 1.
  """

  def _is_invalid(value):
    return value is None or value < 0 or value > 1

  if thresholds is None:
    return
  offenders = list(filter(_is_invalid, thresholds))
  if offenders:
    raise ValueError(
        'Threshold values must be in [0, 1]. Invalid values: {}'.format(
            offenders))
def parse_init_thresholds(thresholds, default_threshold=0.5):
  """Normalizes constructor `thresholds` into a list.

  User-supplied thresholds are passed through `to_list` and range-checked with
  `assert_thresholds_range`. When `thresholds` is `None`, the unchecked
  `default_threshold` is passed through `to_list` instead.

  Args:
    thresholds: Threshold value(s) given by the user, or `None`.
    default_threshold: Value used when `thresholds` is `None`.

  Returns:
    The thresholds as produced by `to_list`.

  Raises:
    ValueError: If a user-supplied threshold is outside [0, 1].
  """
  if thresholds is None:
    return to_list(default_threshold)
  as_list = to_list(thresholds)
  assert_thresholds_range(as_list)
  return as_list
class ConfusionMatrix(Enum):
  """Identifiers of the confusion-matrix counters a metric can update.

  Members are used as the keys of the `variables_to_update` dictionary passed
  to `update_confusion_matrix_variables`. That function looks up all four
  counters, so every one of them must be defined here.
  """
  TRUE_POSITIVES = 'tp'
  FALSE_POSITIVES = 'fp'
  TRUE_NEGATIVES = 'tn'
  FALSE_NEGATIVES = 'fn'
class AUCCurve(Enum):
  """Curve under which the AUC is computed: ROC or precision-recall (PR)."""
  ROC = 'ROC'
  PR = 'PR'

  @staticmethod
  def from_str(key):
    """Returns the curve named by `key` (all-lowercase or all-uppercase).

    Raises:
      ValueError: If `key` is not one of 'pr', 'PR', 'roc' or 'ROC'.
    """
    if key in ('pr', 'PR'):
      return AUCCurve.PR
    if key in ('roc', 'ROC'):
      return AUCCurve.ROC
    raise ValueError('Invalid AUC curve value "%s".' % key)
class AUCSummationMethod(Enum):
  """Type of AUC summation method.

  See https://en.wikipedia.org/wiki/Riemann_sum for background.

  Contains the following values:
  * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
    `PR` curve, interpolates (true/false) positives but not the ratio that is
    precision (see Davis & Goadrich 2006 for details).
  * 'minoring': Applies left summation for increasing intervals and right
    summation for decreasing intervals.
  * 'majoring': Applies right summation for increasing intervals and left
    summation for decreasing intervals.
  """
  INTERPOLATION = 'interpolation'
  MAJORING = 'majoring'
  MINORING = 'minoring'

  @staticmethod
  def from_str(key):
    """Returns the method named by `key` (lowercase or capitalized).

    Raises:
      ValueError: If `key` is not one of the accepted spellings.
    """
    accepted = (
        (('interpolation', 'Interpolation'), AUCSummationMethod.INTERPOLATION),
        (('majoring', 'Majoring'), AUCSummationMethod.MAJORING),
        (('minoring', 'Minoring'), AUCSummationMethod.MINORING),
    )
    for spellings, method in accepted:
      if key in spellings:
        return method
    raise ValueError('Invalid AUC summation method value "%s".' % key)
def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None):
  """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positives: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positives: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary whose keys are `ConfusionMatrix` members
      (the members whose values are 'tp', 'fn', 'tn', 'fp'; plain strings are
      rejected as invalid keys) and whose values are the corresponding
      variables to update. All inputs are cast to the dtype of the first
      variable in the dictionary.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value, float tensor, python list, or tuple of float
      thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).
      Must be `None` when `multi_label` is True.

  Returns:
    Update op grouping one `assign_add` per updated variable, or `None` when
    `variables_to_update` is `None`.

  Raises:
    ValueError: If `multi_label` is True and `label_weights` is given, if
      `variables_to_update` has no valid key or contains invalid keys, if
      `y_pred` and `y_true` have incompatible static shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`.
  """
  if multi_label and label_weights is not None:
    raise ValueError('`label_weights` for multilabel data should be handled '
                     'outside of `update_confusion_matrix_variables` when '
                     '`multi_label` is True.')
  if variables_to_update is None:
    return
  # Enum members are truthy, so this fails only when no key is a
  # `ConfusionMatrix` member.
  if not any(
      key for key in variables_to_update if key in list(ConfusionMatrix)):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(
            list(ConfusionMatrix), variables_to_update.keys()))

  # All computations use the dtype of the first variable in the dictionary.
  variable_dtype = list(variables_to_update.values())[0].dtype

  y_true = math_ops.cast(y_true, dtype=variable_dtype)
  y_pred = math_ops.cast(y_pred, dtype=variable_dtype)
  thresholds = ops.convert_to_tensor_v2_with_dispatch(
      thresholds, dtype=variable_dtype)
  num_thresholds = thresholds.shape[0]
  if multi_label:
    # True when `thresholds` has rank 1, i.e. (per the op name) a single set of
    # thresholds shared by every label.
    one_thresh = math_ops.equal(
        math_ops.cast(1, dtype=dtypes.int32),
        array_ops.rank(thresholds),
        name='one_set_of_thresholds_cond')
  else:
    # Ragged inputs are replaced by their (splits-checked) flat values.
    [y_pred,
     y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
                                                               sample_weight)
    one_thresh = math_ops.cast(True, dtype=dtypes.bool)

  invalid_keys = [
      key for key in variables_to_update if key not in list(ConfusionMatrix)
  ]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, list(ConfusionMatrix)))

  # The range assertions become control dependencies of the ops created inside
  # this block (the squeeze/expand of predictions, labels and weights).
  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          y_pred,
          math_ops.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      check_ops.assert_less_equal(
          y_pred,
          math_ops.cast(1.0, dtype=y_pred.dtype),
          message='predictions must be <= 1')
  ]):
    if sample_weight is None:
      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
          y_pred, y_true)
    else:
      sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype)
      y_pred, y_true, sample_weight = (
          losses_utils.squeeze_or_expand_dimensions(
              y_pred, y_true, sample_weight=sample_weight))
  # Static shape check; raises ValueError on incompatible shapes.
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  if top_k is not None:
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  pred_shape = array_ops.shape(y_pred)
  num_predictions = pred_shape[0]
  if y_pred.shape.ndims == 1:
    num_labels = 1
  else:
    # Product of all non-leading dimensions.
    num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0)
  # Tile count for the label axis of the thresholds: `num_labels` when a single
  # set of thresholds is used, otherwise 1. NOTE(review): presumably a rank-2
  # `thresholds` already carries one column per label -- confirm.
  thresh_label_tile = control_flow_ops.cond(
      one_thresh, lambda: num_labels,
      lambda: math_ops.cast(1, dtype=dtypes.int32))

  # Reshape predictions and labels, adding a dim for thresholding.
  if multi_label:
    predictions_extra_dim = array_ops.expand_dims(y_pred, 0)
    labels_extra_dim = array_ops.expand_dims(
        math_ops.cast(y_true, dtype=dtypes.bool), 0)
  else:
    # Flatten predictions and labels when not multilabel.
    predictions_extra_dim = array_ops.reshape(y_pred, [1, -1])
    labels_extra_dim = array_ops.reshape(
        math_ops.cast(y_true, dtype=dtypes.bool), [1, -1])

  # Tile the thresholds for every prediction.
  if multi_label:
    thresh_pretile_shape = [num_thresholds, 1, -1]
    thresh_tiles = [1, num_predictions, thresh_label_tile]
    data_tiles = [num_thresholds, 1, 1]
  else:
    thresh_pretile_shape = [num_thresholds, -1]
    thresh_tiles = [1, num_predictions * num_labels]
    data_tiles = [num_thresholds, 1]

  thresh_tiled = array_ops.tile(
      array_ops.reshape(thresholds, thresh_pretile_shape),
      array_ops.stack(thresh_tiles))

  # Tile the predictions for every threshold.
  preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles)

  # Compare predictions and threshold (strictly greater counts as positive).
  pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_extra_dim, data_tiles)

  if sample_weight is not None:
    sample_weight = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weight, dtype=variable_dtype), y_pred)
    weights_tiled = array_ops.tile(
        array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
  else:
    weights_tiled = None

  # Label weights apply only to the flattened (non-multilabel) case; they are
  # multiplied into any sample weights.
  if label_weights is not None and not multi_label:
    label_weights = array_ops.expand_dims(label_weights, 0)
    label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
                                                            y_pred)
    label_weights_tiled = array_ops.tile(
        array_ops.reshape(label_weights, thresh_tiles), data_tiles)
    if weights_tiled is None:
      weights_tiled = label_weights_tiled
    else:
      weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled)

  update_ops = []

  def weighted_assign_add(label, pred, weights, var):
    # Counts (optionally weighted) positions where both conditions hold. The
    # sum is over axis 1, leaving axis 0 (one entry per threshold) and, in the
    # multilabel case, the trailing label axis.
    label_and_pred = math_ops.cast(
        math_ops.logical_and(label, pred), dtype=var.dtype)
    if weights is not None:
      label_and_pred *= math_ops.cast(weights, dtype=var.dtype)
    return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))

  # True positives are always computed; the other conditions only when needed.
  loop_vars = {
      ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
  }
  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  if update_fn or update_tn:
    pred_is_neg = math_ops.logical_not(pred_is_pos)
    loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

  if update_fp or update_tn:
    label_is_neg = math_ops.logical_not(label_is_pos)
    loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
    if update_tn:
      # `pred_is_neg` is defined: `update_tn` also entered the branch above.
      loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)

  for matrix_cond, (label, pred) in loop_vars.items():

    if matrix_cond in variables_to_update:
      update_ops.append(
          weighted_assign_add(label, pred, weights_tiled,
                              variables_to_update[matrix_cond]))

  return control_flow_ops.group(update_ops)
def _filter_top_k(x, k):
  """Keeps the k largest entries along the last axis of `x`.

  Every other entry is replaced with `NEG_INF`. Used for computing top-k
  prediction values in dense labels (which have the same shape as predictions)
  for the recall and precision top-k metrics.

  Args:
    x: tensor with any dimensions.
    k: the number of values to keep.

  Returns:
    tensor with same shape and dtype as x.
  """
  num_classes = array_ops.shape(x)[-1]
  _, kept_indices = nn_ops.top_k(x, k, sorted=False)
  # one_hot appends a class axis after the k axis; summing over the k axis
  # yields a mask over the last axis of `x` marking the kept positions.
  keep_mask = math_ops.reduce_sum(
      array_ops.one_hot(kept_indices, num_classes, axis=-1), axis=-2)
  drop_mask = 1 - keep_mask
  return x * keep_mask + NEG_INF * drop_mask
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """If ragged, it checks the compatibility and then returns the flat_values.

  Note: If two tensors are dense, it does not check their compatibility.
  Note: Although two ragged tensors with different ragged ranks could have
  identical overall rank and dimension sizes and hence be compatible, we do
  not support those cases.

  Args:
    values: A single potentially ragged tensor, or a list of potentially
      ragged tensors of the same ragged_rank.
    mask: A potentially ragged tensor of the same ragged_rank as elements in
      `values`.

  Returns:
    A tuple `(values, mask)`. When every value is ragged (and `mask` is `None`
    or ragged), each ragged input is replaced by its `flat_values` with a
    trailing dimension of size 1 added, gated on ragged-splits equality
    assertions. `values` keeps its original form: a single tensor in, a single
    tensor out; a list in, a list out. Otherwise the inputs are returned
    unchanged.

  Raises:
    TypeError: If only some of `values` are ragged, or if `mask` is ragged
      while none of `values` is.
    ValueError: If the ragged inputs have different numbers of ragged
      dimensions (raised by `_assert_splits_match`).
  """
  if isinstance(values, list):
    is_all_ragged = \
        all(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
    is_any_ragged = \
        any(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
  else:
    is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor)
    is_any_ragged = is_all_ragged
  if (is_all_ragged and
      ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))):
    # A single tensor is wrapped in a list here and unwrapped before returning.
    to_be_stripped = False
    if not isinstance(values, list):
      values = [values]
      to_be_stripped = True

    # NOTE: we leave the flat_values compatibility to
    # tf.TensorShape `assert_is_compatible_with`
    # check if both dynamic dimensions are equal and then use the flat_values.
    nested_row_split_list = [rt.nested_row_splits for rt in values]
    assertion_list = _assert_splits_match(nested_row_split_list)

    # if both are ragged sample_weights also should be ragged with same dims.
    if isinstance(mask, ragged_tensor.RaggedTensor):
      assertion_list_for_mask = _assert_splits_match(
          [nested_row_split_list[0], mask.nested_row_splits])
      with ops.control_dependencies(assertion_list_for_mask):
        mask = array_ops.expand_dims(mask.flat_values, -1)

    # values has at least 1 element.
    # NOTE(review): an empty `values` list also reaches this branch (all() of
    # an empty list is True); with a ragged mask the `[0]` above would then
    # raise IndexError.
    flat_values = []
    for value in values:
      with ops.control_dependencies(assertion_list):
        flat_values.append(array_ops.expand_dims(value.flat_values, -1))

    values = flat_values[0] if to_be_stripped else flat_values

  elif is_any_ragged:
    raise TypeError('One of the inputs does not have acceptable types.')
  # None of the values are ragged, but the mask is.
  elif isinstance(mask, ragged_tensor.RaggedTensor):
    raise TypeError('Ragged mask is not allowed with non-ragged inputs.')

  return values, mask
541def _assert_splits_match(nested_splits_lists):
542  """Checks that the given splits lists are identical.
544  Performs static tests to ensure that the given splits lists are identical,
545  and returns a list of control dependency op tensors that check that they are
546  fully identical.
548  Args:
549    nested_splits_lists: A list of nested_splits_lists, where each split_list is
550      a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
551      ragged dimension to innermost ragged dimension.
553  Returns:
554    A list of control dependency op tensors.
555  Raises:
556    ValueError: If the splits are not identical.
557  """
558  error_msg = 'Inputs must have identical ragged splits'
559  for splits_list in nested_splits_lists:
560    if len(splits_list) != len(nested_splits_lists[0]):
561      raise ValueError(error_msg)
562  return [
563      check_ops.assert_equal(s1, s2, message=error_msg)  # pylint: disable=g-complex-comprehension
564      for splits_list in nested_splits_lists[1:]
565      for (s1, s2) in zip(nested_splits_lists[0], splits_list)
566  ]