1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A collection of functions to be used as evaluation metrics."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.contrib import losses
23from tensorflow.contrib.learn.python.learn.estimators import prediction_key
24
25from tensorflow.python.framework import dtypes
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import metrics
29from tensorflow.python.ops import nn
30
# Prediction-dict keys under which estimators expose class probabilities and
# hard class predictions; metric functions below are routed to one of these
# via _PREDICTION_KEYS.
INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES
INFERENCE_PRED_NAME = prediction_key.PredictionKey.CLASSES

# Key for the global feature importance output (presumably consumed by the
# forest estimator that imports this module — confirm against callers).
FEATURE_IMPORTANCE_NAME = 'global_feature_importance'
35
36
def _top_k_generator(k):
  """Builds a streaming mean top-`k` accuracy metric for a fixed `k`.

  Args:
    k: Number of highest-probability classes considered a correct match.

  Returns:
    A function mapping (probabilities, targets) to a streaming mean of
    `nn.in_top_k` hits.
  """
  def _in_top_k_mean(probabilities, targets):
    targets = math_ops.cast(targets, dtypes.int32)
    # in_top_k expects rank-1 targets; drop a trailing size-1 dimension.
    if targets.get_shape().ndims > 1:
      targets = array_ops.squeeze(targets, axis=[1])
    hits = nn.in_top_k(probabilities, targets, k)
    return metrics.mean(hits)
  return _in_top_k_mean
44
45
def _accuracy(predictions, targets, weights=None):
  """Streaming accuracy of hard `predictions` against `targets`."""
  return metrics.accuracy(
      predictions=predictions, labels=targets, weights=weights)
49
50
def _r2(probabilities, targets, weights=None):
  """Streaming R^2 (coefficient of determination) of predictions vs targets.

  Computes 1 - sum(residual SS / total SS) over output columns, then folds
  the scalar into a streaming mean.
  """
  targets = math_ops.cast(targets, dtypes.float32)
  target_means = math_ops.reduce_mean(targets, 0)
  total_ss = math_ops.reduce_sum(
      math_ops.squared_difference(targets, target_means), 0)
  residual_ss = math_ops.reduce_sum(
      math_ops.squared_difference(targets, probabilities), 0)
  r2_score = 1 - math_ops.reduce_sum(residual_ss / total_ss)
  return metrics.mean(r2_score, weights=weights)
60
61
def _squeeze_and_onehot(targets, depth):
  """Drops the trailing size-1 dim of `targets`, then one-hot encodes it."""
  squeezed = array_ops.squeeze(targets, axis=[1])
  return array_ops.one_hot(math_ops.cast(squeezed, dtypes.int32), depth)
65
66
def _sigmoid_entropy(probabilities, targets, weights=None):
  """Streaming mean sigmoid cross-entropy vs one-hot encoded targets."""
  num_classes = array_ops.shape(probabilities)[1]
  onehot_targets = _squeeze_and_onehot(targets, num_classes)
  return metrics.mean(
      losses.sigmoid_cross_entropy(probabilities, onehot_targets),
      weights=weights)
74
75
def _softmax_entropy(probabilities, targets, weights=None):
  """Streaming mean sparse softmax cross-entropy vs integer targets."""
  int_targets = math_ops.cast(targets, dtypes.int32)
  entropy = losses.sparse_softmax_cross_entropy(probabilities, int_targets)
  return metrics.mean(entropy, weights=weights)
81
82
def _predictions(predictions, unused_targets, **unused_kwargs):
  """Identity metric: surfaces the raw predictions as the eval result."""
  return predictions
85
86
def _class_log_loss(probabilities, targets, weights=None):
  """Streaming mean log loss of class probabilities vs one-hot targets."""
  depth = array_ops.shape(probabilities)[1]
  onehot_targets = _squeeze_and_onehot(targets, depth)
  return metrics.mean(
      losses.log_loss(probabilities, onehot_targets), weights=weights)
93
94
def _precision(predictions, targets, weights=None):
  """Streaming precision of hard `predictions` against `targets`."""
  return metrics.precision(
      predictions=predictions, labels=targets, weights=weights)
98
99
def _precision_at_thresholds(predictions, targets, weights=None):
  """Streaming precision at 100 evenly spaced thresholds in [0, 1).

  Operates on the positive-class (column 1) probability only.
  """
  positive_class = array_ops.slice(predictions, [0, 1], [-1, 1])
  thresholds = np.arange(0, 1, 0.01, dtype=np.float32)
  return metrics.precision_at_thresholds(
      labels=targets,
      predictions=positive_class,
      thresholds=thresholds,
      weights=weights)
106
107
def _recall(predictions, targets, weights=None):
  """Streaming recall of hard `predictions` against `targets`."""
  return metrics.recall(
      predictions=predictions, labels=targets, weights=weights)
111
112
def _recall_at_thresholds(predictions, targets, weights=None):
  """Streaming recall at 100 evenly spaced thresholds in [0, 1).

  Operates on the positive-class (column 1) probability only.
  """
  positive_class = array_ops.slice(predictions, [0, 1], [-1, 1])
  thresholds = np.arange(0, 1, 0.01, dtype=np.float32)
  return metrics.recall_at_thresholds(
      labels=targets,
      predictions=positive_class,
      thresholds=thresholds,
      weights=weights)
119
120
def _auc(probs, targets, weights=None):
  """Streaming AUC computed on the positive-class probability column."""
  positive_probs = array_ops.slice(probs, [0, 1], [-1, 1])
  return metrics.auc(
      labels=targets, predictions=positive_probs, weights=weights)
126
127
# Registry mapping a user-facing metric name to its metric function.
# Looked up by get_metric(); keep in sync with _PREDICTION_KEYS below.
_EVAL_METRICS = {
    'auc': _auc,
    'sigmoid_entropy': _sigmoid_entropy,
    'softmax_entropy': _softmax_entropy,
    'accuracy': _accuracy,
    'r2': _r2,
    'predictions': _predictions,
    'classification_log_loss': _class_log_loss,
    'precision': _precision,
    'precision_at_thresholds': _precision_at_thresholds,
    'recall': _recall,
    'recall_at_thresholds': _recall_at_thresholds,
    'top_5': _top_k_generator(5)
}
142
# For each metric name, which estimator prediction (probabilities, hard
# classes, or regression scores) its metric function consumes.
# Looked up by get_prediction_key(); keep in sync with _EVAL_METRICS above.
_PREDICTION_KEYS = {
    'auc': INFERENCE_PROB_NAME,
    'sigmoid_entropy': INFERENCE_PROB_NAME,
    'softmax_entropy': INFERENCE_PROB_NAME,
    'accuracy': INFERENCE_PRED_NAME,
    'r2': prediction_key.PredictionKey.SCORES,
    'predictions': INFERENCE_PRED_NAME,
    'classification_log_loss': INFERENCE_PROB_NAME,
    'precision': INFERENCE_PRED_NAME,
    'precision_at_thresholds': INFERENCE_PROB_NAME,
    'recall': INFERENCE_PRED_NAME,
    'recall_at_thresholds': INFERENCE_PROB_NAME,
    'top_5': INFERENCE_PROB_NAME
}
157
158
def get_metric(metric_name):
  """Given a metric name, return the corresponding metric function.

  Args:
    metric_name: String key of `_EVAL_METRICS`, e.g. 'accuracy' or 'auc'.

  Returns:
    The metric function registered under `metric_name`.

  Raises:
    KeyError: If `metric_name` is not a supported metric; the message lists
      the valid names (the bare dict lookup previously gave no hint).
  """
  try:
    return _EVAL_METRICS[metric_name]
  except KeyError:
    raise KeyError('Unknown metric %r; valid metrics are: %s' %
                   (metric_name, sorted(_EVAL_METRICS)))
162
163
def get_prediction_key(metric_name):
  """Given a metric name, return the prediction key its metric consumes.

  Args:
    metric_name: String key of `_PREDICTION_KEYS`, e.g. 'accuracy' or 'auc'.

  Returns:
    The `prediction_key.PredictionKey` value for `metric_name`.

  Raises:
    KeyError: If `metric_name` is not a supported metric; the message lists
      the valid names (the bare dict lookup previously gave no hint).
  """
  try:
    return _PREDICTION_KEYS[metric_name]
  except KeyError:
    raise KeyError('Unknown metric %r; valid metrics are: %s' %
                   (metric_name, sorted(_PREDICTION_KEYS)))
166