1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Classes and methods related to model_fn (deprecated).
17
18This module and all its submodules are deprecated. See
19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
20for migration instructions.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import collections
28
29import six
30
31from tensorflow.contrib.framework import get_graph_from_inputs
32from tensorflow.contrib.learn.python.learn.estimators import constants
33from tensorflow.contrib.learn.python.learn.estimators import metric_key
34from tensorflow.contrib.learn.python.learn.estimators import prediction_key
35from tensorflow.python.estimator import model_fn as core_model_fn_lib
36from tensorflow.python.estimator.export import export_output as core_export_lib
37from tensorflow.python.framework import dtypes
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.framework import tensor_shape
41from tensorflow.python.ops import array_ops
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.saved_model import signature_constants
44from tensorflow.python.training import session_run_hook
45from tensorflow.python.util.deprecation import deprecated
46
47
class ModeKeys(object):
  """Standard names for model modes (deprecated).

  THIS CLASS IS DEPRECATED.

  The following standard keys are defined:

  * `TRAIN`: training mode.
  * `EVAL`: evaluation mode.
  * `INFER`: inference mode.
  """

  TRAIN = 'train'
  EVAL = 'eval'
  INFER = 'infer'

  @classmethod
  def validate(cls, key):
    """Checks that `key` is one of the standard mode keys.

    Args:
      key: The mode value to check.

    Raises:
      ValueError: If `key` is not one of `TRAIN`, `EVAL` or `INFER`.
    """
    if key not in (cls.TRAIN, cls.EVAL, cls.INFER):
      # Wrap `key` in a 1-tuple so that a tuple-valued key is formatted as a
      # whole rather than unpacked into the %-format arguments (which would
      # raise TypeError instead of the intended ValueError).
      raise ValueError('Invalid mode %s.' % (key,))
68
69
class ModelFnOps(
    collections.namedtuple('ModelFnOps', [
        'predictions', 'loss', 'train_op', 'eval_metric_ops',
        'output_alternatives', 'training_chief_hooks', 'training_hooks',
        'scaffold', 'mode'
    ])):
  """Ops returned from a model_fn.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  @deprecated(None, 'When switching to tf.estimator.Estimator, use '
              'tf.estimator.EstimatorSpec. You can use the `estimator_spec`'
              ' method to create an equivalent one.')
  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              output_alternatives=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None):
    """Creates a validated `ModelFnOps` instance.

    For a multi-headed model, the predictions dict here will contain the outputs
    of all of the heads.  However: at serving time, requests will be made
    specifically for one or more heads, and the RPCs used for these requests may
    differ by problem type (i.e., regression, classification, other).  The
    purpose of the output_alternatives dict is to aid in exporting a SavedModel
    from which such head-specific queries can be served.  These
    output_alternatives will be combined with input_alternatives (see
    `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying
    the valid requests that can be served from this model.

    For a single-headed model, it is still advisable to provide
    output_alternatives with a single entry, because this is how the problem
    type is communicated for export and serving.  If output_alternatives is not
    given, the resulting SavedModel will support only one head of unspecified
    type.

    Args:
      mode: One of `ModeKeys`. Specifies whether this is training, evaluation
        or prediction.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Training loss `Tensor`.
      train_op: Op for the training step.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, such as `Tensor`.
      output_alternatives: a dict of
        `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where
        `submodel_name` is a submodel identifier that should be consistent
        across the pipeline (here likely taken from the name of each `Head`,
        for models that use them), `problem_type` is a `ProblemType`,
        `tensor_name` is a symbolic name for an output Tensor possibly but not
        necessarily taken from `PredictionKey`, and `Tensor` is the
        corresponding output Tensor itself.
      training_chief_hooks: A list (or other iterable) of `SessionRunHook`
        objects that will be run on the chief worker during training. Stored
        as a new list.
      training_hooks: A list (or other iterable) of `SessionRunHook` objects
        that will be run on all workers during training. Stored as a new list.
      scaffold: A `tf.train.Scaffold` object that can be used to set
        initialization, saver, and more to be used in training.

    Returns:
      A validated `ModelFnOps` object.

    Raises:
      ValueError: If validation fails.
      TypeError: If any hook is not a `SessionRunHook` instance.
    """
    ModeKeys.validate(mode)

    # Assert all ops are from the same graph.
    get_graph_from_inputs((predictions, loss, train_op))

    # Validate train_op.
    if train_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing train_op.')
    elif not isinstance(train_op, ops.Operation):
      # TODO(ptucker): Should this be allowed? Consider raising error.
      train_op = ops.convert_to_tensor(train_op).op

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = ops.convert_to_tensor(loss)
      loss_shape = loss.get_shape()
      # An unknown element count (None) is tolerated; anything known must be 1.
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar: %s.' % loss)
      # A one-element, non-scalar loss (e.g. shape [1]) is reshaped to [].
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode in (ModeKeys.INFER, ModeKeys.EVAL):
        raise ValueError('Missing predictions.')
    elif isinstance(predictions, dict):
      predictions = {
          k: sparse_tensor.convert_to_tensor_or_sparse_tensor(v)
          for k, v in six.iteritems(predictions)
      }
    else:
      predictions = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          predictions)

    # Validate eval_metric_ops.
    if eval_metric_ops is None:
      eval_metric_ops = {}
    elif not isinstance(eval_metric_ops, dict):
      raise ValueError('eval_metric_ops must be a dict.')

    # Validate hooks. Normalize both arguments to fresh lists first: the old
    # `training_hooks + training_chief_hooks` raised a confusing TypeError
    # when one was a tuple and the other a list (e.g. a tuple plus the `[]`
    # default). It also failed for generators. Copying additionally decouples
    # the stored hooks from later mutation of the caller's list.
    training_chief_hooks = (
        [] if training_chief_hooks is None else list(training_chief_hooks))
    training_hooks = [] if training_hooks is None else list(training_hooks)
    for hook in training_hooks + training_chief_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError('All hooks returned from model_fn must be '
                        'SessionRunHook instances, got instance of %s: %s' %
                        (type(hook), hook))

    return super(ModelFnOps, cls).__new__(
        cls,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        output_alternatives=output_alternatives,
        training_chief_hooks=training_chief_hooks,
        training_hooks=training_hooks,
        scaffold=scaffold,
        mode=mode)

  def estimator_spec(self, default_serving_output_alternative_key=None):
    """Creates an equivalent `EstimatorSpec`.

    Args:
      default_serving_output_alternative_key: Required for multiple heads. If
        you have multiple entries in `output_alternatives` dict (comparable to
        multiple heads), `EstimatorSpec` requires a default head that will be
        used if a Servo request does not explicitly mention which head to infer
        on. Pass the key of the output alternative here that you want to
        designate as default. A separate ExportOutput for this default head
        will be added to the export_outputs dict with the special key
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
        already an entry in output_alternatives with this special key.

    Returns:
      Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`.

    Raises:
      ValueError: If problem type is unknown.
      KeyError: If there are multiple output alternatives, none keyed by the
        default serving signature key, and
        `default_serving_output_alternative_key` is not one of their keys.
    """
    def _scores(output_tensors):
      """Returns the SCORES tensor, falling back to PROBABILITIES."""
      scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
      if scores is None:
        scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
      return scores

    def _classes(output_tensors):  # pylint: disable=missing-docstring
      classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
      if classes is None:
        logging.warning(
            'classes is None, Servo inference will not have class ids.')
        return None
      elif classes.dtype != dtypes.string:
        # Servo classification can only serve string classes
        logging.warning(
            'classes is not string, Servo inference will not have class ids.')
        return None

      return classes

    def _export_output(problem_type, predictions):  # pylint: disable=missing-docstring
      if problem_type == constants.ProblemType.LINEAR_REGRESSION:
        return core_export_lib.RegressionOutput(_scores(predictions))

      if (problem_type == constants.ProblemType.CLASSIFICATION or
          problem_type == constants.ProblemType.LOGISTIC_REGRESSION):
        return core_export_lib.ClassificationOutput(
            scores=_scores(predictions), classes=_classes(predictions))

      if problem_type == constants.ProblemType.UNSPECIFIED:
        return core_export_lib.PredictOutput(predictions)

      raise ValueError('Unknown problem_type=%s' % problem_type)

    # Converts output_alternatives
    export_outputs_dict = None
    if self.output_alternatives:
      output_alternatives = self.output_alternatives
      # Adds default output_alternative if needed.
      if (len(output_alternatives) > 1 and
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
          output_alternatives):
        if default_serving_output_alternative_key not in output_alternatives:
          # Still a KeyError (as the plain dict lookup raised before) so
          # existing handlers keep working, but with an actionable message
          # instead of e.g. `KeyError: None`.
          raise KeyError(
              'default_serving_output_alternative_key=%r is required when '
              'there are multiple output_alternatives and must be one of %s.'
              % (default_serving_output_alternative_key,
                 list(output_alternatives)))
        # Copy so that self.output_alternatives is not mutated.
        output_alternatives = output_alternatives.copy()
        output_alternatives[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                output_alternatives[default_serving_output_alternative_key])
      export_outputs_dict = {key: _export_output(*val) for key, val in
                             output_alternatives.items()}

    # Drop the loss metric from the eval metrics.
    eval_metric_ops = {
        key: value
        for key, value in six.iteritems(self.eval_metric_ops)
        if key != metric_key.MetricKey.LOSS
    }

    # Convert the contrib mode enum to the core mode enum.
    # Note: mode is validated in __new__(). The table lookup fails with a
    # KeyError (rather than an unbound local) for an instance that bypassed
    # that validation, e.g. one built via namedtuple `_replace`.
    core_mode = {
        ModeKeys.TRAIN: core_model_fn_lib.ModeKeys.TRAIN,
        ModeKeys.EVAL: core_model_fn_lib.ModeKeys.EVAL,
        ModeKeys.INFER: core_model_fn_lib.ModeKeys.PREDICT,
    }[self.mode]

    return core_model_fn_lib.EstimatorSpec(
        mode=core_mode,
        predictions=self.predictions,
        loss=self.loss,
        train_op=self.train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs_dict,
        training_chief_hooks=self.training_chief_hooks,
        training_hooks=self.training_hooks,
        scaffold=self.scaffold)
308