# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Confusion matrix related utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _last_dim_may_be_one(shape):
  """Returns whether the last dimension of a static shape may be 1.

  An unknown rank is treated as "may be 1", so that callers fall back to a
  run-time check. A scalar (rank 0) has no last dimension and can never be
  squeezed. Checking for it explicitly avoids indexing `shape.dims[-1]` on an
  empty list, which would raise `IndexError`.

  Args:
    shape: A `TensorShape`.

  Returns:
    `True` if the rank is unknown, or if the rank is at least 1 and the last
    dimension is compatible with 1. `False` otherwise.
  """
  rank = shape.ndims
  if rank is None:
    return True
  return rank > 0 and shape.dims[-1].is_compatible_with(1)


def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  A dimension is only squeezed if its static size is compatible with 1. A
  tensor of rank 0 is never squeezed.

  This will use static shape if available. Otherwise, it will add graph
  operations (`cond` on the run-time rank difference), which could result in
  a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.get_shape()
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and
          _last_dim_may_be_one(predictions_shape)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and
            _last_dim_may_be_one(labels_shape)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank. At least one rank is unknown here. A side whose static
    # shape rules out a trailing 1 (including a known scalar) is left untouched,
    # so no squeeze op is built for it.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    if _last_dim_may_be_one(predictions_shape):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if _last_dim_may_be_one(labels_shape):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions


@tf_export('math.confusion_matrix', v1=[])
@dispatch.add_dispatch_support
def confusion_matrix(labels,
                     predictions,
                     num_classes=None,
                     weights=None,
                     dtype=dtypes.int32,
                     name=None):
  """Computes the confusion matrix from predictions and labels.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  prediction and labels must be 1-D arrays of the same shape in order for this
  function to work.

  If `num_classes` is `None`, then `num_classes` will be set to one plus the
  maximum value in either predictions or labels. Class labels are expected to
  start at 0. For example, if `num_classes` is 3, then the possible labels
  would be `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can
                 have. If this value is not provided, it will be calculated
                 using both predictions and labels array.
    weights: An optional `Tensor` whose shape matches `predictions`.
    dtype: Data type of the confusion matrix.
    name: Scope name.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If `labels` and `predictions` have statically incompatible
      shapes, or if `weights` is not `None` and its shape doesn't match
      `predictions`.

  The assertion ops added by this function fail at run time in two cases.
  The first is when `labels` or `predictions` contain negative values. The
  second is when `num_classes` is given and a value is not less than it.
  """
  with ops.name_scope(name, 'confusion_matrix',
                      (predictions, labels, num_classes, weights)):
    labels, predictions = remove_squeezable_dimensions(
        ops.convert_to_tensor(labels, name='labels'),
        ops.convert_to_tensor(predictions, name='predictions'))
    # Bring both inputs to a single integer dtype. They are later stacked
    # together into `scatter_nd` indices and compared against `num_classes`.
    predictions = math_ops.cast(predictions, dtypes.int64)
    labels = math_ops.cast(labels, dtypes.int64)

    # Sanity checks - underflow or overflow can cause memory corruption.
    labels = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            labels, message='`labels` contains negative values')],
        labels)
    predictions = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            predictions, message='`predictions` contains negative values')],
        predictions)

    if num_classes is None:
      # Infer the matrix size from the largest class id seen in either input.
      num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
                                     math_ops.reduce_max(labels)) + 1
    else:
      num_classes_int64 = math_ops.cast(num_classes, dtypes.int64)
      labels = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              labels, num_classes_int64, message='`labels` out of bound')],
          labels)
      predictions = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              predictions, num_classes_int64,
              message='`predictions` out of bound')],
          predictions)

    if weights is not None:
      weights = ops.convert_to_tensor(weights, name='weights')
      predictions.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.cast(weights, dtype)

    shape = array_ops.stack([num_classes, num_classes])
    # Each (label, prediction) pair addresses one cell: row = true label,
    # column = prediction.
    indices = array_ops.stack([labels, predictions], axis=1)
    values = (array_ops.ones_like(predictions, dtype)
              if weights is None else weights)
    # Counting relies on `scatter_nd` summing updates that share an index
    # (documented `scatter_nd` behavior).
    return array_ops.scatter_nd(
        indices=indices,
        updates=values,
        shape=math_ops.cast(shape, dtypes.int64))


@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels,
                        predictions,
                        num_classes=None,
                        dtype=dtypes.int32,
                        name=None,
                        weights=None):
  """Computes the confusion matrix from predictions and labels.

  This is the TF1 endpoint. It takes the same arguments as
  `confusion_matrix`, but `weights` comes last in the signature, and it
  delegates directly to `confusion_matrix`.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  prediction and labels must be 1-D arrays of the same shape in order for this
  function to work.

  If `num_classes` is `None`, then `num_classes` will be set to one plus the
  maximum value in either predictions or labels. Class labels are expected to
  start at 0. For example, if `num_classes` is 3, then the possible labels
  would be `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can have.
      If this value is not provided, it will be calculated using both
      predictions and labels array.
    dtype: Data type of the confusion matrix.
    name: Scope name.
    weights: An optional `Tensor` whose shape matches `predictions`.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If `labels` and `predictions` have statically incompatible
      shapes, or if `weights` is not `None` and its shape doesn't match
      `predictions`.
  """
  # Forward by keyword: the two signatures order `weights` differently, so
  # positional forwarding is easy to get wrong when either one changes.
  return confusion_matrix(
      labels,
      predictions,
      num_classes=num_classes,
      weights=weights,
      dtype=dtype,
      name=name)