1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for different types of export output."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import six
24
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.saved_model import signature_def_utils
30
31
@six.add_metaclass(abc.ABCMeta)
class ExportOutput(object):
  """Represents an output of a model that can be served.

  These typically correspond to model heads.
  """

  # Separator used to join the parts of tuple-valued output keys (multi-head
  # models) into a single string key; subclasses also use it to build
  # prefixed key names.
  _SEPARATOR_CHAR = '/'

  @abc.abstractmethod
  def as_signature_def(self, receiver_tensors):
    """Generate a SignatureDef proto for inclusion in a MetaGraphDef.

    The SignatureDef will specify outputs as described in this ExportOutput,
    and will use the provided receiver_tensors as inputs.

    Args:
      receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
        input nodes that will be fed.

    Returns:
      The generated SignatureDef proto.
    """
    pass

  def _check_output_key(self, key, error_label):
    """Validates an output key and normalizes it to a single string.

    Args:
      key: a string, or (for multi-head models) a tuple of strings whose parts
        are joined with `_SEPARATOR_CHAR`.
      error_label: descriptive string used as the prefix of error messages.

    Returns:
      The key as a single string.

    Raises:
      ValueError: if `key` is neither a string nor a tuple of strings.
    """
    # For multi-head models, the key can be a tuple. Only join it when every
    # part is a string; otherwise fall through so the check below raises the
    # documented ValueError (str.join would raise TypeError instead).
    if isinstance(key, tuple) and all(
        isinstance(part, six.string_types) for part in key):
      key = self._SEPARATOR_CHAR.join(key)

    if not isinstance(key, six.string_types):
      raise ValueError(
          '{} output key must be a string; got {}.'.format(error_label, key))
    return key

  def _wrap_and_check_outputs(
      self, outputs, single_output_default_name, error_label=None):
    """Wraps raw tensors as dicts and checks type.

    Note that we create a new dict here so that we can overwrite the keys
    if necessary.

    Args:
      outputs: A `Tensor` or a dict of string to `Tensor`.
      single_output_default_name: A string key for use in the output dict
        if the provided `outputs` is a raw tensor.
      error_label: descriptive string for use in error messages. If none,
        single_output_default_name will be used.

    Returns:
      A dict of tensors

    Raises:
      ValueError: if the outputs dict keys are not strings or tuples of strings
        or the values are not Tensors.
    """
    if not isinstance(outputs, dict):
      outputs = {single_output_default_name: outputs}

    # The label does not depend on the entry being checked; compute it once.
    error_name = error_label or single_output_default_name
    output_dict = {}
    for key, value in outputs.items():
      key = self._check_output_key(key, error_name)
      if not isinstance(value, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                error_name, value))

      output_dict[key] = value
    return output_dict
100
101
class ClassificationOutput(ExportOutput):
  """Represents the output of a classification head.

  Either classes or scores or both must be set.

  The classes `Tensor` must provide string labels, not integer class IDs.

  If only classes is set, it is interpreted as providing top-k results in
  descending order.

  If only scores is set, it is interpreted as providing a score for every class
  in order of class ID.

  If both classes and scores are set, they are interpreted as zipped, so each
  score corresponds to the class at the same index.  Clients should not depend
  on the order of the entries.
  """

  def __init__(self, scores=None, classes=None):
    """Constructor for `ClassificationOutput`.

    Args:
      scores: A float `Tensor` giving scores (sometimes but not always
          interpretable as probabilities) for each class.  May be `None`, but
          only if `classes` is set.  Interpretation varies-- see class doc.
      classes: A string `Tensor` giving predicted class labels.  May be `None`,
          but only if `scores` is set.  Interpretation varies-- see class doc.

    Raises:
      ValueError: if neither classes nor scores is set, or one of them is not a
          `Tensor` with the correct dtype.
    """
    # NOTE: any floating-point dtype is accepted here even though the error
    # message says "float32"; the message text is kept because callers may
    # match on it.
    if (scores is not None
        and not (isinstance(scores, ops.Tensor)
                 and scores.dtype.is_floating)):
      raise ValueError('Classification scores must be a float32 Tensor; '
                       'got {}'.format(scores))
    if (classes is not None
        and not (isinstance(classes, ops.Tensor)
                 and dtypes.as_dtype(classes.dtype) == dtypes.string)):
      raise ValueError('Classification classes must be a string Tensor; '
                       'got {}'.format(classes))
    if scores is None and classes is None:
      raise ValueError('At least one of scores and classes must be set.')

    self._scores = scores
    self._classes = classes

  @property
  def scores(self):
    return self._scores

  @property
  def classes(self):
    return self._classes

  def as_signature_def(self, receiver_tensors):
    """Builds a classification SignatureDef fed by a single string input.

    Args:
      receiver_tensors: a string `Tensor`, or a dict containing exactly one
        string `Tensor`; it is used as the examples input of the signature.

    Returns:
      The SignatureDef built by
      `signature_def_utils.classification_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` is not exactly one string `Tensor`.
    """
    # The base-class contract allows a bare Tensor as well as a dict; calling
    # len()/items() on a bare Tensor would fail with an unrelated error.
    if isinstance(receiver_tensors, ops.Tensor):
      examples = receiver_tensors
    else:
      if len(receiver_tensors) != 1:
        raise ValueError('Classification input must be a single string Tensor; '
                         'got {}'.format(receiver_tensors))
      (_, examples), = receiver_tensors.items()
    # Check the type before reading .dtype so non-Tensor values produce the
    # documented ValueError rather than an AttributeError.
    if not (isinstance(examples, ops.Tensor)
            and dtypes.as_dtype(examples.dtype) == dtypes.string):
      raise ValueError('Classification input must be a single string Tensor; '
                       'got {}'.format(receiver_tensors))
    return signature_def_utils.classification_signature_def(
        examples, self.classes, self.scores)
168
169
class RegressionOutput(ExportOutput):
  """Represents the output of a regression head."""

  def __init__(self, value):
    """Constructor for `RegressionOutput`.

    Args:
      value: a float `Tensor` giving the predicted values.  Required.

    Raises:
      ValueError: if the value is not a `Tensor` with a floating-point dtype.
    """
    # NOTE: any floating-point dtype passes this check; the "float32" wording
    # of the message is kept because callers may match on it.
    if not (isinstance(value, ops.Tensor) and value.dtype.is_floating):
      raise ValueError('Regression output value must be a float32 Tensor; '
                       'got {}'.format(value))
    self._value = value

  @property
  def value(self):
    return self._value

  def as_signature_def(self, receiver_tensors):
    """Builds a regression SignatureDef fed by a single string input.

    Args:
      receiver_tensors: a string `Tensor`, or a dict containing exactly one
        string `Tensor`; it is used as the examples input of the signature.

    Returns:
      The SignatureDef built by `signature_def_utils.regression_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` is not exactly one string `Tensor`.
    """
    # Accept a bare Tensor, as allowed by the base-class contract, instead of
    # failing on len()/items().
    if isinstance(receiver_tensors, ops.Tensor):
      examples = receiver_tensors
    else:
      if len(receiver_tensors) != 1:
        raise ValueError('Regression input must be a single string Tensor; '
                         'got {}'.format(receiver_tensors))
      (_, examples), = receiver_tensors.items()
    # Type-check before touching .dtype so bad inputs raise ValueError.
    if not (isinstance(examples, ops.Tensor)
            and dtypes.as_dtype(examples.dtype) == dtypes.string):
      raise ValueError('Regression input must be a single string Tensor; '
                       'got {}'.format(receiver_tensors))
    return signature_def_utils.regression_signature_def(examples, self.value)
200
201
class PredictOutput(ExportOutput):
  """Represents the output of a generic prediction head.

  A generic prediction need not be either a classification or a regression.

  Named outputs are given as a dict from string to `Tensor`; a single bare
  `Tensor` is also accepted and is stored under the key 'output'.
  """
  _SINGLE_OUTPUT_DEFAULT_NAME = 'output'

  def __init__(self, outputs):
    """Constructor for PredictOutput.

    Args:
      outputs: A `Tensor` or a dict of string to `Tensor` representing the
        predictions.

    Raises:
      ValueError: if any key of `outputs` is not a string (or a tuple of
          strings), or any of its values is not a `Tensor`.
    """
    checked_outputs = self._wrap_and_check_outputs(
        outputs,
        self._SINGLE_OUTPUT_DEFAULT_NAME,
        error_label='Prediction')
    self._outputs = checked_outputs

  @property
  def outputs(self):
    return self._outputs

  def as_signature_def(self, receiver_tensors):
    """Builds a predict SignatureDef from `receiver_tensors` and the outputs."""
    build_predict_signature = signature_def_utils.predict_signature_def
    return build_predict_signature(receiver_tensors, self.outputs)
233
234
@six.add_metaclass(abc.ABCMeta)
class _SupervisedOutput(ExportOutput):
  """Represents the output of a supervised training or eval process."""

  LOSS_NAME = 'loss'
  PREDICTIONS_NAME = 'predictions'
  METRICS_NAME = 'metrics'

  METRIC_VALUE_SUFFIX = 'value'
  METRIC_UPDATE_SUFFIX = 'update_op'

  # Class-level defaults: the corresponding properties return None for any
  # output that was not passed to the constructor.
  _loss = None
  _predictions = None
  _metrics = None

  def __init__(self, loss=None, predictions=None, metrics=None):
    """Constructor for SupervisedOutput (ie, Train or Eval output).

    Args:
      loss: dict of Tensors or single Tensor representing calculated loss.
      predictions: dict of Tensors or single Tensor representing model
        predictions.
      metrics: Dict of metric results keyed by name.
        The values of the dict can be one of the following:
        (1) instance of `Metric` class.
        (2) (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Raises:
      ValueError: if any of the outputs' dict keys are not strings or tuples of
        strings or the values are not Tensors (or Operations in the case of
        update_op).
    """

    if loss is not None:
      loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME)
      self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME)
    if predictions is not None:
      pred_dict = self._wrap_and_check_outputs(
          predictions, self.PREDICTIONS_NAME)
      self._predictions = self._prefix_output_keys(
          pred_dict, self.PREDICTIONS_NAME)
    if metrics is not None:
      self._metrics = self._wrap_and_check_metrics(metrics)

  def _prefix_output_keys(self, output_dict, output_name):
    """Prepend output_name to the output_dict keys if it doesn't exist.

    This produces predictable prefixes for the pre-determined outputs
    of SupervisedOutput.

    Args:
      output_dict: dict of string to Tensor, assumed valid.
      output_name: prefix string to prepend to existing keys.

    Returns:
      dict with updated keys and existing values.
    """
    return {
        self._prefix_key(key, output_name): val
        for key, val in output_dict.items()
    }

  def _prefix_key(self, key, output_name):
    """Returns `key` prefixed with `output_name` + separator if needed.

    Args:
      key: string output key.
      output_name: prefix string.

    Returns:
      `key` unchanged if it already starts with `output_name`, otherwise
      `output_name + _SEPARATOR_CHAR + key`.
    """
    # This is a plain string-prefix test, so a key such as 'loss_total' is also
    # left unprefixed under output_name 'loss'.
    if not key.startswith(output_name):
      key = output_name + self._SEPARATOR_CHAR + key
    return key

  def _wrap_and_check_metrics(self, metrics):
    """Handle the saving of metrics.

    Metrics is either a tuple of (value, update_op), or a dict of such tuples.
    Here, we separate out the tuples and create a dict with names to tensors.

    Args:
      metrics: Dict of metric results keyed by name.
        The values of the dict can be one of the following:
        (1) instance of `Metric` class.
        (2) (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Returns:
      dict of output_names to tensors

    Raises:
      ValueError: if the dict key is not a string, the metric values or ops
        are not tensors, or a `Metric` object does not expose exactly one
        update op.
    """
    if not isinstance(metrics, dict):
      metrics = {self.METRICS_NAME: metrics}

    outputs = {}
    for key, value in metrics.items():
      if isinstance(value, tuple):
        metric_val, metric_op = value
      else:  # value is a keras.Metrics object
        metric_val = value.result()
        updates = value.updates
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(updates) != 1:
          raise ValueError(
              '{} metric must have exactly one update op; got {}.'.format(
                  key, updates))
        metric_op = updates[0]
      key = self._check_output_key(key, self.METRICS_NAME)
      key = self._prefix_key(key, self.METRICS_NAME)

      val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX
      op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX
      if not isinstance(metric_val, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                key, metric_val))
      if (not isinstance(metric_op, ops.Tensor) and
          not isinstance(metric_op, ops.Operation)):
        raise ValueError(
            '{} update_op must be a Tensor or Operation; got {}.'.format(
                key, metric_op))

      # We must wrap any ops in a Tensor before export, as the SignatureDef
      # proto expects tensors only. See b/109740581
      metric_op_tensor = metric_op
      if isinstance(metric_op, ops.Operation):
        with ops.control_dependencies([metric_op]):
          metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')

      outputs[val_name] = metric_val
      outputs[op_name] = metric_op_tensor

    return outputs

  @property
  def loss(self):
    return self._loss

  @property
  def predictions(self):
    return self._predictions

  @property
  def metrics(self):
    return self._metrics

  @abc.abstractmethod
  def _get_signature_def_fn(self):
    """Returns a function that produces a SignatureDef given desired outputs."""
    pass

  def as_signature_def(self, receiver_tensors):
    """Builds the SignatureDef using the subclass-provided builder function.

    Args:
      receiver_tensors: input tensors passed through to the builder.

    Returns:
      The result of calling the function from `_get_signature_def_fn` with
      `receiver_tensors`, `loss`, `predictions` and `metrics`.
    """
    signature_def_fn = self._get_signature_def_fn()
    return signature_def_fn(
        receiver_tensors, self.loss, self.predictions, self.metrics)
384
385
class TrainOutput(_SupervisedOutput):
  """Represents the output of a supervised training process.

  Loss, predictions and metrics are type-checked and wrapped by the shared
  supervised-output base class; this subclass only selects the SignatureDef
  builder used for training exports.
  """

  def _get_signature_def_fn(self):
    """Returns the builder for supervised training SignatureDefs."""
    train_signature_builder = (
        signature_def_utils.supervised_train_signature_def)
    return train_signature_builder
396
397
class EvalOutput(_SupervisedOutput):
  """Represents the output of a supervised eval process.

  Loss, predictions and metrics are type-checked and wrapped by the shared
  supervised-output base class; this subclass only selects the SignatureDef
  builder used for evaluation exports.
  """

  def _get_signature_def_fn(self):
    """Returns the builder for supervised evaluation SignatureDefs."""
    eval_signature_builder = (
        signature_def_utils.supervised_eval_signature_def)
    return eval_signature_builder
408