1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for creating SavedModels."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import os
23import time
24
25import six
26
27from tensorflow.python.platform import gfile
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.saved_model import signature_constants
30from tensorflow.python.saved_model import signature_def_utils
31from tensorflow.python.saved_model import tag_constants
32from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
33from tensorflow.python.saved_model.model_utils import mode_keys
34from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys
35from tensorflow.python.util import compat
36
37
# Mapping of the modes to appropriate MetaGraph tags in the SavedModel.
EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{
    ModeKeys.PREDICT: [tag_constants.SERVING],
    ModeKeys.TRAIN: [tag_constants.TRAINING],
    ModeKeys.TEST: [tag_constants.EVAL]})

# For every exported mode, a SignatureDef map should be created using the
# functions `export_outputs_for_mode` and `build_all_signature_defs`. By
# default, this map will contain a single Signature that defines the input
# tensors and output predictions, losses, and/or metrics (depending on the
# mode). The default keys used in the SignatureDef map are defined below.
# NOTE: the key order of this map appears in the error message raised by
# `export_outputs_for_mode` for unknown modes.
SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{
    ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
    ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY})

# Default names used in the SignatureDef input map, which maps strings to
# TensorInfo protos.
SINGLE_FEATURE_DEFAULT_NAME = 'feature'
# Key under which `build_all_signature_defs` files a receiver that was passed
# as a bare tensor rather than a dict.
SINGLE_RECEIVER_DEFAULT_NAME = 'input'
SINGLE_LABEL_DEFAULT_NAME = 'label'
59
60### Below utilities are specific to SavedModel exports.
61
62
def build_all_signature_defs(receiver_tensors,
                             export_outputs,
                             receiver_tensors_alternatives=None,
                             serving_only=True):
  """Build `SignatureDef`s for all export outputs.

  Args:
    receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
      input nodes where this receiver expects to be fed by default.  Typically,
      this is a single placeholder expecting serialized `tf.Example` protos.
    export_outputs: a dict of ExportOutput instances, each of which has
      an as_signature_def instance method that will be called to retrieve
      the signature_def for all export output tensors.
    receiver_tensors_alternatives: a dict of string to additional
      groups of receiver tensors, each of which may be a `Tensor` or a dict of
      string to `Tensor`.  These named receiver tensor alternatives generate
      additional serving signatures, which may be used to feed inputs at
      different points within the input receiver subgraph.  A typical usage is
      to allow feeding raw feature `Tensor`s *downstream* of the
      tf.io.parse_example() op.  Defaults to None.
    serving_only: boolean; if true, resulting signature defs will only include
      valid serving signatures. If false, all requested signatures will be
      returned.

  Returns:
    A dict mapping signature names to the signature defs returned by each
    export output's `as_signature_def`. Signatures for the default receiver are
    keyed by the export output key; signatures for alternative receivers are
    keyed by '<receiver_name>:<output_key>'. Falsy names are rendered as
    'None'. Export outputs whose `as_signature_def` raises ValueError are
    omitted (and reported in the log).

  Raises:
    ValueError: if export_outputs is not a dict
  """
  # `None` is not a dict, so this also rejects a missing export_outputs.
  if not isinstance(export_outputs, dict):
    raise ValueError('export_outputs must be a dict and not '
                     '{}'.format(type(export_outputs)))

  signature_def_map = {}
  excluded_signatures = {}

  # Default receiver: signature names are just the export output keys.
  _add_signature_defs_for_receiver(
      receiver_tensors, export_outputs, '', signature_def_map,
      excluded_signatures)

  # Alternative receivers: names are prefixed with '<receiver_name>:'.
  if receiver_tensors_alternatives:
    for receiver_name, receiver_tensors_alt in (
        receiver_tensors_alternatives.items()):
      _add_signature_defs_for_receiver(
          receiver_tensors_alt, export_outputs,
          '{}:'.format(receiver_name or 'None'), signature_def_map,
          excluded_signatures)

  _log_signature_report(signature_def_map, excluded_signatures)

  # The as_signature_def calls above should return only valid signatures; if
  # there is a validity problem, they raise a ValueError, in which case we
  # exclude that signature from signature_def_map above.
  # The is_valid_signature check ensures that the signatures produced are
  # valid for serving, and acts as an additional sanity check for export
  # signatures produced for serving. We skip this check for training and eval
  # signatures, which are not intended for serving.
  if serving_only:
    signature_def_map = {
        k: v
        for k, v in signature_def_map.items()
        if signature_def_utils.is_valid_signature(v)
    }
  return signature_def_map


def _add_signature_defs_for_receiver(receiver_tensors, export_outputs,
                                     name_prefix, signature_def_map,
                                     excluded_signatures):
  """Builds one signature def per export output for a single receiver group.

  Signatures are keyed by `name_prefix` followed by the export output key
  ('None' when the key is falsy). Successfully built signatures are stored in
  `signature_def_map`. Outputs whose `as_signature_def` raises ValueError are
  recorded in `excluded_signatures` together with the error message.

  Args:
    receiver_tensors: a `Tensor` or a dict of string to `Tensor`; a non-dict
      value is wrapped under SINGLE_RECEIVER_DEFAULT_NAME.
    export_outputs: dict of ExportOutput instances.
    name_prefix: string prepended to every signature name.
    signature_def_map: dict updated in place with the built signatures.
    excluded_signatures: dict updated in place with exclusion messages.
  """
  if not isinstance(receiver_tensors, dict):
    receiver_tensors = {SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
  for output_key, export_output in export_outputs.items():
    signature_name = '{}{}'.format(name_prefix, output_key or 'None')
    try:
      signature = export_output.as_signature_def(receiver_tensors)
    except ValueError as e:
      excluded_signatures[signature_name] = str(e)
    else:
      signature_def_map[signature_name] = signature
141
142
# Human-readable labels for the standard SignatureDef method names. Used by
# `_log_signature_report` when logging which signatures were exported; method
# names not listed here are logged verbatim.
_FRIENDLY_METHOD_NAMES = {
    signature_constants.CLASSIFY_METHOD_NAME: 'Classify',
    signature_constants.REGRESS_METHOD_NAME: 'Regress',
    signature_constants.PREDICT_METHOD_NAME: 'Predict',
    signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train',
    signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval',
}
150
151
def _log_signature_report(signature_def_map, excluded_signatures):
  """Log a report of which signatures were included in and excluded from export.

  Args:
    signature_def_map: dict mapping signature name to a signature def; each
      value must expose a `method_name` attribute.
    excluded_signatures: dict mapping signature name to the error message
      explaining why that signature was excluded.
  """
  sig_names_by_method_name = collections.defaultdict(list)

  # We'll collect whatever method_names are present, but also make sure to
  # output a line for every method in _FRIENDLY_METHOD_NAMES even if it has no
  # signatures.
  for method_name in _FRIENDLY_METHOD_NAMES:
    sig_names_by_method_name[method_name] = []

  for signature_name, sig in signature_def_map.items():
    sig_names_by_method_name[sig.method_name].append(signature_name)

  # TODO(b/67733540): consider printing the full signatures, not just names
  for method_name, sig_names in sig_names_by_method_name.items():
    # Non-standard method names are logged verbatim.
    friendly_name = _FRIENDLY_METHOD_NAMES.get(method_name, method_name)
    logging.info('Signatures INCLUDED in export for %s: %s', friendly_name,
                 sig_names if sig_names else 'None')

  if excluded_signatures:
    logging.info('Signatures EXCLUDED from export because they cannot be '
                 'served via TensorFlow Serving APIs:')
    for signature_name, message in excluded_signatures.items():
      logging.info('\'%s\' : %s', signature_name, message)

  if not signature_def_map:
    logging.warning('Export includes no signatures!')
  elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
        signature_def_map):
    logging.warning('Export includes no default signature!')
183
184
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another process is also creating these
# directories. In this case we just wait one second to get a new timestamp and
# try again. If this fails several times in a row, then something is seriously
# wrong.
# Upper bound on how many timestamps `get_timestamped_export_dir` tries before
# raising RuntimeError.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
191
192
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the current time,
  as whole seconds since the epoch (`int(time.time())`). This yields
  increasing version numbers across runs of the pipeline as long as the
  system clock does not move backwards.

  If the candidate directory already exists (e.g. another process exported
  within the same second), this waits one second and tries a fresh timestamp,
  up to `MAX_DIRECTORY_CREATION_ATTEMPTS` times in total.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
  Returns:
    The full path of the new subdirectory, as bytes (the directory is not
    actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  for attempt in range(1, MAX_DIRECTORY_CREATION_ATTEMPTS + 1):
    timestamp = int(time.time())

    result_dir = os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes(str(timestamp)))
    if not gfile.Exists(result_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return result_dir
    # Only announce a retry and wait when another attempt will actually be
    # made; after the final attempt we raise immediately instead of sleeping.
    if attempt < MAX_DIRECTORY_CREATION_ATTEMPTS:
      logging.warning(
          'Directory {} already exists; retrying (attempt {}/{})'.format(
              compat.as_str(result_dir), attempt,
              MAX_DIRECTORY_CREATION_ATTEMPTS))
      time.sleep(1)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
228
229
def get_temp_export_dir(timestamped_export_dir):
  """Returns a sibling of `timestamped_export_dir` whose name starts 'temp-'.

  TensorFlow Serving ignores subdirectories of the base directory whose names
  cannot be parsed as integers, so the 'temp-' prefix keeps an in-progress
  export invisible to it.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>

  Returns:
    The sister path as bytes, e.g. /foo/bar/temp-<timestamp>.
  """
  parent_dir, leaf_name = os.path.split(timestamped_export_dir)
  parent_bytes = compat.as_bytes(parent_dir)
  temp_leaf_bytes = compat.as_bytes(
      'temp-{}'.format(six.ensure_text(leaf_name)))
  return os.path.join(parent_bytes, temp_leaf_bytes)
248
249
def export_outputs_for_mode(
    mode, serving_export_outputs=None, predictions=None, loss=None,
    metrics=None):
  """Constructs the `ExportOutput` dict appropriate for `mode`.

  The result can be passed straight to `build_all_signature_defs` as its
  `export_outputs` argument to generate a SignatureDef map.

  Args:
    mode: A `ModeKeys` specifying the mode.
    serving_export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None. Only
      consulted in predict mode.
    predictions: A dict of Tensors or single Tensor representing model
      predictions. In predict mode it is only used when
      serving_export_outputs is not set.
    loss: A dict of Tensors or single Tensor representing calculated loss.
    metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
      metric_value must be a Tensor, and update_op must be a Tensor or Op

  Returns:
    Dictionary mapping a key to an `tf.estimator.export.ExportOutput` object.
    In train and eval modes the key is the mode's default SignatureDef key.

  Raises:
    ValueError: if an appropriate ExportOutput cannot be found for the mode.
  """
  if mode not in SIGNATURE_KEY_MAP:
    raise ValueError(
        'Export output type not found for mode: {}. Expected one of: {}.\n'
        'One likely error is that V1 Estimator Modekeys were somehow passed to '
        'this function. Please ensure that you are using the new ModeKeys.'
        .format(mode, SIGNATURE_KEY_MAP.keys()))
  default_key = SIGNATURE_KEY_MAP[mode]

  # Predict mode delegates to the serving-output validation logic.
  if mode_keys.is_predict(mode):
    return get_export_outputs(serving_export_outputs, predictions)

  # Any remaining (validated) mode is either train or eval.
  if mode_keys.is_train(mode):
    output_class = export_output_lib.TrainOutput
  else:
    output_class = export_output_lib.EvalOutput
  output = output_class(loss=loss, predictions=predictions, metrics=metrics)
  return {default_key: output}
292
293
def get_export_outputs(export_outputs, predictions):
  """Returns validated `export_outputs`, building a default one when absent.

  When `export_outputs` is None, a single `PredictOutput` wrapping
  `predictions` is registered under the default serving signature key.
  The resulting dict is type-checked and then passed to
  `_maybe_add_default_serving_output`, which may add a default serving entry
  to it in place.

  Args:
    export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None.
    predictions:  Predictions `Tensor` or dict of `Tensor`.

  Returns:
    Valid export_outputs dict

  Raises:
    TypeError: if export_outputs is not a dict or its values are not
      ExportOutput instances.
  """
  if export_outputs is None:
    predict_output = export_output_lib.PredictOutput(predictions)
    export_outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: predict_output,
    }

  if not isinstance(export_outputs, dict):
    raise TypeError('export_outputs must be dict, given: {}'.format(
        export_outputs))

  all_valid = all(
      isinstance(output, export_output_lib.ExportOutput)
      for output in export_outputs.values())
  if not all_valid:
    raise TypeError(
        'Values in export_outputs must be ExportOutput objects. '
        'Given: {}'.format(export_outputs))

  # Mutates export_outputs in place; its return value is the same dict.
  _maybe_add_default_serving_output(export_outputs)
  return export_outputs
326
327
328def _maybe_add_default_serving_output(export_outputs):
329  """Add a default serving output to the export_outputs if not present.
330
331  Args:
332    export_outputs: Describes the output signatures to be exported to
333      `SavedModel` and used during serving. Should be a dict.
334
335  Returns:
336    export_outputs dict with default serving signature added if necessary
337
338  Raises:
339    ValueError: if multiple export_outputs were provided without a default
340      serving key.
341  """
342  if len(export_outputs) == 1:
343    (key, value), = export_outputs.items()
344    if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
345      export_outputs[
346          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value
347  if len(export_outputs) > 1:
348    if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
349        not in export_outputs):
350      raise ValueError(
351          'Multiple export_outputs were provided, but none of them is '
352          'specified as the default.  Do this by naming one of them with '
353          'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.')
354
355  return export_outputs
356