1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A collection of functions to be used as evaluation metrics.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.contrib import losses 23from tensorflow.contrib.learn.python.learn.estimators import prediction_key 24 25from tensorflow.python.framework import dtypes 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.ops import metrics 29from tensorflow.python.ops import nn 30 31INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES 32INFERENCE_PRED_NAME = prediction_key.PredictionKey.CLASSES 33 34FEATURE_IMPORTANCE_NAME = 'global_feature_importance' 35 36 37def _top_k_generator(k): 38 def _top_k(probabilities, targets): 39 targets = math_ops.cast(targets, dtypes.int32) 40 if targets.get_shape().ndims > 1: 41 targets = array_ops.squeeze(targets, axis=[1]) 42 return metrics.mean(nn.in_top_k(probabilities, targets, k)) 43 return _top_k 44 45 46def _accuracy(predictions, targets, weights=None): 47 return metrics.accuracy( 48 labels=targets, predictions=predictions, weights=weights) 49 50 51def _r2(probabilities, targets, weights=None): 52 targets = math_ops.cast(targets, dtypes.float32) 53 y_mean = math_ops.reduce_mean(targets, 0) 54 squares_total = math_ops.reduce_sum( 55 math_ops.squared_difference(targets, y_mean), 0) 56 squares_residuals = math_ops.reduce_sum( 57 math_ops.squared_difference(targets, probabilities), 0) 58 score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) 59 return metrics.mean(score, weights=weights) 60 61 62def _squeeze_and_onehot(targets, depth): 63 targets = array_ops.squeeze(targets, axis=[1]) 64 return array_ops.one_hot(math_ops.cast(targets, dtypes.int32), depth) 65 66 67def _sigmoid_entropy(probabilities, targets, weights=None): 68 return metrics.mean( 69 losses.sigmoid_cross_entropy(probabilities, 70 _squeeze_and_onehot( 71 targets, 72 array_ops.shape(probabilities)[1])), 73 weights=weights) 74 75 76def _softmax_entropy(probabilities, targets, weights=None): 77 return metrics.mean( 78 losses.sparse_softmax_cross_entropy(probabilities, 79 math_ops.cast(targets, dtypes.int32)), 80 weights=weights) 81 82 83def _predictions(predictions, unused_targets, **unused_kwargs): 84 return predictions 85 86 87def _class_log_loss(probabilities, targets, weights=None): 88 return metrics.mean( 89 losses.log_loss(probabilities, 90 _squeeze_and_onehot(targets, 91 array_ops.shape(probabilities)[1])), 92 weights=weights) 93 94 95def _precision(predictions, targets, weights=None): 96 return metrics.precision( 97 labels=targets, predictions=predictions, weights=weights) 98 99 100def _precision_at_thresholds(predictions, targets, weights=None): 101 return metrics.precision_at_thresholds( 102 labels=targets, 103 predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), 104 thresholds=np.arange(0, 1, 0.01, dtype=np.float32), 105 weights=weights) 106 107 108def _recall(predictions, targets, weights=None): 109 return metrics.recall( 110 labels=targets, predictions=predictions, weights=weights) 111 112 113def _recall_at_thresholds(predictions, targets, weights=None): 114 return metrics.recall_at_thresholds( 115 labels=targets, 116 predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), 117 thresholds=np.arange(0, 1, 0.01, dtype=np.float32), 118 weights=weights) 119 120 121def _auc(probs, targets, weights=None): 122 return metrics.auc( 123 labels=targets, 124 predictions=array_ops.slice(probs, [0, 1], [-1, 1]), 125 weights=weights) 126 127 128_EVAL_METRICS = { 129 'auc': _auc, 130 'sigmoid_entropy': _sigmoid_entropy, 131 'softmax_entropy': _softmax_entropy, 132 'accuracy': _accuracy, 133 'r2': _r2, 134 'predictions': _predictions, 135 'classification_log_loss': _class_log_loss, 136 'precision': _precision, 137 'precision_at_thresholds': _precision_at_thresholds, 138 'recall': _recall, 139 'recall_at_thresholds': _recall_at_thresholds, 140 'top_5': _top_k_generator(5) 141} 142 143_PREDICTION_KEYS = { 144 'auc': INFERENCE_PROB_NAME, 145 'sigmoid_entropy': INFERENCE_PROB_NAME, 146 'softmax_entropy': INFERENCE_PROB_NAME, 147 'accuracy': INFERENCE_PRED_NAME, 148 'r2': prediction_key.PredictionKey.SCORES, 149 'predictions': INFERENCE_PRED_NAME, 150 'classification_log_loss': INFERENCE_PROB_NAME, 151 'precision': INFERENCE_PRED_NAME, 152 'precision_at_thresholds': INFERENCE_PROB_NAME, 153 'recall': INFERENCE_PRED_NAME, 154 'recall_at_thresholds': INFERENCE_PROB_NAME, 155 'top_5': INFERENCE_PROB_NAME 156} 157 158 159def get_metric(metric_name): 160 """Given a metric name, return the corresponding metric function.""" 161 return _EVAL_METRICS[metric_name] 162 163 164def get_prediction_key(metric_name): 165 return _PREDICTION_KEYS[metric_name] 166