1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The metric spec class to flexibly connect models and metrics (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import six
27
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.util import tf_inspect
30from tensorflow.python.util.deprecation import deprecated
31
32
def _assert_named_args(sentinel):
  """Reject positional use of an adapted metric function.

  Adapted metric functions take `_sentinel` as their first parameter, so any
  positional argument lands there. A value other than `None` therefore means
  the caller did not pass `labels`/`predictions`/`weights` by name.

  Args:
    sentinel: Value received in the sentinel slot.

  Raises:
    ValueError: if `sentinel` is not `None`.
  """
  if sentinel is None:
    return
  message = ('`metric_fn` requires named args: `labels`, `predictions`, '
             'and optionally `weights`.')
  raise ValueError(message)
38
39
def _args(fn):
  """Get argument names for function-like object.

  For `functools.partial`-like objects (anything exposing `func` and
  `keywords`), the names of the wrapped callable are returned minus the ones
  the wrapper has already bound, either positionally (via `args`) or by
  keyword (via `keywords`).

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
  """
  if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
    # Handle functools.partial and similar objects.
    # `keywords` may be `None` rather than an empty dict (e.g. Python 2
    # partials created without keyword arguments), so guard before iterating.
    bound_by_keyword = frozenset(fn.keywords or ())
    # Positionally bound values consume the leading arguments of the wrapped
    # callable; those names are no longer available to callers.
    num_bound_positionally = len(getattr(fn, 'args', None) or ())
    return tuple(
        arg for arg in _args(fn.func)[num_bound_positionally:]
        if arg not in bound_by_keyword)
  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args)
55
56
# Argument names recognized on a user-supplied `metric_fn`. For each role the
# canonical spelling comes first; the remaining names are accepted aliases.
_CANONICAL_LABELS_ARG = 'labels'
_LABELS_ARGS = {_CANONICAL_LABELS_ARG, 'label', 'targets', 'target'}

_CANONICAL_PREDICTIONS_ARG = 'predictions'
_PREDICTIONS_ARGS = {
    _CANONICAL_PREDICTIONS_ARG, 'prediction', 'logits', 'logit'}

_CANONICAL_WEIGHTS_ARG = 'weights'
_WEIGHTS_ARGS = {_CANONICAL_WEIGHTS_ARG, 'weight'}
64
65
def _matching_arg(
    fn_name, fn_args, candidate_args, canonical_arg, is_required=False):
  """Pick the single name in `fn_args` that also appears in `candidate_args`.

  Args:
    fn_name: Name of the inspected function, only used in messages.
    fn_args: Argument names of the `fn_name` function.
    candidate_args: `set` of acceptable argument names to look for in
      `fn_args`.
    canonical_arg: Preferred member of `candidate_args`. A warning is logged
      when the match is a different member.
    is_required: Whether it is an error for `fn_args` to contain none of
      `candidate_args`.

  Returns:
    String argument name if found, or `None` if not found.

  Raises:
    ValueError: if more than one candidate is found, or none is found and
      `is_required` is set.
  """
  assert canonical_arg in candidate_args  # Sanity check.

  # `intersection` builds a new set, so `candidate_args` is left untouched.
  found = candidate_args.intersection(fn_args)
  if len(found) > 1:
    raise ValueError(
        'Ambiguous arguments %s, must provide only one of %s.' % (
            found, candidate_args))

  # At most one element remains at this point.
  match = next(iter(found), None)
  if not match:
    if is_required:
      raise ValueError(
          '%s missing from %s(%s).' % (candidate_args, fn_name, fn_args))
    return match

  if match != canonical_arg:
    logging.warning(
        'Canonical arg %s missing from %s(%s), using %s.',
        canonical_arg, fn_name, fn_args, match)
  return match
102
103
def _fn_name(fn):
  """Return a human-readable name for `fn`.

  Args:
    fn: Function, or function-like object.

  Returns:
    `fn.__name__` if present; otherwise `fn.func.__name__` if present (e.g.,
    for a `functools.partial`); otherwise `str(fn)`.
  """
  if not hasattr(fn, '__name__'):
    # If it's a functools.partial, the wrapped callable carries the name.
    wrapped = getattr(fn, 'func', None)
    return wrapped.__name__ if hasattr(wrapped, '__name__') else str(fn)
  return fn.__name__
110
111
def _adapt_metric_fn(
    metric_fn, metric_fn_name, is_labels_required, is_weights_required):
  """Wrap `metric_fn` in a function that is called with named args only.

  The returned wrapper has the signature
  `(_sentinel=None, labels=None, predictions=None, weights=None)` and forwards
  its arguments to `metric_fn` according to the following rules:
  - `labels` is forwarded only if `metric_fn` has exactly one arg from
    `_LABELS_ARGS`. It is passed by name, unless it is the first arg of
    `metric_fn` and `predictions` must be passed positionally; in that case
    both are passed positionally. Otherwise, `labels` are omitted.
  - `predictions` is passed by name if `metric_fn` has exactly one arg from
    `_PREDICTIONS_ARGS`. Otherwise, it is passed positionally as the first
    non-label argument.
  - `weights` is passed by name, under the arg of `metric_fn` found in
    `_WEIGHTS_ARGS`, and only when the value given to the wrapper is not
    `None`.

  Args:
    metric_fn: Metric function to be wrapped.
    metric_fn_name: `metric_fn` name, only used for logging.
    is_labels_required: Whether `labels` is a required arg.
    is_weights_required: Whether `weights` is a required arg.

  Returns:
    Function accepting only named args `labels`, `predictions`, and `weights`,
    and passing those to `metric_fn`.

  Raises:
    ValueError: if one of the following is true:
    - `metric_fn` has more than one arg of `_LABELS_ARGS`, `_PREDICTIONS_ARGS`,
      or `_WEIGHTS_ARGS`
    - `is_labels_required` is true, and `metric_fn` has no arg from
      `_LABELS_ARGS`.
    - `is_weights_required` is true, and `metric_fn` has no arg from
      `_WEIGHTS_ARGS`.
  """
  fn_args = _args(metric_fn)

  labels_name = _matching_arg(
      metric_fn_name, fn_args, _LABELS_ARGS, _CANONICAL_LABELS_ARG,
      is_required=is_labels_required)
  predictions_name = _matching_arg(
      metric_fn_name, fn_args, _PREDICTIONS_ARGS, _CANONICAL_PREDICTIONS_ARG)
  weights_name = _matching_arg(
      metric_fn_name, fn_args, _WEIGHTS_ARGS, _CANONICAL_WEIGHTS_ARG,
      is_required=is_weights_required)

  def _weights_kwargs(weights):
    """Keyword args forwarding `weights`; empty when `weights` is `None`."""
    if weights is None:
      return {}
    return {weights_name: weights}

  # pylint: disable=invalid-name
  if not labels_name:
    if not predictions_name:
      # Neither labels nor predictions are named, so we just pass predictions
      # as the first arg.
      def _positional_no_labels_metric_fn(
          _sentinel=None, labels=None, predictions=None, weights=None):
        del labels
        _assert_named_args(_sentinel)
        # TODO(ptucker): Should we allow weights with no labels?
        return metric_fn(predictions, **_weights_kwargs(weights))
      return _positional_no_labels_metric_fn

    # No labels, and predictions is named, so we pass the latter as a named
    # arg.
    def _named_no_labels_metric_fn(
        _sentinel=None, labels=None, predictions=None, weights=None):
      del labels
      _assert_named_args(_sentinel)
      call_kwargs = {predictions_name: predictions}
      # TODO(ptucker): Should we allow weights with no labels?
      call_kwargs.update(_weights_kwargs(weights))
      return metric_fn(**call_kwargs)
    return _named_no_labels_metric_fn

  if predictions_name:
    # Both labels and predictions are named args.
    def _named_metric_fn(
        _sentinel=None, labels=None, predictions=None, weights=None):
      _assert_named_args(_sentinel)
      call_kwargs = {labels_name: labels, predictions_name: predictions}
      call_kwargs.update(_weights_kwargs(weights))
      return metric_fn(**call_kwargs)
    return _named_metric_fn

  if fn_args[0] == labels_name:
    # labels is a named arg, and first. predictions is not a named arg, so we
    # want to pass it as the 2nd positional arg (i.e., the first non-labels
    # position), which means passing both positionally.
    def _positional_metric_fn(
        _sentinel=None, labels=None, predictions=None, weights=None):
      _assert_named_args(_sentinel)
      # TODO(ptucker): Should we support metrics that take only labels?
      # Currently, if you want streaming mean of a label, you have to wrap it
      # in a fn that takes and discards predictions.
      return metric_fn(labels, predictions, **_weights_kwargs(weights))
    return _positional_metric_fn

  # labels is a named arg, and not first, so we pass predictions positionally
  # and labels by name.
  def _positional_predictions_metric_fn(
      _sentinel=None, labels=None, predictions=None, weights=None):
    _assert_named_args(_sentinel)
    call_kwargs = {labels_name: labels}
    call_kwargs.update(_weights_kwargs(weights))
    return metric_fn(predictions, **call_kwargs)
  return _positional_predictions_metric_fn
227
228
class MetricSpec(object):
  """MetricSpec connects a model to metric functions.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  The MetricSpec class contains all information necessary to connect the
  output of a `model_fn` to the metrics (usually, streaming metrics) that are
  used in evaluation.

  It is passed in the `metrics` argument of `Estimator.evaluate`. The
  `Estimator` then knows which predictions, labels, and weight to use to call a
  given metric function.

  When building the ops to run in evaluation, an `Estimator` will call
  `create_metric_ops`, which will connect the given `metric_fn` to the model
  as detailed in the docstring for `create_metric_ops`, and return the metric.

  Example:

  Assuming a model has an input function which returns inputs containing
  (among other things) a tensor with key "input_key", and a labels dictionary
  containing "label_key". Let's assume that the `model_fn` for this model
  returns a prediction with key "prediction_key".

  In order to compute the accuracy of the "prediction_key" prediction, we
  would add

  ```
  "prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
                                    prediction_key="prediction_key",
                                    label_key="label_key")
  ```

  to the metrics argument to `evaluate`. `prediction_accuracy_fn` can be either
  a predefined function in metric_ops (e.g., `streaming_accuracy`) or a custom
  function you define.

  If we would like the accuracy to be weighted by "input_key", we can add that
  as the `weight_key` argument.

  ```
  "prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
                                    prediction_key="prediction_key",
                                    label_key="label_key",
                                    weight_key="input_key")
  ```

  An end-to-end example is as follows:

  ```
  estimator = tf.contrib.learn.Estimator(...)
  estimator.fit(...)
  _ = estimator.evaluate(
      input_fn=input_fn,
      steps=1,
      metrics={
          'prediction accuracy':
              metric_spec.MetricSpec(
                  metric_fn=prediction_accuracy_fn,
                  prediction_key="prediction_key",
                  label_key="label_key")
      })
  ```

  """

  @deprecated(None, 'Use tf.estimator.EstimatorSpec.eval_metric_ops.')
  def __init__(self,
               metric_fn,
               prediction_key=None,
               label_key=None,
               weight_key=None):
    """Constructor.

    Creates a MetricSpec.

    Args:
      metric_fn: A function to use as a metric. See `_adapt_metric_fn` for
        rules on how `predictions`, `labels`, and `weights` are passed to this
        function. This must return either a single `Tensor`, which is
        interpreted as a value of this metric, or a pair
        `(value_op, update_op)`, where `value_op` is the op to call to
        obtain the value of the metric, and `update_op` should be run for
        each batch to update internal state.
      prediction_key: The key for a tensor in the `predictions` dict (output
        from the `model_fn`) to use as the `predictions` input to the
        `metric_fn`. Optional. If `None`, the `model_fn` must return a single
        tensor or a dict with only a single entry as `predictions`.
      label_key: The key for a tensor in the `labels` dict (output from the
        `input_fn`) to use as the `labels` input to the `metric_fn`.
        Optional. If `None`, the `input_fn` must return a single tensor or a
        dict with only a single entry as `labels`.
      weight_key: The key for a tensor in the `inputs` dict (output from the
        `input_fn`) to use as the `weights` input to the `metric_fn`.
        Optional. If `None`, no weights will be passed to the `metric_fn`.

    Raises:
      ValueError: if `metric_fn` has ambiguous label, prediction or weight
        args, if `label_key` is given but `metric_fn` has no labels arg, or if
        `weight_key` is given but `metric_fn` has no weights arg.
    """
    self._metric_fn_name = _fn_name(metric_fn)
    self._metric_fn = _adapt_metric_fn(
        metric_fn=metric_fn,
        metric_fn_name=self._metric_fn_name,
        is_labels_required=label_key is not None,
        is_weights_required=weight_key is not None)
    self._prediction_key = prediction_key
    self._label_key = label_key
    self._weight_key = weight_key

  @property
  def prediction_key(self):
    """Key of the predictions tensor, or `None`."""
    return self._prediction_key

  @property
  def label_key(self):
    """Key of the labels tensor, or `None`."""
    return self._label_key

  @property
  def weight_key(self):
    """Key of the weights tensor in the inputs dict, or `None`."""
    return self._weight_key

  @property
  def metric_fn(self):
    """Metric function.

    This function accepts named args: `predictions`, `labels`, `weights`. It
    returns a single `Tensor` or `(value_op, update_op)` pair. See `metric_fn`
    constructor argument for more details.

    Returns:
      Function, see `metric_fn` constructor argument for more details.
    """
    return self._metric_fn

  def __str__(self):
    # All values go through one explicit argument tuple: applying `%` to a
    # bare key would unpack (and fail on) a tuple-valued key.
    return (
        'MetricSpec(metric_fn=%s, prediction_key=%s, label_key=%s, '
        'weight_key=%s)' % (
            self._metric_fn_name, self.prediction_key, self.label_key,
            self.weight_key))

  def create_metric_ops(self, inputs, labels, predictions):
    """Connect our `metric_fn` to the specified members of the given dicts.

    This function will call the `metric_fn` given in our constructor as follows:

    ```
      metric_fn(labels=labels[self.label_key],
                predictions=predictions[self.prediction_key],
                weights=inputs[self.weight_key])
    ```

    And returns the result. `weights` is `None` unless `self.weight_key` is
    set, in which case it is looked up in `inputs`.

    `predictions` and `labels` may be single tensors as well as dicts. If
    `predictions` is a single tensor, `self.prediction_key` must be `None`. If
    `predictions` is a single element dict, `self.prediction_key` is allowed to
    be `None`. Likewise, if `labels` is a single tensor, `self.label_key` must
    be `None`. If `labels` is a single element dict, `self.label_key` is allowed
    to be `None`.

    Args:
      inputs: A dict of inputs produced by the `input_fn`
      labels: A dict of labels or a single label tensor produced by the
        `input_fn`.
      predictions: A dict of predictions or a single tensor produced by the
        `model_fn`.

    Returns:
      The result of calling `metric_fn`.

    Raises:
      ValueError: If `predictions` or `labels` is a single `Tensor` and
        `self.prediction_key` or `self.label_key` is not `None`; or if
        `self.label_key` is `None` but `labels` is a dict with more than one
        element, or if `self.prediction_key` is `None` but `predictions` is a
        dict with more than one element.
      KeyError: If `self.prediction_key` or `self.label_key` is set but missing
        from the corresponding dict.
    """
    def _get_dict(name, dict_or_tensor, key):
      """Get a single tensor or an element of a dict or raise ValueError."""
      # NOTE: any falsy key (e.g. `''` or `0`) is treated like `None` here.
      if key:
        if not isinstance(dict_or_tensor, dict):
          # Values are passed as an explicit tuple so that a tuple/list of
          # tensors is reported instead of breaking the `%` formatting.
          raise ValueError(
              'MetricSpec with %s_key specified requires %ss dict, got %s.\n'
              'You must not provide a %s_key if you only have a single '
              'Tensor as %ss.' % (name, name, dict_or_tensor, name, name))
        if key not in dict_or_tensor:
          raise KeyError(
              'Key \'%s\' missing from %s.' % (key, dict_or_tensor.keys()))
        return dict_or_tensor[key]

      if isinstance(dict_or_tensor, dict):
        if len(dict_or_tensor) != 1:
          raise ValueError(
              'MetricSpec without specified %s_key requires %ss tensor or '
              'single element dict, got %s' % (name, name, dict_or_tensor))
        # Exactly one entry: use its value regardless of its key.
        return six.next(six.itervalues(dict_or_tensor))
      return dict_or_tensor

    # Get the predictions.
    prediction = _get_dict('prediction', predictions, self.prediction_key)

    # Get the labels.
    label = _get_dict('label', labels, self.label_key)

    try:
      # The weights lookup stays inside the `try` so that a missing weight key
      # is logged like any other failure.
      return self.metric_fn(
          labels=label,
          predictions=prediction,
          weights=inputs[self.weight_key] if self.weight_key else None)
    except Exception as ex:
      # Log which spec failed, then let the original exception propagate.
      logging.error('Could not create metric ops for %s, %s.', self, ex)
      raise
443