1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2# Licensed under the Apache License, Version 2.0 (the "License");
3# you may not use this file except in compliance with the License.
4# You may obtain a copy of the License at
5#
6#     http://www.apache.org/licenses/LICENSE-2.0
7#
8# Unless required by applicable law or agreed to in writing, software
9# distributed under the License is distributed on an "AS IS" BASIS,
10# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11# See the License for the specific language governing permissions and
12# limitations under the License.
13# ==============================================================================
14"""Implementation of tf.metrics module."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.distribute import distribution_strategy_context
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import confusion_matrix
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import nn
31from tensorflow.python.ops import sets
32from tensorflow.python.ops import sparse_ops
33from tensorflow.python.ops import state_ops
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.ops import weights_broadcast_ops
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.util.deprecation import deprecated
38from tensorflow.python.util.tf_export import tf_export
39
40
def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Creates a zero-initialized, non-trainable variable holding metric state.

  The variable is added to both the `GraphKeys.LOCAL_VARIABLES` and the
  `GraphKeys.METRIC_VARIABLES` collections, and is created with `ON_READ`
  synchronization and `SUM` aggregation.

  When running in a `DistributionStrategy` context this makes the variable
  "sync on read", which means:

  *   The returned object is a container holding a separate variable for
      each replica of the model.

  *   A write to the variable, e.g. an `assign_add` issued by a metric
      update, is applied only to the variable local to the writing replica.

  *   Obtaining a metric's result value requires summing the variable values
      across the replicas before computing the final answer, and that final
      answer should be computed once rather than in every replica. Both are
      achieved by computing the result inside
      `distribution_strategy_context.get_replica_context().merge_call(fn)`:
      inside `merge_call()` ops are only added to the graph once, and
      accessing a sync on read variable returns the sum across all replicas.

  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a sync on read variable container.
  """

  def _zero_initial_value():
    # Passed as a callable so the zeros tensor is built by the variable
    # machinery rather than eagerly here.
    return array_ops.zeros(shape, dtype)

  metric_collections = [
      ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
  ]
  # Synchronization "ON_READ" already implies trainable=False; it is still
  # spelled out for clarity.
  return variable_scope.variable(
      _zero_initial_value,
      trainable=False,
      collections=metric_collections,
      validate_shape=validate_shape,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM,
      name=name)
86
87
def _remove_squeezable_dimensions(predictions, labels, weights):
  """Aligns the last dimension of `predictions`, `labels` and `weights`.

  `predictions` and `labels` are reconciled by
  `confusion_matrix.remove_squeezable_dimensions`, which squeezes the last
  dim of one of them if their ranks differ by 1. Afterwards the last dim of
  `weights` is squeezed or expanded if its rank differs by 1 from the new
  rank of `predictions`.

  Scalar `weights` are always left scalar.

  Static shapes are used whenever they are available. Otherwise graph
  operations are added to take the decision at run time, which could result
  in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  static_weights_shape = weights.get_shape()
  static_weights_rank = static_weights_shape.ndims
  if static_weights_rank == 0:
    # Statically known scalar weights are returned untouched.
    return predictions, labels, weights

  static_predictions_rank = predictions.get_shape().ndims
  if static_predictions_rank is not None and static_weights_rank is not None:
    # Both ranks are known statically, so the decision is taken in Python.
    static_rank_gap = static_weights_rank - static_predictions_rank
    if static_rank_gap == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif static_rank_gap == -1:
      weights = array_ops.expand_dims(weights, [-1])
    return predictions, labels, weights

  # At least one rank is unknown statically: decide with graph ops instead.
  dynamic_weights_rank = array_ops.rank(weights)
  dynamic_rank_gap = dynamic_weights_rank - array_ops.rank(predictions)

  # Don't attempt the squeeze when the static shape already shows that the
  # last dimension of `weights` is incompatible with 1.
  squeeze_is_possible = (
      static_weights_rank is None or
      static_weights_shape.dims[-1].is_compatible_with(1))

  def _keep_weights():
    return weights

  def _squeeze_last_dim():
    return array_ops.squeeze(weights, [-1])

  def _expand_last_dim():
    return array_ops.expand_dims(weights, [-1])

  def _expand_if_one_dim_short():
    return control_flow_ops.cond(
        math_ops.equal(dynamic_rank_gap, -1), _expand_last_dim, _keep_weights)

  def _squeeze_or_expand():
    return control_flow_ops.cond(
        math_ops.equal(dynamic_rank_gap, 1),
        _squeeze_last_dim if squeeze_is_possible else _keep_weights,
        _expand_if_one_dim_short)

  # Scalar weights are left alone; anything else may gain or lose a trailing
  # dimension to match `predictions`.
  adjusted_weights = control_flow_ops.cond(
      math_ops.equal(dynamic_weights_rank, 0), _keep_weights,
      _squeeze_or_expand)
  return predictions, labels, adjusted_weights
162
163
def _maybe_expand_labels(labels, predictions):
  """Appends a trailing unit dimension to `labels` when one is missing.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if both static ranks are known and the rank of `predictions`
      is neither equal to nor one greater than the rank of `labels`.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    if isinstance(labels, sparse_tensor.SparseTensor):
      # Sparse labels: compare against the length of `dense_shape` at run
      # time and, if a dimension is missing, reshape to `dense_shape + [1]`.
      def _append_unit_dim_sparse():
        return sparse_ops.sparse_reshape(
            labels,
            shape=array_ops.concat((labels.dense_shape, (1,)), 0),
            name=scope)

      sparse_needs_expansion = math_ops.equal(
          array_ops.rank(predictions),
          array_ops.size(labels.dense_shape) + 1)
      return control_flow_ops.cond(
          sparse_needs_expansion, _append_unit_dim_sparse, lambda: labels)

    def _append_unit_dim_dense():
      return array_ops.expand_dims(labels, -1, name=scope)

    # Dense labels: settle it statically when both ranks are known. The
    # static shape of `predictions` is only queried once the rank of
    # `labels` is known.
    static_labels_rank = labels.get_shape().ndims
    if static_labels_rank is not None:
      static_predictions_rank = predictions.get_shape().ndims
      if static_predictions_rank is not None:
        missing_dims = static_predictions_rank - static_labels_rank
        if missing_dims == 0:
          return labels
        if missing_dims == 1:
          return _append_unit_dim_dense()
        raise ValueError(
            'Unexpected labels shape %s for predictions shape %s.' %
            (labels.get_shape(), predictions.get_shape()))

    # Otherwise fall back to comparing the dynamic ranks.
    dense_needs_expansion = math_ops.equal(
        array_ops.rank(predictions), array_ops.rank(labels) + 1)
    return control_flow_ops.cond(
        dense_needs_expansion, _append_unit_dim_dense, lambda: labels)
213
214
def _safe_scalar_div(numerator, denominator, name):
  """Computes `numerator / denominator`, yielding 0 for a zero denominator.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  # Static shape check only; the returned shapes are discarded. Note that the
  # check tolerates rank 1 even though the operands are described as scalars.
  # NOTE(review): presumably `with_rank_at_most` raises when the static rank
  # exceeds 1 -- confirm against `TensorShape`.
  for operand in (numerator, denominator):
    operand.get_shape().with_rank_at_most(1)
  return math_ops.div_no_nan(numerator, denominator, name=name)
229
230
def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an  `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1
      or if its rank is not known statically.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1 or if its rank is not known statically.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
      Flattened under the same conditions as `labels`.

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """

  def _flatten_unless_rank_at_most_one(tensor):
    """Reshapes `tensor` to 1-D unless its static rank is known to be <= 1."""
    static_rank = tensor.get_shape().ndims
    # `ndims` is `None` when the rank is not known statically. Comparing
    # `None > 1` raises `TypeError` on Python 3, so an unknown rank is
    # handled explicitly by flattening.
    if static_rank is None or static_rank > 1:
      return array_ops.reshape(tensor, [-1])
    return tensor

  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.cast(predictions, dtypes.int64)
  labels = math_ops.cast(labels, dtypes.int64)
  num_classes = math_ops.cast(num_classes, dtypes.int64)

  # Flatten the inputs if their rank is > 1 (or unknown).
  predictions = _flatten_unless_rank_at_most_one(predictions)
  labels = _flatten_unless_rank_at_most_one(labels)
  if weights is not None:
    weights = _flatten_unless_rank_at_most_one(weights)

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op
278
279
def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across replicas.

  Runs `metric_value_fn` through `merge_call` of the current replica context.

  Args:
    metrics_collections: Optional collections to which the computed metric
      value is added. Nothing is added when this is empty or `None`.
    metric_value_fn: Callable invoked as `metric_value_fn(distribution, *args)`
      that builds and returns the metric value.
    *args: Extra arguments passed to `merge_call` and on to `metric_value_fn`.

  Returns:
    Whatever `merge_call` returns for the wrapped function.
  """
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution.extended, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.

      # Read the captured context once so that `Enter` and `Exit` are
      # guaranteed to be called on the same object.
      # pylint: disable=protected-access
      outer_context = distribution.extended._outer_control_flow_context
      # pylint: enable=protected-access
      if outer_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        outer_context.Enter()
        try:
          metric_value = metric_value_fn(distribution, *a)
        finally:
          # Always leave the captured context, even if building the metric
          # value raised, so later ops are not created inside it.
          outer_context.Exit()
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribution_strategy_context.get_replica_context().merge_call(
      fn, args=args)
313
314
@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  Two local variables, `total` and `count`, are created and used to compute
  the average of `values`. That average is ultimately returned as `mean`, an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, an `update_op`
  operation is also created; it updates these variables and returns the
  `mean`. `update_op` increments `total` with the reduced sum of the product
  of `values` and `weights`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  def _total_over_count(numerator, denominator, op_name):
    # `div_no_nan` yields 0 for a zero denominator; the denominator is
    # clamped to be non-negative first.
    return math_ops.div_no_nan(
        numerator, math_ops.maximum(denominator, 0), name=op_name)

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is not None:
      # Weighted: scale the values and count the total weight.
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      float_weights = math_ops.cast(weights, dtypes.float32)
      weights = weights_broadcast_ops.broadcast_weights(float_weights, values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)
    else:
      # Unweighted: every element counts once.
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)

    batch_total = math_ops.reduce_sum(values)
    update_total_op = state_ops.assign_add(total, batch_total)
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return _total_over_count(t, c, 'value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = _total_over_count(
        update_total_op, update_count_op, 'update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op
394
395
@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  Two local variables, `total` and `count`, are created and used to compute
  the frequency with which `predictions` matches `labels`. This frequency is
  ultimately returned as `accuracy`: an idempotent operation that simply
  divides `total` by `count`.

  For estimation of the metric over a stream of data, an `update_op`
  operation is also created; it updates these variables and returns the
  `accuracy`. Internally, an `is_correct` operation computes a `Tensor` with
  elements 1.0 where the corresponding elements of `predictions` and `labels`
  match and 0.0 otherwise. Then `update_op` increments `total` with the
  reduced sum of the product of `weights` and `is_correct`, and it increments
  `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  # Compare in the dtype of `labels`.
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  matches = math_ops.equal(predictions, labels)
  is_correct = math_ops.cast(matches, dtypes.float32)

  # Accuracy is the (weighted) mean of the 0/1 correctness indicator.
  return mean(
      is_correct,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name or 'accuracy')
459
460
def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  Up to four local variables are created: `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, an `update_op`
  operation is created for each requested metric; it updates the matching
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
        default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  valid_keys = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = valid_keys
  else:
    for requested_key in includes:
      if requested_key not in valid_keys:
        raise ValueError('Invalid key: %s.' % requested_key)

  # Range assertions on the raw predictions; everything derived below depends
  # on them through the control dependency.
  range_message = 'predictions must be in [0, 1]'
  lower_bound_check = check_ops.assert_greater_equal(
      predictions,
      math_ops.cast(0.0, dtype=predictions.dtype),
      message=range_message)
  upper_bound_check = check_ops.assert_less_equal(
      predictions,
      math_ops.cast(1.0, dtype=predictions.dtype),
      message=range_message)
  with ops.control_dependencies([lower_bound_check, upper_bound_check]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtypes.float32),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)
  tile_per_threshold = [num_thresholds, 1]

  # Predictions become a column, labels a row.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Prefer the static number of predictions, falling back to the dynamic one.
  num_predictions = predictions_2d.get_shape().as_list()[0]
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]

  # One row per threshold, one column per prediction.
  thresholds_column = array_ops.expand_dims(
      array_ops.constant(thresholds), [1])
  thresh_tiled = array_ops.tile(
      thresholds_column, array_ops.stack([1, num_predictions]))

  # Compare every prediction with every threshold.
  predictions_tiled = array_ops.tile(
      array_ops.transpose(predictions_2d), tile_per_threshold)
  pred_is_pos = math_ops.greater(predictions_tiled, thresh_tiled)
  pred_is_neg = None
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Repeat the labels once per threshold.
  label_is_pos = array_ops.tile(labels_2d, tile_per_threshold)
  label_is_neg = None
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  weights_tiled = None
  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), tile_per_threshold)
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())

  # (key, variable name, label mask, prediction mask) for each confusion
  # matrix cell, in the order in which the variables are created. Masks that
  # were not requested are `None` and are never used.
  cell_specs = (
      ('tp', 'true_positives', label_is_pos, pred_is_pos),
      ('fn', 'false_negatives', label_is_pos, pred_is_neg),
      ('tn', 'true_negatives', label_is_neg, pred_is_neg),
      ('fp', 'false_positives', label_is_neg, pred_is_pos),
  )

  values = {}
  update_ops = {}
  for key, variable_name, label_mask, pred_mask in cell_specs:
    if key not in includes:
      continue
    accumulator = metric_variable(
        [num_thresholds], dtypes.float32, name=variable_name)
    cell_hits = math_ops.cast(
        math_ops.logical_and(label_mask, pred_mask), dtypes.float32)
    if weights_tiled is not None:
      cell_hits *= weights_tiled
    # Sum over the predictions axis to get one total per threshold.
    update_ops[key] = state_ops.assign_add(
        accumulator, math_ops.reduce_sum(cell_hits, 1))
    values[key] = accumulator

  return values, update_ops
622
623
def _aggregate_variable(v, collections):
  """Reads `v` via `read_var` inside `_aggregate_across_replicas`.

  Args:
    v: The variable to read.
    collections: Collections passed on as the metrics collections of
      `_aggregate_across_replicas`.

  Returns:
    The result of `_aggregate_across_replicas` for the read value.
  """

  def _read_variable(distribution, value):
    return distribution.extended.read_var(value)

  return _aggregate_across_replicas(collections, _read_variable, v)
627
628
629@tf_export(v1=['metrics.auc'])
630@deprecated(None,
631            'The value of AUC returned by this may race with the update so '
632            'this is deprecated. Please use tf.keras.metrics.AUC instead.')
633def auc(labels,
634        predictions,
635        weights=None,
636        num_thresholds=200,
637        metrics_collections=None,
638        updates_collections=None,
639        curve='ROC',
640        name=None,
641        summation_method='trapezoidal',
642        thresholds=None):
643  """Computes the approximate AUC via a Riemann sum.
644
645  The `auc` function creates four local variables, `true_positives`,
646  `true_negatives`, `false_positives` and `false_negatives` that are used to
647  compute the AUC. To discretize the AUC curve, a linearly spaced set of
648  thresholds is used to compute pairs of recall and precision values. The area
649  under the ROC-curve is therefore computed using the height of the recall
650  values by the false positive rate, while the area under the PR-curve is the
651  computed using the height of the precision values by the recall.
652
653  This value is ultimately returned as `auc`, an idempotent operation that
654  computes the area under a discretized curve of precision versus recall values
655  (computed using the aforementioned variables). The `num_thresholds` variable
656  controls the degree of discretization with larger numbers of thresholds more
657  closely approximating the true AUC. The quality of the approximation may vary
658  dramatically depending on `num_thresholds`.
659
660  For best results, `predictions` should be distributed approximately uniformly
661  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
662  approximation may be poor if this is not the case. Setting `summation_method`
663  to 'minoring' or 'majoring' can help quantify the error in the approximation
664  by providing lower or upper bound estimate of the AUC. The `thresholds`
665  parameter can be used to manually specify thresholds which split the
666  predictions more evenly.
667
668  For estimation of the metric over a stream of data, the function creates an
669  `update_op` operation that updates these variables and returns the `auc`.
670
671  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
672
673  Args:
674    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
675      `bool`.
676    predictions: A floating point `Tensor` of arbitrary shape and whose values
677      are in the range `[0, 1]`.
678    weights: Optional `Tensor` whose rank is either 0, or the same rank as
679      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
680      be either `1`, or the same as the corresponding `labels` dimension).
681    num_thresholds: The number of thresholds to use when discretizing the roc
682      curve.
683    metrics_collections: An optional list of collections that `auc` should be
684      added to.
685    updates_collections: An optional list of collections that `update_op` should
686      be added to.
687    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
688      'PR' for the Precision-Recall-curve.
689    name: An optional variable_scope name.
690    summation_method: Specifies the Riemann summation method used
691      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
692      applies the trapezoidal rule; 'careful_interpolation', a variant of it
693      differing only by a more correct interpolation scheme for PR-AUC -
694      interpolating (true/false) positives but not the ratio that is precision;
695      'minoring' that applies left summation for increasing intervals and right
696      summation for decreasing intervals; 'majoring' that does the opposite.
697      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
698      (to be deprecated soon) as it applies the same method for ROC, and a
699      better one (see Davis & Goadrich 2006 for details) for the PR curve.
700    thresholds: An optional list of floating point values to use as the
701      thresholds for discretizing the curve. If set, the `num_thresholds`
702      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
703      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
704      automatically included with these to correctly handle predictions equal to
705       exactly 0 or 1.
706
707  Returns:
708    auc: A scalar `Tensor` representing the current area-under-curve.
709    update_op: An operation that increments the `true_positives`,
710      `true_negatives`, `false_positives` and `false_negatives` variables
711      appropriately and whose value matches `auc`.
712
713  Raises:
714    ValueError: If `predictions` and `labels` have mismatched shapes, or if
715      `weights` is not `None` and its shape doesn't match `predictions`, or if
716      either `metrics_collections` or `updates_collections` are not a list or
717      tuple.
718    RuntimeError: If eager execution is enabled.
719  """
720  if context.executing_eagerly():
721    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
722                       'is enabled.')
723
724  with variable_scope.variable_scope(name, 'auc',
725                                     (labels, predictions, weights)):
726    if curve != 'ROC' and curve != 'PR':
727      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
728
729    kepsilon = 1e-7  # To account for floating point imprecisions.
730    if thresholds is not None:
731      # If specified, use the supplied thresholds.
732      thresholds = sorted(thresholds)
733      num_thresholds = len(thresholds) + 2
734    else:
735      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
736      # (0, 1).
737      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
738                    for i in range(num_thresholds - 2)]
739
740    # Add an endpoint "threshold" below zero and above one for either threshold
741    # method.
742    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
743
744    values, update_ops = _confusion_matrix_at_thresholds(
745        labels, predictions, thresholds, weights)
746
747    # Add epsilons to avoid dividing by 0.
748    epsilon = 1.0e-6
749
    def interpolate_pr_auc(tp, fp, fn):
      """Interpolation formula inspired by section 4 of (Davis et al., 2006).

      Note here we derive & use a closed formula not present in the paper
      - as follows:
      Modeling all of TP (true positive weight),
      FP (false positive weight) and their sum P = TP + FP (positive weight)
      as varying linearly within each interval [A, B] between successive
      thresholds, we get
        Precision = (TP_A + slope * (P - P_A)) / P
      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
      The area within the interval is thus (slope / total_pos_weight) times
        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
         slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
      where dTP == TP_B - TP_A.
      Note that when P_A == 0 the above calculation simplifies into
        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
      which is really equivalent to imputing constant precision throughout the
      first bucket having >0 true positives.

      In the code below, for each pair of adjacent entries, entry `i + 1` plays
      the role of endpoint A and entry `i` the role of endpoint B. The
      per-interval areas are then summed into a single scalar. The slices rely
      on `num_thresholds` from the enclosing scope.

      Args:
        tp: true positive counts
        fp: false positive counts
        fn: false negative counts

      Returns:
        pr_auc: an approximation of the area under the P-R curve.

      References:
        The Relationship Between Precision-Recall and ROC Curves:
          [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
          ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
      """
      # dTP for every interval: entry i minus entry i + 1 (TP_B - TP_A).
      dtp = tp[:num_thresholds - 1] - tp[1:]
      # P in the docstring: true positives plus false positives.
      p = tp + fp
      # slope = dTP / dP. The denominator is clamped at 0 and `div_no_nan`
      # yields 0 for a zero denominator, so intervals where P does not
      # strictly decrease from entry i to entry i + 1 get a slope of 0.
      prec_slope = math_ops.div_no_nan(
          dtp,
          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
          name='prec_slope')
      # intercept = TP_A - slope * P_A, evaluated at entry i + 1.
      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
      # P_B / P_A, replaced by 1 unless both endpoints have P > 0 so that the
      # log term below vanishes (the P_A == 0 case from the docstring).
      safe_p_ratio = array_ops.where(
          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
          math_ops.div_no_nan(
              p[:num_thresholds - 1],
              math_ops.maximum(p[1:], 0),
              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
      # Per-interval area divided by tp + fn (total_pos_weight in the
      # docstring); intervals with a zero divisor contribute 0.
      return math_ops.reduce_sum(
          math_ops.div_no_nan(
              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
              math_ops.maximum(tp[1:] + fn[1:], 0),
              name='pr_auc_increment'),
          name='interpolate_pr_auc')
806
    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts.

      The curve type and the summation rule are not parameters: they are read
      from `curve` and `summation_method` in the enclosing scope, as are
      `epsilon` and `num_thresholds`.

      Args:
        tp: true positive counts, one entry per threshold.
        fn: false negative counts, one entry per threshold.
        tn: true negative counts, one entry per threshold.
        fp: false positive counts, one entry per threshold.
        name: Name given to the final `reduce_sum` op. It is not used when
          `curve` is 'PR' and `summation_method` is 'careful_interpolation',
          because that case is delegated to `interpolate_pr_auc`.

      Returns:
        The result of summing the per-interval areas under the selected curve.

      Raises:
        ValueError: If `summation_method` is not one of 'trapezoidal',
          'careful_interpolation', 'minoring' or 'majoring'.
      """
      if curve == 'PR':
        if summation_method == 'trapezoidal':
          # Only warn; the trapezoidal PR-AUC is still computed further down.
          logging.warning(
              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
              'please switch to "careful_interpolation" instead.')
        elif summation_method == 'careful_interpolation':
          # This one is a bit tricky and is handled separately.
          return interpolate_pr_auc(tp, fp, fn)
      # Recall (true positive rate); epsilon is added to numerator and
      # denominator.
      rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        # ROC: false positive rate on x, recall on y. Here epsilon appears in
        # the denominator only.
        fp_rate = math_ops.divide(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        # PR: recall on x, precision on y.
        prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
        x = rec
        y = prec
      # Every branch below multiplies the interval width
      # x[i] - x[i + 1] by a height derived from y[i] and y[i + 1]:
      # their average, their minimum or their maximum.
      if summation_method in ('trapezoidal', 'careful_interpolation'):
        # Note that the case ('PR', 'careful_interpolation') has been handled
        # above.
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError('Invalid summation_method: %s' % summation_method)
845
    # Sum up the per-interval areas under the curve (trapeziums only for the
    # 'trapezoidal' summation method) from the accumulated confusion counts.
    def compute_auc_value(_, values):
      """Adapts `compute_auc` to the callback passed for aggregation.

      Args:
        _: Unused first positional argument.
        values: Mapping with the keys 'tp', 'fn', 'tn' and 'fp'.

      Returns:
        The result of `compute_auc` on those four entries, named 'value'.
      """
      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
                         'value')
850
851    auc_value = _aggregate_across_replicas(
852        metrics_collections, compute_auc_value, values)
853    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
854                            update_ops['tn'], update_ops['fp'], 'update_op')
855
856    if updates_collections:
857      ops.add_to_collections(updates_collections, update_op)
858
859    return auc_value, update_op
860
861
@tf_export(v1=['metrics.mean_absolute_error'])
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the streaming mean of `|predictions - labels|`.

  The element-wise absolute difference between `predictions` and `labels` is
  built here and then handed to `mean`, which keeps the running statistics.
  The metric is backed by two local variables, `total` and `count`:
  `update_op` adds the reduced sum of `weights * absolute_errors` to `total`
  and the reduced sum of `weights` to `count`, while `mean_absolute_error` is
  an idempotent operation that simply divides `total` by `count`.

  When `weights` is `None` every element has weight 1. A weight of 0 masks the
  corresponding value.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and which must be broadcastable to `labels` (i.e., every
      dimension is either `1` or equal to the matching `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name; defaults to
      'mean_absolute_error'.

  Returns:
    mean_absolute_error: A `Tensor` holding the current mean, i.e. `total`
      divided by `count`.
    update_op: An operation that increments `total` and `count` appropriately
      and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  scope_name = name or 'mean_absolute_error'
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  abs_diff = math_ops.abs(predictions - labels)
  return mean(abs_diff, weights, metrics_collections, updates_collections,
              scope_name)
921
922
@tf_export(v1=['metrics.mean_cosine_distance'])
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the streaming mean cosine distance of `predictions` to `labels`.

  The element-wise product of `predictions` and `labels` is summed along `dim`,
  the (weighted) streaming mean of those sums is tracked with `mean`, and the
  reported distance is one minus that mean. This function itself applies no
  normalization to its inputs: the sum along `dim` is used as given.

  Two local variables, `total` and `count`, back the metric. `mean_distance` is
  an idempotent operation derived from `total` divided by `count`, and
  `update_op` updates both variables for a new batch.

  When `weights` is `None` every element has weight 1. A weight of 0 masks the
  corresponding value.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and which must be broadcastable to `labels` (i.e., every
      dimension is either `1` or equal to the matching `labels` dimension).
      Also, dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value should be added to.
    updates_collections: An optional list of collections that the metric
      update op should be added to.
    name: An optional variable_scope name; defaults to
      'mean_cosine_distance'.

  Returns:
    mean_distance: A `Tensor` equal to `1.0` minus the current mean (`total`
      divided by `count`).
    update_op: An operation that increments `total` and `count` appropriately;
      it evaluates to `1.0` minus the value of the underlying `mean` update.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)

  elementwise_products = math_ops.multiply(predictions, labels)
  summed_along_dim = math_ops.reduce_sum(
      elementwise_products, axis=[dim], keepdims=True)

  # No collections are passed to `mean`: the tensors that must be registered
  # are the `1 - x` results built below, not the raw mean and its update.
  running_mean, running_mean_update = mean(
      summed_along_dim, weights, None, None, name or 'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, running_mean)
  update_op = math_ops.subtract(1.0, running_mean_update)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op
995
996
@tf_export(v1=['metrics.mean_per_class_accuracy'])
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Computes the streaming mean of the per-class accuracies.

  An accuracy is computed separately for every class, and the reported metric
  is the unweighted mean of those per-class values.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  When `weights` is `None` every element has weight 1. A weight of 0 masks the
  corresponding value.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and which must be broadcastable to `labels` (i.e., every
      dimension is either `1` or equal to the matching `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.cast(labels, dtypes.int64)

    # Inputs of rank > 1 are collapsed to vectors, one entry per example.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])
    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # After flattening, both vectors must be shape-compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    # Note the roles: `total` accumulates the weight of every example seen per
    # label class, while `count` accumulates the weight of the correctly
    # predicted ones. Accuracy is therefore `count / total`.
    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    example_weight = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    hit = math_ops.cast(math_ops.equal(predictions, labels), dtypes.float32)

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.cast(weights, dtypes.float32)

      hit = hit * weights
      example_weight = example_weight * weights

    # Both updates are scattered by label, i.e. by the true class.
    update_total_op = state_ops.scatter_add(total, labels, example_weight)
    update_count_op = state_ops.scatter_add(count, labels, hit)

    def compute_mean_accuracy(_, count, total):
      """Averages `count / total` over classes; a zero `total` gives 0."""
      accuracy_per_class = math_ops.div_no_nan(
          count, math_ops.maximum(total, 0), name=None)
      return math_ops.reduce_mean(accuracy_per_class, name='mean_accuracy')

    mean_accuracy_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_accuracy, count, total)

    update_op = math_ops.div_no_nan(
        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_accuracy_v, update_op
1100
1101
@tf_export(v1=['metrics.mean_iou'])
def mean_iou(labels,
             predictions,
             num_classes,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for semantic
  image segmentation. The IOU of every semantic class is computed first and
  the results are then averaged over classes, where
    IOU = true_positive / (true_positive + false_positive + false_negative).
  Predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is derived from that matrix. Classes whose IOU denominator is 0 are
  left out of the average; if no class qualifies, the result is 0.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  When `weights` is `None` every element has weight 1. A weight of 0 masks the
  corresponding value.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and which must be broadcastable to `labels` (i.e., every
      dimension is either `1` or equal to the matching `labels` dimension).
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_iou is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_iou',
                                     (predictions, labels, weights)):
    # Labels and predictions must be shape-compatible before accumulation.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                      num_classes, weights)

    def compute_mean_iou(_, total_cm):
      """Derives the mean intersection-over-union from a confusion matrix."""
      axis0_totals = math_ops.cast(
          math_ops.reduce_sum(total_cm, 0), dtypes.float32)
      axis1_totals = math_ops.cast(
          math_ops.reduce_sum(total_cm, 1), dtypes.float32)
      diagonal = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32)
      # Both totals contain the diagonal, so it is subtracted once.
      union = axis0_totals + axis1_totals - diagonal

      # Only classes that occur in the labels or the predictions take part in
      # the mean; a class with a zero union is ignored.
      has_entries = math_ops.cast(
          math_ops.not_equal(union, 0), dtype=dtypes.float32)
      num_valid_entries = math_ops.reduce_sum(has_entries)

      # Swap zero unions for 1 so the division below never divides by zero.
      safe_union = array_ops.where(
          math_ops.greater(union, 0), union, array_ops.ones_like(union))
      iou = math_ops.divide(diagonal, safe_union)

      # With no valid class at all the metric is defined as 0.
      any_valid = math_ops.greater(num_valid_entries, 0)
      averaged_iou = (
          math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries)
      return array_ops.where(any_valid, averaged_iou, 0)

    # TODO(priyag): Use outside_compilation if in TPU context.
    mean_iou_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_iou, total_cm)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou_v, update_op
1203
1204
@tf_export(v1=['metrics.mean_relative_error'])
def mean_relative_error(labels,
                        predictions,
                        normalizer,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the streaming mean of `|labels - predictions| / normalizer`.

  The relative error of an element is the absolute difference between
  `labels` and `predictions` divided by the matching `normalizer` entry.
  Wherever `normalizer` equals 0 the relative error is taken to be 0.

  The metric is backed by two local variables, `total` and `count`:
  `update_op` adds the reduced sum of `weights * relative_errors` to `total`
  and the reduced sum of `weights` to `count`, while `mean_relative_error` is
  an idempotent operation that simply divides `total` by `count`.

  When `weights` is `None` every element has weight 1. A weight of 0 masks the
  corresponding value.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and which must be broadcastable to `labels` (i.e., every
      dimension is either `1` or equal to the matching `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name; defaults to
      'mean_relative_error'.

  Returns:
    mean_relative_error: A `Tensor` holding the current mean, i.e. `total`
      divided by `count`.
    update_op: An operation that increments `total` and `count` appropriately
      and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)

  # Align `normalizer` with `predictions` the same way before checking shapes.
  predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
      predictions, normalizer)
  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())

  # Positions with a zero normalizer take the zero branch rather than the
  # quotient.
  normalizer_is_zero = math_ops.equal(normalizer, 0.0)
  zero_errors = array_ops.zeros_like(labels)
  scaled_abs_diff = math_ops.divide(
      math_ops.abs(labels - predictions), normalizer)
  relative_errors = array_ops.where(normalizer_is_zero, zero_errors,
                                    scaled_abs_diff)

  scope_name = name or 'mean_relative_error'
  return mean(relative_errors, weights, metrics_collections,
              updates_collections, scope_name)
1272
1273
@tf_export(v1=['metrics.mean_squared_error'])
def mean_squared_error(labels,
                       predictions,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Computes the streaming mean of `(labels - predictions) ** 2`.

  The element-wise squared difference between `labels` and `predictions` is
  built here and then handed to `mean`, which keeps the running statistics.
  The metric is backed by two local variables, `total` and `count`:
  `update_op` adds the reduced sum of `weights * squared_error` to `total` and
  the reduced sum of `weights` to `count`, while `mean_squared_error` is an
  idempotent operation that simply divides `total` by `count`.

  When `weights` is `None` every element has weight 1. A weight of 0 masks the
  corresponding value.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and which must be broadcastable to `labels` (i.e., every
      dimension is either `1` or equal to the matching `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name; defaults to 'mean_squared_error'.

  Returns:
    mean_squared_error: A `Tensor` holding the current mean, i.e. `total`
      divided by `count`.
    update_op: An operation that increments `total` and `count` appropriately
      and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
                       'eager execution is enabled.')

  scope_name = name or 'mean_squared_error'
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  sq_diff = math_ops.squared_difference(labels, predictions)
  return mean(sq_diff, weights, metrics_collections, updates_collections,
              scope_name)
1333
1334
@tf_export(v1=['metrics.mean_tensor'])
def mean_tensor(values,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  Unlike the `mean` function, which returns a scalar, this function returns an
  average tensor with the same shape as the input tensors.

  Two metric variables, `total_tensor` and `count_tensor`, are created with the
  static shape of `values` and hold the running state. The returned `mean` is
  an idempotent operation that divides `total_tensor` by `count_tensor`
  element by element.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the mean.
  `update_op` adds the element-wise product of `values` and `weights` to
  `total_tensor`, and adds `weights` (broadcast to `values`) to `count_tensor`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions. It is cast to `float32`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of
      `total_tensor` divided by `count_tensor`.
    update_op: An operation that increments the `total_tensor` and
      `count_tensor` variables appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_tensor is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    # Both accumulators mirror the static shape of the (float) input.
    accumulator_shape = values.get_shape()
    total = metric_variable(
        shape=accumulator_shape, dtype=dtypes.float32, name='total_tensor')
    count = metric_variable(
        shape=accumulator_shape, dtype=dtypes.float32, name='count_tensor')

    # Without weights every element contributes a count of exactly one.
    num_values = array_ops.ones_like(values)
    if weights is not None:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      float_weights = math_ops.cast(weights, dtypes.float32)
      weights = weights_broadcast_ops.broadcast_weights(float_weights, values)
      # Scale both the contribution and its count by the per-element weight.
      values = math_ops.multiply(values, weights)
      num_values = math_ops.multiply(num_values, weights)

    update_total_op = state_ops.assign_add(total, values)
    # The count update is created under a control dependency on `values`.
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      """Divides total `t` by count `c`; the first argument is unused."""
      clipped_count = math_ops.maximum(c, 0)
      return math_ops.div_no_nan(t, clipped_count, name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)

    clipped_update_count = math_ops.maximum(update_count_op, 0)
    update_op = math_ops.div_no_nan(
        update_total_op, clipped_update_count, name='update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op
1418
1419
@tf_export(v1=['metrics.percentage_below'])
def percentage_below(values,
                     threshold,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Computes the percentage of values less than the given threshold.

  Each element of `values` is turned into a `float32` indicator that is 1 when
  the element is strictly less than `threshold` and 0 otherwise. The indicator
  tensor is then handed to `mean`, together with `weights` and the collection
  arguments, and the result of `mean` is returned unchanged.

  For estimation of the metric over a stream of data, the returned `update_op`
  updates the underlying state and returns the `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name. When not given (or empty),
      'percentage_below_threshold' is used.

  Returns:
    percentage: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.percentage_below is not supported when '
                       'eager execution is enabled.')

  below_mask = math_ops.less(values, threshold)
  is_below_threshold = math_ops.cast(below_mask, dtypes.float32)
  scope_name = name or 'percentage_below_threshold'
  # The mean of a 0/1 indicator is the (weighted) fraction of ones.
  return mean(is_below_threshold, weights, metrics_collections,
              updates_collections, scope_name)
1473
1474
def _count_condition(values,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None):
  """Sums the weights of cases where the given values are True.

  A scalar `float32` metric variable named `count` holds the running sum. The
  boolean `values` are cast to `float32`, optionally multiplied by `weights`,
  summed over all elements and added to `count` by the returned `update_op`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `bool` `Tensor` of arbitrary size.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = metric_variable(shape=[], dtype=dtypes.float32, name='count')

  values = math_ops.cast(values, dtypes.float32)
  if weights is not None:
    # Weights must be a scalar or have the same rank as `values`; the check is
    # attached as a control dependency of the weighting ops below.
    permitted_ranks = (0, array_ops.rank(values))
    rank_assertion = check_ops.assert_rank_in(weights, permitted_ranks)
    with ops.control_dependencies((rank_assertion,)):
      weights = math_ops.cast(weights, dtypes.float32)
      values = math_ops.multiply(values, weights)

  value_tensor = _aggregate_variable(count, metrics_collections)

  batch_sum = math_ops.reduce_sum(values)
  update_op = state_ops.assign_add(count, batch_sum)
  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return value_tensor, update_op
1519
1520
@tf_export(v1=['metrics.false_negatives'])
def false_negatives(labels,
                    predictions,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Sum the weights of false negatives.

  An element counts as a false negative when its label is `True` and its
  prediction is `False`, after both inputs have been cast to `bool`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_negatives is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'false_negatives',
                                     (predictions, labels, weights)):
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    # False negative: the label is positive but the prediction is negative.
    label_is_positive = math_ops.equal(labels, True)
    prediction_is_negative = math_ops.equal(predictions, False)
    is_false_negative = math_ops.logical_and(label_is_positive,
                                             prediction_is_negative)
    return _count_condition(
        values=is_false_negative,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections)
1571
1572
@tf_export(v1=['metrics.false_negatives_at_thresholds'])
def false_negatives_at_thresholds(labels,
                                  predictions,
                                  thresholds,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes false negatives at provided threshold values.

  The counting is delegated to `_confusion_matrix_at_thresholds`, requesting
  only the `'fn'` entry.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `false_negatives`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_negatives:  A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that updates the `false_negatives` variable and
      returns its current value.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'false_negatives',
                                     (predictions, labels, weights)):
    # Only the false-negative counter is requested from the shared helper.
    counters, counter_updates = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('fn',))

    fn_value = _aggregate_variable(counters['fn'], metrics_collections)

    fn_update_op = counter_updates['fn']
    if updates_collections:
      ops.add_to_collections(updates_collections, fn_update_op)

    return fn_value, fn_update_op
1627
1628
@tf_export(v1=['metrics.false_positives'])
def false_positives(labels,
                    predictions,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Sum the weights of false positives.

  An element counts as a false positive when its label is `False` and its
  prediction is `True`, after both inputs have been cast to `bool`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_positives is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'false_positives',
                                     (predictions, labels, weights)):
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    # False positive: the label is negative but the prediction is positive.
    label_is_negative = math_ops.equal(labels, False)
    prediction_is_positive = math_ops.equal(predictions, True)
    is_false_positive = math_ops.logical_and(label_is_negative,
                                             prediction_is_positive)
    return _count_condition(
        values=is_false_positive,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections)
1680
1681
@tf_export(v1=['metrics.false_positives_at_thresholds'])
def false_positives_at_thresholds(labels,
                                  predictions,
                                  thresholds,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Computes false positives at provided threshold values.

  The counting is delegated to `_confusion_matrix_at_thresholds`, requesting
  only the `'fp'` entry.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `false_positives`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    false_positives:  A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that updates the `false_positives` variable and
      returns its current value.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'false_positives',
                                     (predictions, labels, weights)):
    # Only the false-positive counter is requested from the shared helper.
    counters, counter_updates = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('fp',))

    fp_value = _aggregate_variable(counters['fp'], metrics_collections)

    fp_update_op = counter_updates['fp']
    if updates_collections:
      ops.add_to_collections(updates_collections, fp_update_op)

    return fp_value, fp_update_op
1736
1737
@tf_export(v1=['metrics.true_negatives'])
def true_negatives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Sum the weights of true_negatives.

  An element counts as a true negative when both its label and its prediction
  are `False`, after both inputs have been cast to `bool`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_negatives is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_negatives',
                                     (predictions, labels, weights)):
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    # True negative: both the label and the prediction are negative.
    label_is_negative = math_ops.equal(labels, False)
    prediction_is_negative = math_ops.equal(predictions, False)
    is_true_negative = math_ops.logical_and(label_is_negative,
                                            prediction_is_negative)
    return _count_condition(
        values=is_true_negative,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections)
1789
1790
@tf_export(v1=['metrics.true_negatives_at_thresholds'])
def true_negatives_at_thresholds(labels,
                                 predictions,
                                 thresholds,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes true negatives at provided threshold values.

  The counting is delegated to `_confusion_matrix_at_thresholds`, requesting
  only the `'tn'` entry.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `true_negatives`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    true_negatives:  A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that updates the `true_negatives` variable and
      returns its current value.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_negatives',
                                     (predictions, labels, weights)):
    # Only the true-negative counter is requested from the shared helper.
    counters, counter_updates = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('tn',))

    tn_value = _aggregate_variable(counters['tn'], metrics_collections)

    tn_update_op = counter_updates['tn']
    if updates_collections:
      ops.add_to_collections(updates_collections, tn_update_op)

    return tn_value, tn_update_op
1845
1846
@tf_export(v1=['metrics.true_positives'])
def true_positives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Sum the weights of true_positives.

  An element counts as a true positive when both its label and its prediction
  are `True`, after both inputs have been cast to `bool`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    # True positive: both the label and the prediction are positive.
    label_is_positive = math_ops.equal(labels, True)
    prediction_is_positive = math_ops.equal(predictions, True)
    is_true_positive = math_ops.logical_and(label_is_positive,
                                            prediction_is_positive)
    return _count_condition(
        values=is_true_positive,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections)
1898
1899
@tf_export(v1=['metrics.true_positives_at_thresholds'])
def true_positives_at_thresholds(labels,
                                 predictions,
                                 thresholds,
                                 weights=None,
                                 metrics_collections=None,
                                 updates_collections=None,
                                 name=None):
  """Computes true positives at provided threshold values.

  The counting is delegated to `_confusion_matrix_at_thresholds`, requesting
  only the `'tp'` entry.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `true_positives`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    true_positives:  A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that updates the `true_positives` variable and
      returns its current value.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):
    # Only the true-positive counter is requested from the shared helper.
    counters, counter_updates = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('tp',))

    tp_value = _aggregate_variable(counters['tp'], metrics_collections)

    tp_update_op = counter_updates['tp']
    if updates_collections:
      ops.add_to_collections(updates_collections, tp_update_op)

    return tp_value, tp_update_op
1954
1955
@tf_export(v1=['metrics.precision'])
def precision(labels,
              predictions,
              weights=None,
              metrics_collections=None,
              updates_collections=None,
              name=None):
  """Computes the precision of the predictions with respect to the labels.

  The function builds the `true_positives` and `false_positives` metrics and
  combines them. The returned `precision` is an idempotent operation that
  divides `true_positives` by the sum of `true_positives` and
  `false_positives`; when that sum is not greater than 0 the result is 0.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision`. `update_op` weights each prediction by the corresponding value in
  `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: Scalar float `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately and whose value matches
      `precision`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'precision',
                                     (predictions, labels, weights)):
    bool_predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    bool_labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=bool_predictions, labels=bool_labels, weights=weights)

    # The sub-metrics are deliberately kept out of any collection; only the
    # combined precision value and update op are registered below.
    true_p, true_positives_update_op = true_positives(
        labels=labels, predictions=predictions, weights=weights)
    false_p, false_positives_update_op = false_positives(
        labels=labels, predictions=predictions, weights=weights)

    def compute_precision(tp, fp, name):
      """Returns `tp / (tp + fp)`, or 0 where `tp + fp` is not positive."""
      has_positive_predictions = math_ops.greater(tp + fp, 0)
      ratio = math_ops.divide(tp, tp + fp)
      return array_ops.where(has_positive_predictions, ratio, 0, name)

    # The first argument supplied by the aggregation helper is not needed.
    once_across_replicas = (
        lambda _, tp, fp: compute_precision(tp, fp, 'value'))

    p = _aggregate_across_replicas(metrics_collections, once_across_replicas,
                                   true_p, false_p)

    update_op = compute_precision(true_positives_update_op,
                                  false_positives_update_op, 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return p, update_op
2049
2050
@tf_export(v1=['metrics.precision_at_thresholds'])
def precision_at_thresholds(labels,
                            predictions,
                            thresholds,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Computes precision values for different `thresholds` on `predictions`.

  For every entry of `thresholds`, this metric uses streaming
  `true_positives` and `false_positives` totals. `precision[i]` is defined as
  the total weight of values in `predictions` above `thresholds[i]` whose
  corresponding entry in `labels` is `True`, divided by the total weight of
  values in `predictions` above `thresholds[i]`
  (`true_positives[i] / (true_positives[i] + false_positives[i])`).

  A small constant (`1e-7`) is added to the denominator, so a threshold with
  no positive predictions yields `0` instead of dividing by zero.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the totals and returns the `precision`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `precision`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    precision: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives` and
      `false_positives` totals that are used in the computation of
      `precision`, and whose value is the precision computed from the
      updated totals.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'precision_at_thresholds',
                                     (predictions, labels, weights)):
    # Precision only needs the true-positive and false-positive entries.
    counts, count_updates = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights, includes=('tp', 'fp'))

    # Keeps the denominator strictly positive when tp + fp == 0.
    epsilon = 1e-7

    def _precision_from(tp, fp, suffix):
      denominator = epsilon + tp + fp
      return math_ops.divide(tp, denominator, name='precision_' + suffix)

    def _value_across_replicas(_, per_threshold):
      return _precision_from(per_threshold['tp'], per_threshold['fp'], 'value')

    precision_value = _aggregate_across_replicas(
        metrics_collections, _value_across_replicas, counts)

    # The update op's value is the precision of the freshly updated totals.
    update_op = _precision_from(count_updates['tp'], count_updates['fp'],
                                'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return precision_value, update_op
2130
2131
@tf_export(v1=['metrics.recall'])
def recall(labels,
           predictions,
           weights=None,
           metrics_collections=None,
           updates_collections=None,
           name=None):
  """Computes the recall of the predictions with respect to the labels.

  The `recall` function creates two local variables, `true_positives`
  and `false_negatives`, that are used to compute the recall. This value is
  ultimately returned as `recall`, an idempotent operation that simply divides
  `true_positives` by the sum of `true_positives` and `false_negatives`. When
  that sum is not greater than zero, `recall` is `0`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` that updates these variables and returns the `recall`. `update_op`
  weights each prediction by the corresponding value in `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: Scalar float `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately and whose value matches
      `recall`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    # Message previously read "is not supported is not supported ..."; it now
    # follows the same wording as the other metrics in this module.
    raise RuntimeError('tf.metrics.recall is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'recall',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    # The component metrics are kept out of every collection; only the final
    # recall value and update op are registered below.
    true_p, true_positives_update_op = true_positives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)
    false_n, false_negatives_update_op = false_negatives(
        labels,
        predictions,
        weights,
        metrics_collections=None,
        updates_collections=None,
        name=None)

    def compute_recall(true_p, false_n, name):
      # Falls back to 0 when there are no positive labels (tp + fn == 0).
      return array_ops.where(
          math_ops.greater(true_p + false_n, 0),
          math_ops.divide(true_p, true_p + false_n), 0, name)

    def once_across_replicas(_, true_p, false_n):
      return compute_recall(true_p, false_n, 'value')

    rec = _aggregate_across_replicas(
        metrics_collections, once_across_replicas, true_p, false_n)

    # Built from the update ops so that evaluating it both updates the
    # variables and yields the recall of the updated totals.
    update_op = compute_recall(true_positives_update_op,
                               false_negatives_update_op, 'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return rec, update_op
2223
2224
2225def _at_k_name(name, k=None, class_id=None):
2226  if k is not None:
2227    name = '%s_at_%d' % (name, k)
2228  else:
2229    name = '%s_at_k' % (name)
2230  if class_id is not None:
2231    name = '%s_class%d' % (name, class_id)
2232  return name
2233
2234
def _select_class_id(ids, selected_id):
  """Keeps only the entries of `ids` that equal `selected_id`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`. This contains only the entries
    equal to `selected_id`.
  """
  ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)

  # Sparse input: simply retain the values that match.
  if isinstance(ids, sparse_tensor.SparseTensor):
    keep_mask = math_ops.equal(ids.values, selected_id)
    return sparse_ops.sparse_retain(ids, keep_mask)

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Dense input: build a tensor shaped like `ids` but with the last dimension
  # collapsed to 1, filled with `selected_id`, and intersect it with `ids`.
  dense_shape = array_ops.shape(ids, out_type=dtypes.int64)
  last_axis = array_ops.reshape(array_ops.size(dense_shape) - 1, [1])
  collapsed_shape = math_ops.reduced_shape(dense_shape, last_axis)

  target_id = math_ops.cast(selected_id, dtypes.int64)
  singleton_sets = array_ops.fill(collapsed_shape, target_id)
  matches = sets.set_intersection(singleton_sets, ids)

  # Report the result with the original dense shape of `ids`.
  return sparse_tensor.SparseTensor(
      indices=matches.indices,
      values=matches.values,
      dense_shape=dense_shape)
2267
2268
2269def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
2270  """If class ID is specified, filter all other classes.
2271
2272  Args:
2273    labels: `int64` `Tensor` or `SparseTensor` with shape
2274      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2275      target classes for the associated prediction. Commonly, N=1 and `labels`
2276      has shape [batch_size, num_labels]. [D1, ... DN] must match
2277      `predictions_idx`.
2278    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
2279      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
2280      [batch size, k].
2281    selected_id: Int id to select.
2282
2283  Returns:
2284    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
2285  """
2286  if selected_id is None:
2287    return labels, predictions_idx
2288  return (_select_class_id(labels, selected_id),
2289          _select_class_id(predictions_idx, selected_id))
2290
2291
def _sparse_true_positive_at_k(labels,
                               predictions_idx,
                               class_id=None,
                               weights=None,
                               name=None):
  """Calculates true positives for recall@k and precision@k.

  A true positive is a class that appears in both `predictions_idx` and
  `labels` for the same entry. If `class_id` is specified, only that class is
  considered (binary true positives); otherwise all predicted and label
  classes take part.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of operation.

  Returns:
    A [D1, ... DN] `Tensor` of true positive counts.
  """
  with ops.name_scope(name, 'true_positives',
                      (predictions_idx, labels, weights)):
    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                     class_id)
    # Classes present in both the predictions and the labels.
    hits = sets.set_intersection(predictions_idx, labels)
    tp = math_ops.cast(sets.set_size(hits), dtypes.float64)
    if weights is None:
      return tp

    # The weighting only runs after the broadcast assertion.
    broadcast_check = weights_broadcast_ops.assert_broadcastable(weights, tp)
    with ops.control_dependencies((broadcast_check,)):
      weights = math_ops.cast(weights, dtypes.float64)
      return math_ops.multiply(tp, weights)
2335
2336
def _streaming_sparse_true_positive_at_k(labels,
                                         predictions_idx,
                                         k=None,
                                         class_id=None,
                                         weights=None,
                                         name=None):
  """Accumulates weighted true positives for recall@k and precision@k.

  The per-entry true positives of the current batch are summed and added to a
  scalar `float64` metric variable. If `class_id` is specified, only that
  class is counted (binary true positives); otherwise all predicted and label
  classes take part.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  default_name = _at_k_name('true_positive', k, class_id=class_id)
  with ops.name_scope(name, default_name,
                      (predictions_idx, labels, weights)) as scope:
    per_entry_tp = _sparse_true_positive_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_total_tp = math_ops.cast(
        math_ops.reduce_sum(per_entry_tp), dtypes.float64)

    # Scalar running total, named after the enclosing scope.
    accumulator = metric_variable([], dtypes.float64, name=scope)
    update_op = state_ops.assign_add(
        accumulator, batch_total_tp, name='update')
    return accumulator, update_op
2386
2387
def _sparse_false_negative_at_k(labels,
                                predictions_idx,
                                class_id=None,
                                weights=None):
  """Calculates false negatives for recall@k.

  A false negative is a class that appears in `labels` but not in
  `predictions_idx` for the same entry. If `class_id` is specified, only that
  class is considered (binary false negatives); otherwise all predicted and
  label classes take part.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).

  Returns:
    A [D1, ... DN] `Tensor` of false negative counts.
  """
  with ops.name_scope(None, 'false_negatives',
                      (predictions_idx, labels, weights)):
    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                     class_id)
    # aminusb=False computes `labels - predictions_idx`, i.e. the label
    # classes that were not predicted.
    missed = sets.set_difference(predictions_idx, labels, aminusb=False)
    fn = math_ops.cast(sets.set_size(missed), dtypes.float64)
    if weights is None:
      return fn

    # The weighting only runs after the broadcast assertion.
    broadcast_check = weights_broadcast_ops.assert_broadcastable(weights, fn)
    with ops.control_dependencies((broadcast_check,)):
      weights = math_ops.cast(weights, dtypes.float64)
      return math_ops.multiply(fn, weights)
2430
2431
def _streaming_sparse_false_negative_at_k(labels,
                                          predictions_idx,
                                          k,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Accumulates weighted false negatives for recall@k.

  The per-entry false negatives of the current batch are summed and added to a
  scalar `float64` metric variable. If `class_id` is specified, only that
  class is counted (binary false negatives); otherwise all predicted and label
  classes take part.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  default_name = _at_k_name('false_negative', k, class_id=class_id)
  with ops.name_scope(name, default_name,
                      (predictions_idx, labels, weights)) as scope:
    per_entry_fn = _sparse_false_negative_at_k(
        predictions_idx=predictions_idx,
        labels=labels,
        class_id=class_id,
        weights=weights)
    batch_total_fn = math_ops.cast(
        math_ops.reduce_sum(per_entry_fn), dtypes.float64)

    # Scalar running total, named after the enclosing scope.
    accumulator = metric_variable([], dtypes.float64, name=scope)
    update_op = state_ops.assign_add(
        accumulator, batch_total_fn, name='update')
    return accumulator, update_op
2481
2482
@tf_export(v1=['metrics.recall_at_k'])
def recall_at_k(labels,
                predictions,
                k,
                class_id=None,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Computes recall@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate recall by considering only the
      entries in the batch for which `class_id` is in the label, and computing
      the fraction of them for which `class_id` is in the top-k `predictions`.
  If `class_id` is not specified, we'll calculate recall as how often on
      average a class among the labels of a batch entry is in the top-k
      `predictions`.

  `recall_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
  the recall_at_k frequency. This frequency is ultimately returned as
  `recall_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_negative_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`; those indices are then handed to
  `recall_at_top_k`, which performs the actual bookkeeping.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range always count
      towards `false_negative_at_<k>`.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If class_id is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
    `predictions`, or if either `metrics_collections` or `updates_collections`
    are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall_at_k is not '
                       'supported when eager execution is enabled.')

  default_name = _at_k_name('recall', k, class_id=class_id)
  with ops.name_scope(name, default_name,
                      (predictions, labels, weights)) as scope:
    # Only the indices of the top-k entries matter; their values are unused.
    _unused_values, top_k_idx = nn.top_k(predictions, k)
    metric, update = recall_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)
    return metric, update
2574
2575
@tf_export(v1=['metrics.recall_at_top_k'])
def recall_at_top_k(labels,
                    predictions_idx,
                    k=None,
                    class_id=None,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Computes recall@k of top-k predictions with respect to sparse labels.

  Differs from `recall_at_k` in that predictions must be in the form of top `k`
  class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
  for more details.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be in range [0, num_classes). Values outside this range
      always count towards `false_negative_at_<k>`.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and predictions has shape [batch size, k]. The final
      dimension contains the top `k` predicted class indices. [D1, ... DN] must
      match `labels`. Will be cast to `int64`.
    k: Integer, k for @k metric. Only used for the default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes). If class_id is outside this range, the method
      returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
      by the sum of `true_positives` and `false_negatives`.
    update_op: `Operation` that increments `true_positives` and
      `false_negatives` variables appropriately, and whose value matches
      `recall`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
    `predictions`, or if either `metrics_collections` or `updates_collections`
    are not a list or tuple.
  """
  default_name = _at_k_name('recall', k, class_id=class_id)
  with ops.name_scope(name, default_name,
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)

    # Streaming numerator (tp) and the extra denominator term (fn).
    tp_total, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    fn_total, fn_update = _streaming_sparse_false_negative_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    def _recall_value(_, hits, misses):
      positives = math_ops.add(hits, misses)
      return math_ops.divide(hits, positives, name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, _recall_value, tp_total, fn_total)

    # Recall computed from the update ops, so evaluating it also performs
    # the accumulation.
    updated_positives = math_ops.add(tp_update, fn_update)
    update = math_ops.divide(tp_update, updated_positives, name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
2658
2659
@tf_export(v1=['metrics.recall_at_thresholds'])
def recall_at_thresholds(labels,
                         predictions,
                         thresholds,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes various recall values for different `thresholds` on `predictions`.

  For every entry of `thresholds`, this metric uses streaming
  `true_positives` and `false_negatives` totals. `recall[i]` is defined as the
  total weight of values in `predictions` above `thresholds[i]` whose
  corresponding entry in `labels` is `True`, divided by the total weight of
  `True` values in `labels`
  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).

  A small constant (`1e-7`) is added to the denominator, so a threshold with
  no positive labels yields `0` instead of dividing by zero.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the totals and returns the `recall`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `recall` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    recall: A float `Tensor` of shape `[len(thresholds)]`.
    update_op: An operation that increments the `true_positives` and
      `false_negatives` totals that are used in the computation of `recall`,
      and whose value is the recall computed from the updated totals.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.recall_at_thresholds is not '
                       'supported when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'recall_at_thresholds',
                                     (predictions, labels, weights)):
    # Recall only needs the true-positive and false-negative entries.
    counts, count_updates = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights, includes=('tp', 'fn'))

    # Keeps the denominator strictly positive when tp + fn == 0.
    epsilon = 1e-7

    def _recall_from(tp, fn, suffix):
      denominator = epsilon + tp + fn
      return math_ops.divide(tp, denominator, name='recall_' + suffix)

    def _value_across_replicas(_, per_threshold):
      return _recall_from(per_threshold['tp'], per_threshold['fn'], 'value')

    recall_value = _aggregate_across_replicas(
        metrics_collections, _value_across_replicas, counts)

    # The update op's value is the recall of the freshly updated totals.
    update_op = _recall_from(count_updates['tp'], count_updates['fn'],
                             'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return recall_value, update_op
2736
2737
@tf_export(v1=['metrics.root_mean_squared_error'])
def root_mean_squared_error(labels,
                            predictions,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Computes the root mean squared error between the labels and predictions.

  The streaming mean squared error is delegated to `mean_squared_error`, whose
  `total` and `count` local variables accumulate the weighted squared error
  and the weights. Both tensors returned here are the square root of the
  corresponding `mean_squared_error` tensor.

  For estimation of the metric over a stream of data, run `update_op` once per
  batch: it updates the underlying variables and evaluates to the refreshed
  `root_mean_squared_error`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `root_mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name. Defaults to
      `'root_mean_squared_error'` when empty or `None`.

  Returns:
    root_mean_squared_error: A `Tensor` holding the square root of the current
      mean squared error.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `root_mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.root_mean_squared_error is not '
                       'supported when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)

  # The MSE tensors are only intermediate values: `None` is passed for both
  # collection arguments so that only the RMSE tensors get registered below.
  scope_name = name or 'root_mean_squared_error'
  mse_value, mse_update = mean_squared_error(
      labels, predictions, weights, None, None, scope_name)

  def _sqrt_across_replicas(_, mse_tensor):
    return math_ops.sqrt(mse_tensor)

  rmse_value = _aggregate_across_replicas(
      metrics_collections, _sqrt_across_replicas, mse_value)

  rmse_update = math_ops.sqrt(mse_update)
  if updates_collections:
    ops.add_to_collections(updates_collections, rmse_update)

  return rmse_value, rmse_update
2807
2808
@tf_export(v1=['metrics.sensitivity_at_specificity'])
def sensitivity_at_specificity(labels,
                               predictions,
                               specificity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the sensitivity at a given specificity.

  The `sensitivity_at_specificity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the sensitivity at the given
  specificity value. The confusion matrix is evaluated at `num_thresholds`
  evenly spaced thresholds; the threshold whose specificity is closest to the
  requested `specificity` is selected and the sensitivity at that threshold is
  returned.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    specificity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      specificity.
    metrics_collections: An optional list of collections that `sensitivity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    sensitivity: A scalar `Tensor` representing the sensitivity at the given
      `specificity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `sensitivity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `specificity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
                       'supported when eager execution is enabled.')

  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
                                     (predictions, labels, weights)):
    # Small slack that absorbs floating point imprecision, both at the
    # threshold end points and in the ratio denominators below.
    kepsilon = 1e-7

    # Interior thresholds are evenly spaced strictly inside (0, 1); the two end
    # points are pushed just outside of [0, 1].
    last_index = num_thresholds - 1
    interior_thresholds = [
        i * 1.0 / last_index for i in range(1, last_index)
    ]
    thresholds = [0.0 - kepsilon] + interior_thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def _sensitivity_at_target(tp, tn, fp, fn, op_name):
      # Specificity at every threshold: tn / (tn + fp).
      per_threshold_specificity = math_ops.divide(tn, tn + fp + kepsilon)
      # Index of the threshold whose specificity is closest to the target.
      gap_to_target = math_ops.abs(per_threshold_specificity - specificity)
      closest = math_ops.cast(math_ops.argmin(gap_to_target, 0), dtypes.int32)

      # Sensitivity at that threshold: tp / (tp + fn).
      return math_ops.divide(
          tp[closest], tp[closest] + fn[closest] + kepsilon, op_name)

    def _sensitivity_across_replicas(_, counts):
      return _sensitivity_at_target(
          counts['tp'], counts['tn'], counts['fp'], counts['fn'], 'value')

    sensitivity = _aggregate_across_replicas(
        metrics_collections, _sensitivity_across_replicas, values)

    update_op = _sensitivity_at_target(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return sensitivity, update_op
2909
2910
2911def _expand_and_tile(tensor, multiple, dim=0, name=None):
2912  """Slice `tensor` shape in 2, then tile along the sliced dimension.
2913
2914  A new dimension is inserted in shape of `tensor` before `dim`, then values are
2915  tiled `multiple` times along the new dimension.
2916
2917  Args:
2918    tensor: Input `Tensor` or `SparseTensor`.
2919    multiple: Integer, number of times to tile.
2920    dim: Integer, dimension along which to tile.
2921    name: Name of operation.
2922
2923  Returns:
2924    `Tensor` result of expanding and tiling `tensor`.
2925
2926  Raises:
2927    ValueError: if `multiple` is less than 1, or `dim` is not in
2928    `[-rank(tensor), rank(tensor)]`.
2929  """
2930  if multiple < 1:
2931    raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
2932  with ops.name_scope(name, 'expand_and_tile',
2933                      (tensor, multiple, dim)) as scope:
2934    # Sparse.
2935    tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
2936    if isinstance(tensor, sparse_tensor.SparseTensor):
2937      if dim < 0:
2938        expand_dims = array_ops.reshape(
2939            array_ops.size(tensor.dense_shape) + dim, [1])
2940      else:
2941        expand_dims = [dim]
2942      expanded_shape = array_ops.concat(
2943          (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
2944           array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
2945          0,
2946          name='expanded_shape')
2947      expanded = sparse_ops.sparse_reshape(
2948          tensor, shape=expanded_shape, name='expand')
2949      if multiple == 1:
2950        return expanded
2951      return sparse_ops.sparse_concat(
2952          dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
2953
2954    # Dense.
2955    expanded = array_ops.expand_dims(
2956        tensor, dim if (dim >= 0) else (dim - 1), name='expand')
2957    if multiple == 1:
2958      return expanded
2959    ones = array_ops.ones_like(array_ops.shape(tensor))
2960    tile_multiples = array_ops.concat(
2961        (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
2962    return array_ops.tile(expanded, tile_multiples, name=scope)
2963
2964
2965def _num_relevant(labels, k):
2966  """Computes number of relevant values for each row in labels.
2967
2968  For labels with shape [D1, ... DN, num_labels], this is the minimum of
2969  `num_labels` and `k`.
2970
2971  Args:
2972    labels: `int64` `Tensor` or `SparseTensor` with shape
2973      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2974      target classes for the associated prediction. Commonly, N=1 and `labels`
2975      has shape [batch_size, num_labels].
2976    k: Integer, k for @k metric.
2977
2978  Returns:
2979    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
2980    relevant values for that row.
2981
2982  Raises:
2983    ValueError: if inputs have invalid dtypes or values.
2984  """
2985  if k < 1:
2986    raise ValueError('Invalid k=%s.' % k)
2987  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
2988    # For SparseTensor, calculate separate count for each row.
2989    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
2990    if isinstance(labels, sparse_tensor.SparseTensor):
2991      return math_ops.minimum(sets.set_size(labels), k, name=scope)
2992
2993    # The relevant values for each (d1, ... dN) is the minimum of k and the
2994    # number of labels along the last dimension that are non-negative.
2995    num_labels = math_ops.reduce_sum(
2996        array_ops.where_v2(math_ops.greater_equal(labels, 0),
2997                           array_ops.ones_like(labels),
2998                           array_ops.zeros_like(labels)),
2999        axis=-1)
3000    return math_ops.minimum(num_labels, k, name=scope)
3001
3002
def _sparse_average_precision_at_top_k(labels, predictions_idx):
  """Computes average precision@k of predictions with respect to sparse labels.

  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
  for each row is:

    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items

  A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
  `labels`, and the result `Tensors`. In the common case, this is [batch_size].
  Each row of the results contains the average precision for that row.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be non-negative. Negative values are ignored.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
      dimension must be set and contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`. Values should be in range
      [0, num_classes).

  Returns:
    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
    precision for that row.

  Raises:
    ValueError: if `predictions_idx` is a scalar, or if its last dimension is
      not set.
  """
  with ops.name_scope(None, 'average_precision',
                      (predictions_idx, labels)) as scope:
    predictions_idx = math_ops.cast(
        predictions_idx, dtypes.int64, name='predictions_idx')
    idx_shape = predictions_idx.get_shape()
    if idx_shape.ndims == 0:
      raise ValueError('The rank of predictions_idx must be at least 1.')
    # `k` has to be known statically: it is used as a Python tiling multiple.
    k = predictions_idx.get_shape().as_list()[-1]
    if k is None:
      raise ValueError('The last dimension of predictions_idx must be set.')
    labels = _maybe_expand_labels(labels, predictions_idx)

    # [D1, ... DN, k, 1]: one single-element prediction per position, so that
    # a hit can be determined separately for every position in the top k.
    idx_per_position = array_ops.expand_dims(
        predictions_idx, -1, name='predictions_idx_per_k')

    # [D1, ... DN, k, num_labels]: the labels repeated once per position.
    labels_per_position = _expand_and_tile(
        labels, multiple=k, dim=-1, name='labels_per_k')

    # Everything below has shape [D1, ... DN, k], i.e. one value per row and
    # per position in the top k.

    # rel_{i}: 1 where the prediction at position i is a label, 0 otherwise.
    hit_per_position = _sparse_true_positive_at_k(
        labels_per_position, idx_per_position, name='relevant_per_k')
    # Running number of hits among the first i predictions.
    hits_so_far = math_ops.cumsum(hit_per_position, axis=-1, name='tp_per_k')
    # Running number of predictions considered (1, 2, ..., k); this is the
    # precision denominator.
    seen_so_far = math_ops.cumsum(
        array_ops.ones_like(hit_per_position), axis=-1, name='retrieved_per_k')
    # P_{i}: precision over the first i predictions.
    hits_f64 = math_ops.cast(hits_so_far, dtypes.float64)
    seen_f64 = math_ops.cast(seen_so_far, dtypes.float64)
    precision_at_position = math_ops.divide(
        hits_f64, seen_f64, name='precision_per_k')
    # P_{i} * rel_{i}: keep the precision only where position i is a hit.
    hit_f64 = math_ops.cast(hit_per_position, dtypes.float64)
    precision_at_hits = math_ops.multiply(
        precision_at_position, hit_f64, name='relevant_precision_per_k')

    # Sum over the k positions, leaving a [D1, ... DN] tensor.
    summed_precision = math_ops.reduce_sum(
        precision_at_hits, axis=(-1,), name='precision_sum')

    # Normalize by "num_relevant_items" to obtain "AveP".
    relevant_count = math_ops.cast(_num_relevant(labels, k), dtypes.float64)
    return math_ops.divide(summed_precision, relevant_count, name=scope)
3090
3091
def _streaming_sparse_average_precision_at_top_k(labels,
                                                 predictions_idx,
                                                 weights=None,
                                                 metrics_collections=None,
                                                 updates_collections=None,
                                                 name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  Two metric variables are created, one under a `total` name scope and one
  under a `max` name scope. `total` accumulates the (weighted) per-row average
  precision, `max` accumulates the largest value `total` could have reached
  (the number of rows, or the sum of the broadcast weights). The returned
  metric divides `total` by `max`.

  For estimation of the metric over a stream of data, the function creates an
  `update` operation that adds the current batch to both variables and returns
  the resulting mean average precision.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be non-negative. Negative values are ignored.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
      dimension contains the top `k` predicted class indices. [D1, ... DN] must
      match `labels`. Values should be in range [0, num_classes).
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.
  """
  with ops.name_scope(name, 'average_precision_at_top_k',
                      (predictions_idx, labels, weights)) as scope:
    # Per-row average precision, scaled by the weights when they are given.
    row_precision = _sparse_average_precision_at_top_k(
        predictions_idx=predictions_idx, labels=labels)
    if weights is not None:
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float64), row_precision)
      row_precision = math_ops.multiply(row_precision, weights)

    # Accumulator for the best achievable total. The average precision of a
    # row is at most 1.0, so a batch contributes:
    # - its number of rows in the unweighted case;
    # - the sum of the weights broadcast across the rows otherwise.
    with ops.name_scope(None, 'max', (row_precision,)) as max_scope:
      max_accumulator = metric_variable([], dtypes.float64, name=max_scope)
      if weights is None:
        row_count = array_ops.size(row_precision, name='batch_max')
        max_increment = math_ops.cast(row_count, dtypes.float64)
      else:
        max_increment = math_ops.reduce_sum(weights, name='batch_max')
      max_update = state_ops.assign_add(
          max_accumulator, max_increment, name='update')

    # Accumulator for the total that was actually achieved.
    with ops.name_scope(None, 'total', (row_precision,)) as total_scope:
      total_accumulator = metric_variable(
          [], dtypes.float64, name=total_scope)
      total_increment = math_ops.reduce_sum(row_precision, name='batch_total')
      total_update = state_ops.assign_add(
          total_accumulator, total_increment, name='update')

    # The mean is total / max, both for the variables and for the update ops.
    def _mean_across_replicas(_, total, maximum):
      return _safe_scalar_div(total, maximum, name='mean')

    mean_average_precision = _aggregate_across_replicas(
        metrics_collections, _mean_across_replicas, total_accumulator,
        max_accumulator)

    update = _safe_scalar_div(total_update, max_update, name=scope)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)

    return mean_average_precision, update
3183
3184
def _clean_out_of_range_indices(labels, num_classes):
  """Replaces large out-of-range labels by small out-of-range labels.

  Every value in `labels` that is greater than or equal to `num_classes` is
  replaced by -1. The replacement sits behind a `cond` on the maximum label,
  so `labels` is passed through untouched when no value is out of range.

  Args:
    labels: `int64` `Tensor` or `SparseTensor`.
    num_classes: `int64` scalar `Tensor`.
  Returns:
    An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater
    or equal to num_classes replaced by -1.
  """
  # `labels` is never rebound below, so its kind can be decided once.
  labels_are_sparse = isinstance(
      labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))

  def _replace_large_values(values):
    """Replaces by -1 any entry of `values` that is >= `num_classes`."""
    is_too_large = math_ops.greater_equal(values, num_classes)
    minus_ones = -1 * array_ops.ones_like(values)
    return array_ops.where_v2(is_too_large, minus_ones, values)

  def _cleaned_labels():
    """Returns `labels` with large out-of-range values replaced by -1."""
    if not labels_are_sparse:
      return _replace_large_values(labels)
    # Rebuild the same sparse type; only the values change.
    return type(labels)(indices=labels.indices,
                        values=_replace_large_values(labels.values),
                        dense_shape=labels.dense_shape)

  def _unchanged_labels():
    """Returns `labels` as given."""
    return labels

  label_values = labels.values if labels_are_sparse else labels
  largest_label = math_ops.reduce_max(label_values)
  return control_flow_ops.cond(
      math_ops.greater_equal(largest_label, num_classes),
      _cleaned_labels,
      _unchanged_labels)
3224
3225
@tf_export(v1=['metrics.sparse_average_precision_at_k'])
@deprecated(None, 'Use average_precision_at_k instead')
def sparse_average_precision_at_k(labels,
                                  predictions,
                                  k,
                                  weights=None,
                                  metrics_collections=None,
                                  updates_collections=None,
                                  name=None):
  """Renamed to `average_precision_at_k`, please use that method instead.

  Deprecated alias: every argument is forwarded unchanged to
  `average_precision_at_k` and its result is returned as is. See that function
  for the description of the arguments, return values and errors.
  """
  return average_precision_at_k(
      labels,
      predictions,
      k,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
3244
3245
@tf_export(v1=['metrics.average_precision_at_k'])
def average_precision_at_k(labels,
                           predictions,
                           k,
                           weights=None,
                           metrics_collections=None,
                           updates_collections=None,
                           name=None):
  """Computes average precision@k of predictions with respect to sparse labels.

  `average_precision_at_k` creates two local variables,
  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
  are used to compute the frequency. This frequency is ultimately returned as
  `average_precision_at_<k>`: an idempotent operation that simply divides
  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `average_precision_at_<k>`. Internally, a `top_k` operation computes a
  `Tensor` holding the indices of the top `k` `predictions`. The average
  precision of each row is computed from those indices and `labels`, and is
  multiplied by `weights` when they are given. Then `update_op` adds the batch
  sum to `total` and the largest achievable sum (the number of rows, or the
  sum of the weights) to `max`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and `predictions` has shape
      [batch size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric. This will calculate an average precision for
      range `[1,k]`, as documented above.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    mean_average_precision: Scalar `float64` `Tensor` with the mean average
      precision values.
    update: `Operation` that increments variables appropriately, and whose
      value matches `metric`.

  Raises:
    ValueError: if k is invalid.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    # Name this function rather than its deprecated `sparse_` alias, which
    # merely forwards here.
    raise RuntimeError('tf.metrics.average_precision_at_k is not '
                       'supported when eager execution is enabled.')

  if k < 1:
    raise ValueError('Invalid k=%s.' % k)
  with ops.name_scope(name, _at_k_name('average_precision', k),
                      (predictions, labels, weights)) as scope:
    # Calculate top k indices to produce [D1, ... DN, k] tensor.
    _, predictions_idx = nn.top_k(predictions, k)
    # The documentation states that labels should be in [0, ..., num_classes),
    # but num_classes is lost when predictions_idx replaces predictions.
    # For conformity with the documentation, any label >= num_classes, which is
    # ignored, is replaced by -1.
    num_classes = math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64)
    labels = _clean_out_of_range_indices(labels, num_classes)
    return _streaming_sparse_average_precision_at_top_k(
        labels=labels,
        predictions_idx=predictions_idx,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)
3329
3330
def _sparse_false_positive_at_k(labels,
                                predictions_idx,
                                class_id=None,
                                weights=None):
  """Calculates false positives for precision@k.

  If `class_id` is specified, calculate binary false positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  A false positive is a predicted class that does not appear among the labels
  of the same row; the per-row count is the size of the set difference
  `predictions_idx - labels`.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).

  Returns:
    A [D1, ... DN] `float64` `Tensor` of false positive counts, multiplied by
    `weights` when they are given.
  """
  with ops.name_scope(None, 'false_positives',
                      (predictions_idx, labels, weights)):
    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
                                                     class_id)
    # Predicted classes that are absent from the labels of the same row.
    predicted_only = sets.set_difference(
        predictions_idx, labels, aminusb=True)
    false_positives = math_ops.cast(
        sets.set_size(predicted_only), dtypes.float64)
    if weights is None:
      return false_positives

    # Apply the weights only once they are known to broadcast to the counts.
    broadcast_check = weights_broadcast_ops.assert_broadcastable(
        weights, false_positives)
    with ops.control_dependencies((broadcast_check,)):
      weights_f64 = math_ops.cast(weights, dtypes.float64)
      return math_ops.multiply(false_positives, weights_f64)
3373
3374
def _streaming_sparse_false_positive_at_k(labels,
                                          predictions_idx,
                                          k=None,
                                          class_id=None,
                                          weights=None,
                                          name=None):
  """Accumulates weighted false positives for precision@k across steps.

  Each call to the returned update op adds the current batch's total false
  positive count to a scalar `float64` metric variable.

  If `class_id` is specified, calculate binary false positives for `class_id`
      only.
  If `class_id` is not specified, calculate metrics for `k` predicted vs
      `n` label classes, where `n` is the 2nd dimension of `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
      match `labels`.
    k: Integer, k for @k metric. This is only used for default op name.
    class_id: Class for which we want binary metrics.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    name: Name of new variable, and namespace for other dependent ops.

  Returns:
    A tuple of `Variable` and update `Operation`.

  Raises:
    ValueError: If `weights` is not `None` and has an incompatible shape.
  """
  default_name = _at_k_name('false_positive', k, class_id=class_id)
  with ops.name_scope(name, default_name,
                      (predictions_idx, labels, weights)) as scope:
    per_example_fp = _sparse_false_positive_at_k(
        labels=labels,
        predictions_idx=predictions_idx,
        class_id=class_id,
        weights=weights)
    # Collapse the per-example counts into one scalar for this batch.
    batch_total_fp = math_ops.reduce_sum(per_example_fp)
    batch_total_fp = math_ops.cast(batch_total_fp, dtypes.float64)

    # Running total; the variable is named after the enclosing scope.
    accumulator = metric_variable([], dtypes.float64, name=scope)
    update_op = state_ops.assign_add(accumulator, batch_total_fp, name='update')
    return accumulator, update_op
3424
3425
@tf_export(v1=['metrics.precision_at_top_k'])
def precision_at_top_k(labels,
                       predictions_idx,
                       k=None,
                       class_id=None,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Computes precision@k from already-selected top-k class indices.

  Unlike `precision_at_k` (formerly `sparse_precision_at_k`), which takes
  logits and selects the top `k` classes itself, this function expects the top
  `k` class indices directly. Refer to `precision_at_k` for more details.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
      The final dimension contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`. It is cast to `int64` before use.
    k: Integer, k for @k metric. Only used for the default op name.
    class_id: Integer class ID for which we want binary metrics. If no
      prediction ever matches `class_id`, both counters stay at zero and the
      metric evaluates to NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions_idx`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_top_k is not '
                       'supported when eager execution is enabled.')

  default_name = _at_k_name('precision', k, class_id=class_id)
  with ops.name_scope(name, default_name,
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)

    # Both streaming counters are built from exactly the same inputs.
    counter_kwargs = dict(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    tp, tp_update = _streaming_sparse_true_positive_at_k(**counter_kwargs)
    fp, fp_update = _streaming_sparse_false_positive_at_k(**counter_kwargs)

    def _precision_from_totals(_, true_pos, false_pos):
      # precision = tp / (tp + fp), computed on the aggregated totals.
      return math_ops.divide(
          true_pos, math_ops.add(true_pos, false_pos), name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, _precision_from_totals, tp, fp)

    # Same ratio, but on the outputs of the two counter update ops.
    updated_total = math_ops.add(tp_update, fp_update)
    update = math_ops.divide(tp_update, updated_total, name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
3513
3514
@tf_export(v1=['metrics.sparse_precision_at_k'])
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
                          predictions,
                          k,
                          class_id=None,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Deprecated alias of `precision_at_k`; please call that method instead."""
  # Pure forwarder: every argument is handed to `precision_at_k` unchanged.
  return precision_at_k(
      labels,
      predictions,
      k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)
3535
3536
@tf_export(v1=['metrics.precision_at_k'])
def precision_at_k(labels,
                   predictions,
                   k,
                   class_id=None,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
      entries in the batch for which `class_id` is in the top-k highest
      `predictions`, and computing the fraction of them for which `class_id` is
      indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
      average a class among the top-k classes with the highest predicted values
      of a batch entry is correct and can be found in the label for that entry.

  `precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    # Report the function actually being called, not its deprecated alias.
    raise RuntimeError('tf.metrics.precision_at_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    # Only the indices of the k largest logits matter; the values are dropped.
    _, top_k_idx = nn.top_k(predictions, k)
    # Passing `scope` as the name keeps the delegated ops under this scope.
    return precision_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)
3629
3630
@tf_export(v1=['metrics.specificity_at_sensitivity'])
def specificity_at_sensitivity(labels,
                               predictions,
                               sensitivity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the specificity at a given sensitivity.

  Confusion-matrix counts (`true_positives`, `true_negatives`,
  `false_positives` and `false_negatives`) are kept for a fixed list of
  `num_thresholds` thresholds. The threshold whose sensitivity is closest to
  the requested `sensitivity` is selected, and the specificity at that
  threshold is returned.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these counts and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
                       'supported when eager execution is enabled.')

  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     (predictions, labels, weights)):
    epsilon = 1e-7  # Guards against floating point imprecision.

    # Evenly spaced interior thresholds, bracketed by two end points that are
    # each shifted down by `epsilon`.
    last_step = num_thresholds - 1
    interior = [step * 1.0 / last_step for step in range(1, last_step)]
    thresholds = [0.0 - epsilon] + interior + [1.0 - epsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def _specificity_at_target(tp, tn, fp, fn, name):
      """Returns the specificity at the threshold matching `sensitivity`.

      Args:
        tp: True positives, one entry per threshold.
        tn: True negatives, one entry per threshold.
        fp: False positives, one entry per threshold.
        fn: False negatives, one entry per threshold.
        name: The name of the resulting operation.

      Returns:
        The specificity at the selected threshold.
      """
      sensitivities = math_ops.divide(tp, tp + fn + epsilon)

      # tf.argmax cannot be told which index to prefer on ties, so the choice
      # is made explicit: mark every threshold at the minimum distance from
      # the target, take a running count of the marks, and let argmax return
      # the first position where that count peaks, i.e. the last marked one.
      smallest_gap = math_ops.reduce_min(
          math_ops.abs(sensitivities - sensitivity))
      is_closest = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), smallest_gap)
      closest_count = math_ops.cumsum(
          math_ops.cast(is_closest, dtypes.int64))
      tf_index = math_ops.cast(math_ops.argmax(closest_count, 0), dtypes.int32)

      # With the threshold fixed, specificity is tn / (tn + fp).
      selected_tn = tn[tf_index]
      negatives = tn[tf_index] + fp[tf_index] + epsilon
      return math_ops.divide(selected_tn, negatives, name)

    def _specificity_from_values(_, counts):
      return _specificity_at_target(
          counts['tp'], counts['tn'], counts['fp'], counts['fn'], 'value')

    specificity = _aggregate_across_replicas(
        metrics_collections, _specificity_from_values, values)

    update_op = _specificity_at_target(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
3751