1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Some common SessionRunHook classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import time
23
24import numpy as np
25import six
26
27from tensorflow.core.framework.summary_pb2 import Summary
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.core.util.event_pb2 import SessionLog
30from tensorflow.python.client import timeline
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import errors
33from tensorflow.python.framework import meta_graph
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import init_ops
36from tensorflow.python.ops import variable_scope
37from tensorflow.python.platform import gfile
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.training import session_run_hook
40from tensorflow.python.training import training_util
41from tensorflow.python.training.session_run_hook import SessionRunArgs
42from tensorflow.python.training.summary_io import SummaryWriterCache
43from tensorflow.python.util.tf_export import tf_export
44
45
# Variable-scope name used by get_or_create_steps_per_run_variable(); also the
# prefix of the graph collection name that holds the steps_per_run variable.
_HOOKS = "hooks"
# Name of the steps_per_run variable; also the suffix of its collection name.
_STEPS_PER_RUN_VAR = "steps_per_run"
48
49
50class _HookTimer(object):
51  """Base timer for determining when Hooks should trigger.
52
53  Should not be instantiated directly.
54  """
55
56  def __init__(self):
57    pass
58
59  def reset(self):
60    """Resets the timer."""
61    pass
62
63  def should_trigger_for_step(self, step):
64    """Return true if the timer should trigger for the specified step."""
65    raise NotImplementedError
66
67  def update_last_triggered_step(self, step):
68    """Update the last triggered time and step number.
69
70    Args:
71      step: The current step.
72
73    Returns:
74      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number
75      of seconds between the current trigger and the last one (a float), and
76      `elapsed_steps` is the number of steps between the current trigger and
77      the last one. Both values will be set to `None` on the first trigger.
78    """
79    raise NotImplementedError
80
81  def last_triggered_step(self):
82    """Returns the last triggered time step or None if never triggered."""
83    raise NotImplementedError
84
85
@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.
  """

  def __init__(self, every_secs=None, every_steps=None):
    """Creates the timer.

    Args:
      every_secs: Minimum number of seconds between triggers, or None.
      every_steps: Minimum number of steps between triggers, or None.

    Raises:
      ValueError: If neither or both of `every_secs` and `every_steps` are
        provided.
    """
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    secs_given = every_secs is not None
    steps_given = every_steps is not None
    if not (secs_given or steps_given):
      raise ValueError("Either every_secs or every_steps should be provided.")
    if secs_given and steps_given:
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    """Forgets the last trigger so the next check always fires."""
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Returns True if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the timer has never triggered. Otherwise, False if `step` equals
      the last triggered step. Otherwise, True if at least `every_secs`
      seconds have passed since the last trigger, or if `step` is at least
      `every_steps` past the last triggered step. False in all other cases.
    """
    previous_step = self._last_triggered_step
    if previous_step is None:
      return True
    if previous_step == step:
      return False

    if self._every_secs is not None:
      deadline = self._last_triggered_time + self._every_secs
      if time.time() >= deadline:
        return True

    if self._every_steps is not None:
      if step >= previous_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    """Records a trigger at `step` and reports the time/steps since the last.

    Args:
      step: The current step.

    Returns:
      `(elapsed_secs, elapsed_steps)` since the previous trigger, or
      `(None, None)` if this is the first trigger.
    """
    now = time.time()
    if self._last_triggered_time is None:
      elapsed = (None, None)
    else:
      elapsed = (now - self._last_triggered_time,
                 step - self._last_triggered_step)

    self._last_triggered_time, self._last_triggered_step = now, step
    return elapsed

  def last_triggered_step(self):
    """Returns the step of the last trigger, or None if never triggered."""
    return self._last_triggered_step
149
150
class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    """Always returns False; `step` is ignored."""
    del step  # Unused.
    return False

  def update_last_triggered_step(self, step):
    """Ignores `step` and reports no elapsed time or steps."""
    del step  # Unused.
    return None, None

  def last_triggered_step(self):
    """Always returns None, since this timer never triggers."""
    return None
164
165
@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
    tf.logging.set_verbosity(tf.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  """

  def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
               at_end=False, formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
          or `iterable` of tensors/tensor names. With a dict, values are
          logged in sorted tag order; with an iterable, in iteration order.
      every_n_iter: `int`, print the values of `tensors` once every N local
          steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every N
          seconds. At most one of `every_n_iter` and `every_n_secs` may be
          provided, and exactly one is required unless `at_end` is True.
      at_end: `bool` specifying whether to print the values of `tensors` at the
          end of the run.
      formatter: function that takes a dict mapping each tag to the value
          fetched for it and returns a string. If `None`, all values are
          printed as `tag = value` pairs.

    Raises:
      ValueError: if both `every_n_iter` and `every_n_secs` are provided, if
        neither is provided and `at_end` is False, or if `every_n_iter` is
        non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      # Materialize the iterable once. It is consumed both to build the tag
      # map below and on every log call. A one-shot iterator (such as a
      # generator) would otherwise be exhausted here, and nothing would ever
      # be logged.
      self._tag_order = list(tensors)
      tensors = {item: item for item in self._tag_order}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else
        SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    """Resets the timer and resolves tensor names to graph elements."""
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given
    self._current_tensors = {tag: _as_graph_element(tensor)
                             for (tag, tensor) in self._tensors.items()}

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Requests the tensors only on iterations where the timer triggers."""
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    """Logs `tensor_values` (a dict keyed by tag) and updates the timer."""
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    # numpy print options are shared state. Restore them even if the
    # formatter or logging raises, so other code is not left printing with
    # suppress=True.
    try:
      elapsed_secs, _ = self._timer.update_last_triggered_step(
          self._iter_count)
      if self._formatter:
        logging.info(self._formatter(tensor_values))
      else:
        stats = ["%s = %s" % (tag, tensor_values[tag])
                 for tag in self._tag_order]
        if elapsed_secs is not None:
          logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
        else:
          logging.info("%s", ", ".join(stats))
    finally:
      np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    """Logs the fetched values if requested, then advances the local count."""
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    """Evaluates and logs the tensors once more if `at_end` was requested."""
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)
263
264
def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation (the model_fn) is wrapped inside
  a tf.while_loop for peak performance. This variable gives the number of loop
  iterations. Its value is adjusted on the CPU after each device program
  execution and before the next one.

  A variable is used rather than a constant so that Estimator can adapt the
  number of device iterations to the total number of steps requested by the
  user. For example, if the user sets steps_per_run to 4 and steps to 10 in
  Estimator.train(), the variable holds the following value before each run:

      - 1st execution: steps_per_run = 4
      - 2nd execution: steps_per_run = 4
      - 3rd execution: steps_per_run = 2

  Since model_fn increments the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 requested by
  the user.

  Returns:
    A TF non-trainable resource variable, looked up in (or added to) the
    "hooks_steps_per_run" collection of the default graph.

  Raises:
    RuntimeError: If that collection holds more than one variable.
  """
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  existing = ops.get_default_graph().get_collection(collection_name)
  if len(existing) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")
  if existing:
    return existing[0]

  # Not created yet: build a scalar int32 resource variable initialized to 1.
  # It is registered in both the private collection and LOCAL_VARIABLES.
  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        shape=[],
        dtype=dtypes.int32,
        initializer=init_ops.ones_initializer(),
        trainable=False,
        use_resource=True,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES])
310
311
class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `MultiStepStopAtStepHook`.

    This hook requests stop after either a number of steps have been executed
    or a last step has been reached. Exactly one of the two options must be
    specified.

    If `num_steps` is specified, it is the number of steps to execute, counted
    from the global step read in `after_create_session()`. If instead
    `last_step` is specified, it is the last step we want to execute.

    In Estimator, the user-provided computation (the model_fn) is wrapped
    inside a tf.while_loop for peak performance. The steps_per_run variable
    determines the number of iterations of the loop before returning to the
    CPU. This hook keeps that variable at no more than the steps remaining.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If neither or both of `num_steps` and `last_step` are given,
        or if `steps_per_run` is None or less than 1.
    """
    if num_steps is None:
      if last_step is None:
        raise ValueError("One of num_steps or last_step must be specified.")
    elif last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    """Looks up the global step and the steps_per_run variable."""
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    """Loads min(steps remaining, configured steps_per_run) into the variable."""
    remaining = self._last_step - global_step
    steps = min(remaining, self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    """Resolves `last_step` from `num_steps` if needed and primes the variable."""
    current_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = current_step + self._num_steps
    self._update_steps_per_run_variable(current_step, session)

  def after_run(self, run_context, run_values):
    """Requests stop at `last_step`; otherwise shrinks the next loop if needed."""
    # The global step is read here rather than through SessionRunArgs in
    # before_run, because of a race condition in hook execution.
    session = run_context.session
    global_step = session.run(self._global_step_tensor)
    if global_step < self._last_step:
      self._update_steps_per_run_variable(global_step, session)
    else:
      run_context.request_stop()
374
375
@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been executed
    or a last step has been reached. Exactly one of the two options must be
    specified.

    If `num_steps` is specified, it is the number of steps to execute, counted
    from the global step read in `after_create_session()`. If instead
    `last_step` is specified, it is the last step we want to execute, as
    passed to the `after_run()` call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If neither or both of `num_steps` and `last_step` are given.
    """
    if num_steps is None:
      if last_step is None:
        raise ValueError("One of num_steps or last_step must be specified.")
    elif last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    """Looks up (or creates) the global step read tensor."""
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    """Converts `num_steps` into an absolute `last_step` when needed."""
    if self._last_step is not None:
      return
    start_step = session.run(self._global_step_tensor)
    self._last_step = start_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the global step as read before this run."""
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    """Requests stop once the global step has reached `last_step`."""
    # `run_values.results` is the global step read *before* this run, so the
    # run is expected to have advanced it by one.
    if run_values.results + 1 < self._last_step:
      return
    # We cannot be sure this session.run actually incremented the global
    # step, so re-read it to confirm the target step has really been reached.
    latest_step = run_context.session.run(self._global_step_tensor)
    if latest_step >= self._last_step:
      run_context.request_stop()
430
431
@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is
  triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It can also use its own schedule to act less frequently,
  e.g. based on global_step_value. In that case, implementors should
  implement the `end()` method to handle actions related to the last
  checkpoint save, but the listener should not act twice if `after_save()`
  already handled that last save.

  A `CheckpointSaverListener` can request that training stop by returning True
  from `after_save`. Note that in a replicated distributed training setting,
  only the `chief` should use this behavior; otherwise each worker will do its
  own evaluation, which may waste resources.
  """

  def begin(self):
    """Called from `CheckpointSaverHook.begin()`; ops may be added here."""

  def before_save(self, session, global_step_value):
    """Called immediately before each checkpoint save."""

  def after_save(self, session, global_step_value):
    """Called immediately after each save; return True to request a stop."""

  def end(self, session, global_step_value):
    """Called at the end of the session with the final global step value."""
496
497
@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: If not exactly one of `save_steps` or `save_secs` is set.
      ValueError: If both `saver` and `scaffold` are set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    """Sets how many global steps one `session.run` is expected to advance.

    `after_run` adds this to the pre-run global step when deciding whether a
    save may be due.
    """
    self._steps_per_run = steps_per_run

  def begin(self):
    """Sets up the summary writer and global step, and notifies listeners."""
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    """Writes the graph and meta graph, then saves an initial checkpoint."""
    global_step = session.run(self._global_step_tensor)
    # The graph and saver_def are written here rather than in begin(): other
    # hooks may still change the graph and add variables in their begin(),
    # and the graph is only finalized after all begin() calls.
    graph = ops.get_default_graph()
    # Serialize once and reuse for both the text graph and the meta graph.
    graph_def = graph.as_graph_def(add_shapes=True)
    training_util.write_graph(graph_def, self._checkpoint_dir, "graph.pbtxt")
    saver = self._get_saver()
    saver_def = saver.saver_def if saver else None
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph_def,
        saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the global step so `after_run` can decide whether to save."""
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    """Saves a checkpoint when the timer triggers; may request a stop."""
    stale_global_step = run_values.results
    # Pre-check with the pre-run step advanced by the expected number of
    # steps; only then read the real global step.
    if self._timer.should_trigger_for_step(
        stale_global_step + self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    """Saves a final checkpoint unless this step was already saved."""
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop.

    Args:
      session: The session to save from.
      step: The global step value recorded with the checkpoint.

    Returns:
      True if any listener's `after_save` returned a truthy value.
    """
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for l in self._listeners:
      l.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)

    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: %s", l)
        should_stop = True
    return should_stop

  def _get_saver(self):
    """Returns the saver to use for checkpointing.

    Resolution order: the explicit `saver`, then `scaffold.saver`, then the
    single saver in the SAVERS collection (cached in `self._saver`).

    Raises:
      RuntimeError: If the SAVERS collection is consulted and is empty or
        holds more than one saver.
    """
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor.".
          format(collection_key))

    self._saver = savers[0]
    return savers[0]
635
636
@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):
    """Initializes a `StepCounterHook`.

    Args:
      every_n_steps: `int`, report steps/sec once every N steps.
      every_n_secs: `int` or `float`, report steps/sec once every N seconds.
        Exactly one of `every_n_steps` and `every_n_secs` must be non-None.
        Since `every_n_steps` defaults to 100, pass `every_n_steps=None` when
        using `every_n_secs`.
      output_dir: `str`, directory used to obtain a summary writer when
        `summary_writer` is not given.
      summary_writer: writer used to record the steps/sec summary. If `None`
        and `output_dir` is set, one is obtained from `SummaryWriterCache`.

    Raises:
      ValueError: if not exactly one of `every_n_steps` and `every_n_secs` is
        provided.
    """
    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(every_steps=every_n_steps,
                                    every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._global_step_check_count = 0
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    """Sets how many global steps one `session.run` is expected to advance."""
    self._steps_per_run = steps_per_run

  def begin(self):
    """Resolves the summary writer, global step tensor and summary tag."""
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the global step as read before this run."""
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    """Logs steps/sec and writes it as a summary if a writer is available.

    Args:
      elapsed_steps: Steps since the previous report.
      elapsed_time: Seconds since the previous report; must be non-zero.
      global_step: Step at which the summary is recorded.
    """
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[Summary.Value(
          tag=self._summary_tag, simple_value=steps_per_sec)])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    """Reports steps/sec when due and warns if the global step is stuck."""
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(
        stale_global_step + self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        # elapsed_time is None on the first trigger. It comes from the wall
        # clock (time.time()), so it can also be 0 on systems with coarse
        # clock resolution, or negative if the clock was set back. No
        # meaningful rate exists in those cases, so skip reporting rather
        # than dividing by zero.
        if elapsed_time is not None and elapsed_time > 0:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use the
    # timer.last_triggered_step as the timer might record a different global
    # step value such that the comparison could be unreliable. For simplicity,
    # we just compare the stale_global_step with previously recorded version.
    if stale_global_step == self._last_global_step:
      # Here, we use a counter to count how many times we have observed that the
      # global step has not been increased. For some Optimizers, the global step
      # is not increased each time by design. For example, SyncReplicaOptimizer
      # doesn't increase the global step in worker's main train step.
      self._global_step_check_count += 1
      if self._global_step_check_count % 20 == 0:
        self._global_step_check_count = 0
        logging.warning(
            "It seems that global step (tf.train.get_global_step) has not "
            "been increased. Current value (could be stable): %s vs previous "
            "value: %s. You could increase the global step by passing "
            "tf.train.get_global_step() to Optimizer.apply_gradients or "
            "Optimizer.minimize.", stale_global_step, self._last_global_step)
    else:
      # Whenever we observe the increment, reset the counter.
      self._global_step_check_count = 0

    self._last_global_step = stale_global_step
719
720
@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):
  """Raised by `NanTensorHook` when the monitored loss is NaN."""

  def __str__(self):
    """Returns a fixed message, regardless of constructor arguments."""
    message = "NaN loss during training."
    return message
726
727
@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor. If it has more than one element,
        training is treated as diverged when any element is NaN.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the loss tensor on every run."""
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    """Raises or requests a stop when the fetched loss contains NaN.

    Raises:
      NanLossDuringTrainingError: if the loss is NaN and `fail_on_nan_loss`
        is True.
    """
    # np.any keeps the check well-defined for non-scalar losses too. A bare
    # truth test on a multi-element array raises ValueError.
    if np.any(np.isnan(run_values.results)):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()
758
759
@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
          `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used
          if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
          one will be created accordingly.
      scaffold: `Scaffold` to get summary_op if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized `Summary`
          protocol buffer or a list of `Tensor`. They are most likely an output
          by TF summary methods like `tf.summary.scalar` or
          `tf.summary.merge_all`. It can be passed in as one tensor; if more
          than one, they must be passed in as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    # Both unset or both set are equally invalid: there must be exactly one
    # source for the summary op.
    if (scaffold is None) == (summary_op is None):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    """Resolves the summary writer and the global step read tensor.

    Raises:
      RuntimeError: If no global step tensor can be obtained.
    """
    # Fall back to a cached writer for `output_dir` only when the caller did
    # not supply one explicitly.
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Requests the global step, plus the summaries when they are due."""
    # Summaries are always written on the first step (`_next_step is None`),
    # then whenever the timer fires for the step about to run.
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      summary_op = self._get_summary_op()
      if summary_op is not None:
        requests["summary"] = summary_op

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    """Writes the fetched summaries (and a START log on the first step)."""
    if not self._summary_writer:
      return

    # The fetched global step was read before this run's increment, so assume
    # +1 by default and only pay for an exact re-read when it is written out.
    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    """Flushes any pending summaries to the writer, if there is one."""
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`s, or `None` when neither the explicit
      `summary_op` nor the scaffold provides one.
    """
    summary_op = self._summary_op
    if summary_op is None:
      # `__init__` guarantees a scaffold whenever `summary_op` is unset.
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op
867
868
@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))` assuming that
  task_id=0 is the chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int`, the global step this worker waits for before
        it starts running.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    """Resets the started flag and resolves the global step read tensor.

    Raises:
      RuntimeError: If no global step tensor can be obtained.
    """
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    """Blocks, polling the global step, until `wait_until_step` is reached.

    Once the target step has been observed (or if it is <= 0), later calls
    return immediately without touching the session.

    Args:
      run_context: A `SessionRunContext` whose session is used to read the
        global step.

    Returns:
      None; this hook requests no extra fetches.
    """
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      # Throttle progress logging to roughly once per 1000 global steps.
      if current_step - last_logged_step > 1000:
        logging.info("Waiting for global step %d before starting training. "
                     "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)
915
916
@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running
        `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    """Results of running `final_ops` in `end`.

    `None` until `end` has run them successfully, and always `None` when
    `final_ops` is `None`.
    """
    return self._final_ops_values

  def end(self, session):
    """Runs `final_ops` once and stores the results in `final_ops_values`.

    Args:
      session: The session used to evaluate `final_ops`.

    Raises:
      errors.OutOfRangeError, StopIteration: Re-raised unchanged after logging
        a diagnostic warning, since they usually indicate that `final_ops`
        depend on an exhausted input source.
    """
    if self._final_ops is not None:
      try:
        self._final_ops_values = session.run(
            self._final_ops, feed_dict=self._final_ops_feed_dict)
      except (errors.OutOfRangeError, StopIteration):
        logging.warning(
            "An OutOfRangeError or StopIteration exception is raised by the "
            "code in FinalOpsHook. This typically means the Ops running by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two Ops: "
            "`value_op` and  `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in input source and, once "
            "it is exhausted, it call the `value_op` to get the metric values. "
            "The `value_op` here should have dependency back to variables "
            "reading only, rather than reading another batch from input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data reading, which ends OutOfRangeError/StopIteration. "
            "Please fix that.")
        # A bare `raise` preserves the original traceback; `raise e` would
        # restart it here on Python 2, which this module still supports.
        raise
959
960
@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Feeds the dictionary produced by a user callback into every run."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: A zero-argument callable returning a `dict` that maps `Tensor`s
        to the values to feed. It is invoked once per `session.run` call.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Calls `feed_fn` and hands its result to the session as `feed_dict`."""
    fresh_feed = self.feed_fn()
    return session_run_hook.SessionRunArgs(fetches=None, feed_dict=fresh_feed)
977
978
@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    `options.run_metadata` argument of `tf.Session.Run` is used to collect
    metadata about execution. This hook sets the metadata and dumps it in Chrome
    Trace format.


    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
          `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
          Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
          producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
          showing the sizes and lifetimes of tensors.
    """
    self._output_dir = output_dir
    # Template kept for backward compatibility with code that inspects it.
    # Actual paths come from `_timeline_path`, which formats only the file
    # name so braces in `output_dir` are never read as format placeholders.
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    """Resolves the global step read tensor.

    Raises:
      RuntimeError: If no global step tensor can be obtained.
    """
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    """Requests the global step, and a full trace when a snapshot is due."""
    # The first step is never traced: `_next_step` is only set after a run.
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
            if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    """Saves the collected trace as a timeline file and run metadata."""
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or seconds
      # have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step,
                 self._timeline_path(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _timeline_path(self, step):
    """Returns the timeline file path for `step` inside `output_dir`."""
    return os.path.join(self._output_dir, "timeline-{}.json".format(step))

  def _save(self, step, save_path, step_stats):
    """Writes `step_stats` to `save_path` in Chrome Trace format."""
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow, show_memory=self._show_memory))
1062
1063
def _as_graph_element(obj):
  """Resolves `obj` to an element of the current default graph.

  Args:
    obj: Either an object carrying a `graph` attribute (returned as-is after
      validation), or a string name. A name containing ":" is looked up
      directly; a bare name is resolved to its ":0" output.

  Returns:
    The graph element corresponding to `obj`.

  Raises:
    ValueError: If a non-string `obj` has no `graph` attribute or belongs to a
      different graph, or if a bare name refers to an `Operation` that also
      has a ":1" output (i.e. the name is ambiguous).
  """
  graph = ops.get_default_graph()

  if isinstance(obj, six.string_types):
    if ":" in obj:
      return graph.as_graph_element(obj)
    element = graph.as_graph_element(obj + ":0")
    # A bare name is only unambiguous when there is no second output.
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      return element
    raise ValueError("Name %s is ambiguous, "
                     "as this `Operation` has multiple outputs "
                     "(at least 2)." % obj)

  belongs_to_graph = hasattr(obj, "graph") and obj.graph == graph
  if not belongs_to_graph:
    raise ValueError("Passed %s should have graph attribute that is equal "
                     "to current graph %s." % (obj, graph))
  return obj
1086