1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Confusion matrix related utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import check_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.util import deprecation
28from tensorflow.python.util import dispatch
29from tensorflow.python.util.tf_export import tf_export
30
31
def _last_dim_is_compatible_with_one(shape):
  """Returns whether `shape` has a static last dimension that may equal 1.

  Returns False when the rank is unknown or the shape is a scalar (rank 0):
  in both cases there is no static last dimension to inspect, and indexing
  `shape.dims[-1]` on a rank-0 shape would raise `IndexError`.
  """
  rank = shape.ndims
  return (rank is not None and rank > 0 and
          shape.dims[-1].is_compatible_with(1))


def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  Only a last dimension that may be 1 is squeezed; scalar (rank-0) inputs are
  never squeezed.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose rank is expected to be
      `rank(predictions) - expected_rank_diff`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.get_shape()
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and
          _last_dim_is_compatible_with_one(predictions_shape)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and
            _last_dim_is_compatible_with_one(labels_shape)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank. At least one static rank is unknown here; a side whose
    # static rank IS known is only considered when its last dim may be 1 (a
    # known scalar is skipped, since it has nothing to squeeze).
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    if (predictions_rank is None) or _last_dim_is_compatible_with_one(
        predictions_shape):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or _last_dim_is_compatible_with_one(
        labels_shape):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions
92
93
@tf_export('math.confusion_matrix', v1=[])
@dispatch.add_dispatch_support
def confusion_matrix(labels,
                     predictions,
                     num_classes=None,
                     weights=None,
                     dtype=dtypes.int32,
                     name=None):
  """Computes the confusion matrix from predictions and labels.

  Rows of the result are indexed by the true label and columns by the
  predicted label, so cell `[i, j]` accumulates the examples whose label is
  `i` and whose prediction is `j`. The result is always a 2-D tensor of shape
  `[n, n]`, where `n` is the number of classes. `labels` and `predictions`
  must be 1-D tensors of the same shape; if their ranks differ by exactly one
  and the extra trailing dimension may be 1, that dimension is squeezed first.

  When `num_classes` is `None`, `n` is inferred as one more than the largest
  value found in `labels` or `predictions`. Classes are numbered from 0, so
  `num_classes=3` means the valid classes are `[0, 1, 2]`.

  When `weights` is given, each example adds its own weight to its cell
  instead of 1.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Here the inferred classes are `[0, 1, 2, 3, 4]`, which yields a 5x5 matrix.

  Runtime assertions reject negative values in `labels` or `predictions` and,
  when `num_classes` is given, values greater than or equal to `num_classes`.

  Args:
    labels: 1-D `Tensor` of true labels for the classification task.
    predictions: 1-D `Tensor` of predicted labels.
    num_classes: Number of distinct classes the task can produce. If omitted,
      it is computed from the contents of `labels` and `predictions`.
    weights: Optional `Tensor` whose shape matches `predictions`; each entry
      is the amount added to its example's cell.
    dtype: Data type of the returned confusion matrix.
    name: Scope name.

  Returns:
    A `Tensor` of type `dtype` and shape `[n, n]` holding the confusion
    matrix, where `n` is the number of classes.

  Raises:
    ValueError: If `labels` and `predictions` have incompatible shapes, or if
      `weights` is not `None` and its shape is not compatible with
      `predictions`.
  """

  def _with_check(tensor, check):
    # Gate `tensor` on `check` so the assertion executes before any consumer
    # of the returned tensor.
    return control_flow_ops.with_dependencies([check], tensor)

  with ops.name_scope(name, 'confusion_matrix',
                      (predictions, labels, num_classes, weights)):
    label_ids = ops.convert_to_tensor(labels, name='labels')
    predicted_ids = ops.convert_to_tensor(predictions, name='predictions')
    label_ids, predicted_ids = remove_squeezable_dimensions(
        label_ids, predicted_ids)
    predicted_ids = math_ops.cast(predicted_ids, dtypes.int64)
    label_ids = math_ops.cast(label_ids, dtypes.int64)

    # The ids become scatter_nd indices below; out-of-range values could
    # corrupt memory, so validate them before use.
    label_ids = _with_check(
        label_ids,
        check_ops.assert_non_negative(
            label_ids, message='`labels` contains negative values'))
    predicted_ids = _with_check(
        predicted_ids,
        check_ops.assert_non_negative(
            predicted_ids, message='`predictions` contains negative values'))

    if num_classes is not None:
      class_limit = math_ops.cast(num_classes, dtypes.int64)
      label_ids = _with_check(
          label_ids,
          check_ops.assert_less(
              label_ids, class_limit, message='`labels` out of bound'))
      predicted_ids = _with_check(
          predicted_ids,
          check_ops.assert_less(
              predicted_ids, class_limit,
              message='`predictions` out of bound'))
    else:
      top_prediction = math_ops.reduce_max(predicted_ids)
      top_label = math_ops.reduce_max(label_ids)
      num_classes = math_ops.maximum(top_prediction, top_label) + 1

    weight_values = None
    if weights is not None:
      weight_values = ops.convert_to_tensor(weights, name='weights')
      predicted_ids.get_shape().assert_is_compatible_with(
          weight_values.get_shape())
      weight_values = math_ops.cast(weight_values, dtype)

    matrix_shape = array_ops.stack([num_classes, num_classes])
    cell_indices = array_ops.stack([label_ids, predicted_ids], axis=1)
    if weight_values is None:
      cell_updates = array_ops.ones_like(predicted_ids, dtype)
    else:
      cell_updates = weight_values
    return array_ops.scatter_nd(
        indices=cell_indices,
        updates=cell_updates,
        shape=math_ops.cast(matrix_shape, dtypes.int64))
199
200
@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels,
                        predictions,
                        num_classes=None,
                        dtype=dtypes.int32,
                        name=None,
                        weights=None):
  """Computes the confusion matrix from predictions and labels.

  Rows of the result are indexed by the true label and columns by the
  predicted label, so cell `[i, j]` accumulates the examples whose label is
  `i` and whose prediction is `j`. The result is always a 2-D tensor of shape
  `[n, n]`, where `n` is the number of classes. `labels` and `predictions`
  must be 1-D tensors of the same shape.

  When `num_classes` is `None`, `n` is inferred as one more than the largest
  value found in `labels` or `predictions`. Classes are numbered from 0, so
  `num_classes=3` means the valid classes are `[0, 1, 2]`.

  When `weights` is given, each example adds its own weight to its cell
  instead of 1.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Here the inferred classes are `[0, 1, 2, 3, 4]`, which yields a 5x5 matrix.

  This is the TF1 endpoint; it differs from `confusion_matrix` only in the
  order of its parameters.

  Args:
    labels: 1-D `Tensor` of true labels for the classification task.
    predictions: 1-D `Tensor` of predicted labels.
    num_classes: Number of distinct classes the task can produce. If omitted,
      it is computed from the contents of `labels` and `predictions`.
    dtype: Data type of the returned confusion matrix.
    name: Scope name.
    weights: Optional `Tensor` whose shape matches `predictions`; each entry
      is the amount added to its example's cell.

  Returns:
    A `Tensor` of type `dtype` and shape `[n, n]` holding the confusion
    matrix, where `n` is the number of classes.

  Raises:
    ValueError: If `labels` and `predictions` have incompatible shapes, or if
      `weights` is not `None` and its shape is not compatible with
      `predictions`.
  """
  # Forward by keyword so the differing parameter order between the v1 and v2
  # signatures cannot be mismatched.
  return confusion_matrix(
      labels=labels,
      predictions=predictions,
      num_classes=num_classes,
      weights=weights,
      dtype=dtype,
      name=name)
262