1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and creates session."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import time
21
22import numpy as np
23
24from tensorflow.python.client import session
25from tensorflow.python.distribute import distribution_strategy_context
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import ops
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.training import checkpoint_management
30from tensorflow.python.util.tf_export import tf_export
31
32
def _maybe_name(obj):
  """Describes `obj` by its name, for use in error messages.

  Args:
    obj: Any object, possibly `None`.

  Returns:
    The string "None" when `obj` is `None`, `obj.name` when `obj` has a
    `name` attribute, and otherwise a "<no name for TYPE>" placeholder that
    mentions the object's type.
  """
  if obj is None:
    return "None"
  if not hasattr(obj, "name"):
    return "<no name for %s>" % type(obj)
  return obj.name
48
49
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
  """Training helper that restores from checkpoint and creates session.

  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions to facilitate coordination
  among multiple training threads or processes:

  * Initializing variables on startup, restoring them from the most recent
    checkpoint after a crash, or waiting for checkpoints to become available.
  * Waiting for a model that another thread or process initializes.

  ### Usage:

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a SessionManager that restores the model from `checkpoint_dir`
    # when a checkpoint is available there.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `prepare_session()` initializes or restores a model. It is usually given an
  `init_op` (run when nothing could be restored) and a `saver` (used to
  restore from `checkpoint_dir`).

  A second process could wait for the model to be ready by doing the following:

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `wait_for_session()` waits for a model to be initialized by other processes.
  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None,
               local_init_feed_dict=None):
    """Creates a SessionManager.

    The `local_init_op` is an `Operation` that is always run after a new
    session is created. If `None`, this step is skipped.

    The `ready_op` is an `Operation` used to check if the model is ready. The
    model is considered ready if that operation returns an empty 1D string
    tensor. If the operation returns a non-empty 1D string tensor, the elements
    are concatenated and used to indicate to the user why the model is not
    ready.

    The `ready_for_local_init_op` is an `Operation` used to check if the model
    is ready to run local_init_op. The model is considered ready if that
    operation returns an empty 1D string tensor. If the operation returns a
    non-empty 1D string tensor, the elements are concatenated and used to
    indicate to the user why the model is not ready.

    If `ready_op` is `None`, the model is not checked for readiness.

    `recovery_wait_secs` is the number of seconds between checks that
    the model is ready. It is used by processes to wait for a model to
    be initialized or restored. Defaults to 30 seconds.

    Args:
      local_init_op: An `Operation` run immediately after session creation.
         Usually used to initialize tables and local variables.
      ready_op: An `Operation` to check if the model is initialized.
      ready_for_local_init_op: An `Operation` to check if the model is ready
         to run local_init_op.
      graph: The `Graph` that the model will use. Defaults to the current
        default graph.
      recovery_wait_secs: Seconds between checks for the model to be ready.
      local_init_run_options: RunOptions to be passed to session.run when
        executing the local_init_op.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.

    Raises:
      ValueError: If ready_for_local_init_op is not None but local_init_op is
        None.
    """
    # `ready_for_local_init_op` only gates running `local_init_op`, so it is
    # meaningless on its own; reject the combination up front.
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op, "
                       "ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    self._local_init_feed_dict = local_init_feed_dict

  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    If restoring raises, the newly created session is closed before the
    exception propagates.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # Validate before initializing devices or creating a session, so that an
    # invalid call has no side effects and leaks no session.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required to so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribution_strategy_context.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)

    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    try:
      if checkpoint_filename_with_path:
        saver.restore(sess, checkpoint_filename_with_path)
        return sess, True

      # Waits up until max_wait_secs for checkpoint to become available.
      wait_time = 0
      ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      while not ckpt or not ckpt.model_checkpoint_path:
        if wait_for_checkpoint and wait_time < max_wait_secs:
          logging.info("Waiting for checkpoint to be available.")
          time.sleep(self._recovery_wait_secs)
          wait_time += self._recovery_wait_secs
          ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
        else:
          return sess, False

      # Loads the checkpoint.
      saver.restore(sess, ckpt.model_checkpoint_path)
      saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    except Exception:
      # The caller never receives `sess` when we raise, so release it here.
      self._safe_close(sess)
      raise
    return sess, True

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Creates a `Session`. Makes sure the model is ready to be used.

    Creates a `Session` on 'master'. If a `saver` object is passed in, and
    `checkpoint_dir` points to a directory containing valid checkpoint
    files, then it will try to recover the model from checkpoint. If
    no checkpoint files are available, and `wait_for_checkpoint` is
    `True`, then the process would check every `recovery_wait_secs`,
    up to `max_wait_secs`, for recovery to succeed.

    If the model cannot be recovered successfully then it is initialized by
    running the `init_op` and calling `init_fn` if they are provided.
    The `local_init_op` is also run after init_op and init_fn, regardless of
    whether the model was recovered successfully, but only if
    `ready_for_local_init_op` passes.

    If the model is recovered from a checkpoint it is assumed that all
    global variables have been initialized, in particular neither `init_op`
    nor `init_fn` will be executed.

    It is an error if the model cannot be recovered and no `init_op`
    or `init_fn` or `local_init_op` are passed.

    If any step after session creation raises, the session is closed before
    the exception propagates.

    Args:
      master: `String` representation of the TensorFlow master to use.
      init_op: Optional `Operation` used to initialize the model.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.
      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
        values.  This feed dictionary is passed to the session `run()` call when
        running the init op.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.

    Returns:
      A `Session` object that can be used to drive the model.

    Raises:
      RuntimeError: If the model cannot be initialized or recovered.
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    try:
      if not is_loaded_from_checkpoint:
        if init_op is None and not init_fn and self._local_init_op is None:
          raise RuntimeError("Model is not initialized and no init_op or "
                             "init_fn or local_init_op was given")
        if init_op is not None:
          sess.run(init_op, feed_dict=init_feed_dict)
        if init_fn:
          init_fn(sess)

      local_init_success, msg = self._try_run_local_init_op(sess)
      if not local_init_success:
        raise RuntimeError(
            "Init operations did not make model ready for local_init.  "
            "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                     init_fn,
                                                     msg))

      is_ready, msg = self._model_ready(sess)
      if not is_ready:
        raise RuntimeError(
            "Init operations did not make model ready.  "
            "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
            (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    except Exception:
      # The caller never receives `sess` when we raise, so release it here.
      self._safe_close(sess)
      raise
    return sess

  def recover_session(self,
                      master,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it. `local_init_op` is
    attempted whether or not recovery succeeded.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered and initialized, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    try:
      # Always try to run local_init_op
      local_init_success, msg = self._try_run_local_init_op(sess)

      if not is_loaded_from_checkpoint:
        # Do not need to run checks for readiness
        return sess, False

      restoring_file = checkpoint_dir or checkpoint_filename_with_path
      if not local_init_success:
        logging.info(
            "Restoring model from %s did not make model ready for local init:"
            " %s", restoring_file, msg)
        return sess, False

      is_ready, msg = self._model_ready(sess)
    except Exception:
      # The caller never receives `sess` when we raise, so release it here.
      self._safe_close(sess)
      raise

    if not is_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   restoring_file, msg)
      return sess, False

    logging.info("Restored model from %s", restoring_file)
    return sess, True

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Creates a new `Session` on 'master'.  Waits for the model to be
    initialized or recovered from a checkpoint.  It's expected that
    another thread or process will make the model ready, and that this
    is intended to be used by threads/processes that participate in a
    distributed training configuration where a different thread/process
    is responsible for initializing or recovering the model being trained.

    Each attempt creates a fresh session. A session that is not ready, or whose
    readiness checks raise, is closed before retrying or re-raising.

    NB: The amount of time this method waits for the session is bounded
    by max_wait_secs. By default, this function will wait indefinitely.

    Args:
      master: `String` representation of the TensorFlow master to use.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
        `None` means wait indefinitely.

    Returns:
      A `Session` whose model passed the readiness checks.

    Raises:
      tf.errors.DeadlineExceededError: if the session is not available after
        max_wait_secs.
    """
    self._target = master

    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      try:
        local_init_success, not_ready_local_msg = self._try_run_local_init_op(
            sess)
        if local_init_success:
          # Successful if local_init_op is None, or ready_for_local_init_op
          # passes.
          is_ready, not_ready_msg = self._model_ready(sess)
          if is_ready:
            return sess
      except Exception:
        # Unexpected errors propagate to the caller; don't leak the session.
        self._safe_close(sess)
        raise

      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

  def _safe_close(self, sess):
    """Closes a session without raising an exception.

    Just like sess.close() but ignores exceptions.

    Args:
      sess: A `Session`.
    """
    # pylint: disable=broad-except
    try:
      sess.close()
    except Exception:
      # Intentionally not logging to avoid user complaints that
      # they get cryptic errors.  We really do not care that Close
      # fails.
      pass
    # pylint: enable=broad-except

  def _model_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready and False
      otherwise, and msg is `None` if the model is ready, a `String` with the
      reason why it is not ready otherwise.
    """
    return _ready(self._ready_op, sess, "Model not ready")

  def _model_ready_for_local_init(self, sess):
    """Checks if the model is ready to run local_init_op.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready to run
      local_init_op and False otherwise, and msg is `None` if the model is
      ready to run local_init_op, a `String` with the reason why it is not ready
      otherwise.
    """
    return _ready(self._ready_for_local_init_op, sess,
                  "Model not ready for local init")

  def _try_run_local_init_op(self, sess):
    """Tries to run _local_init_op, if not None, and is ready for local init.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_successful, msg), where is_successful is True if
      _local_init_op is None, or we ran _local_init_op, and False otherwise;
      and msg is `None` on success, or a `String` with the reason why the
      model was not ready to run local init.
    """
    if self._local_init_op is None:
      return True, None

    is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
    if not is_ready_for_local_init:
      return False, msg

    logging.info("Running local_init_op.")
    sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
             options=self._local_init_run_options)
    logging.info("Done running local_init_op.")
    return True, None
513
514
def _ready(op, sess, msg):
  """Checks if the model is ready or not, as determined by `op`.

  Args:
    op: An op, either _ready_op or _ready_for_local_init_op, which defines the
      readiness of the model. If `None`, the model is considered ready.
    sess: A `Session`.
    msg: A message logged as a warning, together with the error text, when
      running `op` fails with a `FailedPreconditionError` that is not about
      uninitialized values.

  Returns:
    A tuple (is_ready, msg), where is_ready is True if ready and False
    otherwise, and msg is `None` if the model is ready, a `String` with the
    reason why it is not ready otherwise.

  Raises:
    FailedPreconditionError: if running `op` fails with a precondition error
      whose message does not mention "uninitialized".
  """
  if op is None:
    return True, None

  try:
    ready_value = sess.run(op)
  except errors.FailedPreconditionError as e:
    if "uninitialized" not in str(e):
      logging.warning("%s : error [%s]", msg, str(e))
      raise
    return False, str(e)

  # The model is considered ready if ready_op returns an empty 1-D tensor.
  # Also compare to `None` and dtype being int32 for backward
  # compatibility.
  if ready_value is None:
    return True, None
  # Normalize scalars (e.g. plain `bytes` from a 0-d string tensor, which has
  # no `.dtype`) and 0-d arrays to ndarrays so they can be inspected and
  # iterated uniformly.
  ready_value = np.asarray(ready_value)
  if ready_value.dtype == np.int32 or ready_value.size == 0:
    return True, None

  # TODO(sherrym): If a custom ready_op returns other types of tensor,
  # or strings other than variable names, this message could be
  # confusing.
  non_initialized_varnames = ", ".join(
      name.decode("utf-8") if isinstance(name, bytes) else str(name)
      for name in ready_value.ravel())
  return False, "Variables not initialized: " + non_initialized_varnames
552
553
class _CountDownTimer(object):
  """Tracks how much of a fixed time budget is left.

  The countdown starts when the timer is constructed and is measured with
  `time.time()`.
  """

  __slots__ = ["_start_time_secs", "_duration_secs"]

  def __init__(self, duration_secs):
    """Starts a countdown of `duration_secs` seconds from now."""
    self._start_time_secs = time.time()
    self._duration_secs = duration_secs

  def secs_remaining(self):
    """Returns the seconds left in the budget, clamped at zero."""
    elapsed_secs = time.time() - self._start_time_secs
    remaining_secs = self._duration_secs - elapsed_secs
    return remaining_secs if remaining_secs > 0 else 0
565