1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Some common SessionRunHook classes.
16
17Note that the symbols that are exported to v1 tf.train namespace are also
18exported to v2 in tf.estimator namespace. See
19https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import os
27import time
28
29import numpy as np
30import six
31
32from tensorflow.core.framework.summary_pb2 import Summary
33from tensorflow.core.protobuf import config_pb2
34from tensorflow.core.util.event_pb2 import SessionLog
35from tensorflow.python.client import timeline
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import meta_graph
39from tensorflow.python.framework import ops
40from tensorflow.python.ops import init_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.platform import gfile
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.training import session_run_hook
45from tensorflow.python.training import training_util
46from tensorflow.python.training.session_run_hook import SessionRunArgs
47from tensorflow.python.training.summary_io import SummaryWriterCache
48from tensorflow.python.util.tf_export import tf_export
49
# Variable scope under which hook-owned variables are created; also the prefix
# of the graph collection name that holds the steps_per_run variable.
_HOOKS = "hooks"
# Name of the steps_per_run variable; also the suffix of its collection name.
_STEPS_PER_RUN_VAR = "steps_per_run"
52
53
54class _HookTimer(object):
55  """Base timer for determining when Hooks should trigger.
56
57  Should not be instantiated directly.
58  """
59
60  def __init__(self):
61    pass
62
63  def reset(self):
64    """Resets the timer."""
65    pass
66
67  def should_trigger_for_step(self, step):
68    """Return true if the timer should trigger for the specified step."""
69    raise NotImplementedError
70
71  def update_last_triggered_step(self, step):
72    """Update the last triggered time and step number.
73
74    Args:
75      step: The current step.
76
77    Returns:
78      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number
79      of seconds between the current trigger and the last one (a float), and
80      `elapsed_steps` is the number of steps between the current trigger and
81      the last one. Both values will be set to `None` on the first trigger.
82    """
83    raise NotImplementedError
84
85  def last_triggered_step(self):
86    """Returns the last triggered time step or None if never triggered."""
87    raise NotImplementedError
88
89
@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer firing at most once per N seconds, or once per N steps.

  Exactly one of the two periods must be configured. This symbol is also
  exported to v2 in the tf.estimator namespace; see
  https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
  """

  def __init__(self, every_secs=None, every_steps=None):
    """Creates the timer.

    Args:
      every_secs: Minimum number of seconds between triggers, or None.
      every_steps: Minimum number of steps between triggers, or None.

    Raises:
      ValueError: If neither or both of `every_secs` / `every_steps` are set.
    """
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    secs_given = every_secs is not None
    steps_given = every_steps is not None
    if not (secs_given or steps_given):
      raise ValueError("Either every_secs or every_steps should be provided.")
    if secs_given and steps_given:
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    """Forgets the previous trigger so the next query fires immediately."""
    self._last_triggered_step = self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Decides whether the timer fires for `step`.

    Args:
      step: Training step to trigger on.

    Returns:
      True if nothing has triggered yet. False if `step` is the step that
      triggered last. Otherwise True when at least `every_secs` seconds have
      passed since the last trigger, or when `step` is at least `every_steps`
      past the last triggered step; False otherwise.
    """
    previous_step = self._last_triggered_step
    if previous_step is None:
      return True
    if previous_step == step:
      return False

    if (self._every_secs is not None and
        time.time() >= self._last_triggered_time + self._every_secs):
      return True
    if (self._every_steps is not None and
        step >= previous_step + self._every_steps):
      return True
    return False

  def update_last_triggered_step(self, step):
    """Marks `step` as triggered now; returns `(elapsed_secs, elapsed_steps)`.

    Both elements are None when there was no previous trigger.
    """
    now = time.time()
    if self._last_triggered_time is None:
      elapsed = (None, None)
    else:
      elapsed = (now - self._last_triggered_time,
                 step - self._last_triggered_step)
    self._last_triggered_time = now
    self._last_triggered_step = step
    return elapsed

  def last_triggered_step(self):
    """Returns the step of the last trigger, or None."""
    return self._last_triggered_step
156
157
class NeverTriggerTimer(_HookTimer):
  """A timer whose trigger condition is never satisfied."""

  def should_trigger_for_step(self, step):
    """Always returns False; `step` is ignored."""
    del step  # Unused.
    return False

  def update_last_triggered_step(self, step):
    """Records nothing and always reports `(None, None)` elapsed values."""
    del step  # Unused.
    return (None, None)

  def last_triggered_step(self):
    """Always None, since this timer never triggers."""
    return None
171
172
@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names, or
        `iterable` of tensors/tensor names. A non-dict iterable is consumed
        exactly once, so one-shot iterators such as generators are accepted.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every N
        seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at the
        end of the run.
      formatter: function, takes dict of `tag`->`Tensor` and returns a string.
        If `None` uses default printing all tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive, if both `every_n_iter` and
        `every_n_secs` are given, or if neither is given and `at_end` is False.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      # Materialize the iterable once: it is needed both as the tag order and
      # to build the tag->tensor mapping, and a one-shot iterator (e.g. a
      # generator) would otherwise be exhausted by the first traversal,
      # leaving nothing to log.
      self._tag_order = list(tensors)
      tensors = {item: item for item in self._tag_order}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    """Resets the timer/iteration count and resolves tags to graph elements."""
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the tensors only on iterations where the timer fires."""
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    """Logs `tensor_values` (a tag -> value dict) and records the trigger.

    numpy's `suppress=True` print option is enabled while formatting; the
    previous global print options are restored even if formatting or logging
    raises.
    """
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    try:
      elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
      if self._formatter:
        logging.info(self._formatter(tensor_values))
      else:
        stats = [
            "%s = %s" % (tag, tensor_values[tag]) for tag in self._tag_order
        ]
        if elapsed_secs is not None:
          logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
        else:
          logging.info("%s", ", ".join(stats))
    finally:
      np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    """Logs fetched values when this iteration triggered; advances the count."""
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    """If `at_end` was requested, evaluates and logs the tensors once more."""
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)
276
277
def get_or_create_steps_per_run_variable():
  """Returns the graph's steps_per_run variable, creating it on first use.

  In Estimator, the user-provided model_fn is wrapped inside a tf.while_loop
  for peak performance. This variable holds the loop's iteration count; its
  value is adjusted on the CPU after each device program execution and before
  the next one.

  A variable (rather than a constant) lets Estimator match the total number of
  steps requested by the user. For example, with steps_per_run=4 and steps=10
  in Estimator.train(), the variable takes these values before each run:

      - 1-st execution: steps_per_run = 4
      - 2-nd execution: steps_per_run = 4
      - 3-rd execution: steps_per_run = 2

  Since model_fn increments the global step once per train_op invocation, the
  global step ends at 10, matching the requested steps.

  The variable is looked up in (and created into) the collection named
  "<_HOOKS>_<_STEPS_PER_RUN_VAR>" of the default graph.

  Returns:
    A non-trainable int32 scalar resource variable, initialized to 1.

  Raises:
    RuntimeError: If more than one steps_per_run variable is in the collection.
  """
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  existing = ops.get_default_graph().get_collection(collection_name)
  if len(existing) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")
  if existing:
    return existing[0]

  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        shape=[],
        dtype=dtypes.int32,
        initializer=init_ops.ones_initializer(),
        trainable=False,
        use_resource=True,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES])
323
324
class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Requests a stop at a target step, sizing each multi-step device run."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `MultiStepStopAtStepHook`.

    Exactly one of `num_steps` / `last_step` must be given. `num_steps` counts
    steps from the global step observed when the session is created;
    `last_step` is an absolute global step after which to stop.

    In Estimator, the user-provided model_fn is wrapped inside a tf.while_loop;
    the steps_per_run variable sets how many loop iterations run before control
    returns to the CPU. This hook keeps that variable from overshooting the
    last step.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    has_num_steps = num_steps is not None
    has_last_step = last_step is not None
    if not has_num_steps and not has_last_step:
      raise ValueError("One of num_steps or last_step must be specified.")
    if has_num_steps and has_last_step:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    """Looks up the global step and the shared steps_per_run variable."""
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    """Loads min(steps remaining, configured steps_per_run) into the variable."""
    remaining = self._last_step - global_step
    self._steps_per_run_variable.load(
        min(remaining, self._steps_per_run_initial_value), session=session)

  def after_create_session(self, session, coord):
    """Resolves the absolute last step and sizes the first run."""
    current_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = current_step + self._num_steps
    self._update_steps_per_run_variable(current_step, session)

  def after_run(self, run_context, run_values):
    """Stops once the last step is reached; otherwise resizes the next run."""
    # The global step is read here rather than fetched through SessionRunArgs
    # in before_run, to avoid a race condition in hook execution.
    session = run_context.session
    current_step = session.run(self._global_step_tensor)
    if current_step < self._last_step:
      self._update_steps_per_run_variable(current_step, session)
    else:
      run_context.request_stop()
387
388
@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Requests a training stop once a target global step is reached."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    Exactly one of the two arguments must be given. With `num_steps`, the
    target is that many steps past the global step observed when the session
    is created. With `last_step`, the target is that absolute global step.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    """Gets the global-step read tensor; fails if there is no global step."""
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    """Converts a relative `num_steps` into an absolute last step (once)."""
    if self._last_step is not None:
      return
    self._last_step = session.run(self._global_step_tensor) + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the global step alongside every run."""
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    """Requests a stop once the real global step reaches the target."""
    # The fetched value is the global step read *before* this run, so `+ 1`
    # is only an estimate. When it suggests the target may be reached, re-read
    # the actual value, since this run may or may not have incremented it.
    if run_values.results + 1 < self._last_step:
      return
    actual_step = run_context.session.run(self._global_step_tensor)
    if actual_step >= self._last_step:
      run_context.request_stop()
443
444
@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is
  triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.estimator.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.compat.v1.train.MonitoredTrainingSession(
      chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply act after every checkpoint save. It
  may also follow its own, less frequent schedule (e.g. based on
  global_step_value). In that case, implementors should implement `end()` to
  handle actions related to the last checkpoint save, without acting twice if
  `after_save()` already handled that last save.

  A `CheckpointSaverListener` can request that training be stopped by
  returning True from `after_save`. In a replicated distributed training
  setting, only the `chief` should use this behavior; otherwise each worker
  will do its own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    """Called from the hook's `begin()`; ops may still be added to the graph."""

  def before_save(self, session, global_step_value):
    """Called right before each checkpoint save at `global_step_value`."""

  def after_save(self, session, global_step_value):
    """Called right after each checkpoint save; return True to request stop."""

  def end(self, session, global_step_value):
    """Called from the hook's `end()` with the final global step value."""
510
511
@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None,
               save_graph_def=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.
      save_graph_def: Whether to save the GraphDef and MetaGraphDef to
        `checkpoint_dir`. The GraphDef is saved after the session is created as
        `graph.pbtxt`, and the graph and MetaGraphDef are added to the summary
        writer at that time. MetaGraphDefs are saved out for every checkpoint
        as `model.ckpt-*.meta`. When False, none of these are written.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1
    self._save_graph_def = save_graph_def

  def _set_steps_per_run(self, steps_per_run):
    """Sets how many global steps a single `session.run` is expected to take."""
    self._steps_per_run = steps_per_run

  def begin(self):
    """Gets the summary writer and global step; calls listeners' `begin`."""
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    """Writes graph artifacts (if enabled) and saves an initial checkpoint."""
    global_step = session.run(self._global_step_tensor)
    if self._save_graph_def:
      # Graph artifacts are written here rather than in `begin`, because other
      # hooks may still change the graph and add variables in their `begin`;
      # the graph is finalized only after all `begin` calls.
      graph = ops.get_default_graph()
      training_util.write_graph(
          graph.as_graph_def(add_shapes=True), self._checkpoint_dir,
          "graph.pbtxt")
      saver = self._get_saver()
      saver_def = saver.saver_def if saver else None
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the (pre-run) global step alongside every run."""
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    """Saves a checkpoint when the timer fires; may request a stop."""
    stale_global_step = run_values.results
    # Cheap pre-check using the expected post-run step, so the extra
    # session.run below is only paid when a save is likely due.
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    """Saves a final checkpoint unless one exists for the last step."""
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop.

    Listeners' `before_save` run before the save and `after_save` after it;
    the return value is True if any `after_save` returned a truthy value.
    """
    logging.info("Calling checkpoint listeners before saving checkpoint %d...",
                 step)
    for l in self._listeners:
      l.before_save(session, step)

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._get_saver().save(session, self._save_path, global_step=step,
                           write_meta_graph=self._save_graph_def)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    logging.info("Calling checkpoint listeners after saving checkpoint %d...",
                 step)
    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: %s", l)
        should_stop = True
    return should_stop

  def _get_saver(self):
    """Returns the explicit saver, the scaffold's saver, or the SAVERS one.

    The saver found in the `SAVERS` collection is cached on the instance.

    Raises:
      RuntimeError: if the SAVERS collection is empty or has several items.
    """
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]
657
658
@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):
    """Initializes a `StepCounterHook`.

    Args:
      every_n_steps: `int`, report the step rate once every N global steps.
      every_n_secs: `int` or `float`, report the step rate once every N
        seconds. Exactly one of `every_n_steps` and `every_n_secs` must be
        non-None; since `every_n_steps` defaults to 100, pass
        `every_n_steps=None` when using `every_n_secs`.
      output_dir: Directory passed to `SummaryWriterCache.get` when no
        `summary_writer` is supplied.
      summary_writer: Writer that receives the steps/sec summary. If `None`
        and `output_dir` is set, one is obtained from `SummaryWriterCache`.

    Raises:
      ValueError: if not exactly one of `every_n_steps` and `every_n_secs` is
        provided.
    """
    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    """Sets how many global steps a single `session.run` is expected to take."""
    self._steps_per_run = steps_per_run

  def begin(self):
    """Resolves the summary writer, global step tensor and summary tag."""
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the (pre-run) global step alongside every run."""
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    """Logs steps/sec and, if a writer is configured, writes it as a summary.

    `elapsed_time` must be positive.
    """
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    """Reports the step rate when the timer fires; warns on a stuck step."""
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        # elapsed_time is None on the first trigger. It can also be 0.0 if the
        # wall clock did not advance between two triggers; computing a rate
        # then would raise ZeroDivisionError, so that window is skipped.
        if elapsed_time is not None and elapsed_time > 0:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use the
    # timer.last_triggered_step as the timer might record a different global
    # step value such that the comparison could be unreliable. For simplicity,
    # we just compare the stale_global_step with previously recorded version.
    if stale_global_step == self._last_global_step:
      # Here, we give a warning in the first 5 times if we have observed that
      # the global step has not been increased. For some Optimizers, the global
      # step is not increased each time by design. For example,
      # SyncReplicaOptimizer doesn't increase the global step in worker's main
      # train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stable): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step
737
738
@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):
  """Raised by `NanTensorHook` when the monitored loss is NaN."""

  def __str__(self):
    message = "NaN loss during training."
    return message
744
745
@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor. Usually a scalar; if it has more
        than one element, training is treated as diverged when any element is
        NaN.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Fetches the loss tensor on every run."""
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    """Raises or requests a stop if the fetched loss contains a NaN.

    Raises:
      NanLossDuringTrainingError: if a NaN is found and `fail_on_nan_loss` is
        True.
    """
    # `np.any` keeps the check valid for non-scalar losses: a bare
    # `if np.isnan(array)` raises ValueError for arrays with >1 element.
    if np.any(np.isnan(run_values.results)):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()
776
777
@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Writes summaries to a summary writer every N steps or seconds."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Creates a hook that periodically writes summaries.

    Args:
      save_steps: `int`, write summaries every this many steps. Set exactly
        one of `save_secs` and `save_steps`.
      save_secs: `int`, write summaries every this many seconds.
      output_dir: `string`, directory for the summaries. It is only consulted
        when no `summary_writer` is given.
      summary_writer: `SummaryWriter` to write to. When `None` and
        `output_dir` is set, a writer is obtained from `SummaryWriterCache`
        in `begin`.
      scaffold: `Scaffold` whose `summary_op` is used when `summary_op` is not
        given.
      summary_op: A string `Tensor` holding a serialized `Summary` protocol
        buffer, or a list of such tensors. These typically come from TF
        summary methods like `tf.compat.v1.summary.scalar` or
        `tf.compat.v1.summary.merge_all`. Pass a single tensor directly, or
        several as a list.

    Raises:
      ValueError: if both or neither of `scaffold` and `summary_op` are set.
    """
    has_scaffold = scaffold is not None
    has_summary_op = summary_op is not None
    # Both set or both unset is an error: exactly one source is required.
    if has_scaffold == has_summary_op:
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._scaffold = scaffold
    self._summary_op = summary_op
    self._output_dir = output_dir
    self._summary_writer = summary_writer
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    """Resolves the summary writer and the global step tensor.

    Raises:
      RuntimeError: if no global step tensor exists.
    """
    if self._output_dir and self._summary_writer is None:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    # `None` marks that no step has been observed yet.
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Always fetches the global step, plus the summaries when triggered."""
    is_first_run = self._next_step is None
    self._request_summary = (
        is_first_run or self._timer.should_trigger_for_step(self._next_step))

    fetches = {"global_step": self._global_step_tensor}
    if self._request_summary:
      summary_ops = self._get_summary_op()
      if summary_ops is not None:
        fetches["summary"] = summary_ops
    return SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Writes the session START log and any fetched summaries.

    This does nothing when no summary writer is available.
    """
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    is_first_run = self._next_step is None
    if is_first_run or self._request_summary:
      # The fetched step may be stale, so read the current value instead.
      global_step = run_context.session.run(self._global_step_tensor)
    else:
      global_step = stale_global_step + 1

    if is_first_run:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      results = run_values.results
      if "summary" in results:
        for serialized_summary in results["summary"]:
          self._summary_writer.add_summary(serialized_summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    """Flushes the summary writer, if there is one."""
    writer = self._summary_writer
    if writer:
      writer.flush()

  def _get_summary_op(self):
    """Returns the summary tensors to fetch, or `None` if there are none.

    An explicit `summary_op` takes precedence over the scaffold's. A single
    tensor is wrapped in a one-element list.
    """
    if self._summary_op is not None:
      op = self._summary_op
    else:
      op = self._scaffold.summary_op

    if op is None:
      return None
    return op if isinstance(op, list) else [op]
885
886
@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until global step reaches `wait_until_step`. It
  is used to gradually start workers in distributed settings. One example usage
  would be setting `wait_until_step=int(K*log(task_id+1))` assuming that
  task_id=0 is the chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int` shows until which global step should we wait.
        Values <= 0 disable waiting.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    """Resets the started flag and resolves the global step tensor.

    Raises:
      RuntimeError: if no global step tensor exists.
    """
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    """Blocks until the global step reaches `wait_until_step`.

    The blocking only happens until the threshold is first reached. After
    that, the hook returns immediately. It polls every 0.5 seconds and logs
    progress roughly every 1000 steps.

    Returns:
      Always `None`; this hook adds no fetches or feeds.
    """
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      # Throttle progress logging so a long wait does not flood the logs.
      if current_step - last_logged_step > 1000:
        logging.info(
            "Waiting for global step %d before starting training. "
            "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)
934
935
@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
        to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running
        `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    """Values obtained by running `final_ops` in `end`.

    This is `None` until `end` has successfully evaluated `final_ops`. It is
    always `None` when `final_ops` is `None`.
    """
    return self._final_ops_values

  def end(self, session):
    """Evaluates `final_ops` and stores the result in `final_ops_values`.

    Args:
      session: the session used to run `final_ops`.

    Raises:
      errors.OutOfRangeError, StopIteration: re-raised unchanged, after a
        warning is logged, if evaluating `final_ops` raises them.
    """
    if self._final_ops is None:
      return
    try:
      self._final_ops_values = session.run(
          self._final_ops, feed_dict=self._final_ops_feed_dict)
    except (errors.OutOfRangeError, StopIteration):
      logging.warning(
          "An OutOfRangeError or StopIteration exception is raised by the "
          "code in FinalOpsHook. This typically means the Ops running by the "
          "FinalOpsHook have a dependency back to some input source, which "
          "should not happen. For example, for metrics in "
          "tf.estimator.Estimator, all metrics functions return two Ops: "
          "`value_op` and  `update_op`. Estimator.evaluate calls the "
          "`update_op` for each batch of the data in input source and, once "
          "it is exhausted, it call the `value_op` to get the metric values. "
          "The `value_op` here should have dependency back to variables "
          "reading only, rather than reading another batch from input. "
          "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
          "another data reading, which ends OutOfRangeError/StopIteration. "
          "Please fix that.")
      # Bare `raise` keeps the original traceback; `raise e` would lose it on
      # Python 2 and adds a redundant frame on Python 3.
      raise
978
979
@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Calls `feed_fn` in `before_run` and feeds its result to the run."""

  def __init__(self, feed_fn):
    """Creates a `FeedFnHook`.

    Args:
      feed_fn: zero-argument callable returning the `dict` of `Tensor` to
        feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    # Nothing is fetched; the hook only contributes a feed_dict.
    feed_dict = self.feed_fn()
    return session_run_hook.SessionRunArgs(fetches=None, feed_dict=feed_dict)
996
997
@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    On triggered steps, the hook passes
    `RunOptions(trace_level=FULL_TRACE)` to the session run. It then writes
    the returned `run_metadata.step_stats` as a Chrome Trace file, and also
    records the run metadata with a summary writer for `output_dir`.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    # Keep the directory apart from the file-name template. Calling
    # `str.format` on a joined path would misinterpret any "{" or "}" in
    # `output_dir` as format fields.
    self._output_dir = output_dir
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    """Resolves the global step tensor.

    Raises:
      RuntimeError: if no global step tensor exists.
    """
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    """Fetches the global step and requests a full trace when triggered."""
    # The first run is never traced; `after_run` primes the timer on it.
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (
        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
        if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    """Saves the timeline and run metadata on traced steps."""
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or seconds
      # have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      # The fetched step may be stale, so read the current value.
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step, self._timeline_path(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _timeline_path(self, step):
    """Returns the path of the "timeline-<step>.json" file for `step`."""
    # Format only the file name so braces in the directory are left intact.
    return os.path.join(self._output_dir, "timeline-{}.json".format(step))

  def _save(self, step, save_path, step_stats):
    """Writes `step_stats` to `save_path` in Chrome Trace format."""
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow, show_memory=self._show_memory))
1081
1082
def _as_graph_element(obj):
  """Resolves `obj` to an element of the current default graph.

  Args:
    obj: Either a string name, or an object with a `graph` attribute. A
      non-string object is returned unchanged after checking that it belongs
      to the default graph. A name containing ":" is looked up directly. A
      bare name is resolved to its ":0" output and must not have a ":1"
      output.

  Returns:
    The resolved graph element.

  Raises:
    ValueError: if a non-string `obj` does not belong to the default graph,
      or a bare name refers to an `Operation` with more than one output.
  """
  graph = ops.get_default_graph()

  if isinstance(obj, six.string_types):
    if ":" in obj:
      return graph.as_graph_element(obj)

    first_output = graph.as_graph_element(obj + ":0")
    # A bare name is only unambiguous when the op has a single output, so
    # probe for a second one.
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      return first_output
    raise ValueError("Name %s is ambiguous, "
                     "as this `Operation` has multiple outputs "
                     "(at least 2)." % obj)

  if not hasattr(obj, "graph") or obj.graph != graph:
    raise ValueError("Passed %s should have graph attribute that is equal "
                     "to current graph %s." % (obj, graph))
  return obj
1105