1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and creates session."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import time
21import numpy as np
22
23from tensorflow.python.client import session
24from tensorflow.python.distribute import distribution_strategy_context
25from tensorflow.python.framework import errors
26from tensorflow.python.framework import ops
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.training import checkpoint_management
29from tensorflow.python.util.tf_export import tf_export
30
31
32def _maybe_name(obj):
33  """Returns object name if it has one, or a message otherwise.
34
35  This is useful for names that apper in error messages.
36  Args:
37    obj: Object to get the name of.
38  Returns:
39    name, "None", or a "no name" message.
40  """
41  if obj is None:
42    return "None"
43  elif hasattr(obj, "name"):
44    return obj.name
45  else:
46    return "<no name for %s>" % type(obj)
47
48
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
  """Training helper that restores from checkpoint and creates session.

  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions to facilitate
  coordination among multiple training threads or processes.

  * Checkpointing trained variables as the training progresses.
  * Initializing variables on startup, restoring them from the most recent
    checkpoint after a crash, or waiting for checkpoints to become available.

  ### Usage:

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `prepare_session()` initializes or restores a model. It is usually given an
  `init_op` and a `saver` as arguments.

  A second process could wait for the model to be ready by doing the following:

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `wait_for_session()` waits for a model to be initialized by other processes.

  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None):
    """Creates a SessionManager.

    The `local_init_op` is an `Operation` that is always run after a new
    session is created. If `None`, this step is skipped.

    The `ready_op` is an `Operation` used to check if the model is ready.  The
    model is considered ready if that operation returns an empty 1D string
    tensor. If the operation returns a non empty 1D string tensor, the elements
    are concatenated and used to indicate to the user why the model is not
    ready.

    The `ready_for_local_init_op` is an `Operation` used to check if the model
    is ready to run local_init_op.  The model is considered ready if that
    operation returns an empty 1D string tensor. If the operation returns a non
    empty 1D string tensor, the elements are concatenated and used to indicate
    to the user why the model is not ready.

    If `ready_op` is `None`, the model is not checked for readiness.

    `recovery_wait_secs` is the number of seconds between checks that
    the model is ready.  It is used by processes to wait for a model to
    be initialized or restored.  Defaults to 30 seconds.

    Args:
      local_init_op: An `Operation` run immediately after session creation.
         Usually used to initialize tables and local variables.
      ready_op: An `Operation` to check if the model is initialized.
      ready_for_local_init_op: An `Operation` to check if the model is ready
         to run local_init_op.
      graph: The `Graph` that the model will use. Defaults to the current
        default graph.
      recovery_wait_secs: Seconds between checks for the model to be ready.
      local_init_run_options: RunOptions to be passed to session.run when
        executing the local_init_op.

    Raises:
      ValueError: If ready_for_local_init_op is not None but local_init_op is
        None
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    # Master target of the most recently created session; set by the
    # session-creating methods below.
    self._target = None
    self._local_init_run_options = local_init_run_options
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set. This is checked before any session is created.
      Any exception raised while restoring (e.g. by `saver.restore`) is
      re-raised after the newly created session has been closed.
    """
    self._target = master

    # Validate arguments up front so that a bad combination never leaves an
    # open Session behind.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribution_strategy_context.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)
    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    try:
      if checkpoint_filename_with_path:
        saver.restore(sess, checkpoint_filename_with_path)
        return sess, True

      # Waits up until max_wait_secs for checkpoint to become available.
      wait_time = 0
      ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      while not ckpt or not ckpt.model_checkpoint_path:
        if wait_for_checkpoint and wait_time < max_wait_secs:
          logging.info("Waiting for checkpoint to be available.")
          time.sleep(self._recovery_wait_secs)
          wait_time += self._recovery_wait_secs
          ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
        else:
          return sess, False

      # Loads the checkpoint.
      saver.restore(sess, ckpt.model_checkpoint_path)
      saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
      return sess, True
    except Exception:  # pylint: disable=broad-except
      # The caller never receives `sess` when restoring fails, so close it
      # here instead of leaking it, then propagate the original error.
      self._safe_close(sess)
      raise

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Creates a `Session`. Makes sure the model is ready to be used.

    Creates a `Session` on 'master'. If a `saver` object is passed in, and
    `checkpoint_dir` points to a directory containing valid checkpoint
    files, then it will try to recover the model from checkpoint. If
    no checkpoint files are available, and `wait_for_checkpoint` is
    `True`, then the process would check every `recovery_wait_secs`,
    up to `max_wait_secs`, for recovery to succeed.

    If the model cannot be recovered successfully then it is initialized by
    running the `init_op` and calling `init_fn` if they are provided.
    The `local_init_op` is also run after init_op and init_fn, regardless of
    whether the model was recovered successfully, but only if
    `ready_for_local_init_op` passes.

    If the model is recovered from a checkpoint it is assumed that all
    global variables have been initialized, in particular neither `init_op`
    nor `init_fn` will be executed.

    It is an error if the model cannot be recovered and no `init_op`
    or `init_fn` or `local_init_op` are passed.

    If any step fails, the session created by this call is closed before the
    exception propagates.

    Args:
      master: `String` representation of the TensorFlow master to use.
      init_op: Optional `Operation` used to initialize the model.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.
      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
        values.  This feed dictionary is passed to the session `run()` call when
        running the init op.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.

    Returns:
      A `Session` object that can be used to drive the model.

    Raises:
      RuntimeError: If the model cannot be initialized or recovered.
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    try:
      if not is_loaded_from_checkpoint:
        if init_op is None and not init_fn and self._local_init_op is None:
          raise RuntimeError("Model is not initialized and no init_op or "
                             "init_fn or local_init_op was given")
        if init_op is not None:
          sess.run(init_op, feed_dict=init_feed_dict)
        if init_fn:
          init_fn(sess)

      local_init_success, msg = self._try_run_local_init_op(sess)
      if not local_init_success:
        raise RuntimeError(
            "Init operations did not make model ready for local_init.  "
            "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                     init_fn,
                                                     msg))

      is_ready, msg = self._model_ready(sess)
      if not is_ready:
        raise RuntimeError(
            "Init operations did not make model ready.  "
            "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
            (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    except Exception:  # pylint: disable=broad-except
      # The caller never receives `sess` when initialization fails, so close
      # it here instead of leaking it, then propagate the original error.
      self._safe_close(sess)
      raise
    return sess

  def recover_session(self,
                      master,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered and initialized, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # Always try to run local_init_op
    local_init_success, msg = self._try_run_local_init_op(sess)

    if not is_loaded_from_checkpoint:
      # Do not need to run checks for readiness
      return sess, False

    restoring_file = checkpoint_dir or checkpoint_filename_with_path
    if not local_init_success:
      logging.info(
          "Restoring model from %s did not make model ready for local init:"
          " %s", restoring_file, msg)
      return sess, False

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   restoring_file, msg)
      return sess, False

    logging.info("Restored model from %s", restoring_file)
    return sess, is_loaded_from_checkpoint

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Creates a new `Session` on 'master'.  Waits for the model to be
    initialized or recovered from a checkpoint.  It's expected that
    another thread or process will make the model ready, and that this
    is intended to be used by threads/processes that participate in a
    distributed training configuration where a different thread/process
    is responsible for initializing or recovering the model being trained.

    NB: The amount of time this method waits for the session is bounded
    by max_wait_secs. By default, this function will wait indefinitely.

    Args:
      master: `String` representation of the TensorFlow master to use.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
        `None` is treated as "wait forever".

    Returns:
      A `Session` whose model passed the readiness checks.

    Raises:
      tf.DeadlineExceededError: if the session is not available after
        max_wait_secs.
    """
    self._target = master

    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      try:
        local_init_success, not_ready_local_msg = (
            self._try_run_local_init_op(sess))
        if local_init_success:
          # Successful if local_init_op is None, or ready_for_local_init_op
          # passes.
          is_ready, not_ready_msg = self._model_ready(sess)
          if is_ready:
            return sess
      except Exception:  # pylint: disable=broad-except
        # An unexpected failure ends the wait; do not leak this attempt's
        # session.
        self._safe_close(sess)
        raise

      # Not ready yet: discard this attempt's session before retrying.
      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

  def _safe_close(self, sess):
    """Closes a session without raising an exception.

    Just like sess.close() but ignores exceptions.

    Args:
      sess: A `Session`.
    """
    # pylint: disable=broad-except
    try:
      sess.close()
    except Exception:
      # Intentionally not logging to avoid user complaints that
      # they get cryptic errors.  We really do not care that Close
      # fails.
      pass
    # pylint: enable=broad-except

  def _model_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready and False
      otherwise, and msg is `None` if the model is ready, a `String` with the
      reason why it is not ready otherwise.
    """
    return _ready(self._ready_op, sess, "Model not ready")

  def _model_ready_for_local_init(self, sess):
    """Checks if the model is ready to run local_init_op.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready to run
      local_init_op and False otherwise, and msg is `None` if the model is
      ready to run local_init_op, a `String` with the reason why it is not ready
      otherwise.
    """
    return _ready(self._ready_for_local_init_op, sess,
                  "Model not ready for local init")

  def _try_run_local_init_op(self, sess):
    """Tries to run _local_init_op, if not None, and is ready for local init.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_successful, msg), where is_successful is True if
      _local_init_op is None, or we ran _local_init_op, and False otherwise;
      and msg is `None` on success, or a `String` with the reason why the
      model was not ready to run local init.
    """
    if self._local_init_op is not None:
      is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
      if is_ready_for_local_init:
        logging.info("Running local_init_op.")
        sess.run(self._local_init_op, options=self._local_init_run_options)
        logging.info("Done running local_init_op.")
        return True, None
      else:
        return False, msg
    return True, None
507
508
509def _ready(op, sess, msg):
510  """Checks if the model is ready or not, as determined by op.
511
512  Args:
513    op: An op, either _ready_op or _ready_for_local_init_op, which defines the
514      readiness of the model.
515    sess: A `Session`.
516    msg: A message to log to warning if not ready
517
518  Returns:
519    A tuple (is_ready, msg), where is_ready is True if ready and False
520    otherwise, and msg is `None` if the model is ready, a `String` with the
521    reason why it is not ready otherwise.
522  """
523  if op is None:
524    return True, None
525  else:
526    try:
527      ready_value = sess.run(op)
528      # The model is considered ready if ready_op returns an empty 1-D tensor.
529      # Also compare to `None` and dtype being int32 for backward
530      # compatibility.
531      if (ready_value is None or ready_value.dtype == np.int32 or
532          ready_value.size == 0):
533        return True, None
534      else:
535        # TODO(sherrym): If a custom ready_op returns other types of tensor,
536        # or strings other than variable names, this message could be
537        # confusing.
538        non_initialized_varnames = ", ".join(
539            [i.decode("utf-8") for i in ready_value])
540        return False, "Variables not initialized: " + non_initialized_varnames
541    except errors.FailedPreconditionError as e:
542      if "uninitialized" not in str(e):
543        logging.warning("%s : error [%s]", msg, str(e))
544        raise e
545      return False, str(e)
546
547
548class _CountDownTimer(object):
549
550  def __init__(self, duration_secs):
551    self._start_time_secs = time.time()
552    self._duration_secs = duration_secs
553
554  def secs_remaining(self):
555    diff = self._duration_secs - (time.time() - self._start_time_secs)
556    return max(0, diff)
557