1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Abstractions for the head(s) of a model (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import abc
27
28import six
29
30from tensorflow.contrib import framework as framework_lib
31from tensorflow.contrib import layers as layers_lib
32from tensorflow.contrib.learn.python.learn.estimators import constants
33from tensorflow.contrib.learn.python.learn.estimators import model_fn
34from tensorflow.contrib.learn.python.learn.estimators import prediction_key
35from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey as mkey
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import sparse_tensor
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import lookup_ops
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import metrics as metrics_lib
44from tensorflow.python.ops import nn
45from tensorflow.python.ops import sparse_ops
46from tensorflow.python.ops import string_ops
47from tensorflow.python.ops import variable_scope
48from tensorflow.python.ops import weights_broadcast_ops
49from tensorflow.python.ops.losses import losses as losses_lib
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.summary import summary
52from tensorflow.python.training import training
53from tensorflow.python.util import tf_decorator
54from tensorflow.python.util import tf_inspect
55from tensorflow.python.util.deprecation import deprecated
56
57
@six.add_metaclass(abc.ABCMeta)
class Head(object):
  """Interface for the head/top of a model.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Given logits (or output of a hidden layer), a Head knows how to compute
  predictions, loss, default metric and export signature. It is meant to:

  1) Simplify writing model_fn and make model_fn more configurable.
  2) Support a wide range of machine learning models. Since most heads can
      work with logits, they can support DNN, RNN, Wide, Wide&Deep,
      Global objectives, Gradient boosted trees and many other types
      of machine learning models.
  3) Allow users to seamlessly switch between 1 to n heads for multi
      objective learning (see the _MultiHead implementation for more details).

  Common usage:
  Here is a simplified model_fn to build a multiclass DNN model.
    ```python
    def _my_dnn_model_fn(features, labels, mode, params, config=None):
      # Optionally your callers can pass head to model_fn as a param.
      head = tf.contrib.learn.multi_class_head(...)
      input = tf.contrib.layers.input_from_feature_columns(features, ...)
      last_hidden_layer_out = tf.contrib.layers.stack(
          input, tf.contrib.layers.fully_connected, [1000, 500])
      logits = tf.contrib.layers.fully_connected(
          last_hidden_layer_out, head.logits_dimension, activation_fn=None)

      def _train_op_fn(loss):
        return optimizer.minimize(loss)

      return head.create_model_fn_ops(
          features=features,
          labels=labels,
          mode=mode,
          train_op_fn=_train_op_fn,
          logits=logits,
          scope=...)
    ```

  Most heads also support logits_input, which is typically the output of the
  last hidden layer. Some heads (like heads responsible for candidate sampling
  or hierarchical softmax) intrinsically will not support logits and you have
  to pass logits_input. Here is a common usage:
    ```python
    return head.create_model_fn_ops(
        features=features,
        labels=labels,
        mode=mode,
        train_op_fn=_train_op_fn,
        logits_input=last_hidden_layer_out,
        scope=...)
    ```

  There are cases where computing and applying gradients can not be
  meaningfully captured with the train_op_fn we support (for example, with a
  sync optimizer). In such cases, you can take the responsibility on your own.
  Here is a common use case:
    ```python
    model_fn_ops = head.create_model_fn_ops(
        features=features,
        labels=labels,
        mode=mode,
        train_op_fn=tf.contrib.learn.no_op_train_fn,
        logits=logits,
        scope=...)
    if mode == tf.contrib.learn.ModeKeys.TRAIN:
      optimizer = ...
      sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
      update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
                                                  loss=model_fn_ops.loss, ...)
      hooks = [sync.make_session_run_hook(is_chief)]
      ... update train_op and hooks in ModelFnOps and return
    ```
  """

  @abc.abstractproperty
  def logits_dimension(self):
    """Size of the last dimension of the logits `Tensor`.

    Typically, logits is of shape `[batch_size, logits_dimension]`.

    Returns:
      The expected size of the `logits` tensor.
    """
    # Abstract: concrete heads must override this property.
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """Returns `ModelFnOps` that a model_fn can return.

    Please note that,
    + Exactly one of `logits` and `logits_input` must be provided.
    + All args must be passed via name.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
          to optimize the model with the loss. This is used in TRAIN mode and
          must not be None. None is allowed in other modes. If you want to
          optimize loss yourself you can pass `no_op_train_fn` and then use
          `ModelFnOps.loss` to compute and apply gradients.
      logits: logits `Tensor` to be used by the head.
      logits_input: `Tensor` from which to build logits, often needed when you
        don't want to compute the logits. Typically this is the activation of
        the last hidden layer in a DNN. Some heads (like the ones responsible
        for candidate sampling) intrinsically avoid computing full logits and
        only accept logits_input.
      scope: Optional scope for `variable_scope`.

    Returns:
      An instance of `ModelFnOps`.

    Raises:
      ValueError: If `mode` is not recognized.
      ValueError: If neither or both of `logits` and `logits_input` is provided.
    """
    # Abstract: concrete heads must override this method.
    raise NotImplementedError("Calling an abstract method.")
188
189
@deprecated(None, "Please switch to tf.contrib.estimator.*_head.")
def regression_head(label_name=None,
                    weight_column_name=None,
                    label_dimension=1,
                    enable_centered_bias=False,
                    head_name=None,
                    link_fn=None):
  """Builds a `Head` for linear regression trained with mean squared loss.

  Args:
    label_name: String key of the label in the label dict. May be None when
      the label is a single tensor (single headed models).
    weight_column_name: String name of the feature column holding example
      weights. The weight is multiplied by the loss of the example, so it can
      be used to down weight or boost examples during training.
    label_dimension: Number of regression labels per example, i.e. the size of
      the last dimension of the labels `Tensor` (typically shaped
      `[batch_size, label_dimension]`).
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: Name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    link_fn: Optional link function converting logits to predictions. When
      None, `array_ops.identity` is used.

  Returns:
    An instance of `Head` for linear regression.
  """
  # Fall back to the identity link when the caller does not supply one.
  if link_fn is None:
    link_fn = array_ops.identity

  head_kwargs = dict(
      label_name=label_name,
      weight_column_name=weight_column_name,
      label_dimension=label_dimension,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      loss_fn=_mean_squared_loss,
      link_fn=link_fn)
  return _RegressionHead(**head_kwargs)
228
229
@deprecated(None, "Please switch to tf.contrib.estimator.*_head.")
def poisson_regression_head(label_name=None,
                            weight_column_name=None,
                            label_dimension=1,
                            enable_centered_bias=False,
                            head_name=None):
  """Builds a `Head` for poisson regression.

  The head is a `_RegressionHead` configured with `_poisson_loss` as the loss
  and `math_ops.exp` as the link function applied to the logits.

  Args:
    label_name: String key of the label in the label dict. May be None when
      the label is a single tensor (single headed models).
    weight_column_name: String name of the feature column holding example
      weights. The weight is multiplied by the loss of the example, so it can
      be used to down weight or boost examples during training.
    label_dimension: Number of regression labels per example, i.e. the size of
      the last dimension of the labels `Tensor` (typically shaped
      `[batch_size, label_dimension]`).
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: Name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.

  Returns:
    An instance of `Head` for poisson regression.
  """
  poisson_head = _RegressionHead(
      label_dimension=label_dimension,
      loss_fn=_poisson_loss,
      link_fn=math_ops.exp,
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name)
  return poisson_head
265
266# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression
267
268
@deprecated(None, "Please switch to tf.contrib.estimator.*_head.")
def multi_class_head(n_classes,
                     label_name=None,
                     weight_column_name=None,
                     enable_centered_bias=False,
                     head_name=None,
                     thresholds=None,
                     metric_class_ids=None,
                     loss_fn=None,
                     label_keys=None):
  """Builds a `Head` for multi class single label classification.

  The Head uses softmax cross entropy loss.

  Integer labels specifying the class index are expected by default. When
  `label_keys` is given, labels must instead be strings from that vocabulary,
  and the predicted classes will be strings from the same vocabulary.

  For `n_classes == 2` a `_BinaryLogisticHead` is returned; otherwise a
  `_MultiClassHead` is returned.

  Args:
    n_classes: Integer, number of classes, must be >= 2
    label_name: String key of the label in the label dict. May be None when
      the label is a single tensor (single headed models).
    weight_column_name: String name of the feature column holding example
      weights. The weight is multiplied by the loss of the example, so it can
      be used to down weight or boost examples during training.
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: Name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    thresholds: thresholds for eval metrics, defaults to [.5]
    metric_class_ids: List of class IDs for which we should report per-class
      metrics. Must all be in the range `[0, n_classes)`. Invalid if
      `n_classes` is 2.
    loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as
      parameter and returns a weighted scalar loss. `weights` should be
      optional. See `tf.losses`
    label_keys: Optional list of strings with size `[n_classes]` defining the
      label vocabulary. Only supported for `n_classes` > 2.

  Returns:
    An instance of `Head` for multi class classification.

  Raises:
    ValueError: if `n_classes` is None or < 2.
    ValueError: If `metric_class_ids` or `label_keys` is provided when
      `n_classes` is 2.
    ValueError: If `len(label_keys) != n_classes`.
  """
  if n_classes is None or n_classes < 2:
    raise ValueError("n_classes must be > 1 for classification: %s." %
                     n_classes)

  # Validate the custom loss signature first, then adapt it for the head.
  wrapped_loss_fn = None
  if loss_fn:
    _verify_loss_fn_args(loss_fn)
    wrapped_loss_fn = _wrap_custom_loss_fn(loss_fn)

  if n_classes != 2:
    return _MultiClassHead(
        n_classes=n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        enable_centered_bias=enable_centered_bias,
        head_name=head_name,
        thresholds=thresholds,
        metric_class_ids=metric_class_ids,
        loss_fn=wrapped_loss_fn,
        label_keys=label_keys)

  # Binary case: per-class metrics and label vocabularies are rejected.
  if metric_class_ids:
    raise ValueError("metric_class_ids invalid for n_classes==2.")
  if label_keys:
    raise ValueError("label_keys is not supported for n_classes=2.")
  return _BinaryLogisticHead(
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      thresholds=thresholds,
      loss_fn=wrapped_loss_fn)
349
350
@deprecated(None, "Please switch to tf.contrib.estimator.*_head.")
def binary_svm_head(label_name=None,
                    weight_column_name=None,
                    enable_centered_bias=False,
                    head_name=None,
                    thresholds=None):
  """Builds a `Head` for binary classification with SVMs.

  The head uses binary hinge loss.

  Args:
    label_name: String key of the label in the label dict. May be None when
      the label is a single tensor (single headed models).
    weight_column_name: String name of the feature column holding example
      weights. The weight is multiplied by the loss of the example, so it can
      be used to down weight or boost examples during training.
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: Name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    thresholds: thresholds for eval metrics, defaults to [.5]

  Returns:
    An instance of `Head` for binary classification with SVM.
  """
  svm_head = _BinarySvmHead(
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      thresholds=thresholds)
  return svm_head
385
386
@deprecated(None, "Please switch to tf.contrib.estimator.*_head.")
def multi_label_head(n_classes,
                     label_name=None,
                     weight_column_name=None,
                     enable_centered_bias=False,
                     head_name=None,
                     thresholds=None,
                     metric_class_ids=None,
                     loss_fn=None):
  """Creates a Head for multi label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set.  This is distinct from
  `multi_class_head` which has exactly one label from a discrete set.

  This head by default uses sigmoid cross entropy loss, which expects as input
  a multi-hot tensor of shape `(batch_size, num_classes)`.

  Args:
    n_classes: Integer, number of classes, must be >= 2
    label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    enable_centered_bias: A bool. If True, estimator will learn a centered
      bias variable for each class. Rest of the model structure learns the
      residual after centered bias.
    head_name: name of the head. If provided, predictions, summary and metrics
      keys will be suffixed by `"/" + head_name` and the default variable scope
      will be `head_name`.
    thresholds: thresholds for eval metrics, defaults to [.5]
    metric_class_ids: List of class IDs for which we should report per-class
      metrics. Must all be in the range `[0, n_classes)`.
    loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as
      parameter and returns a weighted scalar loss. `weights` should be
      optional. See `tf.losses`

  Returns:
    An instance of `Head` for multi label classification.

  Raises:
    ValueError: If n_classes is None or < 2
    ValueError: If loss_fn does not have expected signature.
  """
  # Check for None explicitly, mirroring `multi_class_head`: on Python 3
  # `None < 2` raises TypeError instead of the documented ValueError.
  if (n_classes is None) or (n_classes < 2):
    raise ValueError("n_classes must be > 1 for classification.")
  if loss_fn:
    _verify_loss_fn_args(loss_fn)

  return _MultiLabelHead(
      n_classes=n_classes,
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name,
      thresholds=thresholds,
      metric_class_ids=metric_class_ids,
      # A custom loss is adapted via `_wrap_custom_loss_fn`; None is passed
      # through unchanged.
      loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
446
447
@deprecated(None, "Please switch to tf.contrib.estimator.*_head.")
def loss_only_head(loss_fn, head_name=None):
  """Builds a Head that carries only loss terms.

  A loss only head holds additional loss terms to be added to other heads and
  usually represents additional regularization terms in the objective function.

  Args:
    loss_fn: a function that takes no argument and returns a list of
        scalar tensors.
    head_name: a name for the head.

  Returns:
    An instance of `Head` to hold the additional losses.
  """
  extra_loss_head = _LossOnlyHead(loss_fn, head_name=head_name)
  return extra_loss_head
464
465
@deprecated(None, "Please switch to tf.contrib.estimator.*_head.")
def multi_head(heads, loss_weights=None):
  """Builds a MultiHead stemming from the same logits/hidden layer.

  Args:
    heads: list of Head objects.
    loss_weights: optional list of weights to be used to merge losses from
        each head. All losses are weighted equally if not provided.

  Returns:
    An instance of `Head` that merges multiple heads.

  Raises:
    ValueError: if heads and loss_weights have different size.
  """
  if loss_weights and len(loss_weights) != len(heads):
    raise ValueError("heads and loss_weights must have same size")

  def _weighted_loss_merger(losses):
    """Sums `losses`, scaling each by its entry in `loss_weights` if given."""
    if not loss_weights:
      return math_ops.add_n(losses)
    if len(losses) != len(loss_weights):
      raise ValueError("losses and loss_weights must have same size")
    scaled_losses = [
        math_ops.multiply(head_loss, head_weight)
        for head_loss, head_weight in zip(losses, loss_weights)
    ]
    return math_ops.add_n(scaled_losses)

  return _MultiHead(heads, loss_merger=_weighted_loss_merger)
497
498
@deprecated(None, "Use 'lambda _: tf.no_op()'.")
def no_op_train_fn(loss):
  """Returns a no-op; the `loss` argument is discarded."""
  del loss  # Intentionally unused.
  train_op = control_flow_ops.no_op()
  return train_op
503
504
class _SingleHead(Head):
  """Interface for a single head/top of a model.

  Stores the configuration shared by single heads (problem type, logits
  dimension, label name, weight column name and head name) and exposes it
  through read-only properties.
  """

  def __init__(self,
               problem_type,
               logits_dimension,
               label_name=None,
               weight_column_name=None,
               head_name=None):
    """Validates and records the head configuration.

    Args:
      problem_type: Problem type reported in the output alternatives. Must not
        be None.
      logits_dimension: Size of the last dimension of the logits. Must not be
        None and must be at least 1.
      label_name: Optional name of the key in the label dict.
      weight_column_name: Optional name of the weight feature column.
      head_name: Optional name of the head.

    Raises:
      ValueError: If `problem_type` is None, or `logits_dimension` is None or
        less than 1.
    """
    if problem_type is None:
      raise ValueError("Invalid problem_type %s." % problem_type)
    if (logits_dimension is None) or (logits_dimension < 1):
      raise ValueError("Invalid logits_dimension %s." % logits_dimension)

    self._head_name = head_name
    self._label_name = label_name
    self._weight_column_name = weight_column_name
    self._logits_dimension = logits_dimension
    self._problem_type = problem_type

  @property
  def logits_dimension(self):
    """Size of the last dimension of the logits, as given at construction."""
    return self._logits_dimension

  @property
  def label_name(self):
    """Name of the key in the label dict, or None."""
    return self._label_name

  @property
  def weight_column_name(self):
    """Name of the weight feature column, or None."""
    return self._weight_column_name

  @property
  def head_name(self):
    """Name of the head, or None."""
    return self._head_name

  def _create_output_alternatives(self, predictions):
    """Builds the output alternatives dict for this head.

    Args:
      predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
        symbolic name for an output Tensor possibly but not necessarily taken
        from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
        itself.

    Returns:
      A single-entry `dict` of
      {head_name: (problem_type, {tensor_name: Tensor})}, keyed by this head's
      name and carrying this head's problem type together with the given
      `predictions` dict.
    """
    alternative = (self._problem_type, predictions)
    return {self._head_name: alternative}
556
557
558# TODO(zakaria): use contrib losses.
def _mean_squared_loss(labels, logits, weights=None):
  """Computes the squared difference between `logits` and `labels`.

  Args:
    labels: Target values; converted to a `Tensor` and cast to float32.
    logits: Predicted values; converted to a `Tensor`.
    weights: Optional weights forwarded to `_compute_weighted_loss`.

  Returns:
    The result of `_compute_weighted_loss` applied to the element-wise squared
    difference and `weights`.
  """
  with ops.name_scope(
      None, "mean_squared_loss", (logits, labels)) as scope_name:
    logits_tensor = ops.convert_to_tensor(logits)
    labels_tensor = ops.convert_to_tensor(labels)
    # Give rank-1 inputs a trailing unit dimension so that the subtraction
    # below does not broadcast.
    if len(labels_tensor.get_shape()) == 1:
      labels_tensor = array_ops.expand_dims(labels_tensor, axis=1)
    # TODO(zakaria): make sure it does not recreate the broadcast bug.
    if len(logits_tensor.get_shape()) == 1:
      logits_tensor = array_ops.expand_dims(logits_tensor, axis=1)
    logits_tensor.get_shape().assert_is_compatible_with(
        labels_tensor.get_shape())
    float_labels = math_ops.cast(labels_tensor, dtypes.float32)
    squared_error = math_ops.squared_difference(
        logits_tensor, float_labels, name=scope_name)
    return _compute_weighted_loss(squared_error, weights)
573
574
def _poisson_loss(labels, logits, weights=None):
  """Computes the full log poisson loss of `labels` given `logits`.

  Args:
    labels: Target values; converted to a `Tensor`.
    logits: Predicted values; converted to a `Tensor`.
    weights: Optional weights forwarded to `_compute_weighted_loss`.

  Returns:
    The result of `_compute_weighted_loss` applied to the log poisson loss
    and `weights`.
  """
  with ops.name_scope(None, "_poisson_loss", (logits, labels)) as scope_name:
    logits_tensor = ops.convert_to_tensor(logits)
    labels_tensor = ops.convert_to_tensor(labels)
    # Give rank-1 inputs a trailing unit dimension so that the loss op does
    # not broadcast.
    if len(labels_tensor.get_shape()) == 1:
      labels_tensor = array_ops.expand_dims(labels_tensor, axis=1)
    # TODO(zakaria): make sure it does not recreate the broadcast bug.
    if len(logits_tensor.get_shape()) == 1:
      logits_tensor = array_ops.expand_dims(logits_tensor, axis=1)
    logits_tensor.get_shape().assert_is_compatible_with(
        labels_tensor.get_shape())
    poisson_loss = nn.log_poisson_loss(
        labels_tensor, logits_tensor, compute_full_loss=True, name=scope_name)
    return _compute_weighted_loss(poisson_loss, weights)
590
591
def _logits(logits_input, logits, logits_dimension):
  """Validates the logits arguments and builds `logits` when needed.

  Exactly one of `logits_input` and `logits` must be provided.

  Args:
    logits_input: `Tensor` input from which `logits` is built.
    logits: `Tensor` output.
    logits_dimension: Integer, last dimension of `logits`. When `logits` is
      `None` it sizes the linear layer applied to `logits_input`; otherwise it
      is checked against the last dimension of `logits`.

  Returns:
    `logits` `Tensor`.

  Raises:
    ValueError: if `logits_dimension` is None or less than 1.
    ValueError: if neither or both of `logits` and `logits_input` are supplied.
  """
  if (logits_dimension is None) or (logits_dimension < 1):
    raise ValueError("Invalid logits_dimension %s." % logits_dimension)

  # Reject the two invalid argument combinations up front.
  if logits is None and logits_input is None:
    raise ValueError("Neither logits nor logits_input supplied.")
  if logits is not None and logits_input is not None:
    raise ValueError("Both logits and logits_input supplied.")

  # Only `logits_input` was given: build the logits with a linear layer.
  if logits is None:
    return layers_lib.linear(logits_input, logits_dimension, scope="logits")

  # Only `logits` was given: check its last dimension if the rank is known.
  logits_tensor = ops.convert_to_tensor(logits, name="logits")
  known_dims = logits_tensor.get_shape().dims
  if known_dims is not None:
    known_dims[-1].assert_is_compatible_with(logits_dimension)
  return logits_tensor
628
629
def _create_model_fn_ops(features,
                         mode,
                         loss_fn,
                         logits_to_predictions_fn,
                         metrics_fn,
                         create_output_alternatives_fn,
                         labels=None,
                         train_op_fn=None,
                         logits=None,
                         logits_dimension=None,
                         head_name=None,
                         weight_column_name=None,
                         enable_centered_bias=False):
  """Assembles a `ModelFnOps` from logits and head-specific callbacks.

  Loss, summary, train op and eval metrics are only built when `mode` is not
  INFER and `labels` is provided; otherwise they are left as None.

  Args:
    features: Features passed to `_weight_tensor` to look up the weights.
    mode: Estimator `ModeKeys` value; validated by `_check_mode_valid`.
    loss_fn: Callable invoked as `loss_fn(labels, logits, weights)` that
      returns a `(loss, weighted_average_loss)` pair.
    logits_to_predictions_fn: Callable mapping logits to predictions.
    metrics_fn: Callable invoked as
      `metrics_fn(weighted_average_loss, predictions, labels, weights)`.
    create_output_alternatives_fn: Callable mapping predictions to the
      `output_alternatives` of the result.
    labels: Optional labels.
    train_op_fn: Callable forwarded to `_train_op`; required in TRAIN mode.
    logits: Logits tensor.
    logits_dimension: Logits dimension forwarded to `_centered_bias`.
    head_name: Optional head name used for the bias and the summary key.
    weight_column_name: Optional weight column name.
    enable_centered_bias: If True, a centered bias is added to `logits`.

  Returns:
    A `ModelFnOps` object.

  Raises:
    ValueError: If `train_op_fn` is None in TRAIN mode.
  """
  _check_mode_valid(mode)

  bias_variable = None
  if enable_centered_bias:
    bias_variable = _centered_bias(logits_dimension, head_name)
    logits = nn.bias_add(logits, bias_variable)

  predictions = logits_to_predictions_fn(logits)

  total_loss, train_op, eval_metric_ops = None, None, None
  needs_loss = (mode != model_fn.ModeKeys.INFER) and (labels is not None)
  if needs_loss:
    example_weights = _weight_tensor(features, weight_column_name)
    total_loss, average_loss = loss_fn(labels, logits, example_weights)
    # The name_scope escapism is needed to maintain the same summary tag
    # after switching away from the now unsupported API.
    with ops.name_scope(""):
      loss_for_summary = array_ops.identity(average_loss)
      summary.scalar(_summary_key(head_name, mkey.LOSS), loss_for_summary)

    if mode == model_fn.ModeKeys.TRAIN:
      if train_op_fn is None:
        raise ValueError("train_op_fn can not be None in TRAIN mode")
      num_examples = array_ops.shape(logits)[0]
      train_op = _train_op(total_loss, labels, train_op_fn, bias_variable,
                           num_examples, loss_fn, example_weights)
    eval_metric_ops = metrics_fn(
        average_loss, predictions, labels, example_weights)

  return model_fn.ModelFnOps(
      mode=mode,
      predictions=predictions,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      output_alternatives=create_output_alternatives_fn(predictions))
679
680
class _RegressionHead(_SingleHead):
  """`Head` for regression with a generalized linear model."""

  def __init__(self,
               label_dimension,
               loss_fn,
               link_fn,
               logits_dimension=None,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None):
    """`Head` for regression.

    Args:
      label_dimension: Number of regression labels per example. This is the
        size of the last dimension of the labels `Tensor` (typically, this has
        shape `[batch_size, label_dimension]`).
      loss_fn: Loss function, takes logits and labels and returns loss.
      link_fn: Link function, takes a logits tensor and returns the output.
      logits_dimension: Number of logits per example. This is the
        size of the last dimension of the logits `Tensor` (typically, this has
        shape `[batch_size, logits_dimension]`).
        Default value: `label_dimension`.
      label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      head_name: name of the head. Predictions, summary and metrics keys are
        suffixed by `"/" + head_name` and the default variable scope is
        `head_name`.
    """
    # The logits dimension defaults to the label dimension when unspecified.
    if logits_dimension is None:
      logits_dimension = label_dimension
    super(_RegressionHead, self).__init__(
        problem_type=constants.ProblemType.LINEAR_REGRESSION,
        logits_dimension=logits_dimension,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)

    self._enable_centered_bias = enable_centered_bias
    self._link_fn = link_fn
    self._loss_fn = loss_fn

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`."""
    scope_values = (
        tuple(six.itervalues(features)) + (labels, logits, logits_input))
    default_scope_name = self.head_name or "regression_head"
    with variable_scope.variable_scope(
        scope, default_name=default_scope_name, values=scope_values):
      labels = self._transform_labels(mode=mode, labels=labels)
      logits = _logits(logits_input, logits, self.logits_dimension)
      model_fn_ops = _create_model_fn_ops(
          features=features,
          mode=mode,
          loss_fn=self._loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=self._create_output_alternatives,
          labels=labels,
          train_op_fn=train_op_fn,
          logits=logits,
          logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)
      return model_fn_ops

  def _transform_labels(self, mode, labels):
    """Resolves `labels` to a labels tensor, or None when not applicable.

    Args:
      mode: Estimator `ModeKeys` value.
      labels: Labels as passed to `create_model_fn_ops`, or None.

    Returns:
      None in INFER mode or when `labels` is None; otherwise the result of
      `_to_labels_tensor` after it has passed `_check_no_sparse_tensor`.
    """
    labels_unavailable = (mode == model_fn.ModeKeys.INFER) or (labels is None)
    if labels_unavailable:
      return None
    resolved_labels = _to_labels_tensor(labels, self._label_name)
    _check_no_sparse_tensor(resolved_labels)
    return resolved_labels

  def _logits_to_predictions(self, logits):
    """Returns a dict of predictions.

    Args:
      logits: logits `Tensor` after applying possible centered bias.

    Returns:
      Dict of prediction `Tensor` keyed by `PredictionKey`.
    """
    scores_key = prediction_key.PredictionKey.SCORES
    with ops.name_scope(None, "predictions", (logits,)):
      # A single logit per example is squeezed along axis 1 before the link
      # function is applied.
      if self.logits_dimension == 1:
        logits = array_ops.squeeze(logits, axis=(1,), name=scores_key)
      scores = self._link_fn(logits)
      return {scores_key: scores}

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name.

    Only the mean of `eval_loss` is reported, under the loss summary key for
    this head.
    """
    del predictions, labels, weights  # Unused by this head.
    with ops.name_scope("metrics", values=[eval_loss]):
      loss_key = _summary_key(self.head_name, mkey.LOSS)
      return {loss_key: metrics_lib.mean(eval_loss)}
790
791
def _log_loss_with_two_classes(labels, logits, weights=None):
  """Sigmoid cross-entropy loss for a single-logit binary problem.

  Args:
    labels: Labels; cast to float32. Labels whose static rank is 1 get an
      extra trailing dimension.
    logits: Logits; converted to a `Tensor`.
    weights: Optional weights, forwarded to `_compute_weighted_loss`.

  Returns:
    The result of `_compute_weighted_loss` on the element-wise loss.
  """
  with ops.name_scope(
      None, "log_loss_with_two_classes", (logits, labels)) as scope_name:
    logits_tensor = ops.convert_to_tensor(logits)
    float_labels = math_ops.cast(labels, dtypes.float32)
    # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
    # TODO(ptucker): This will break for dynamic shapes.
    static_rank = len(float_labels.get_shape())
    if static_rank == 1:
      float_labels = array_ops.expand_dims(float_labels, axis=1)
    elementwise_loss = nn.sigmoid_cross_entropy_with_logits(
        labels=float_labels, logits=logits_tensor, name=scope_name)
    return _compute_weighted_loss(elementwise_loss, weights)
804
805
def _one_class_to_two_class_logits(logits):
  """Prepends a column of zeros so single-class logits look two-class.

  Args:
    logits: Logits `Tensor`; the concatenation happens along axis 1.

  Returns:
    `Tensor` made of `zeros_like(logits)` followed by `logits` along axis 1.
  """
  zero_logits = array_ops.zeros_like(logits)
  return array_ops.concat((zero_logits, logits), 1)
808
809
class _BinaryLogisticHead(_SingleHead):
  """`Head` for binary classification with logistic regression.

  The head works on a single logit. Its predictions expose the raw logits,
  their sigmoid, a two-class softmax and the arg-max class.
  """

  def __init__(self,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None,
               loss_fn=None,
               thresholds=None):
    """Creates a binary logistic regression `Head`.

    Args:
      label_name: String, name of the key in label dict. Can be `None` if label
          is a tensor (single headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      head_name: name of the head. Predictions, summary, metrics keys are
        suffixed by `"/" + head_name` and the default variable scope is
        `head_name`.
      loss_fn: Loss function. `_log_loss_with_two_classes` is used when this
        is not provided.
      thresholds: thresholds for eval. `(.5,)` is used when this is not
        provided.
    """
    super(_BinaryLogisticHead, self).__init__(
        problem_type=constants.ProblemType.LOGISTIC_REGRESSION,
        logits_dimension=1,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)
    self._thresholds = thresholds or (.5,)
    self._loss_fn = loss_fn or _log_loss_with_two_classes
    self._enable_centered_bias = enable_centered_bias

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`.

    Enters a variable scope (named by `scope`, falling back to the head name
    and then to "binary_logistic_head"), prepares labels and logits inside it,
    and delegates the remaining work to `_create_model_fn_ops`.
    """
    fallback_scope_name = self.head_name or "binary_logistic_head"
    scope_inputs = tuple(six.itervalues(features)) + (
        labels, logits, logits_input)
    with variable_scope.variable_scope(
        scope, default_name=fallback_scope_name, values=scope_inputs):
      prepared_labels = self._transform_labels(mode=mode, labels=labels)
      prepared_logits = _logits(logits_input, logits, self.logits_dimension)
      return _create_model_fn_ops(
          features=features, mode=mode,
          loss_fn=self._loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=_classification_output_alternatives(
              self.head_name, self._problem_type),
          labels=prepared_labels, train_op_fn=train_op_fn,
          logits=prepared_logits, logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Converts `labels` into a dense labels tensor, when labels apply.

    Returns `None` in INFER mode or when `labels` is `None`. Raises
    `ValueError` (through `_check_no_sparse_tensor`) for `SparseTensor` labels.
    """
    if mode == model_fn.ModeKeys.INFER:
      return None
    if labels is None:
      return None
    dense_labels = _to_labels_tensor(labels, self._label_name)
    _check_no_sparse_tensor(dense_labels)
    return dense_labels

  def _logits_to_predictions(self, logits):
    """Builds the predictions dict for this head.

    Args:
      logits: logits `Tensor` after applying possible centered bias.

    Returns:
      Dict of prediction `Tensor` keyed by `PredictionKey`: the logits, their
      sigmoid (LOGISTIC), the softmax over the two-class logits
      (PROBABILITIES) and the arg-max over the two-class logits (CLASSES).
    """
    keys = prediction_key.PredictionKey
    with ops.name_scope(None, "predictions", (logits,)):
      two_class_logits = _one_class_to_two_class_logits(logits)
      # Ops are created in the same order in which they appear in the dict.
      logistic = math_ops.sigmoid(logits, name=keys.LOGISTIC)
      probabilities = nn.softmax(two_class_logits, name=keys.PROBABILITIES)
      classes = math_ops.argmax(two_class_logits, 1, name=keys.CLASSES)
      return {
          keys.LOGITS: logits,
          keys.LOGISTIC: logistic,
          keys.PROBABILITIES: probabilities,
          keys.CLASSES: classes,
      }

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name.

    Args:
      eval_loss: Loss value fed to the streaming mean.
      predictions: Dict of predictions; the CLASSES and LOGISTIC entries are
        read.
      labels: Labels passed to the individual metrics.
      weights: Weights passed to the individual metrics.

    Returns:
      Dict of metrics keyed through `_summary_key` with this head's name.
    """
    keys = prediction_key.PredictionKey

    def named(metric_key):
      # Every metric of this head goes through the same key helper.
      return _summary_key(self.head_name, metric_key)

    scope_inputs = (
        [eval_loss, labels, weights] + list(six.itervalues(predictions)))
    with ops.name_scope("metrics", values=scope_inputs):
      classes = predictions[keys.CLASSES]
      logistic = predictions[keys.LOGISTIC]

      metrics = {named(mkey.LOSS): metrics_lib.mean(eval_loss)}
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics[named(mkey.ACCURACY)] = metrics_lib.accuracy(
          labels, classes, weights)
      metrics[named(mkey.PREDICTION_MEAN)] = _predictions_streaming_mean(
          logistic, weights)
      metrics[named(mkey.LABEL_MEAN)] = _indicator_labels_streaming_mean(
          labels, weights)

      # Also include the streaming mean of the label as an accuracy baseline, as
      # a reminder to users.
      metrics[named(mkey.ACCURACY_BASELINE)] = (
          _indicator_labels_streaming_mean(labels, weights))
      metrics[named(mkey.AUC)] = _streaming_auc(logistic, labels, weights)
      metrics[named(mkey.AUC_PR)] = _streaming_auc(
          logistic, labels, weights, curve="PR")

      # (key template, metric builder) pairs, applied in this order for every
      # threshold. Precision and recall are for positive examples.
      thresholded_metrics = (
          (mkey.ACCURACY_MEAN, _streaming_accuracy_at_threshold),
          (mkey.PRECISION_MEAN, _streaming_precision_at_threshold),
          (mkey.RECALL_MEAN, _streaming_recall_at_threshold),
      )
      for threshold in self._thresholds:
        for key_template, build_metric in thresholded_metrics:
          metrics[named(key_template % threshold)] = build_metric(
              logistic, labels, weights, threshold)

    return metrics
962
963
def _softmax_cross_entropy_loss(labels, logits, weights=None):
  """Sparse softmax cross-entropy loss over integer class labels.

  Args:
    labels: Integer labels. Labels whose static rank is 2 are squeezed on
      axis 1 before the loss is computed.
    logits: Logits passed to `sparse_softmax_cross_entropy_with_logits`.
    weights: Optional weights, forwarded to `_compute_weighted_loss`.

  Returns:
    The result of `_compute_weighted_loss` on the per-example loss.

  Raises:
    ValueError: if the labels dtype is not an integer type.
  """
  with ops.name_scope(
      None, "softmax_cross_entropy_loss", (logits, labels,)) as scope_name:
    label_ids = ops.convert_to_tensor(labels)
    # Check that we got integer for classification.
    if not label_ids.dtype.is_integer:
      raise ValueError("Labels dtype should be integer "
                       "Instead got %s." % label_ids.dtype)

    # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
    # TODO(ptucker): This will break for dynamic shapes.
    labels_were_squeezed = len(label_ids.get_shape()) == 2
    if labels_were_squeezed:
      label_ids = array_ops.squeeze(label_ids, axis=(1,))

    per_example_loss = nn.sparse_softmax_cross_entropy_with_logits(
        labels=label_ids, logits=logits, name=scope_name)

    # Restore squeezed dimension, if necessary, so loss matches weights shape.
    if labels_were_squeezed:
      per_example_loss = array_ops.expand_dims(per_example_loss, axis=(1,))

    return _compute_weighted_loss(per_example_loss, weights)
988
989
class _MultiClassHead(_SingleHead):
  """'Head' for multi class classification."""

  def __init__(self,
               n_classes,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None,
               loss_fn=None,
               thresholds=None,
               metric_class_ids=None,
               label_keys=None):
    """'Head' for multi class classification.

    By default the head is fed integer labels holding the class index. When
    `label_keys` is given, labels must instead be strings from that
    vocabulary, and the predicted classes are strings from the same
    vocabulary.

    Args:
      n_classes: Number of classes, must be greater than 2 (for 2 classes, use
        `_BinaryLogisticHead`).
      label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      head_name: name of the head. If provided, predictions, summary, metrics
        keys will be suffixed by `"/" + head_name` and the default variable
        scope will be `head_name`.
      loss_fn: Loss function. Defaults to softmax cross entropy loss.
      thresholds: thresholds for eval. `(.5,)` is used when this is not
        provided.
      metric_class_ids: List of class IDs for which we should report per-class
        metrics. Must all be in the range `[0, n_classes)`.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary.

    Raises:
      ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
    """
    super(_MultiClassHead, self).__init__(
        problem_type=constants.ProblemType.CLASSIFICATION,
        logits_dimension=n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)

    if n_classes is None or n_classes <= 2:
      raise ValueError("n_classes must be > 2: %s." % n_classes)

    self._thresholds = thresholds or (.5,)
    self._loss_fn = loss_fn or _softmax_cross_entropy_loss
    self._enable_centered_bias = enable_centered_bias

    if metric_class_ids is None:
      self._metric_class_ids = ()
    else:
      self._metric_class_ids = tuple(metric_class_ids)
    for class_id in self._metric_class_ids:
      is_negative = class_id < 0
      if is_negative or class_id >= n_classes:
        raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes))

    # A vocabulary, when given, needs exactly one entry per class.
    if label_keys and len(label_keys) != n_classes:
      raise ValueError("Length of label_keys must equal n_classes.")
    self._label_keys = label_keys

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`.

    Enters a variable scope (named by `scope`, falling back to the head name
    and then to "multi_class_head"), prepares labels and logits inside it, and
    delegates the remaining work to `_create_model_fn_ops`.
    """
    fallback_scope_name = self.head_name or "multi_class_head"
    scope_inputs = tuple(six.itervalues(features)) + (
        labels, logits, logits_input)
    with variable_scope.variable_scope(
        scope, default_name=fallback_scope_name, values=scope_inputs):
      labels_dict = self._transform_labels(mode=mode, labels=labels)
      prepared_logits = _logits(logits_input, logits, self.logits_dimension)
      # The loss goes through `_wrapped_loss_fn` because labels are a dict here.
      return _create_model_fn_ops(
          features=features, mode=mode,
          loss_fn=self._wrapped_loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=_classification_output_alternatives(
              self.head_name, self._problem_type, self._label_keys),
          labels=labels_dict, train_op_fn=train_op_fn,
          logits=prepared_logits, logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Returns a dict that contains both the original labels and label IDs.

    Returns `None` in INFER mode or when `labels` is `None`. With a label
    vocabulary, the IDs come from an index lookup table built from
    `label_keys`; otherwise the labels themselves are used as IDs.
    """
    if mode == model_fn.ModeKeys.INFER:
      return None
    if labels is None:
      return None
    labels_tensor = _to_labels_tensor(labels, self._label_name)
    _check_no_sparse_tensor(labels_tensor)

    label_ids = labels_tensor
    if self._label_keys:
      id_table = lookup_ops.index_table_from_tensor(
          self._label_keys, name="label_id_lookup")
      label_ids = id_table.lookup(labels_tensor)
    return {
        "labels": labels_tensor,
        "label_ids": label_ids,
    }

  def _labels(self, labels_dict):
    """Returns labels `Tensor` of the same type as classes."""
    original_labels = labels_dict["labels"]
    return original_labels

  def _label_ids(self, labels_dict):
    """Returns integer label ID `Tensor`."""
    ids = labels_dict["label_ids"]
    return ids

  def _wrapped_loss_fn(self, labels, logits, weights=None):
    """Calls the configured loss function with the label IDs from `labels`."""
    label_ids = self._label_ids(labels)
    return self._loss_fn(label_ids, logits, weights=weights)

  def _logits_to_predictions(self, logits):
    """Builds the predictions dict for this head.

    Args:
      logits: logits `Tensor` after applying possible centered bias.

    Returns:
      Dict of prediction `Tensor` keyed by `PredictionKey`: the logits, their
      softmax (PROBABILITIES) and the arg-max classes (CLASSES). With a label
      vocabulary, the class IDs are mapped through a string lookup table.
    """
    keys = prediction_key.PredictionKey
    with ops.name_scope(None, "predictions", (logits,)):
      class_ids = math_ops.argmax(logits, 1, name=keys.CLASSES)
      classes = class_ids
      if self._label_keys:
        string_table = lookup_ops.index_to_string_table_from_tensor(
            self._label_keys, name="class_string_lookup")
        classes = string_table.lookup(class_ids)
      probabilities = nn.softmax(logits, name=keys.PROBABILITIES)
      return {
          keys.LOGITS: logits,
          keys.PROBABILITIES: probabilities,
          keys.CLASSES: classes,
      }

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name.

    Args:
      eval_loss: Loss value fed to the streaming mean.
      predictions: Dict of predictions; the LOGITS, PROBABILITIES and CLASSES
        entries are read.
      labels: Labels dict as produced by `_transform_labels`.
      weights: Weights passed to the individual metrics.

    Returns:
      Dict of metrics keyed through `_summary_key` with this head's name.
      Per-class metrics are only added when no label vocabulary is set.
    """
    keys = prediction_key.PredictionKey
    labels_tensor = self._labels(labels)
    label_ids = self._label_ids(labels)

    def named(metric_key):
      # Every metric of this head goes through the same key helper.
      return _summary_key(self.head_name, metric_key)

    scope_inputs = (
        (eval_loss, labels_tensor, label_ids, weights) +
        tuple(six.itervalues(predictions)))
    with ops.name_scope("metrics", values=scope_inputs):
      logits = predictions[keys.LOGITS]
      probabilities = predictions[keys.PROBABILITIES]
      classes = predictions[keys.CLASSES]

      metrics = {named(mkey.LOSS): metrics_lib.mean(eval_loss)}
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics[named(mkey.ACCURACY)] = metrics_lib.accuracy(
          labels_tensor, classes, weights)

      # Classes are IDs only when there is no label vocabulary; the per-class
      # metrics below are skipped otherwise.
      reported_class_ids = () if self._label_keys else self._metric_class_ids
      for class_id in reported_class_ids:
        metrics[named(mkey.CLASS_PREDICTION_MEAN % class_id)] = (
            _class_predictions_streaming_mean(classes, weights, class_id))
        # TODO(ptucker): Add per-class accuracy, precision, recall.
        metrics[named(mkey.CLASS_LABEL_MEAN % class_id)] = (
            _class_labels_streaming_mean(label_ids, weights, class_id))
        metrics[named(mkey.CLASS_PROBABILITY_MEAN % class_id)] = (
            _predictions_streaming_mean(probabilities, weights, class_id))
        metrics[named(mkey.CLASS_LOGITS_MEAN % class_id)] = (
            _predictions_streaming_mean(logits, weights, class_id))

    return metrics
1178
1179
def _to_labels_tensor(labels, label_name):
  """Resolves `labels` to a single tensor.

  Args:
    labels: Label `Tensor` or `SparseTensor` or a dict containing labels.
    label_name: Key to read when `labels` is a dict; ignored otherwise.

  Returns:
    Label `Tensor` or `SparseTensor`.
  """
  if isinstance(labels, dict):
    labels = labels[label_name]
  return framework_lib.convert_to_tensor_or_sparse_tensor(labels)
1192
1193
def _check_no_sparse_tensor(x):
  """Rejects `SparseTensor` input.

  Args:
    x: Value to check.

  Raises:
    ValueError: if `x` is a `SparseTensor`.
  """
  if not isinstance(x, sparse_tensor.SparseTensor):
    return
  raise ValueError("SparseTensor is not supported.")
1198
1199
def _sparse_labels_to_indicator(labels, num_classes):
  """Densifies `SparseTensor` labels; other labels pass through untouched.

  Args:
    labels: Label `Tensor` or `SparseTensor`.
    num_classes: Number of classes.

  Returns:
    `labels` itself when it is not a `SparseTensor`; otherwise its indicator
    representation cast to int64.

  Raises:
    ValueError: If labels is `SparseTensor` and `num_classes` < 2.
  """
  if not isinstance(labels, sparse_tensor.SparseTensor):
    return labels
  if num_classes < 2:
    raise ValueError("Must set num_classes >= 2 when passing labels as a "
                     "SparseTensor.")
  indicator = sparse_ops.sparse_to_indicator(labels, num_classes)
  return math_ops.cast(indicator, dtypes.int64)
1220
1221
def _assert_labels_rank(labels):
  """Returns an `Assert` op that checks `labels` has rank of at most 2."""
  rank_is_valid = math_ops.less_equal(array_ops.rank(labels), 2)
  failure_data = (
      "labels shape should be either [batch_size, 1] or [batch_size]",)
  return control_flow_ops.Assert(rank_is_valid, failure_data)
1226
1227
class _BinarySvmHead(_SingleHead):
  """`Head` for binary classification using SVM."""

  def __init__(self, label_name, weight_column_name, enable_centered_bias,
               head_name, thresholds):
    """Creates a binary SVM `Head` that trains with a hinge loss.

    Args:
      label_name: String, name of the key in label dict.
      weight_column_name: A string defining feature column name representing
        weights.
      enable_centered_bias: A bool, forwarded to `_create_model_fn_ops`.
      head_name: name of the head.
      thresholds: thresholds for eval. `(.5,)` is used when this is falsy.
    """

    def _loss_fn(labels, logits, weights=None):
      """Unreduced hinge loss, weighted by `_compute_weighted_loss`."""
      with ops.name_scope(None, "hinge_loss", (logits, labels)) as scope_name:
        rank_check = _assert_labels_rank(labels)
        # The reshape only runs once the rank assertion has been evaluated.
        with ops.control_dependencies((rank_check,)):
          column_labels = array_ops.reshape(labels, shape=(-1, 1))
        unreduced_loss = losses_lib.hinge_loss(
            labels=column_labels, logits=logits, scope=scope_name,
            reduction=losses_lib.Reduction.NONE)
        return _compute_weighted_loss(unreduced_loss, weights)

    super(_BinarySvmHead, self).__init__(
        problem_type=constants.ProblemType.LOGISTIC_REGRESSION,
        logits_dimension=1,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)
    self._thresholds = thresholds or (.5,)
    self._loss_fn = _loss_fn
    self._enable_centered_bias = enable_centered_bias

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`.

    Enters a variable scope (named by `scope`, falling back to the head name
    and then to "binary_svm_head"), prepares labels and logits inside it, and
    delegates the remaining work to `_create_model_fn_ops`.
    """
    fallback_scope_name = self.head_name or "binary_svm_head"
    scope_inputs = tuple(six.itervalues(features)) + (
        labels, logits, logits_input)
    with variable_scope.variable_scope(
        scope, default_name=fallback_scope_name, values=scope_inputs):
      prepared_labels = self._transform_labels(mode=mode, labels=labels)
      prepared_logits = _logits(logits_input, logits, self.logits_dimension)
      return _create_model_fn_ops(
          features=features, mode=mode,
          loss_fn=self._loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          # TODO(zakaria): Handle labels for export.
          create_output_alternatives_fn=self._create_output_alternatives,
          labels=prepared_labels, train_op_fn=train_op_fn,
          logits=prepared_logits, logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Converts `labels` into a dense labels tensor, when labels apply.

    Returns `None` in INFER mode or when `labels` is `None`. Raises
    `ValueError` (through `_check_no_sparse_tensor`) for `SparseTensor` labels.
    """
    if mode == model_fn.ModeKeys.INFER:
      return None
    if labels is None:
      return None
    dense_labels = _to_labels_tensor(labels, self._label_name)
    _check_no_sparse_tensor(dense_labels)
    return dense_labels

  def _logits_to_predictions(self, logits):
    """See `_MultiClassHead`.

    Returns the logits together with the arg-max over their two-class form.
    """
    keys = prediction_key.PredictionKey
    with ops.name_scope(None, "predictions", (logits,)):
      two_class_logits = _one_class_to_two_class_logits(logits)
      classes = math_ops.argmax(two_class_logits, 1, name=keys.CLASSES)
      return {
          keys.LOGITS: logits,
          keys.CLASSES: classes,
      }

  def _metrics(self, eval_loss, predictions, labels, weights):
    """See `_MultiClassHead`.

    Reports the mean eval loss and the accuracy of the predicted classes.
    """
    scope_inputs = (
        [eval_loss, labels, weights] + list(six.itervalues(predictions)))
    with ops.name_scope("metrics", values=scope_inputs):
      loss_key = _summary_key(self.head_name, mkey.LOSS)
      metrics = {loss_key: metrics_lib.mean(eval_loss)}

      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      predicted_classes = predictions[prediction_key.PredictionKey.CLASSES]
      accuracy = metrics_lib.accuracy(labels, predicted_classes, weights)
      metrics[_summary_key(self.head_name, mkey.ACCURACY)] = accuracy
      # TODO(sibyl-vie3Poto): add more metrics relevant for svms.

    return metrics
1320
1321
class _MultiLabelHead(_SingleHead):
  """`Head` for multi-label classification."""

  # TODO(zakaria): add signature and metric for multilabel.
  def __init__(self,
               n_classes,
               label_name,
               weight_column_name,
               enable_centered_bias,
               head_name,
               thresholds,
               metric_class_ids=None,
               loss_fn=None):
    """Creates a multi-label `Head` with one logit per class.

    Args:
      n_classes: Number of classes; used as the logits dimension.
      label_name: String, name of the key in label dict.
      weight_column_name: A string defining feature column name representing
        weights.
      enable_centered_bias: A bool, forwarded to `_create_model_fn_ops`.
      head_name: name of the head.
      thresholds: thresholds for eval. `(.5,)` is used when this is falsy.
      metric_class_ids: Class IDs for which per-class metrics are reported.
        Must all be in the range `[0, n_classes)`.
      loss_fn: Loss function. `_sigmoid_cross_entropy_loss` is used when this
        is not provided.

    Raises:
      ValueError: if an entry of `metric_class_ids` is out of range.
    """
    super(_MultiLabelHead, self).__init__(
        problem_type=constants.ProblemType.CLASSIFICATION,
        logits_dimension=n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        head_name=head_name)

    self._thresholds = thresholds or (.5,)
    self._loss_fn = loss_fn or _sigmoid_cross_entropy_loss
    self._enable_centered_bias = enable_centered_bias

    if metric_class_ids is None:
      self._metric_class_ids = ()
    else:
      self._metric_class_ids = tuple(metric_class_ids)
    for class_id in self._metric_class_ids:
      is_negative = class_id < 0
      if is_negative or class_id >= n_classes:
        raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes))

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `Head`.

    Enters a variable scope (named by `scope`, falling back to the head name
    and then to "multi_label_head"), prepares labels and logits inside it, and
    delegates the remaining work to `_create_model_fn_ops`.
    """
    fallback_scope_name = self.head_name or "multi_label_head"
    scope_inputs = tuple(six.itervalues(features)) + (
        labels, logits, logits_input)
    with variable_scope.variable_scope(
        scope, default_name=fallback_scope_name, values=scope_inputs):
      prepared_labels = self._transform_labels(mode=mode, labels=labels)
      prepared_logits = _logits(logits_input, logits, self.logits_dimension)
      return _create_model_fn_ops(
          features=features, mode=mode,
          loss_fn=self._loss_fn,
          logits_to_predictions_fn=self._logits_to_predictions,
          metrics_fn=self._metrics,
          create_output_alternatives_fn=_classification_output_alternatives(
              self.head_name, self._problem_type),
          labels=prepared_labels, train_op_fn=train_op_fn,
          logits=prepared_logits, logits_dimension=self.logits_dimension,
          head_name=self.head_name,
          weight_column_name=self.weight_column_name,
          enable_centered_bias=self._enable_centered_bias)

  def _transform_labels(self, mode, labels):
    """Resolves `labels` to a tensor, densifying `SparseTensor` labels.

    Returns `None` in INFER mode or when `labels` is `None`. Unlike the other
    heads, sparse labels are accepted and converted to an indicator tensor.
    """
    if mode == model_fn.ModeKeys.INFER:
      return None
    if labels is None:
      return None
    resolved_labels = _to_labels_tensor(labels, self._label_name)
    return _sparse_labels_to_indicator(resolved_labels, self._logits_dimension)

  def _logits_to_predictions(self, logits):
    """See `_MultiClassHead`.

    Returns the logits, their sigmoid (PROBABILITIES) and an int64 indicator
    of which logits are greater than zero (CLASSES).
    """
    keys = prediction_key.PredictionKey
    with ops.name_scope(None, "predictions", (logits,)):
      probabilities = math_ops.sigmoid(logits, name=keys.PROBABILITIES)
      is_positive = math_ops.greater(logits, 0)
      classes = math_ops.cast(is_positive, dtypes.int64, name=keys.CLASSES)
      return {
          keys.LOGITS: logits,
          keys.PROBABILITIES: probabilities,
          keys.CLASSES: classes,
      }

  def _metrics(self, eval_loss, predictions, labels, weights):
    """Returns a dict of metrics keyed by name.

    Args:
      eval_loss: Loss value fed to the streaming mean.
      predictions: Dict of predictions; the CLASSES, PROBABILITIES and LOGITS
        entries are read.
      labels: Labels passed to the individual metrics.
      weights: Weights passed to the individual metrics.

    Returns:
      Dict of metrics keyed through `_summary_key` with this head's name,
      including per-class entries for every configured metric class ID.
    """
    keys = prediction_key.PredictionKey

    def named(metric_key):
      # Every metric of this head goes through the same key helper.
      return _summary_key(self.head_name, metric_key)

    scope_inputs = (
        [eval_loss, labels, weights] + list(six.itervalues(predictions)))
    with ops.name_scope("metrics", values=scope_inputs):
      classes = predictions[keys.CLASSES]
      probabilities = predictions[keys.PROBABILITIES]
      logits = predictions[keys.LOGITS]

      metrics = {named(mkey.LOSS): metrics_lib.mean(eval_loss)}
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics[named(mkey.ACCURACY)] = metrics_lib.accuracy(
          labels, classes, weights)
      metrics[named(mkey.AUC)] = _streaming_auc(probabilities, labels, weights)
      metrics[named(mkey.AUC_PR)] = _streaming_auc(
          probabilities, labels, weights, curve="PR")

      for class_id in self._metric_class_ids:
        # TODO(ptucker): Add per-class accuracy, precision, recall.
        metrics[named(mkey.CLASS_PREDICTION_MEAN % class_id)] = (
            _predictions_streaming_mean(classes, weights, class_id))
        metrics[named(mkey.CLASS_LABEL_MEAN % class_id)] = (
            _indicator_labels_streaming_mean(labels, weights, class_id))
        metrics[named(mkey.CLASS_PROBABILITY_MEAN % class_id)] = (
            _predictions_streaming_mean(probabilities, weights, class_id))
        metrics[named(mkey.CLASS_LOGITS_MEAN % class_id)] = (
            _predictions_streaming_mean(logits, weights, class_id))
        metrics[named(mkey.CLASS_AUC % class_id)] = _streaming_auc(
            probabilities, labels, weights, class_id)
        metrics[named(mkey.CLASS_AUC_PR % class_id)] = _streaming_auc(
            probabilities, labels, weights, class_id, curve="PR")

    return metrics
1449
1450
1451class _LossOnlyHead(Head):
1452  """`Head` implementation for additional loss terms.
1453
1454  This class only holds loss terms unrelated to any other heads (labels),
1455  e.g. regularization.
1456
1457  Common usage:
  This is often combined with other heads in a multi head setup.
1459    ```python
1460    head = multi_head([
1461        head1, head2, loss_only_head('regularizer', regularizer)])
1462    ```
1463  """
1464
1465  def __init__(self, loss_fn, head_name=None):
1466    self._loss_fn = loss_fn
1467    self.head_name = head_name or "loss_only_head"
1468
1469  @property
1470  def logits_dimension(self):
1471    return 0
1472
1473  def create_model_fn_ops(self,
1474                          features,
1475                          mode,
1476                          labels=None,
1477                          train_op_fn=None,
1478                          logits=None,
1479                          logits_input=None,
1480                          scope=None):
1481    """See `_Head.create_model_fn_ops`.
1482
1483    Args:
1484      features: Not been used.
1485      mode: Estimator's `ModeKeys`.
1486      labels: Labels `Tensor`, or `dict` of same.
1487      train_op_fn: Function that takes a scalar loss and returns an op to
1488          optimize with the loss.
1489      logits: Not been used.
1490      logits_input: Not been used.
1491      scope: Optional scope for variable_scope. If provided, will be passed to
1492          all heads. Most users will want to set this to `None`, so each head
1493          constructs a separate variable_scope according to its `head_name`.
1494
1495    Returns:
1496      A `ModelFnOps` object.
1497
1498    Raises:
1499      ValueError: if `mode` is not recognition.
1500    """
1501    _check_mode_valid(mode)
1502    loss = None
1503    train_op = None
1504    if mode != model_fn.ModeKeys.INFER:
1505      with variable_scope.variable_scope(scope, default_name=self.head_name):
1506        loss = self._loss_fn()
1507        if isinstance(loss, list):
1508          loss = math_ops.add_n(loss)
1509        # The name_scope escapism is needed to maintain the same summary tag
1510        # after switching away from the now unsupported API.
1511        with ops.name_scope(""):
1512          summary_loss = array_ops.identity(loss)
1513          summary.scalar(_summary_key(self.head_name, mkey.LOSS),
1514                         summary_loss)
1515        if mode == model_fn.ModeKeys.TRAIN:
1516          if train_op_fn is None:
1517            raise ValueError("train_op_fn can not be None in TRAIN mode")
1518          with ops.name_scope(None, "train_op", (loss,)):
1519            train_op = train_op_fn(loss)
1520
1521    return model_fn.ModelFnOps(
1522        mode=mode,
1523        loss=loss,
1524        train_op=train_op,
1525        predictions={},
1526        eval_metric_ops={})
1527
1528
class _MultiHead(Head):
  """`Head` implementation for multi objective learning.

  This class is responsible for using and merging the output of multiple
  `Head` objects.

  All heads stem from the same logits/logit_input tensor.

  Common usage:
  For simple use cases you can pass the activation of hidden layer like
  this from your model_fn,
    ```python
    last_hidden_layer_activation = ... Build your model.
    multi_head = ...
    return multi_head.create_model_fn_ops(
        ..., logits_input=last_hidden_layer_activation, ...)
    ```

  Or you can create a logits tensor of
  [batch_size, multi_head.logits_dimension] shape. _MultiHead will split the
  logits for you.
    ```python
    return multi_head.create_model_fn_ops(..., logits=logits, ...)
    ```

  For more complex use cases like a multi-task/multi-tower model or when logits
  for each head has to be created separately, you can pass a dict of logits
  where the keys match the name of the single heads.
    ```python
    logits = {"head1": logits1, "head2": logits2}
    return multi_head.create_model_fn_ops(..., logits=logits, ...)
    ```

  Here is what this class does,
  + For training, merges losses of each head according to a function provided
      by the user, and calls the user provided train_op_fn with this final
      loss. The train ops of the individual heads are grouped with it.
  + For eval, merges the eval metrics of all heads into a single dict under
      the keys reported by each head (keys are not renamed here, so on a key
      collision the later head wins).
  + For inference, updates keys in prediction dict to a 2-tuple,
      (head_name, prediction_key)
  """

  def __init__(self, heads, loss_merger):
    """Builds a `Head` that merges multiple `Head` objects.

    Args:
      heads: iterable of `Head` objects, each with a non-empty `head_name`.
        The iterable is copied, so one-shot iterables such as generators are
        accepted.
      loss_merger: function that takes a list of loss tensors for the heads
        and returns the final loss tensor for the multi head.

    Raises:
      ValueError: if any head does not have a name.
    """
    # Materialize once: the heads are walked here for validation and again on
    # every `create_model_fn_ops` call, which a one-shot iterable would not
    # survive (it would be left exhausted and no head would ever be built).
    heads = tuple(heads)
    total_logits_dimension = 0
    for head in heads:
      if not head.head_name:
        raise ValueError("Members of MultiHead must have names.")
      total_logits_dimension += head.logits_dimension

    self._logits_dimension = total_logits_dimension
    self._heads = heads
    self._loss_merger = loss_merger

  @property
  def logits_dimension(self):
    """Sum of the `logits_dimension` of all member heads."""
    return self._logits_dimension

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `_Head.create_model_fn_ops`.

    Args:
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss and returns an op to
          optimize with the loss.
      logits: Concatenated logits for all heads or a dict of head name to logits
          tensor. If concatenated logits, it should have (batchsize, x) shape
          where x is the sum of `logits_dimension` of all the heads,
          i.e., same as `logits_dimension` of this class. create_model_fn_ops
          will split the logits tensor and pass logits of proper size to each
          head. This is useful if you want to be agnostic about whether you
          are creating a single head or a multi head. logits can also be a
          dict for convenience where you are creating the head specific
          logits explicitly and don't want to concatenate them yourself.
          `_LossOnlyHead` members are not looked up in the dict.
      logits_input: tensor to build logits from. Only forwarded to the heads
          when `logits` is `None`.
      scope: Optional scope for variable_scope. If provided, will be passed to
        all heads. Most users will want to set this to `None`, so each head
        constructs a separate variable_scope according to its `head_name`.

    Returns:
      `ModelFnOps`.

    Raises:
      ValueError: if `mode` is not recognized, or if `train_op_fn` is `None`
          in TRAIN mode. The combination of `logits` and `logits_input` is
          not validated here; that is left to the individual heads.
    """
    _check_mode_valid(mode)
    if logits is None:
      # Every head builds its own logits from the shared `logits_input`.
      all_model_fn_ops = [
          head.create_model_fn_ops(
              features=features,
              mode=mode,
              labels=labels,
              train_op_fn=no_op_train_fn,
              logits_input=logits_input,
              scope=scope)
          for head in self._heads]
    else:
      if isinstance(logits, dict):
        # Loss-only heads take no logits, so they need no dict entry.
        per_head_logits = [
            None if isinstance(head, _LossOnlyHead) else logits[head.head_name]
            for head in self._heads]
      else:
        # Split the concatenated logits for each head.
        per_head_logits = self._split_logits(logits)

      all_model_fn_ops = [
          head.create_model_fn_ops(
              features=features,
              mode=mode,
              labels=labels,
              train_op_fn=no_op_train_fn,
              logits=head_logits,
              scope=scope)
          for head, head_logits in zip(self._heads, per_head_logits)]

    if mode == model_fn.ModeKeys.TRAIN:
      if train_op_fn is None:
        raise ValueError("train_op_fn can not be None in TRAIN mode.")
      return self._merge_train(all_model_fn_ops, train_op_fn)
    if mode == model_fn.ModeKeys.INFER:
      return self._merge_infer(all_model_fn_ops)
    if mode == model_fn.ModeKeys.EVAL:
      return self._merge_eval(all_model_fn_ops)
    # Defensive guard; `_check_mode_valid` above is expected to have rejected
    # any other mode already.
    raise ValueError("mode=%s unrecognized" % str(mode))

  def _split_logits(self, logits):
    """Splits logits for heads.

    Each head receives a slice of the second dimension whose width is its
    `logits_dimension`, taken in the order of the heads.

    Args:
      logits: the logits tensor.

    Returns:
      A list of logits for the individual heads.
    """
    all_logits = []
    begin = 0
    for head in self._heads:
      current_logits_size = head.logits_dimension
      all_logits.append(
          array_ops.slice(logits, [0, begin], [-1, current_logits_size]))
      begin += current_logits_size
    return all_logits

  def _merge_train(self, all_model_fn_ops, train_op_fn):
    """Merges list of ModelFnOps for training.

    Args:
      all_model_fn_ops: list of ModelFnOps for the individual heads.
      train_op_fn: Function to create train op. See `create_model_fn_ops`
          documentation for more details.

    Returns:
      ModelFnOps that merges all heads for TRAIN.
    """
    losses = [m.loss for m in all_model_fn_ops]
    additional_train_ops = [m.train_op for m in all_model_fn_ops]
    metrics = {}
    for m in all_model_fn_ops:
      if m.eval_metric_ops is not None:
        # Keys are kept as reported by the head.
        metrics.update(m.eval_metric_ops)
    loss = self._loss_merger(losses)

    # Run the per-head train ops together with the op for the merged loss.
    train_op = train_op_fn(loss)
    train_op = control_flow_ops.group(train_op, *additional_train_ops)
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)

  def _merge_infer(self, all_model_fn_ops):
    """Merges list of ModelFnOps for inference.

    Loss-only heads are skipped; they contribute neither predictions nor
    output alternatives.

    Args:
      all_model_fn_ops: list of ModelFnOps for the individual heads.

    Returns:
      ModelFnOps that merges all the heads for INFER.
    """
    predictions = {}
    output_alternatives = {}
    for head, m in zip(self._heads, all_model_fn_ops):
      if isinstance(head, _LossOnlyHead):
        continue
      head_name = head.head_name
      output_alternatives[head_name] = m.output_alternatives[head_name]
      for k, v in m.predictions.items():
        predictions[(head_name, k)] = v

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.INFER,
        predictions=predictions,
        output_alternatives=output_alternatives)

  def _merge_eval(self, all_model_fn_ops):
    """Merges list of ModelFnOps for eval.

    Args:
      all_model_fn_ops: list of ModelFnOps for the individual heads.

    Returns:
      ModelFnOps that merges all the heads for EVAL.
    """
    predictions = {}
    metrics = {}
    losses = []
    for head, m in zip(self._heads, all_model_fn_ops):
      losses.append(m.loss)
      head_name = head.head_name
      for k, v in m.predictions.items():
        predictions[(head_name, k)] = v
      # Metric keys are kept as reported by the head.
      metrics.update(m.eval_metric_ops)
    loss = self._loss_merger(losses)

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.EVAL,
        predictions=predictions,
        loss=loss,
        eval_metric_ops=metrics)
1777
1778
1779def _weight_tensor(features, weight_column_name):
1780  """Returns weights as `Tensor` of rank 0, or at least 2."""
1781  if not weight_column_name:
1782    return None
1783  if weight_column_name not in features:
1784    raise ValueError("Weights {} missing from features.".format(
1785        weight_column_name))
1786  with ops.name_scope(None, "weight_tensor", tuple(six.itervalues(features))):
1787    weight_tensor = math_ops.cast(features[weight_column_name], dtypes.float32)
1788    shape = weight_tensor.get_shape()
1789    rank = shape.ndims
1790    # We don't bother with expanding dims of non-staticly shaped tensors or
1791    # scalars, and >1d is already in a good format.
1792    if rank == 1:
1793      logging.warning("Weights {} has shape {}, expanding to make it 2d.".
1794                      format(weight_column_name, shape))
1795      return (
1796          sparse_ops.sparse_reshape(weight_tensor, (-1, 1))
1797          if isinstance(weight_tensor, sparse_tensor.SparseTensor) else
1798          array_ops.reshape(weight_tensor, (-1, 1)))
1799    return weight_tensor
1800
1801
1802# TODO(zakaria): This function is needed for backward compatibility and should
1803#   be removed when we migrate to core.
def _compute_weighted_loss(loss_unweighted, weight, name="loss"):
  """Returns a tuple of (loss_train, loss_report).

  The first loss is meant for gradient descent, the second one for summaries
  (kept for backward compatibility of the reported values).

  Without weights both entries are the mean of `loss_unweighted`. With
  weights, the training loss is the mean of the element-wise product

    L = mean_{i}(w_{i} * l_{i})

  so that example weights are respected by the gradient, while the reported
  loss is the weighted average

    sum_{i}(w_{i} * l_{i}) / sum_{i}(w_{i})

  where l_{i}, w_{i} are the individual losses and (broadcast) weights.

  Args:
    loss_unweighted: Unweighted loss
    weight: Weight tensor, or `None`
    name: Optional name

  Returns:
    A tuple of losses. First one for training and the second one for reporting.
  """
  with ops.name_scope(name, values=(loss_unweighted, weight)) as outer_scope:
    if weight is None:
      mean_loss = math_ops.reduce_mean(loss_unweighted, name=outer_scope)
      return mean_loss, mean_loss

    weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted)
    with ops.name_scope(None, "weighted_loss",
                        (loss_unweighted, weight)) as inner_scope:
      weighted_loss = math_ops.multiply(
          loss_unweighted, weight, name=inner_scope)

    training_loss = math_ops.reduce_mean(weighted_loss, name=outer_scope)
    # Numerator before denominator, matching the order the ops are created.
    weighted_sum = math_ops.reduce_sum(weighted_loss)
    weight_sum = math_ops.cast(math_ops.reduce_sum(weight), dtypes.float32)
    reporting_loss = math_ops.div(
        weighted_sum, weight_sum, name="weighted_average_loss")

    return training_loss, reporting_loss
1841
1842
1843def _wrap_custom_loss_fn(loss_fn):
1844  def _wrapper(labels, logits, weights=None):
1845    if weights is None:
1846      loss = loss_fn(labels, logits)
1847    else:
1848      loss = loss_fn(labels, logits, weights)
1849    return loss, loss
1850  return _wrapper
1851
1852
def _check_mode_valid(mode):
  """Raises ValueError unless `mode` is one of TRAIN, INFER or EVAL."""
  known_modes = (
      model_fn.ModeKeys.TRAIN,
      model_fn.ModeKeys.INFER,
      model_fn.ModeKeys.EVAL,
  )
  if all(mode != known for known in known_modes):
    raise ValueError("mode=%s unrecognized." % str(mode))
1858
1859
def _get_arguments(func):
  """Returns the argspec of `func` after unwrapping it.

  Objects without `__code__` are resolved recursively through their `func`
  attribute (partial functions) or their `__call__` attribute (callable
  objects). `None` is returned when none of these attributes exist.
  """
  _, target = tf_decorator.unwrap(func)
  if hasattr(target, "__code__"):
    # Regular function.
    return tf_inspect.getargspec(target)
  if hasattr(target, "func"):
    # Partial function.
    return _get_arguments(target.func)
  if hasattr(target, "__call__"):
    # Callable object.
    return _get_arguments(target.__call__)
  return None
1872
1873
def _verify_loss_fn_args(loss_fn):
  """Checks that `loss_fn` declares `labels`, `logits` and `weights`.

  Raises:
    ValueError: naming the first of the three arguments that is missing.
  """
  declared = _get_arguments(loss_fn).args
  missing = [
      required for required in ("labels", "logits", "weights")
      if required not in declared
  ]
  if missing:
    raise ValueError("Argument %s not found in loss_fn." % missing[0])
1879
1880
1881def _centered_bias(logits_dimension, head_name=None):
1882  """Returns centered_bias `Variable`.
1883
1884  Args:
1885    logits_dimension: Last dimension of `logits`. Must be >= 1.
1886    head_name: Optional name of the head.
1887
1888  Returns:
1889    `Variable` with shape `[logits_dimension]`.
1890
1891  Raises:
1892    ValueError: if `logits_dimension` is invalid.
1893  """
1894  if (logits_dimension is None) or (logits_dimension < 1):
1895    raise ValueError("Invalid logits_dimension %s." % logits_dimension)
1896  # Do not create a variable with variable_scope.get_variable, because that may
1897  # create a PartitionedVariable, which does not support indexing, so
1898  # summary.scalar will not work.
1899  centered_bias = variable_scope.variable(
1900      name="centered_bias_weight",
1901      initial_value=array_ops.zeros(shape=(logits_dimension,)),
1902      trainable=True)
1903  for dim in range(logits_dimension):
1904    if head_name:
1905      summary.scalar("centered_bias/bias_%d/%s" % (dim, head_name),
1906                     centered_bias[dim])
1907    else:
1908      summary.scalar("centered_bias/bias_%d" % dim, centered_bias[dim])
1909  return centered_bias
1910
1911
def _centered_bias_step(centered_bias, batch_size, labels, loss_fn, weights):
  """Creates and returns training op for centered bias.

  The bias is tiled `batch_size` times and reshaped to
  `(batch_size, logits_dimension)` to act as logits; the mean of `loss_fn`
  on those logits is minimized with respect to `centered_bias` only.
  """
  with ops.name_scope(None, "centered_bias_step", (labels,)) as step_name:
    num_logits = array_ops.shape(centered_bias)[0]
    tiled_bias = array_ops.tile(centered_bias, (batch_size,))
    batch_logits = array_ops.reshape(tiled_bias, (batch_size, num_logits))
    with ops.name_scope(None, "centered_bias", (labels, batch_logits)):
      bias_loss = math_ops.reduce_mean(
          loss_fn(labels, batch_logits, weights), name="training_loss")
  # Learn the central bias with its own optimizer. 0.1 is a conservative
  # learning rate for a single variable.
  optimizer = training.AdagradOptimizer(0.1)
  return optimizer.minimize(
      bias_loss, var_list=(centered_bias,), name=step_name)
1926
1927
1928def _summary_key(head_name, val):
1929  return "%s/%s" % (val, head_name) if head_name else val
1930
1931
def _train_op(loss, labels, train_op_fn, centered_bias, batch_size, loss_fn,
              weights):
  """Returns op for the training step.

  The op is `train_op_fn(loss)`. When `centered_bias` is given, it is grouped
  with the centered bias update step.
  """
  bias_update = None
  if centered_bias is not None:
    bias_update = _centered_bias_step(
        centered_bias=centered_bias,
        batch_size=batch_size,
        labels=labels,
        loss_fn=loss_fn,
        weights=weights)
  with ops.name_scope(None, "train_op", (loss, labels)):
    main_update = train_op_fn(loss)
    if bias_update is None:
      return main_update
    return control_flow_ops.group(main_update, bias_update)
1949
1950
def _sigmoid_cross_entropy_loss(labels, logits, weights=None):
  """Sigmoid cross entropy on float32-cast labels, weighted by `weights`.

  Returns:
    The result of `_compute_weighted_loss` for the element-wise loss.
  """
  with ops.name_scope(None, "sigmoid_cross_entropy_loss",
                      (logits, labels)) as scope_name:
    # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
    float_labels = math_ops.cast(labels, dtypes.float32)
    elementwise_loss = nn.sigmoid_cross_entropy_with_logits(
        labels=float_labels, logits=logits, name=scope_name)
    return _compute_weighted_loss(elementwise_loss, weights)
1958
1959
1960def _float_weights_or_none(weights):
1961  if weights is None:
1962    return None
1963  with ops.name_scope(None, "float_weights", (weights,)) as name:
1964    return math_ops.cast(weights, dtypes.float32, name=name)
1965
1966
def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
  """Mean metric over float32-cast labels.

  Weights, if given, are cast to float32 and broadcast to the labels. With a
  `class_id`, only column `class_id` of labels (and weights) is averaged.
  """
  float_labels = math_ops.cast(labels, dtypes.float32)
  float_weights = _float_weights_or_none(weights)
  if float_weights is not None:
    float_weights = weights_broadcast_ops.broadcast_weights(
        float_weights, float_labels)
    if class_id is not None:
      float_weights = float_weights[:, class_id]
  if class_id is not None:
    float_labels = float_labels[:, class_id]
  return metrics_lib.mean(float_labels, float_weights)
1977
1978
def _predictions_streaming_mean(predictions,
                                weights=None,
                                class_id=None):
  """Mean metric over float32-cast predictions.

  Weights, if given, are cast to float32 and broadcast to the predictions.
  With a `class_id`, only column `class_id` of predictions (and weights) is
  averaged.
  """
  float_predictions = math_ops.cast(predictions, dtypes.float32)
  float_weights = _float_weights_or_none(weights)
  if float_weights is not None:
    float_weights = weights_broadcast_ops.broadcast_weights(
        float_weights, float_predictions)
    if class_id is not None:
      float_weights = float_weights[:, class_id]
  if class_id is not None:
    float_predictions = float_predictions[:, class_id]
  return metrics_lib.mean(float_predictions, float_weights)
1991
1992
1993# TODO(ptucker): Add support for SparseTensor labels.
1994def _class_id_labels_to_indicator(labels, num_classes):
1995  if (num_classes is None) or (num_classes < 2):
1996    raise ValueError("Invalid num_classes %s." % num_classes)
1997  with ops.control_dependencies((_assert_labels_rank(labels),)):
1998    labels = array_ops.reshape(labels, (-1,))
1999  return array_ops.one_hot(labels, depth=num_classes, axis=-1)
2000
2001
def _class_predictions_streaming_mean(predictions, weights, class_id):
  """Weighted mean of the indicator `int32(predictions) == class_id`."""
  matches_class = math_ops.equal(
      math_ops.cast(class_id, dtypes.int32),
      math_ops.cast(predictions, dtypes.int32))
  ones = array_ops.ones_like(predictions)
  zeros = array_ops.zeros_like(predictions)
  indicator = array_ops.where(matches_class, ones, zeros)
  return metrics_lib.mean(indicator, weights=weights)
2010
2011
def _class_labels_streaming_mean(labels, weights, class_id):
  """Weighted mean of the indicator `int32(labels) == class_id`."""
  matches_class = math_ops.equal(
      math_ops.cast(class_id, dtypes.int32),
      math_ops.cast(labels, dtypes.int32))
  ones = array_ops.ones_like(labels)
  zeros = array_ops.zeros_like(labels)
  indicator = array_ops.where(matches_class, ones, zeros)
  return metrics_lib.mean(indicator, weights=weights)
2020
2021
def _streaming_auc(predictions, labels, weights=None, class_id=None,
                   curve="ROC"):
  """AUC metric of float32 `predictions` against boolean `labels`.

  Non-boolean labels are cast to bool (a warning is logged). Weights, if
  given, are cast to float32 and broadcast to the predictions. With a
  `class_id`, only column `class_id` of each input is used.
  """
  scores = math_ops.cast(predictions, dtypes.float32)
  if labels.dtype.base_dtype != dtypes.bool:
    logging.warning("Casting %s labels to bool.", labels.dtype)
    labels = math_ops.cast(labels, dtypes.bool)
  float_weights = _float_weights_or_none(weights)
  if float_weights is not None:
    float_weights = weights_broadcast_ops.broadcast_weights(
        float_weights, scores)
    if class_id is not None:
      float_weights = float_weights[:, class_id]
  if class_id is not None:
    scores = scores[:, class_id]
    labels = labels[:, class_id]
  return metrics_lib.auc(labels, scores, float_weights, curve=curve)
2038
2039
2040def _assert_class_id(class_id, num_classes=None):
2041  """Average label value for class `class_id`."""
2042  if (class_id is None) or (class_id < 0):
2043    raise ValueError("Invalid class_id %s." % class_id)
2044  if num_classes is not None:
2045    if num_classes < 2:
2046      raise ValueError("Invalid num_classes %s." % num_classes)
2047    if class_id >= num_classes:
2048      raise ValueError("Invalid class_id %s." % class_id)
2049
2050
def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold):
  """Accuracy metric of `predictions >= threshold` (as float32) vs `labels`."""
  at_or_above = math_ops.greater_equal(predictions, threshold)
  hard_predictions = math_ops.cast(at_or_above, dtypes.float32)
  return metrics_lib.accuracy(labels, hard_predictions, weights)
2055
2056
def _streaming_precision_at_threshold(predictions, labels, weights, threshold):
  """Returns squeezed (precision, update_op) for the single `threshold`."""
  float_weights = _float_weights_or_none(weights)
  precision, update_op = metrics_lib.precision_at_thresholds(
      labels, predictions, (threshold,), float_weights)
  return array_ops.squeeze(precision), array_ops.squeeze(update_op)
2061
2062
def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
  """Returns squeezed (recall, update_op) for the single `threshold`."""
  float_weights = _float_weights_or_none(weights)
  recall, update_op = metrics_lib.recall_at_thresholds(
      labels, predictions, (threshold,), float_weights)
  return array_ops.squeeze(recall), array_ops.squeeze(update_op)
2067
2068
2069def _classification_output_alternatives(head_name, problem_type,
2070                                        label_keys=None):
2071  """Creates a func to generate output alternatives for classification.
2072
2073  Servo expects classes to be a string tensor, and have the same dimensions
2074  as the probabilities tensor. It should contain the labels of the corresponding
2075  entries in probabilities. This function creates a new classes tensor that
2076  satisfies these conditions and can be exported.
2077
2078  Args:
2079    head_name: Name of the head.
2080    problem_type: `ProblemType`
2081    label_keys: Optional label keys
2082
2083  Returns:
2084    A function to generate output alternatives.
2085  """
2086  def _create_output_alternatives(predictions):
2087    """Creates output alternative for the Head.
2088
2089    Args:
2090      predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
2091        symbolic name for an output Tensor possibly but not necessarily taken
2092        from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
2093        itself.
2094
2095    Returns:
2096      `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
2097      'submodel_name' is a submodel identifier that should be consistent across
2098      the pipeline (here likely taken from the head_name),
2099      'problem_type' is a `ProblemType`,
2100      'tensor_name' is a symbolic name for an output Tensor possibly but not
2101       necessarily taken from `PredictionKey`, and
2102      'Tensor' is the corresponding output Tensor itself.
2103
2104    Raises:
2105      ValueError: if predictions does not have PredictionKey.PROBABILITIES key.
2106    """
2107    probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES)
2108    if probabilities is None:
2109      raise ValueError("%s missing in predictions" %
2110                       prediction_key.PredictionKey.PROBABILITIES)
2111
2112    with ops.name_scope(None, "_classification_output_alternatives",
2113                        (probabilities,)):
2114      batch_size = array_ops.shape(probabilities)[0]
2115      if label_keys:
2116        classes = array_ops.tile(
2117            input=array_ops.expand_dims(input=label_keys, axis=0),
2118            multiples=[batch_size, 1],
2119            name="classes_tensor")
2120      else:
2121        n = array_ops.shape(probabilities)[1]
2122        classes = array_ops.tile(
2123            input=array_ops.expand_dims(input=math_ops.range(n), axis=0),
2124            multiples=[batch_size, 1])
2125        classes = string_ops.as_string(classes, name="classes_tensor")
2126
2127    exported_predictions = {
2128        prediction_key.PredictionKey.PROBABILITIES: probabilities,
2129        prediction_key.PredictionKey.CLASSES: classes}
2130    return {head_name: (problem_type, exported_predictions)}
2131
2132  return _create_output_alternatives
2133
# Aliases
# Each underscore-prefixed name below is bound to the same object as its
# public counterpart, so both spellings can be used interchangeably.
# TODO(zakaria): Remove these aliases, See b/34751732
_regression_head = regression_head
_poisson_regression_head = poisson_regression_head
_multi_class_head = multi_class_head
_binary_svm_head = binary_svm_head
_multi_label_head = multi_label_head
_multi_head = multi_head
_Head = Head
2143