1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities supporting export to SavedModel (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20
21Some contents of this file are moved to tensorflow/python/estimator/export.py:
22
23get_input_alternatives() -> obsolete
24get_output_alternatives() -> obsolete, but see _get_default_export_output()
25build_all_signature_defs() -> build_all_signature_defs()
get_timestamped_export_dir() -> get_timestamped_export_dir()
27_get_* -> obsolete
28_is_* -> obsolete
29
30Functionality of build_standardized_signature_def() is moved to
31tensorflow/python/estimator/export_output.py as ExportOutput.as_signature_def().
32
33Anything to do with ExportStrategies or garbage collection is not moved.
34"""
35from __future__ import absolute_import
36from __future__ import division
37from __future__ import print_function
38
39import os
40import time
41
42from tensorflow.contrib.layers.python.layers import feature_column
43from tensorflow.contrib.learn.python.learn import export_strategy
44from tensorflow.contrib.learn.python.learn.estimators import constants
45from tensorflow.contrib.learn.python.learn.estimators import metric_key
46from tensorflow.contrib.learn.python.learn.estimators import prediction_key
47from tensorflow.contrib.learn.python.learn.utils import gc
48from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
49from tensorflow.python.estimator import estimator as core_estimator
50from tensorflow.python.estimator.export import export as core_export
51from tensorflow.python.framework import dtypes
52from tensorflow.python.framework import errors_impl
53from tensorflow.python.platform import gfile
54from tensorflow.python.platform import tf_logging as logging
55from tensorflow.python.saved_model import signature_constants
56from tensorflow.python.saved_model import signature_def_utils
57from tensorflow.python.summary import summary_iterator
58from tensorflow.python.training import checkpoint_management
59from tensorflow.python.util import compat
60from tensorflow.python.util.deprecation import deprecated
61
62
# A key for use in the input_alternatives dict indicating the default input.
# This is the input that will be expected when a serving request does not
# specify a specific signature.
# The default input alternative specifies placeholders that the input_fn
# requires to be fed (in the typical case, a single placeholder for a
# serialized tf.Example).
# build_all_signature_defs() requires this key to be present with a non-empty
# value, because it backs the default serving signature.
DEFAULT_INPUT_ALTERNATIVE_KEY = 'default_input_alternative'

# A key for use in the input_alternatives dict indicating the features input.
# The features inputs alternative specifies the feature Tensors provided as
# input to the model_fn, i.e. the outputs of the input_fn.
# NOTE: get_input_alternatives() does not currently populate this key; the
# "features" input signature is disabled pending b/34253951.
FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative'

# Key used by get_output_alternatives() for the single output alternative it
# synthesizes when the model_fn provides no output_alternatives at all.
# This is the output that will be provided when a serving request does not
# specify a specific signature.
# In a single-headed model, the single output is automatically the default.
# In a multi-headed model, the name of the desired default head should be
# provided to get_output_alternatives.
_FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative'
83
84
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def build_standardized_signature_def(input_tensors, output_tensors,
                                     problem_type):
  """Builds a SignatureDef from a problem type and input/output Tensor dicts.

  The signature itself is produced by the helpers in
  tensorflow/python/saved_model/signature_def_utils.py. Classification and
  regression signatures may rename the tensors to the keys that are
  standardized in the context of SavedModel; every other case falls back to a
  generic predict signature that keeps the caller's keys.

  Args:
    input_tensors: a dict of string key to `Tensor`.
    output_tensors: a dict of string key to `Tensor`.
    problem_type: an instance of constants.ProblemType, specifying
      classification, regression, etc.

  Returns:
    A SignatureDef using SavedModel standard keys where possible.

  Raises:
    ValueError: if input_tensors or output_tensors is None or empty.
  """
  if not input_tensors:
    raise ValueError('input_tensors must be provided.')
  if not output_tensors:
    raise ValueError('output_tensors must be provided.')

  if _is_classification_problem(problem_type, input_tensors, output_tensors):
    # The predicate guarantees exactly one input tensor.
    [examples] = list(input_tensors.values())
    classes = _get_classification_classes(output_tensors)
    scores = _get_classification_scores(output_tensors)
    if classes is None and scores is None:
      # No standard key matched, so the predicate guarantees a single output;
      # route it by dtype: strings are class labels, anything else is scores.
      [sole_output] = list(output_tensors.values())
      if sole_output.dtype == dtypes.string:
        classes = sole_output
      else:
        scores = sole_output
    return signature_def_utils.classification_signature_def(
        examples, classes, scores)

  if _is_regression_problem(problem_type, input_tensors, output_tensors):
    # The predicate guarantees exactly one input and one output tensor.
    [examples] = list(input_tensors.values())
    [predictions] = list(output_tensors.values())
    return signature_def_utils.regression_signature_def(examples, predictions)

  return signature_def_utils.predict_signature_def(input_tensors,
                                                   output_tensors)
133
134
def _get_classification_scores(output_tensors):
  """Returns the scores output, falling back to probabilities, else None."""
  for candidate_key in (prediction_key.PredictionKey.SCORES,
                        prediction_key.PredictionKey.PROBABILITIES):
    candidate = output_tensors.get(candidate_key)
    if candidate is not None:
      return candidate
  return None
140
141
def _get_classification_classes(output_tensors):
  """Returns the classes output if it is a string tensor, otherwise None."""
  classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
  # Servo classification can only serve string classes, so a non-string
  # classes tensor is reported as absent.
  servable = classes is None or classes.dtype == dtypes.string
  return classes if servable else None
148
149
def _is_classification_problem(problem_type, input_tensors, output_tensors):
  """Decides whether a classification SignatureDef can be built.

  True when problem_type is CLASSIFICATION or LOGISTIC_REGRESSION, there is
  exactly one input tensor, and the outputs either contain a recognizable
  classes/scores tensor or consist of a single tensor.
  """
  classification_types = (constants.ProblemType.CLASSIFICATION,
                          constants.ProblemType.LOGISTIC_REGRESSION)
  if problem_type not in classification_types:
    return False
  if len(input_tensors) != 1:
    return False
  if len(output_tensors) == 1:
    return True
  return (_get_classification_classes(output_tensors) is not None or
          _get_classification_scores(output_tensors) is not None)
158
159
def _is_regression_problem(problem_type, input_tensors, output_tensors):
  """True for LINEAR_REGRESSION with exactly one input and one output."""
  if problem_type != constants.ProblemType.LINEAR_REGRESSION:
    return False
  return len(input_tensors) == 1 and len(output_tensors) == 1
163
164
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_input_alternatives(input_ops):
  """Obtain all input alternatives using the input_fn output and heuristics.

  Args:
    input_ops: either an `input_fn_utils.InputFnOps` unpacking to
      (features, labels, default_inputs), or a plain (features, labels) pair.

  Returns:
    A tuple `(input_alternatives, features)`. `input_alternatives` maps
    DEFAULT_INPUT_ALTERNATIVE_KEY to the default inputs when `input_ops` is an
    `InputFnOps`, and is empty otherwise.

  Raises:
    ValueError: if the features are empty or None.
  """
  if isinstance(input_ops, input_fn_utils.InputFnOps):
    features, _, default_inputs = input_ops
    input_alternatives = {DEFAULT_INPUT_ALTERNATIVE_KEY: default_inputs}
  else:
    features, _ = input_ops
    input_alternatives = {}

  if not features:
    raise ValueError('Features must be defined.')

  # TODO(b/34253951): reinstate the "features" input_signature (a defensive
  # copy of `features` under FEATURES_INPUT_ALTERNATIVE_KEY). As previously
  # written it did not work with SparseTensors, so it is disabled as a stopgap
  # pending discussion on the bug.
  return input_alternatives, features
189
190
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_output_alternatives(model_fn_ops, default_output_alternative_key=None):
  """Obtain all output alternatives using the model_fn output and heuristics.

  Args:
    model_fn_ops: a `ModelFnOps` object produced by a `model_fn`.  This may or
      may not have output_alternatives populated.
    default_output_alternative_key: the name of the head to serve when an
      incoming serving request does not explicitly request a specific head.
      Not needed for single-headed models.

  Returns:
    A tuple of (output_alternatives, actual_default_output_alternative_key),
    where the latter names the head that will actually be served by default.
    This may differ from the requested default_output_alternative_key when
    a) no output_alternatives are provided at all, so one must be generated, or
    b) there is exactly one head, which is used regardless of the requested
    default.

  Raises:
    ValueError: if the requested default_output_alternative_key is not available
      in output_alternatives, or if there are multiple output_alternatives and
      no default is specified.
  """
  output_alternatives = model_fn_ops.output_alternatives

  if output_alternatives:
    if default_output_alternative_key:
      # An explicitly requested head must be one of the provided ones.
      if default_output_alternative_key not in output_alternatives:
        raise ValueError('Requested default_output_alternative: {}, '
                         'but available output_alternatives are: {}'.format(
                             default_output_alternative_key,
                             sorted(output_alternatives.keys())))
      return output_alternatives, default_output_alternative_key

    if len(output_alternatives) != 1:
      raise ValueError('Please specify a default_output_alternative.  '
                       'Available output_alternatives are: {}'.format(
                           sorted(output_alternatives.keys())))

    # A single head is served by default whatever its name.
    [sole_key] = list(output_alternatives)
    return output_alternatives, sole_key

  # Nothing was provided, so no requested default can possibly be honored.
  if default_output_alternative_key:
    raise ValueError('Requested default_output_alternative: {}, '
                     'but available output_alternatives are: []'.format(
                         default_output_alternative_key))

  # Treat the model as single-headed with an unknown problem type, wrapping
  # bare predictions into a dict under the generic prediction key.
  predictions = model_fn_ops.predictions
  if isinstance(predictions, dict):
    fallback_outputs = predictions
  else:
    fallback_outputs = {prediction_key.PredictionKey.GENERIC: predictions}
  fallback_alternatives = {
      _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY: (
          constants.ProblemType.UNSPECIFIED, fallback_outputs)
  }
  return fallback_alternatives, _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY
255
256
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def build_all_signature_defs(input_alternatives, output_alternatives,
                             actual_default_output_alternative_key):
  """Build `SignatureDef`s from all pairs of input and output alternatives.

  Args:
    input_alternatives: dict mapping an input alternative key to a dict of
      input Tensors. Must contain a non-empty entry for
      DEFAULT_INPUT_ALTERNATIVE_KEY.
    output_alternatives: dict mapping an output alternative key to a
      `(problem_type, outputs)` tuple, where `outputs` is a dict of Tensors.
    actual_default_output_alternative_key: the key in `output_alternatives`
      whose outputs back the default serving signature.

  Returns:
    A dict mapping `'<input_key>:<output_key>'` (a falsy output key is
    rendered as 'None') for every input/output pair, plus
    `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`, to a SignatureDef.

  Raises:
    ValueError: if no non-empty default input alternative is provided, if
      `actual_default_output_alternative_key` is not in `output_alternatives`,
      or if build_standardized_signature_def rejects an alternative.
  """
  # Validate the defaults before building any signatures so that a
  # misconfiguration fails fast with a clear message instead of a bare
  # KeyError after all the other signatures have been built.
  default_inputs = input_alternatives.get(DEFAULT_INPUT_ALTERNATIVE_KEY)
  if not default_inputs:
    raise ValueError('A default input_alternative must be provided.')
  if actual_default_output_alternative_key not in output_alternatives:
    # list() rather than sorted(): keys may be None or of mixed types.
    raise ValueError(
        'Default output_alternative {} is not among the available '
        'output_alternatives: {}'.format(actual_default_output_alternative_key,
                                          list(output_alternatives.keys())))

  signature_def_map = {}
  for input_key, inputs in input_alternatives.items():
    for output_key, (problem_type, outputs) in output_alternatives.items():
      signature_name = '%s:%s' % (input_key, output_key or 'None')
      signature_def_map[signature_name] = build_standardized_signature_def(
          inputs, outputs, problem_type)

  default_problem_type, default_outputs = (
      output_alternatives[actual_default_output_alternative_key])
  signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
      build_standardized_signature_def(default_inputs, default_outputs,
                                       default_problem_type))

  return signature_def_map
282
283
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another worker is also writing exports.
# In this case get_timestamped_export_dir() waits one second to get a new
# timestamp and tries again, giving up with a RuntimeError after this many
# attempts. If this fails several times in a row, then something is seriously
# wrong.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
289
290
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time.  This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
  Returns:
    The full path (as bytes) of the new subdirectory (which is not actually
    created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  for attempt in range(1, MAX_DIRECTORY_CREATION_ATTEMPTS + 1):
    export_timestamp = int(time.time())

    export_dir = os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(str(export_timestamp)))
    if not gfile.Exists(export_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return export_dir
    # Only wait and announce a retry when another attempt will actually be
    # made; after the final attempt we fail immediately.
    if attempt < MAX_DIRECTORY_CREATION_ATTEMPTS:
      logging.warning(
          'Export directory %s already exists; retrying (attempt %d/%d)',
          export_dir, attempt, MAX_DIRECTORY_CREATION_ATTEMPTS)
      # Timestamps have one-second resolution, so wait for the next one.
      time.sleep(1)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
328
329
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_temp_export_dir(timestamped_export_dir):
  """Returns the 'temp-' prefixed sibling of a timestamped export directory.

  The naming relies on TensorFlow Serving ignoring subdirectories of the base
  directory that can't be parsed as integers, so a directory with this name
  is not picked up by Serving.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>

  Returns:
    A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>,
    as bytes.
  """
  parent_dir, leaf_name = os.path.split(timestamped_export_dir)
  temp_leaf_name = 'temp-{}'.format(compat.as_text(leaf_name))
  return os.path.join(compat.as_bytes(parent_dir),
                      compat.as_bytes(temp_leaf_name))
349
350
351# create a simple parser that pulls the export_version from the directory.
352def _export_version_parser(path):
353  filename = os.path.basename(path.path)
354  if not (len(filename) == 10 and filename.isdigit()):
355    return None
356  return path._replace(export_version=int(filename))
357
358
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_most_recent_export(export_dir_base):
  """Locate the most recent SavedModel export in a directory of many exports.

  SavedModel subdirectories are assumed to be named as a timestamp (seconds
  from epoch), as produced by get_timestamped_export_dir(); entries that do
  not parse as such are ignored.

  Args:
    export_dir_base: A base directory containing multiple timestamped
                     directories.

  Returns:
    A gc.Path, which is just a namedtuple of (path, export_version), or None
    if no export is found.
  """
  keep_newest = gc.largest_export_versions(1)
  candidates = gc.get_paths(export_dir_base, parser=_export_version_parser)
  newest = keep_newest(candidates)
  if not newest:
    return None
  return next(iter(newest))
377
378
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def garbage_collect_exports(export_dir_base, exports_to_keep):
  """Deletes older exports, retaining only a given number of the most recent.

  Export subdirectories are assumed to be named with monotonically increasing
  integers; the most recent are taken to be those with the largest values.

  Args:
    export_dir_base: the base directory under which each export is in a
      versioned subdirectory.
    exports_to_keep: the number of recent exports to retain, or None to
      disable garbage collection entirely.
  """
  if exports_to_keep is None:
    return

  stale_filter = gc.negation(gc.largest_export_versions(exports_to_keep))
  all_exports = gc.get_paths(export_dir_base, parser=_export_version_parser)
  for stale_export in stale_filter(all_exports):
    try:
      gfile.DeleteRecursively(stale_export.path)
    except errors_impl.NotFoundError as e:
      # Best-effort: a missing directory (presumably already removed) is
      # logged rather than treated as fatal.
      logging.warn('Can not delete %s recursively: %s', stale_export.path, e)
402
403
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def make_export_strategy(serving_input_fn,
                         default_output_alternative_key=None,
                         assets_extra=None,
                         as_text=False,
                         exports_to_keep=5,
                         strip_default_attrs=None):
  """Create an ExportStrategy for use with Experiment.

  Args:
    serving_input_fn: A function that takes no arguments and returns an
      `InputFnOps`.
    default_output_alternative_key: the name of the head to serve when an
      incoming serving request does not explicitly request a specific head.
      Must be `None` if the estimator inherits from `tf.estimator.Estimator`
      or for single-headed models.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel.  Each key should give the destination
      path (including the filename) relative to the assets.extra directory.
      The corresponding value gives the full path of the source file to be
      copied.  For example, the simple case of copying a single file without
      renaming it is specified as
      `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
    as_text: whether to write the SavedModel proto in text format.
    exports_to_keep: Number of exports to keep.  Older exports will be
      garbage-collected.  Defaults to 5.  Set to None to disable garbage
      collection.
    strip_default_attrs: Boolean. If True, default attrs in the
      `GraphDef` will be stripped on write. This is recommended for better
      forward compatibility of the resulting `SavedModel`.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """

  # NOTE: the parameter names `checkpoint_path` and `strip_default_attrs` are
  # kept as-is; presumably ExportStrategy inspects them when calling this
  # function (confirm in export_strategy.py).
  def export_fn(estimator, export_dir_base, checkpoint_path=None,
                strip_default_attrs=False):
    """Exports the given Estimator as a SavedModel, then garbage-collects.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the most recent checkpoint found within the model directory is chosen.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: If `estimator` is a `tf.estimator.Estimator` instance
        and `default_output_alternative_key` was specified.
    """
    savedmodel_kwargs = dict(
        assets_extra=assets_extra,
        as_text=as_text,
        checkpoint_path=checkpoint_path,
        strip_default_attrs=strip_default_attrs)

    if isinstance(estimator, core_estimator.Estimator):
      # Core estimators do not accept default_output_alternative_key.
      if default_output_alternative_key is not None:
        raise ValueError(
            'default_output_alternative_key is not supported in core '
            'Estimator. Given: {}'.format(default_output_alternative_key))
    else:
      savedmodel_kwargs['default_output_alternative_key'] = (
          default_output_alternative_key)

    export_result = estimator.export_savedmodel(
        export_dir_base, serving_input_fn, **savedmodel_kwargs)

    garbage_collect_exports(export_dir_base, exports_to_keep)
    return export_result

  return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs)
485
486
@deprecated(None,
            'Use tf.estimator.export.build_parsing_serving_input_receiver_fn')
def make_parsing_export_strategy(feature_columns,
                                 default_output_alternative_key=None,
                                 assets_extra=None,
                                 as_text=False,
                                 exports_to_keep=5,
                                 target_core=False,
                                 strip_default_attrs=None):
  """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s.

  Creates a SavedModel export that expects to be fed with a single string
  Tensor containing serialized tf.Examples.  At serving time, incoming
  tf.Examples will be parsed according to the provided `FeatureColumn`s.

  Args:
    feature_columns: An iterable of `FeatureColumn`s representing the features
      that must be provided at serving time (excluding labels!).
    default_output_alternative_key: the name of the head to serve when an
      incoming serving request does not explicitly request a specific head.
      Must be `None` if the estimator inherits from `tf.estimator.Estimator`
      or for single-headed models.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel.  Each key should give the destination
      path (including the filename) relative to the assets.extra directory.
      The corresponding value gives the full path of the source file to be
      copied.  For example, the simple case of copying a single file without
      renaming it is specified as
      `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
    as_text: whether to write the SavedModel proto in text format.
    exports_to_keep: Number of exports to keep.  Older exports will be
      garbage-collected.  Defaults to 5.  Set to None to disable garbage
      collection.
    target_core: If True, prepare an ExportStrategy for use with
      tensorflow.python.estimator.*.  If False (default), prepare an
      ExportStrategy for use with tensorflow.contrib.learn.python.learn.*.
    strip_default_attrs: Boolean. If True, default attrs in the
      `GraphDef` will be stripped on write. This is recommended for better
      forward compatibility of the resulting `SavedModel`.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """
  feature_spec = feature_column.create_feature_spec_for_parsing(feature_columns)
  # Core estimators expect a serving_input_receiver_fn; contrib estimators
  # expect the older serving input_fn returning InputFnOps.
  serving_input_fn_builder = (
      core_export.build_parsing_serving_input_receiver_fn if target_core
      else input_fn_utils.build_parsing_serving_input_fn)
  serving_input_fn = serving_input_fn_builder(feature_spec)
  return make_export_strategy(
      serving_input_fn,
      default_output_alternative_key=default_output_alternative_key,
      assets_extra=assets_extra,
      as_text=as_text,
      exports_to_keep=exports_to_keep,
      strip_default_attrs=strip_default_attrs)
544
545
def _default_compare_fn(curr_best_eval_result, cand_eval_result):
  """Returns True if the candidate eval result has a strictly lower loss.

  Both evaluation results must contain a value for MetricKey.LOSS; those
  values are the only thing compared.

  Args:
    curr_best_eval_result: current best eval metrics.
    cand_eval_result: candidate eval metrics.

  Returns:
    True if cand_eval_result is better.

  Raises:
    ValueError: If input eval result is None or no loss is available.
  """
  loss_key = metric_key.MetricKey.LOSS

  def _has_loss(eval_result):
    return bool(eval_result) and loss_key in eval_result

  if not _has_loss(curr_best_eval_result):
    raise ValueError(
        'curr_best_eval_result cannot be empty or no loss is found in it.')
  if not _has_loss(cand_eval_result):
    raise ValueError(
        'cand_eval_result cannot be empty or no loss is found in it.')

  # Lower loss wins.
  return curr_best_eval_result[loss_key] > cand_eval_result[loss_key]
572
573
class BestModelSelector(object):
  """A helper that keeps track of export selection candidates.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  @deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
  def __init__(self, event_file_pattern=None, compare_fn=None):
    """Constructor of this class.

    Args:
      event_file_pattern: absolute event file name pattern. When given, the
        best eval result found in the matching event files seeds the selector.
      compare_fn: a function `(curr_best_eval_result, cand_eval_result)` that
        returns true if the candidate is better than the current best model.
        Defaults to `_default_compare_fn` (lower loss wins).
    """
    self._compare_fn = compare_fn or _default_compare_fn
    self._best_eval_result = self._get_best_eval_result(event_file_pattern)

  def update(self, checkpoint_path, eval_result):
    """Records a given checkpoint and reports whether it is the best model.

    Args:
      checkpoint_path: the checkpoint path to export.
      eval_result: a dictionary which is usually generated in evaluation runs.
        By default, eval_results contains 'loss' field.

    Returns:
      A `(checkpoint_path, eval_result)` tuple when the candidate becomes the
      new best result, otherwise `('', None)`.

    Raises:
      ValueError: if checkpoint path is empty.
      ValueError: if eval_results is None object.
    """
    if not checkpoint_path:
      raise ValueError('Checkpoint path is empty.')
    if eval_result is None:
      # Interpolate the path into the message (it used to be passed as a
      # separate, never-formatted exception argument).
      raise ValueError(
          '%s has empty evaluation results.' % (checkpoint_path,))

    if (self._best_eval_result is None or
        self._compare_fn(self._best_eval_result, eval_result)):
      self._best_eval_result = eval_result
      return checkpoint_path, eval_result
    return '', None

  def _get_best_eval_result(self, event_files):
    """Get the best eval result from event files.

    Args:
      event_files: Absolute pattern of event files.

    Returns:
      The best eval result (a dict of scalar summary tag to value), or None if
      `event_files` is empty or no matching event carries scalar values.
    """
    if not event_files:
      return None

    best_eval_result = None
    for event_file in gfile.Glob(event_files):
      for event in summary_iterator.summary_iterator(event_file):
        if not event.HasField('summary'):
          continue
        event_eval_result = {}
        for value in event.summary.value:
          if value.HasField('simple_value'):
            event_eval_result[value.tag] = value.simple_value
        if not event_eval_result:
          # A summary without scalar values carries no metrics; letting it
          # become (or be compared as) the best result makes the default
          # comparator raise on the missing loss.
          continue
        if best_eval_result is None or self._compare_fn(
            best_eval_result, event_eval_result):
          best_eval_result = event_eval_result
    return best_eval_result
646
647
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def make_best_model_export_strategy(
    serving_input_fn,
    exports_to_keep=1,
    model_dir=None,
    event_file_pattern=None,
    compare_fn=None,
    default_output_alternative_key=None,
    strip_default_attrs=None):
  """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

  The strategy exports a checkpoint only when its eval result is judged better
  (by `compare_fn`) than the best result seen so far.

  Args:
    serving_input_fn: a function that takes no arguments and returns an
      `InputFnOps`.
    exports_to_keep: an integer indicating how many historical best models need
      to be preserved.
    model_dir: Directory where model parameters, graph etc. are saved. This will
        be used to load eval metrics from the directory when the export strategy
        is created. So the best metrics would not be lost even if the export
        strategy got preempted, which guarantees that only the best model would
        be exported regardless of preemption. If None, however, the export
        strategy would not be preemption-safe. To be preemption-safe, both
        model_dir and event_file_pattern would be needed.
    event_file_pattern: event file name pattern relative to model_dir, e.g.
        "eval_continuous/*.tfevents.*". If None, however, the export strategy
        would not be preemption-safe. To be preemption-safe, both
        model_dir and event_file_pattern would be needed.
    compare_fn: a function `(curr_best_eval_result, cand_eval_result)` that
        returns True if the candidate eval result is better than the current
        best one. Defaults to comparing MetricKey.LOSS (lower is better).
    default_output_alternative_key: the key for default serving signature for
        multi-headed inference graphs.
    strip_default_attrs: Boolean. If True, default attrs in the
      `GraphDef` will be stripped on write. This is recommended for better
      forward compatibility of the resulting `SavedModel`.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """
  best_model_export_strategy = make_export_strategy(
      serving_input_fn,
      exports_to_keep=exports_to_keep,
      default_output_alternative_key=default_output_alternative_key,
      strip_default_attrs=strip_default_attrs)

  full_event_file_pattern = (
      os.path.join(model_dir, event_file_pattern)
      if model_dir and event_file_pattern else None)
  best_model_selector = BestModelSelector(full_event_file_pattern, compare_fn)

  def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
    """Exports the given Estimator as a SavedModel if it is the best so far.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export.  If empty or None, the
        most recent checkpoint found within the model directory is chosen.
      eval_result: the eval metrics for the checkpoint; the parameter also
        matches the call signature of ExportStrategy.

    Returns:
      The path to the exported directory, or '' when the checkpoint is not
      better than the best one seen so far.
    """
    if not checkpoint_path:
      # TODO(b/67425018): switch to
      #    checkpoint_path = estimator.latest_checkpoint()
      #  as soon as contrib is cleaned up and we can thus be sure that
      #  estimator is a tf.estimator.Estimator and not a
      #  tf.contrib.learn.Estimator
      checkpoint_path = checkpoint_management.latest_checkpoint(
          estimator.model_dir)
    export_checkpoint_path, export_eval_result = best_model_selector.update(
        checkpoint_path, eval_result)

    if export_checkpoint_path and export_eval_result is not None:
      checkpoint_base = os.path.basename(export_checkpoint_path)
      # Normalize both components to bytes, as this module does for its other
      # export paths: export_dir_base may be bytes while checkpoint paths are
      # typically str, and os.path.join cannot mix the two on Python 3.
      export_dir = os.path.join(
          compat.as_bytes(export_dir_base), compat.as_bytes(checkpoint_base))
      return best_model_export_strategy.export(
          estimator, export_dir, export_checkpoint_path, export_eval_result)
    else:
      return ''

  return export_strategy.ExportStrategy('best_model', export_fn)
731
732
733# TODO(b/67013778): Revisit this approach when corresponding changes to
734# TF Core are finalized.
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def extend_export_strategy(base_export_strategy,
                           post_export_fn,
                           post_export_name=None):
  """Extend ExportStrategy, calling post_export_fn after export.

  The returned strategy first exports with `base_export_strategy` into a
  temporary directory under the export base directory. It then calls
  `post_export_fn` to write a modified SavedModel into a second temporary
  directory, and finally renames that result into the export base directory.
  The temporary directories created by the export are removed afterwards, and
  also (best-effort) when any step fails.

  Args:
    base_export_strategy: An ExportStrategy that can be passed to the Experiment
      constructor.
    post_export_fn: A user-specified function to call after exporting the
      SavedModel. Takes two arguments - the path to the SavedModel exported by
      base_export_strategy and the directory where to export the SavedModel
      modified by the post_export_fn. Returns the path to the exported
      SavedModel, which must be located inside the directory it was given.
    post_export_name: The directory name under the export base directory where
      SavedModels generated by the post_export_fn will be written. If None, the
      directory name of base_export_strategy is used.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """

  def export_fn(estimator, export_dir_base, checkpoint_path=None,
                eval_result=None):
    """Exports the given Estimator as a SavedModel and invokes post_export_fn.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graphs and checkpoint.
      checkpoint_path: The checkpoint path to export. If None (the default),
        the most recent checkpoint found within the model directory is chosen.
      eval_result: Evaluation result associated with `checkpoint_path`. It is
        forwarded unchanged to `base_export_strategy.export`, so base
        strategies that rely on it receive it. Defaults to None.

    Returns:
      The string path to the SavedModel indicated by post_export_fn.

    Raises:
      ValueError: If `estimator` is a `tf.estimator.Estimator` instance
        and `default_output_alternative_key` was specified or if post_export_fn
        does not return a valid directory.
      RuntimeError: If unable to create temporary or final export directory.
    """
    # Temporary directories that this call itself created. Only these are
    # cleaned up; a directory that already existed (and therefore triggered
    # one of the RuntimeErrors below) is never deleted.
    tmp_dirs = []
    succeeded = False
    try:
      tmp_base_export_folder = 'temp-base-export-' + str(int(time.time()))
      tmp_base_export_dir = os.path.join(export_dir_base,
                                         tmp_base_export_folder)
      if gfile.Exists(tmp_base_export_dir):
        raise RuntimeError('Failed to obtain base export directory')
      gfile.MakeDirs(tmp_base_export_dir)
      tmp_dirs.append(tmp_base_export_dir)
      tmp_base_export = base_export_strategy.export(
          estimator, tmp_base_export_dir, checkpoint_path, eval_result)

      tmp_post_export_folder = 'temp-post-export-' + str(int(time.time()))
      tmp_post_export_dir = os.path.join(export_dir_base,
                                         tmp_post_export_folder)
      if gfile.Exists(tmp_post_export_dir):
        raise RuntimeError('Failed to obtain temp export directory')

      gfile.MakeDirs(tmp_post_export_dir)
      tmp_dirs.append(tmp_post_export_dir)
      tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)

      # The post-export result is moved relative to its temp directory, so it
      # has to live inside that directory.
      if not tmp_post_export.startswith(tmp_post_export_dir):
        raise ValueError('post_export_fn must return a sub-directory of {}'
                         .format(tmp_post_export_dir))
      post_export_relpath = os.path.relpath(tmp_post_export,
                                            tmp_post_export_dir)
      post_export = os.path.join(export_dir_base, post_export_relpath)
      if gfile.Exists(post_export):
        raise RuntimeError('Failed to obtain final export directory')
      gfile.Rename(tmp_post_export, post_export)
      succeeded = True
    finally:
      if not succeeded:
        # Best-effort cleanup so a failed export does not leave temp folders
        # behind. A cleanup failure must not mask the original exception,
        # which is re-raised when this finally clause completes.
        for tmp_dir in tmp_dirs:
          try:
            gfile.DeleteRecursively(tmp_dir)
          except Exception:  # pylint: disable=broad-except
            pass

    # Success path: remove the temporary directories exactly as before,
    # letting any deletion error propagate to the caller.
    for tmp_dir in tmp_dirs:
      gfile.DeleteRecursively(tmp_dir)
    return post_export

  name = post_export_name if post_export_name else base_export_strategy.name
  return export_strategy.ExportStrategy(name, export_fn)
806