1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for different types of export output."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import six
24
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.saved_model import signature_def_utils
31
32
class ExportOutput(object):
  """Represents an output of a model that can be served.

  These typically correspond to model heads.
  """

  # NOTE(review): `__metaclass__` is only honored by Python 2. Under Python 3
  # this attribute is ignored, so `@abc.abstractmethod` below is not enforced
  # when the class is instantiated.
  __metaclass__ = abc.ABCMeta

  # Used to join the components of tuple keys (multi-head models) and to
  # separate key prefixes/suffixes.
  _SEPARATOR_CHAR = '/'

  @abc.abstractmethod
  def as_signature_def(self, receiver_tensors):
    """Generate a SignatureDef proto for inclusion in a MetaGraphDef.

    The SignatureDef will specify outputs as described in this ExportOutput,
    and will use the provided receiver_tensors as inputs.

    Args:
      receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
        input nodes that will be fed.

    Returns:
      A SignatureDef proto.
    """
    pass

  def _check_output_key(self, key, error_label):
    """Normalizes an output key to a single string and validates it.

    Args:
      key: a string, or a tuple of strings (multi-head models). A tuple's
        components are joined with `_SEPARATOR_CHAR`.
      error_label: descriptive string used as the prefix of error messages.

    Returns:
      The key as a single string.

    Raises:
      ValueError: if `key` is neither a string nor a tuple of strings.
    """
    # For multi-head models, the key can be a tuple.
    if isinstance(key, tuple):
      # Validate the components before joining: `str.join` would otherwise
      # raise a TypeError rather than the documented ValueError.
      if not all(isinstance(part, six.string_types) for part in key):
        raise ValueError(
            '{} output key must be a string; got {}.'.format(error_label, key))
      key = self._SEPARATOR_CHAR.join(key)

    if not isinstance(key, six.string_types):
      raise ValueError(
          '{} output key must be a string; got {}.'.format(error_label, key))
    return key

  def _wrap_and_check_outputs(
      self, outputs, single_output_default_name, error_label=None):
    """Wraps raw tensors as dicts and checks type.

    Note that we create a new dict here so that we can overwrite the keys
    if necessary.

    Args:
      outputs: A `Tensor` or a dict of string to `Tensor`.
      single_output_default_name: A string key for use in the output dict
        if the provided `outputs` is a raw tensor.
      error_label: descriptive string for use in error messages. If None or
        empty, single_output_default_name will be used.

    Returns:
      A new dict of string key to `Tensor`.

    Raises:
      ValueError: if the outputs dict keys are not strings or tuples of strings
        or the values are not Tensors.
    """
    if not isinstance(outputs, dict):
      outputs = {single_output_default_name: outputs}

    # The label depends only on the arguments, so compute it once.
    error_name = error_label or single_output_default_name
    output_dict = {}
    for key, value in outputs.items():
      key = self._check_output_key(key, error_name)
      if not isinstance(value, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                error_name, value))

      output_dict[key] = value
    return output_dict
101
102
class ClassificationOutput(ExportOutput):
  """Represents the output of a classification head.

  Either classes or scores or both must be set.

  The classes `Tensor` must provide string labels, not integer class IDs.

  If only classes is set, it is interpreted as providing top-k results in
  descending order.

  If only scores is set, it is interpreted as providing a score for every class
  in order of class ID.

  If both classes and scores are set, they are interpreted as zipped, so each
  score corresponds to the class at the same index.  Clients should not depend
  on the order of the entries.
  """

  def __init__(self, scores=None, classes=None):
    """Constructor for `ClassificationOutput`.

    Args:
      scores: A floating-point `Tensor` giving scores (sometimes but not always
          interpretable as probabilities) for each class.  May be `None`, but
          only if `classes` is set.  Interpretation varies-- see class doc.
      classes: A string `Tensor` giving predicted class labels.  May be `None`,
          but only if `scores` is set.  Interpretation varies-- see class doc.

    Raises:
      ValueError: if neither classes nor scores is set, if `scores` is not a
          floating-point `Tensor`, or if `classes` is not a string `Tensor`.
    """
    # NOTE: any floating dtype is accepted here, although the error message
    # below mentions float32; the message text is intentionally left as is.
    if (scores is not None
        and not (isinstance(scores, ops.Tensor)
                 and scores.dtype.is_floating)):
      raise ValueError('Classification scores must be a float32 Tensor; '
                       'got {}'.format(scores))
    if (classes is not None
        and not (isinstance(classes, ops.Tensor)
                 and dtypes.as_dtype(classes.dtype) == dtypes.string)):
      raise ValueError('Classification classes must be a string Tensor; '
                       'got {}'.format(classes))
    if scores is None and classes is None:
      raise ValueError('At least one of scores and classes must be set.')

    self._scores = scores
    self._classes = classes

  @property
  def scores(self):
    """The scores `Tensor` passed to the constructor, or None."""
    return self._scores

  @property
  def classes(self):
    """The classes `Tensor` passed to the constructor, or None."""
    return self._classes

  def as_signature_def(self, receiver_tensors):
    """Builds a classification SignatureDef.

    Args:
      receiver_tensors: a single string `Tensor`, or a dict containing exactly
        one string `Tensor`, used as the classification input.

    Returns:
      The SignatureDef built by
      `signature_def_utils.classification_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` does not hold exactly one string
        `Tensor`.
    """
    if isinstance(receiver_tensors, ops.Tensor):
      # The base-class contract allows a bare `Tensor`; without this branch it
      # would fail in `len()` / `.items()` below.
      examples = receiver_tensors
    else:
      if len(receiver_tensors) != 1:
        raise ValueError('Classification input must be a single string Tensor; '
                         'got {}'.format(receiver_tensors))
      (_, examples), = receiver_tensors.items()
    if dtypes.as_dtype(examples.dtype) != dtypes.string:
      raise ValueError('Classification input must be a single string Tensor; '
                       'got {}'.format(receiver_tensors))
    return signature_def_utils.classification_signature_def(
        examples, self.classes, self.scores)
169
170
class RegressionOutput(ExportOutput):
  """Represents the output of a regression head."""

  def __init__(self, value):
    """Constructor for `RegressionOutput`.

    Args:
      value: a floating-point `Tensor` giving the predicted values.  Required.

    Raises:
      ValueError: if the value is not a `Tensor` with a floating-point dtype.
    """
    # NOTE: any floating dtype is accepted here, although the error message
    # below mentions float32; the message text is intentionally left as is.
    if not (isinstance(value, ops.Tensor) and value.dtype.is_floating):
      raise ValueError('Regression output value must be a float32 Tensor; '
                       'got {}'.format(value))
    self._value = value

  @property
  def value(self):
    """The predicted-values `Tensor` passed to the constructor."""
    return self._value

  def as_signature_def(self, receiver_tensors):
    """Builds a regression SignatureDef.

    Args:
      receiver_tensors: a single string `Tensor`, or a dict containing exactly
        one string `Tensor`, used as the regression input.

    Returns:
      The SignatureDef built by `signature_def_utils.regression_signature_def`.

    Raises:
      ValueError: if `receiver_tensors` does not hold exactly one string
        `Tensor`.
    """
    if isinstance(receiver_tensors, ops.Tensor):
      # The base-class contract allows a bare `Tensor`; without this branch it
      # would fail in `len()` / `.items()` below.
      examples = receiver_tensors
    else:
      if len(receiver_tensors) != 1:
        raise ValueError('Regression input must be a single string Tensor; '
                         'got {}'.format(receiver_tensors))
      (_, examples), = receiver_tensors.items()
    if dtypes.as_dtype(examples.dtype) != dtypes.string:
      raise ValueError('Regression input must be a single string Tensor; '
                       'got {}'.format(receiver_tensors))
    return signature_def_utils.regression_signature_def(examples, self.value)
201
202
class PredictOutput(ExportOutput):
  """Represents the output of a generic prediction head.

  A generic prediction need not be either a classification or a regression.

  Named outputs are given as a dict from string (or tuple of strings) to
  `Tensor`. A single `Tensor` is also accepted and is stored under the key
  `'output'`.
  """
  _SINGLE_OUTPUT_DEFAULT_NAME = 'output'

  def __init__(self, outputs):
    """Constructor for PredictOutput.

    Args:
      outputs: A `Tensor`, or a dict of string to `Tensor`, holding the
        predictions.

    Raises:
      ValueError: if any key of the (possibly wrapped) outputs dict is not a
          string or tuple of strings, or any value is not a `Tensor`.
    """
    checked_outputs = self._wrap_and_check_outputs(
        outputs,
        self._SINGLE_OUTPUT_DEFAULT_NAME,
        error_label='Prediction')
    self._outputs = checked_outputs

  @property
  def outputs(self):
    """Dict of output name to `Tensor`, as validated by the constructor."""
    return self._outputs

  def as_signature_def(self, receiver_tensors):
    """Builds a predict SignatureDef from `receiver_tensors` and `outputs`."""
    named_outputs = self.outputs
    return signature_def_utils.predict_signature_def(
        receiver_tensors, named_outputs)
234
235
class _SupervisedOutput(ExportOutput):
  """Represents the output of a supervised training or eval process.

  Loss, predictions and metrics are validated and stored as flat dicts of
  string to `Tensor`, whose keys are prefixed with `LOSS_NAME`,
  `PREDICTIONS_NAME` or `METRICS_NAME` respectively. Subclasses choose the
  SignatureDef builder by implementing `_get_signature_def_fn`.
  """
  # NOTE(review): `__metaclass__` is only honored by Python 2; under Python 3
  # it is ignored, so the abstract `_get_signature_def_fn` is not enforced at
  # instantiation time.
  __metaclass__ = abc.ABCMeta

  LOSS_NAME = 'loss'
  PREDICTIONS_NAME = 'predictions'
  METRICS_NAME = 'metrics'

  # Suffixes appended to each metric key for its value and update-op tensors.
  METRIC_VALUE_SUFFIX = 'value'
  METRIC_UPDATE_SUFFIX = 'update_op'

  # Class-level defaults; they remain None when the matching constructor
  # argument is None.
  _loss = None
  _predictions = None
  _metrics = None

  def __init__(self, loss=None, predictions=None, metrics=None):
    """Constructor for SupervisedOutput (ie, Train or Eval output).

    Args:
      loss: dict of Tensors or single Tensor representing calculated loss.
      predictions: dict of Tensors or single Tensor representing model
        predictions.
      metrics: Dict of metric results keyed by name.
        The values of the dict can be one of the following:
        (1) instance of `Metric` class.
        (2) (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Raises:
      ValueError: if any of the outputs' dict keys are not strings or tuples of
        strings or the values are not Tensors (or Operations in the case of
        update_op), or a `Metric` object does not have exactly one update op.
    """

    if loss is not None:
      loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME)
      self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME)
    if predictions is not None:
      pred_dict = self._wrap_and_check_outputs(
          predictions, self.PREDICTIONS_NAME)
      self._predictions = self._prefix_output_keys(
          pred_dict, self.PREDICTIONS_NAME)
    if metrics is not None:
      self._metrics = self._wrap_and_check_metrics(metrics)

  def _prefix_output_keys(self, output_dict, output_name):
    """Prepend output_name to the output_dict keys if it doesn't exist.

    This produces predictable prefixes for the pre-determined outputs
    of SupervisedOutput.

    Args:
      output_dict: dict of string to Tensor, assumed valid.
      output_name: prefix string to prepend to existing keys.

    Returns:
      dict with updated keys and existing values.
    """
    return {self._prefix_key(key, output_name): val
            for key, val in output_dict.items()}

  def _prefix_key(self, key, output_name):
    """Returns `key` prefixed with `output_name` unless it already starts so.

    This is a plain string-prefix test: for example, with output_name 'loss'
    the key 'lossy' is returned unchanged.
    """
    if not key.startswith(output_name):
      key = output_name + self._SEPARATOR_CHAR + key
    return key

  def _wrap_and_check_metrics(self, metrics):
    """Handle the saving of metrics.

    Metrics is either a single metric result or a dict of them, where each
    result is a (value, update_op) tuple or a `Metric` object. Here, we
    separate out the value and update op of each and create a dict with
    names to tensors.

    Args:
      metrics: Dict of metric results keyed by name.
        The values of the dict can be one of the following:
        (1) instance of `Metric` class.
        (2) (metric_value, update_op) tuples, or a single tuple.
        metric_value must be a Tensor, and update_op must be a Tensor or Op.

    Returns:
      dict of output_names to tensors. Each metric contributes two entries,
      '<key>/value' and '<key>/update_op'.

    Raises:
      ValueError: if the dict key is not a string, the metric values or ops
        are not tensors, or a `Metric` object does not have exactly one
        update op.
    """
    if not isinstance(metrics, dict):
      metrics = {self.METRICS_NAME: metrics}

    outputs = {}
    for key, value in metrics.items():
      if isinstance(value, tuple):
        metric_val, metric_op = value
      else:  # value is presumably a keras Metric object.
        metric_val = value.result()
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(value.updates) != 1:
          raise ValueError(
              '{} metric must have exactly one update op; got {}.'.format(
                  key, value.updates))
        metric_op = value.updates[0]
      key = self._check_output_key(key, self.METRICS_NAME)
      key = self._prefix_key(key, self.METRICS_NAME)

      val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX
      op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX
      if not isinstance(metric_val, ops.Tensor):
        raise ValueError(
            '{} output value must be a Tensor; got {}.'.format(
                key, metric_val))
      if not (tensor_util.is_tf_type(metric_op) or
              isinstance(metric_op, ops.Operation)):
        raise ValueError(
            '{} update_op must be a Tensor or Operation; got {}.'.format(
                key, metric_op))

      # We must wrap any ops (or variables) in a Tensor before export, as the
      # SignatureDef proto expects tensors only. See b/109740581
      metric_op_tensor = metric_op
      if not isinstance(metric_op, ops.Tensor):
        with ops.control_dependencies([metric_op]):
          metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')

      outputs[val_name] = metric_val
      outputs[op_name] = metric_op_tensor

    return outputs

  @property
  def loss(self):
    """Prefixed dict of loss tensors, or None if no loss was given."""
    return self._loss

  @property
  def predictions(self):
    """Prefixed dict of prediction tensors, or None if none were given."""
    return self._predictions

  @property
  def metrics(self):
    """Dict of metric value/update-op tensors, or None if none were given."""
    return self._metrics

  @abc.abstractmethod
  def _get_signature_def_fn(self):
    """Returns a function that produces a SignatureDef given desired outputs."""
    pass

  def as_signature_def(self, receiver_tensors):
    """Builds the SignatureDef via the subclass-selected builder function."""
    signature_def_fn = self._get_signature_def_fn()
    return signature_def_fn(
        receiver_tensors, self.loss, self.predictions, self.metrics)
385
386
class TrainOutput(_SupervisedOutput):
  """Export output describing a supervised training graph.

  The base class type-checks and wraps the loss, predictions and metrics;
  this subclass only selects the training SignatureDef builder.
  """

  def _get_signature_def_fn(self):
    """Returns `signature_def_utils.supervised_train_signature_def`."""
    builder = signature_def_utils.supervised_train_signature_def
    return builder
397
398
class EvalOutput(_SupervisedOutput):
  """Export output describing a supervised evaluation graph.

  The base class type-checks and wraps the loss, predictions and metrics;
  this subclass only selects the eval SignatureDef builder.
  """

  def _get_signature_def_fn(self):
    """Returns `signature_def_utils.supervised_eval_signature_def`."""
    builder = signature_def_utils.supervised_eval_signature_def
    return builder
409