# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions related to train_and_evaluate."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
import time

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

_MAX_DELAY_SECS = 60
_DELAY_SECS_PER_WORKER = 5
_TF_CONFIG_ENV = 'TF_CONFIG'
_ENVIRONMENT_KEY = 'environment'
_ENVIRONMENT_GOOGLE_VALUE = 'google'
_TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER,
                 run_config_lib.TaskType.WORKER)


def _validate_input_fn(input_fn):
  """Validates the `input_fn`."""
  if not callable(input_fn):
    raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn))


def _validate_hooks(hooks):
  """Validates the `hooks`."""
  hooks = tuple(hooks or [])
  for hook in hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError(
          'All hooks must be `SessionRunHook` instances, given: {}'.format(
              hook))
  return hooks


def _validate_exporters(exporters):
  """Validates `exporters` and returns them as a tuple."""
  if not exporters:
    return ()

  if isinstance(exporters, exporter_lib.Exporter):
    exporters = [exporters]

  unique_names = []  # `Exporter`s should have unique names.
  try:
    for exporter in exporters:
      if not isinstance(exporter, exporter_lib.Exporter):
        # Error message will be printed out by the outer try/except.
        raise TypeError

      if not exporter.name:
        full_list_of_names = [e.name for e in exporters]
        raise ValueError('An Exporter cannot have a name that is `None` or'
                         ' empty. All exporter names:'
                         ' {}'.format(full_list_of_names))

      if not isinstance(exporter.name, six.string_types):
        raise ValueError('An Exporter must have a string name. Given: '
                         '{}'.format(type(exporter.name)))

      if exporter.name in unique_names:
        full_list_of_names = [e.name for e in exporters]
        raise ValueError(
            '`exporters` must have unique names. Such a name cannot be `None`.'
            ' All exporter names: {}'.format(full_list_of_names))
      unique_names.append(exporter.name)
  except TypeError:
    # Two possibilities:
    # - `exporters` is neither `Exporter` nor iterable.  Python has
    #   raised a `TypeError` when iterating over `exporters`.
    # - an `exporter` was None or not of type `Exporter`, so we raised a
    #   `TypeError`.
    raise TypeError('`exporters` must be an Exporter,'
                    ' an iterable of Exporter, or `None`,'
                    ' found %s.' % exporters)

  return tuple(exporters)


def _is_google_env():
  """Detects whether current environment is google."""
  tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV) or '{}')
  if not tf_config:
    logging.warn('TF_CONFIG should not be empty in distributed environment.')
  return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE


@tf_export('estimator.TrainSpec')
class TrainSpec(
    collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
  """Configuration for the "train" part for the `train_and_evaluate` call.

  `TrainSpec` determines the input data for the training, as well as the
  duration. Optional hooks run at various stages of training.
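
  Usage (a minimal sketch; `my_train_input_fn` stands for a user-defined
  input function):

  ```python
  train_spec = tf.estimator.TrainSpec(
      input_fn=my_train_input_fn,
      max_steps=10000)
  ```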
  """

  def __new__(cls, input_fn, max_steps=None, hooks=None):
    """Creates a validated `TrainSpec` instance.

    Args:
      input_fn: Training input function returning a tuple of:
          features - `Tensor` or dictionary of string feature name to `Tensor`.
          labels - `Tensor` or dictionary of `Tensor` with labels.
      max_steps: Int. Positive number of total steps for which to train model.
        If `None`, train forever. The training `input_fn` is not expected to
        generate `OutOfRangeError` or `StopIteration` exceptions. See the
        `train_and_evaluate` stop condition section for details.
      hooks: Iterable of `tf.train.SessionRunHook` objects to run
        on all workers (including chief) during training.

    Returns:
      A validated `TrainSpec` object.

    Raises:
      ValueError: If any of the input arguments is invalid.
      TypeError: If any of the arguments is not of the expected type.
    """
    # Validate input_fn.
    _validate_input_fn(input_fn)

    # Validate max_steps.
    if max_steps is not None and max_steps <= 0:
      raise ValueError(
          'Must specify max_steps > 0, given: {}'.format(max_steps))

    # Validate hooks.
    hooks = _validate_hooks(hooks)

    return super(TrainSpec, cls).__new__(
        cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)


@tf_export('estimator.EvalSpec')
class EvalSpec(
    collections.namedtuple('EvalSpec', [
        'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
        'throttle_secs'
    ])):
  """Configuration for the "eval" part for the `train_and_evaluate` call.

  `EvalSpec` combines details of evaluation of the trained model as well as its
  export. Evaluation consists of computing metrics to judge the performance of
  the trained model. Export writes out the trained model onto external
  storage.
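
  Usage (a minimal sketch; `my_eval_input_fn` stands for a user-defined
  input function):

  ```python
  eval_spec = tf.estimator.EvalSpec(
      input_fn=my_eval_input_fn,
      steps=None,  # Evaluate until input_fn raises an end-of-input exception.
      start_delay_secs=60,
      throttle_secs=300)
  ```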
  """

  def __new__(cls,
              input_fn,
              steps=100,
              name=None,
              hooks=None,
              exporters=None,
              start_delay_secs=120,
              throttle_secs=600):
    """Creates a validated `EvalSpec` instance.

    Args:
      input_fn: Evaluation input function returning a tuple of:
          features - `Tensor` or dictionary of string feature name to `Tensor`.
          labels - `Tensor` or dictionary of `Tensor` with labels.
      steps: Int. Positive number of steps for which to evaluate model. If
        `None`, evaluates until `input_fn` raises an end-of-input exception.
        See `Estimator.evaluate` for details.
      name: String. Name of the evaluation if user needs to run multiple
        evaluations on different data sets. Metrics for different evaluations
        are saved in separate folders, and appear separately in TensorBoard.
      hooks: Iterable of `tf.train.SessionRunHook` objects to run
        during evaluation.
      exporters: Iterable of `Exporter`s, or a single one, or `None`.
        `exporters` will be invoked after each evaluation.
      start_delay_secs: Int. Start evaluating after waiting for this many
        seconds.
      throttle_secs: Int. Do not re-evaluate unless the last evaluation was
        started at least this many seconds ago. Evaluation does not occur if
        no new checkpoints are available; hence, this is the minimum interval
        between evaluations.

    Returns:
      A validated `EvalSpec` object.

    Raises:
      ValueError: If any of the input arguments is invalid.
      TypeError: If any of the arguments is not of the expected type.
    """
    # Validate input_fn.
    _validate_input_fn(input_fn)

    # Validate steps.
    if steps is not None and steps <= 0:
      raise ValueError('Must specify steps > 0, given: {}'.format(steps))

    # Validate name.
    if name is not None and not isinstance(name, six.string_types):
      raise TypeError('`name` must be string, given: {}'.format(name))

    # Validate hooks.
    hooks = _validate_hooks(hooks)

    # Validate exporters.
    exporters = _validate_exporters(exporters)

    # Validate start_delay_secs.
    if start_delay_secs < 0:
      raise ValueError('Must specify start_delay_secs >= 0, given: {}'.format(
          start_delay_secs))

    # Validate throttle_secs.
    if throttle_secs < 0:
      raise ValueError(
          'Must specify throttle_secs >= 0, given: {}'.format(throttle_secs))

    return super(EvalSpec, cls).__new__(
        cls,
        input_fn=input_fn,
        steps=steps,
        name=name,
        hooks=hooks,
        exporters=exporters,
        start_delay_secs=start_delay_secs,
        throttle_secs=throttle_secs)


@tf_export('estimator.train_and_evaluate')
def train_and_evaluate(estimator, train_spec, eval_spec):
  """Train and evaluate the `estimator`.

  This utility function trains, evaluates, and (optionally) exports the model
  using the given `estimator`. All training-related specification is held in
  `train_spec`, including the training `input_fn` and training max steps. All
  evaluation- and export-related specification is held in `eval_spec`,
  including the evaluation `input_fn` and steps.

  This utility function provides consistent behavior for both local
  (non-distributed) and distributed configurations. Currently, the only
  supported distributed training configuration is between-graph replication.

  Overfitting: In order to avoid overfitting, it is recommended to set up the
  training `input_fn` to shuffle the training data properly. It is also
  recommended to train the model a little longer, say multiple epochs, before
  performing evaluation, as the input pipeline starts from scratch for each
  training run. This is particularly important for local training and
  evaluation.

  Stop condition: In order to support both distributed and non-distributed
  configurations reliably, the only supported stop condition for model
  training is `train_spec.max_steps`. If `train_spec.max_steps` is `None`, the
  model is trained forever. *Use with care* if the model's stop condition
  differs. For example, assume that the model is expected to be trained with
  one epoch of training data, and the training `input_fn` is configured to
  throw `OutOfRangeError` after going through one epoch, which stops
  `Estimator.train`. For a three-training-worker distributed configuration,
  each training worker is likely to go through the whole epoch independently.
  So, the model will be trained with three epochs of training data instead of
  one epoch.
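
  For instance, to avoid end-of-input exceptions entirely, the training
  `input_fn` can repeat the dataset and let `max_steps` control the training
  duration. A minimal sketch, where `features` and `labels` are placeholders
  for the actual training data:
  ```python
  def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.shuffle(buffer_size=10000).repeat().batch(128)
    return dataset.make_one_shot_iterator().get_next()
  ```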

  Example of local (non-distributed) training:
  ```python
  # Set up feature columns.
  categorical_feature_a = categorical_column_with_hash_bucket(...)
  categorical_feature_a_emb = embedding_column(
      categorical_column=categorical_feature_a, ...)
  ...  # other feature columns

  estimator = DNNClassifier(
      feature_columns=[categorical_feature_a_emb, ...],
      hidden_units=[1024, 512, 256])

  # Or set up the model directory
  #   estimator = DNNClassifier(
  #       config=tf.estimator.RunConfig(
  #           model_dir='/my_model', save_summary_steps=100),
  #       feature_columns=[categorical_feature_a_emb, ...],
  #       hidden_units=[1024, 512, 256])

  # Input pipeline for train and evaluate.
  def train_input_fn():  # returns x, y
    # Please shuffle the data.
    pass

  def eval_input_fn():  # returns x, y
    pass

  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  ```

  Example of distributed training:

  For distributed training, the code above can be used without change. (Do
  make sure that `RunConfig.model_dir` for all workers is set to the same
  directory, i.e., a shared file system all workers can read and write.) The
  only extra work is setting the environment variable `TF_CONFIG` properly
  for each worker.

  Also see: https://www.tensorflow.org/deploy/distributed

  Setting the environment variable depends on the platform. For example, on
  Linux, it can be done as follows (`$` is the shell prompt):
  ```
  $ TF_CONFIG='<replace_with_real_content>' python train_model.py
  ```

  For the content in `TF_CONFIG`, assume that the training cluster spec looks
  like:
  ```
  cluster = {"chief": ["host0:2222"],
             "worker": ["host1:2222", "host2:2222", "host3:2222"],
             "ps": ["host4:2222", "host5:2222"]}
  ```

  Example of `TF_CONFIG` for chief training worker (must have one and only one):
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "chief", "index": 0}
  }'
  ```
  Note that the chief worker also does the model training job, similar to other
  non-chief training workers (see next paragraph). In addition to the model
  training, it manages some extra work, e.g., checkpoint saving and restoring,
  and writing summaries.

  Example of `TF_CONFIG` for non-chief training worker (optional, could be
  multiple):
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 0}
  }'
  ```
  where `task.index` should be set to 0, 1, and 2, respectively, for the
  non-chief training workers in this example.

  Example of `TF_CONFIG` for parameter server, aka ps (could be multiple):
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "ps", "index": 0}
  }'
  ```
  where `task.index` should be set to 0 and 1, respectively, for the parameter
  servers in this example.

  Example of `TF_CONFIG` for the evaluator task. The evaluator is a special
  task that is not part of the training cluster; there can be only one, and it
  is used for model evaluation.
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "evaluator", "index": 0}
  }'
  ```

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.

  Raises:
    ValueError: if the environment variable `TF_CONFIG` is incorrectly set.
  """
  executor = _TrainingExecutor(
      estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)

  config = estimator.config
  if (config.task_type == run_config_lib.TaskType.EVALUATOR and
      config.task_id > 0):
    raise ValueError(
        'For distributed training, there can only be one `evaluator` task '
        '(with task id 0).  Given task id {}'.format(config.task_id))

  executor.run()


class _StopAtSecsHook(session_run_hook.SessionRunHook):
  """Requests a stop a given number of seconds after `begin` is called."""

  def __init__(self, stop_after_secs):
    self._stop_after_secs = stop_after_secs
    self._start_time = None

  def begin(self):
    self._start_time = time.time()

  def after_run(self, run_context, run_values):
    del run_values
    if time.time() - self._start_time >= self._stop_after_secs:
      run_context.request_stop()


class _TrainingExecutor(object):
  """The executor to run `Estimator` training and evaluation.

  This implementation supports both distributed and non-distributed (aka local)
  training and evaluation based on the setting in `tf.estimator.RunConfig`.
  """

  def __init__(self,
               estimator,
               train_spec,
               eval_spec,
               train_hooks=None,
               continuous_eval_listener=None):
    if not isinstance(estimator, estimator_lib.Estimator):
      raise TypeError('`estimator` must have type `tf.estimator.Estimator`.')
    self._estimator = estimator

    if not isinstance(train_spec, TrainSpec):
      raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`.')
    self._train_spec = train_spec

    if not isinstance(eval_spec, EvalSpec):
      raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.')
    self._eval_spec = eval_spec

    self._train_hooks = _validate_hooks(train_hooks)

    if (continuous_eval_listener and
        not isinstance(continuous_eval_listener, _ContinuousEvalListener)):
      raise TypeError('`continuous_eval_listener` must have type '
                      '`_ContinuousEvalListener`.')
    self._continuous_eval_listener = (
        continuous_eval_listener or _ContinuousEvalListener())

  @property
  def estimator(self):
    return self._estimator

  def run(self):
    """Executes the run_foo for task type `foo`.

    `_TrainingExecutor` predefines the procedure for task types 'chief',
    'worker', 'ps', and 'evaluator'. For task type `foo`, the corresponding
    procedure is `run_foo`. This `run` method invokes the procedure based on
    the `RunConfig.task_type`.

    Raises:
      ValueError: if the estimator.config is misconfigured.
    """
    config = self._estimator.config

    if (not config.cluster_spec and
        config.task_type != run_config_lib.TaskType.EVALUATOR):
      logging.info('Running training and evaluation locally (non-distributed).')
      self.run_local()
      return

    # Distributed case.
    if not config.task_type:
      # TODO(xiejw): Improve the error message about how to set the TF_CONFIG
      # correctly.
      raise ValueError(
          '`estimator.config` must have task_type set. This usually means '
          'TF_CONFIG environment is not set correctly.')

    if config.task_type == 'local':
      raise ValueError(
          '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
          '`task` properties in TF_CONFIG absent triggers train and evaluate '
          '`Estimator` locally (non-distributed).')

    # For task type foo, call executor.run_foo.
    available_tasks = [
        x for x in dir(self)
        if x.startswith('run_') and x != 'run_local' and
        callable(getattr(self, x))
    ]
    task_to_run = 'run_' + config.task_type
    if task_to_run not in available_tasks:
      raise ValueError(
          'Task type {} is not supported. Supported task types are {}'.format(
              config.task_type, [x[len('run_'):] for x in available_tasks]))
    getattr(self, task_to_run)()

  def run_chief(self):
    """Runs task chief."""
    # TODO(xiejw): To allow execution framework to add train hooks.
    return self._start_distributed_training()

  def run_worker(self):
    """Runs task (training) worker."""
    # TODO(xiejw): To allow execution framework to add train hooks.
    return self._start_distributed_training()

  def run_master(self):
    """Runs task master."""

    class NewCheckpointListener(
        basic_session_run_hooks.CheckpointSaverListener):

      def __init__(self, evaluator, eval_throttle_secs):
        self._evaluator = evaluator
        self._eval_throttle_secs = eval_throttle_secs

      def begin(self):
        self._timer = basic_session_run_hooks.SecondOrStepTimer(
            every_secs=self._eval_throttle_secs)

      def after_save(self, session, global_step_value):
        del session  # unused; required by signature.

        if self._timer.should_trigger_for_step(global_step_value):
          self._timer.update_last_triggered_step(global_step_value)
          self._evaluator.evaluate_and_export()
        else:
          logging.info('Skip the current checkpoint eval due to throttle secs '
                       '({} secs).'.format(self._eval_throttle_secs))
    # Final export signal: For any eval result with global_step >= train
    # max_steps, the evaluator will send the final export signal. There is a
    # small chance that the Estimator.train stopping logic sees a different
    # global_step value (due to global step race condition and the fact the
    # saver sees a larger value for checkpoint saving), which does not end
    # the training. When the training ends, a new checkpoint is generated,
    # which triggers the listener again. So, it could be the case that the
    # final export is triggered twice.
    #
    # But here, throttle_secs will skip the next intermediate checkpoint and,
    # so, the double final export chance is very small.
    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)

    # When the underlying `Estimator` object saves a new checkpoint, we would
    # like this callback to be called so that evaluation and export can trigger.
    saving_listeners = [
        NewCheckpointListener(evaluator, self._eval_spec.throttle_secs)
    ]
    self._start_distributed_training(saving_listeners=saving_listeners)

    if not evaluator.is_final_export_triggered:
      logging.info('Training has already ended. But the last eval is skipped '
                   'due to eval throttle_secs. Now evaluating the final '
                   'checkpoint.')
      evaluator.evaluate_and_export()

  def run_evaluator(self):
    """Runs task evaluator."""
    # TODO(xiejw): To allow execution framework to add continuous eval listener.
    return self._start_continuous_evaluation()

  def run_ps(self):
    """Runs task parameter server (in training cluster spec)."""
    config = self._estimator.config
    server = self._start_std_server(config)
    server.join()

  def run_local(self):
    """Runs training and evaluation locally (non-distributed)."""

    def _should_stop_local_train(global_step):
      if self._train_spec.max_steps is None:
        return False
      if global_step >= self._train_spec.max_steps:
        return True
      return False

    if self._eval_spec.throttle_secs <= 0:
      raise ValueError('eval_spec.throttle_secs should be positive, given: {}. '
                       'It is used to determine how long each training '
                       'iteration should run when training and evaluating '
                       'locally.'.format(self._eval_spec.throttle_secs))

    stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
    train_hooks = (
        list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks))
    logging.info('Starting the train and evaluate loop. Evaluation will '
                 'happen after {} secs (eval_spec.throttle_secs) or when '
                 'training is finished.'.format(self._eval_spec.throttle_secs))

    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)

    while True:
      self._estimator.train(
          input_fn=self._train_spec.input_fn,
          max_steps=self._train_spec.max_steps,
          hooks=train_hooks)

      # Final export signal: For any eval result with global_step >= train
      # max_steps, the evaluator will send the final export signal. The
      # _should_stop_local_train will then end the while True as the stopping
      # condition is satisfied (both checks use the same global_step value,
      # i.e., no race condition)
      eval_result = evaluator.evaluate_and_export()

      if eval_result.status != _EvalStatus.EVALUATED:
        # This is unexpected; should never happen.
        # Training should always end with a new checkpoint.
        raise RuntimeError('There was no new checkpoint after the training. '
                           'Eval status: {}'.format(eval_result.status))

      if _should_stop_local_train(
          eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
        break

  def _start_std_server(self, config):
    """Creates, starts, and returns a server_lib.Server."""
    if (not config.cluster_spec or not config.task_type or
        config.task_id is None):
      raise RuntimeError('Could not start server; be sure to specify '
                         'cluster_spec, task_type, and task in '
                         'RunConfig or set the TF_CONFIG environment variable.')

    if not config.master:
      jobs = config.cluster_spec.jobs
      if (len(jobs) == 1 and
          len(config.cluster_spec.job_tasks(jobs[0])) == 1 and
          config.task_type in _TRAINER_JOBS):
        # For distributed training, config.master is empty if and only if the
        # cluster spec has a single node. In this case, we should not start
        # the server.
        logging.info('Skip starting TensorFlow server as there is only one '
                     'node in the cluster.')
        return
      else:
        raise RuntimeError(
            'Could not start server; be sure to specify master in '
            'RunConfig or set the TF_CONFIG environment variable.')

    logging.info('Start TensorFlow server.')

    if config.session_config is None:
      session_config = config_pb2.ConfigProto(log_device_placement=False)
    else:
      session_config = config_pb2.ConfigProto(
          log_device_placement=False,
          gpu_options=config.session_config.gpu_options)

    server = server_lib.Server(
        config.cluster_spec,
        job_name=config.task_type,
        task_index=config.task_id,
        config=session_config,
        start=False)
    server.start()
    return server

  def _start_distributed_training(self, saving_listeners=None):
    """Calls `Estimator` train in a distributed setting."""
    config = self._estimator.config

    # Start in-process TensorFlow server if needed. It's important to start the
    # server before we (optionally) sleep. Otherwise, the servers will wait to
    # connect to each other before starting to train.
    if not _is_google_env():
      self._start_std_server(config)

    # Delay the worker start. For asynchronous training, this usually helps the
    # model converge faster. The chief starts the training immediately, so the
    # worker with task id x (0-based) should wait (x+1) * _DELAY_SECS_PER_WORKER
    # secs.
    start_delay_secs = 0
    if config.task_type == run_config_lib.TaskType.WORKER:
      # TODO(xiejw): Replace the hard code logic (task_id + 1) with unique id in
      # training cluster.
      start_delay_secs = min(_MAX_DELAY_SECS,
                             (config.task_id + 1) * _DELAY_SECS_PER_WORKER)
    if start_delay_secs > 0:
      logging.info('Waiting %d secs before starting training.',
                   start_delay_secs)
      time.sleep(start_delay_secs)

    self._estimator.train(
        input_fn=self._train_spec.input_fn,
        max_steps=self._train_spec.max_steps,
        hooks=list(self._train_spec.hooks) + list(self._train_hooks),
        saving_listeners=saving_listeners)

  def _start_continuous_evaluation(self):
    """Repeatedly calls `Estimator` evaluate and export until training ends."""
    start_delay_secs = self._eval_spec.start_delay_secs
    if start_delay_secs:
      logging.info('Waiting %f secs before starting eval.', start_delay_secs)
      time.sleep(start_delay_secs)

    latest_eval_result = None
    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)

    should_early_stop = False
    while not should_early_stop:
      if (latest_eval_result and
          latest_eval_result.status == _EvalStatus.EVALUATED):
        global_step = latest_eval_result.metrics.get(ops.GraphKeys.GLOBAL_STEP)
        if (global_step and self._train_spec.max_steps and
            global_step >= self._train_spec.max_steps):
          logging.info(
              'Exiting evaluation, global_step=%s >= train max_steps=%s',
              global_step, self._train_spec.max_steps)
          return

      latest_eval_result, should_early_stop = self._execute_evaluator_once(
          evaluator, self._continuous_eval_listener,
          self._eval_spec.throttle_secs)

  def _execute_evaluator_once(self, evaluator, continuous_eval_listener,
                              throttle_secs):
    """Executes the `evaluator`."""
    start = time.time()

    eval_result = None
    should_early_stop = False

    if not continuous_eval_listener.before_eval():
      logging.info('Exiting evaluation, as requested by '
                   '_ContinuousEvalListener.before_eval.')
      should_early_stop = True
      return (eval_result, should_early_stop)

    # Final export signal: For any eval result with global_step >= train
    # max_steps, the evaluator will send the final export signal. The next
    # iteration of while loop will end the continuous eval as the stopping
    # condition is satisfied (both checks use the same global_step value,
    # i.e., no race condition)
    eval_result = evaluator.evaluate_and_export()

    if not self._continuous_eval_listener.after_eval(eval_result):
      logging.info('Exiting evaluation, as requested by '
                   '_ContinuousEvalListener.after_eval.')
      should_early_stop = True
      return (eval_result, should_early_stop)

    # Throttle if necessary.
    elapsed_time = time.time() - start
    difference = throttle_secs - elapsed_time
    if difference > 0:
      logging.info('Waiting %f secs before starting next eval run.', difference)
      time.sleep(difference)

    return (eval_result, should_early_stop)

  class _Evaluator(object):
    """A helper class to call `Estimator.evaluate` and export model."""

    def __init__(self, estimator, eval_spec, max_training_steps):
      self._estimator = estimator
      self._eval_spec = eval_spec
      self._is_final_export_triggered = False
      self._previous_ckpt_path = None
      self._last_warning_time = 0
      self._max_training_steps = max_training_steps

    @property
    def is_final_export_triggered(self):
      return self._is_final_export_triggered

    def evaluate_and_export(self):
      """Evaluate and (maybe) export the current model.

      Returns:
        An `_EvalResult` instance.

      Raises:
        RuntimeError: for any unexpected internal error.
        TypeError: if evaluation result has wrong type.
      """
      latest_ckpt_path = self._estimator.latest_checkpoint()
      if not latest_ckpt_path:
        self._log_err_msg('Estimator is not trained yet. Will start an '
                          'evaluation when a checkpoint is ready.')
        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)

      if latest_ckpt_path == self._previous_ckpt_path:
        self._log_err_msg(
            'No new checkpoint ready for evaluation. Skip the current '
            'evaluation pass as evaluation results are expected to be the '
            'same for the same checkpoint.')
        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT)

      metrics = self._estimator.evaluate(
          input_fn=self._eval_spec.input_fn,
          steps=self._eval_spec.steps,
          name=self._eval_spec.name,
          checkpoint_path=latest_ckpt_path,
          hooks=self._eval_spec.hooks)

      # _EvalResult validates the metrics.
      eval_result = _EvalResult(
          status=_EvalStatus.EVALUATED,
          metrics=metrics,
          checkpoint_path=latest_ckpt_path)

      is_the_final_export = (
          eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
          self._max_training_steps if self._max_training_steps else False)
      self._export_eval_result(eval_result, is_the_final_export)

      if is_the_final_export:
        logging.debug('Calling exporter with `is_the_final_export=True`.')
        self._is_final_export_triggered = True

      self._last_warning_time = 0
      self._previous_ckpt_path = latest_ckpt_path
      return eval_result

    def _log_err_msg(self, message):
      """Prints warning `message` every 10 mins."""
      current_time = time.time()
      if current_time - self._last_warning_time > 600:
        logging.warning(message)
        self._last_warning_time = current_time

    def _export_eval_result(self, eval_result, is_the_final_export):
      """Export `eval_result` according to exporters in `EvalSpec`."""
      export_dir_base = os.path.join(
          compat.as_str_any(self._estimator.model_dir),
          compat.as_str_any('export'))

      for exporter in self._eval_spec.exporters:
        exporter.export(
            estimator=self._estimator,
            export_path=os.path.join(
                compat.as_str_any(export_dir_base),
                compat.as_str_any(exporter.name)),
            checkpoint_path=eval_result.checkpoint_path,
            eval_result=eval_result.metrics,
            is_the_final_export=is_the_final_export)


class _EvalStatus(object):
  """The status of an evaluation event.

  For local training and evaluation, the status can only be `EVALUATED` as
  `Estimator.train` always generates a new checkpoint.

  For distributed training and evaluation, a separate evaluator keeps looking
  for new checkpoints. So, multiple situations might occur:

  - EVALUATED: A new checkpoint is found since last evaluation.
      `Estimator.evaluate` will be invoked.
  - MISSING_CHECKPOINT: No checkpoint can be found. Typically, this means
      the trainer has not yet produced any checkpoint.
  - NO_NEW_CHECKPOINT: No new checkpoint can be found since last evaluation.
      Typically, this means the trainer has not yet produced any new checkpoint.
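
  For example, callers holding a `_TrainingExecutor._Evaluator` (here named
  `evaluator`) typically branch on the status before reading the metrics
  (a minimal sketch):

  ```python
  eval_result = evaluator.evaluate_and_export()
  if eval_result.status == _EvalStatus.EVALUATED:
    global_step = eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]
  ```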
  """

  EVALUATED = 'evaluated'
  MISSING_CHECKPOINT = 'missing checkpoint'
  NO_NEW_CHECKPOINT = 'no new checkpoint'


class _EvalResult(
    collections.namedtuple('EvalResult',
                           ['status', 'metrics', 'checkpoint_path'])):
  """_EvalResult holds the result of an evaluation event."""

  def __new__(cls, status, metrics=None, checkpoint_path=None):
    """Creates a validated `_EvalResult`.

    Args:
      status: See `_EvalStatus`.
      metrics: The evaluation results returned by `Estimator.evaluate`. Only set
          if status is `EVALUATED`.
      checkpoint_path: The corresponding checkpoint path for the `metrics`. Only
          set if status is `EVALUATED`.

    Returns:
      A validated `_EvalResult` object.

    Raises:
      ValueError: If validation fails.
      TypeError: If any of the arguments is not the expected type.
    """

    if status != _EvalStatus.EVALUATED:
      if metrics:
        raise ValueError(
            'metrics must be `None` if status is not {}; got status {},'
            ' metrics {}'.format(_EvalStatus.EVALUATED, status, metrics))
      if checkpoint_path:
        raise ValueError(
            'checkpoint must be `None` if status is not {}; got status {}, '
            'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,
                                        checkpoint_path))
      return super(_EvalResult, cls).__new__(cls, status, metrics,
                                             checkpoint_path)

    # Now, evaluated case.
    assert status == _EvalStatus.EVALUATED

    # Validates metrics.
    if not metrics:
      raise ValueError(
          'Internal error: `Estimator.evaluate` should never return empty '
          'metrics.')
    if not isinstance(metrics, dict):
      raise TypeError(
          '`Estimator.evaluate` should return dict. Given {}.'.format(
              type(metrics)))
    if ops.GraphKeys.GLOBAL_STEP not in metrics:
      raise ValueError(
          'Internal error: `Estimator.evaluate` result should have '
          '`global_step` in result. Given {}'.format(metrics))

    # Validates checkpoint_path.
    if not checkpoint_path:
      raise ValueError(
          'Internal error: `checkpoint_path` should never be empty.')

    return super(_EvalResult, cls).__new__(cls, status, metrics,
                                           checkpoint_path)


class _ContinuousEvalListener(object):
  """Interface for listeners that take action before or after evaluation."""
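
  # A minimal sketch of a custom listener (hypothetical; the 'accuracy' metric
  # name below is an assumption about the evaluation metrics):
  #
  #   class _StopAtAccuracyListener(_ContinuousEvalListener):
  #
  #     def after_eval(self, eval_result):
  #       # Request early stop once the desired accuracy has been reached.
  #       if (eval_result.status == _EvalStatus.EVALUATED and
  #           eval_result.metrics.get('accuracy', 0.0) >= 0.99):
  #         return False
  #       return True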

  def before_eval(self):
    """Called before evaluation.

    Returns:
      `False` if you want to skip the current evaluation and early stop the
      continuous evaluation; `True` otherwise.
    """
    return True

  def after_eval(self, eval_result):
    """Called after the evaluation is executed.

    Args:
      eval_result: An `_EvalResult` instance.

    Returns:
      `False` if you want to early stop continuous evaluation; `True` otherwise.
    """
    del eval_result
    return True