1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tools to work with checkpoints."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import time
22
23import six
24
25from tensorflow.python.distribute import distribution_strategy_context
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import io_ops
28from tensorflow.python.ops import resource_variable_ops
29from tensorflow.python.ops import variable_scope as vs
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import gfile
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.training import checkpoint_management
34from tensorflow.python.training import py_checkpoint_reader
35from tensorflow.python.training.saving import saveable_object_util
36from tensorflow.python.util.tf_export import tf_export
37
38
# Public names exposed by this module (also what `from ... import *` picks up).
__all__ = [
    "load_checkpoint",
    "load_variable",
    "list_variables",
    "checkpoints_iterator",
    "init_from_checkpoint",
]
43
44
@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  When `ckpt_dir_or_file` is a directory, the reader is opened on the latest
  checkpoint recorded in that directory; any other value is used directly as
  the checkpoint path.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  resolved_path = _get_checkpoint_filename(ckpt_dir_or_file)
  if resolved_path is not None:
    return py_checkpoint_reader.NewCheckpointReader(resolved_path)
  raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                   "given directory %s" % ckpt_dir_or_file)
68
69
@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.

  A trailing ":0" suffix on `name` is stripped before the lookup, so a
  graph-style name such as "dense/kernel:0" is looked up as "dense/kernel".

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  checkpoint_key = name[:-2] if name.endswith(":0") else name
  return load_checkpoint(ckpt_dir_or_file).get_tensor(checkpoint_key)
86
87
@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists the checkpoint keys and shapes of variables in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph.

  Example usage:

  ```python
  import tensorflow as tf
  ckpt_directory = "/tmp/training_checkpoints/ckpt"
  # `optimizer`, `model` and `train_and_checkpoint` stand in for your own
  # training setup.
  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
  train_and_checkpoint(model, manager)
  tf.train.list_variables(manager.latest_checkpoint)
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(key, shape)`, sorted by key.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints (raised by `load_checkpoint`).
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  return [(key, variable_map[key]) for key in sorted(variable_map)]
119
120
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Polls `checkpoint_dir` until a checkpoint other than `last_checkpoint` appears.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    a new checkpoint path, or None if the timeout was reached.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  deadline = None if timeout is None else time.time() + timeout
  while True:
    latest = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if latest is not None and latest != last_checkpoint:
      logging.info("Found new checkpoint at %s", latest)
      return latest
    # Give up rather than sleep past the deadline.
    if deadline is not None and time.time() + seconds_to_sleep > deadline:
      return None
    time.sleep(seconds_to_sleep)
150
151
@tf_export("train.checkpoints_iterator")
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only checks for new checkpoints when control flow has been
  reverted to it. This means it can miss checkpoints if your code takes longer
  to run between iterations than `min_interval_secs` or the interval at which
  new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint.  It is used in combination with the `timeout_fn` as
  follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called and if it returns
    a true boolean value the iterator stops yielding.
  * If the function returns a false boolean value then the iterator resumes the
    wait for new checkpoints.  At this point the timeout logic applies again.

  This behavior gives control to callers on what to do if checkpoints do not
  come fast enough or stop being generated.  For example, if callers have a way
  to detect that the training has stopped and know that no new checkpoints
  will be generated, they can provide a `timeout_fn` that returns `True` when
  the training has stopped.  If they know that the training is still going on
  they return `False` instead.

  The `min_interval_secs` window is measured from just before a path is
  yielded, so time the caller spends processing counts towards it.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum number of seconds to wait between checkpoints. If left
      as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout.  If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit.  The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  previous_path = None
  while True:
    candidate = wait_for_new_checkpoint(
        checkpoint_dir, previous_path, timeout=timeout)
    if candidate is None:
      if timeout_fn:
        if not timeout_fn():
          # The caller expects more checkpoints; keep waiting.
          continue
      else:
        logging.info("Timed-out waiting for a checkpoint.")
      # Either no timeout_fn was given, or it reported we are truly done.
      return
    yielded_at = time.time()
    previous_path = candidate
    yield previous_path
    remaining = yielded_at + min_interval_secs - time.time()
    if remaining > 0:
      time.sleep(remaining)
216
217
@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names. Both the key and the value of a scope mapping must end with '/'.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (i.e. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable(
        'var1', shape=[20, 2],
        initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable(
        'var2', shape=[50, 4],
        initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(
        name='var3', shape=[100, 100],
        initializer=tf.compat.v1.zeros_initializer(),
        partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.
  """

  def init_from_checkpoint_fn(_):
    # The single positional argument is supplied by `merge_call` (presumably
    # the distribution strategy) and is unused; the cross-replica path below
    # passes `None`.
    return _init_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if distribution_strategy_context.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    # In a replica context, run the initializer override once via merge_call.
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)
307
308
def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation.

  Resolves the checkpoint file once, then walks `assignment_map` in sorted key
  order. Each entry is handled in one of two ways:

  * If the value is a variable, a list of variables, or the name of a variable
    (or of a partitioned variable's `/part_<i>` pieces) in the default variable
    store, that variable's initializer is overridden to restore the checkpoint
    tensor named by the key.
  * Otherwise the entry is treated as a scope mapping ('ckpt_scope/':
    'scope/'). Every variable in the default variable store under the current
    scope is mapped to the checkpoint tensor with the same relative name under
    the checkpoint scope.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict mapping checkpoint tensor names (or scopes) to current
      variables, lists of partitioned variables, or their names.

  Raises:
    ValueError: If no checkpoint can be found, if a requested tensor is not in
      the checkpoint, if a variable's shape is incompatible with the
      checkpointed tensor, if a scope-only mapping key does not end with '/',
      or if a list of slices does not come from a single tensor.
  """
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  if ckpt_file is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  # Build the reader from the file resolved above instead of resolving
  # `ckpt_dir_or_file` a second time. If a newer checkpoint lands in the
  # directory between two resolutions, the shapes validated below would
  # otherwise come from a different checkpoint than the initializers restore.
  reader = py_checkpoint_reader.NewCheckpointReader(ckpt_file)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join(v.name for v in var)
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      # Scope-to-scope mapping. `var` can only be None when the store lookup
      # branch above ran, so `store_vars` is bound here.
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          # Drop the '/' separating the current scope from the relative name.
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)
394
395
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Resolves `ckpt_dir_or_file` to a checkpoint path.

  A directory is mapped to the latest checkpoint recorded in it. That result
  may be `None`, which callers check for. Any other value is returned unchanged.
  """
  if not gfile.IsDirectory(ckpt_dir_or_file):
    return ckpt_dir_or_file
  return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
401
402
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint. Mutates `variable` in place: its `_initializer_op`
  and `_initial_value` are replaced.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the RestoreV2 operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    # Distinct loop names keep the `name` argument (the op name above) from
    # being silently rebound.
    saveable_objects = [
        saveable
        for saveable_name, saveable_op in names_to_saveables.items()
        for saveable in saveable_object_util.saveable_objects_for_op(
            saveable_op, saveable_name)
    ]

    assert len(saveable_objects) == 1  # Should be only one variable.
  # Built outside the device scope above.
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
  # pylint:enable=protected-access
443
444
def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables (using each one's slice spec), or once for a single variable (with
  an empty slice spec). For a list, all slices are validated before any
  initializer is changed. A mismatch therefore leaves every variable untouched.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: if the objects in `variable_or_list` are not all partitions of
      the same large variable (their `_save_slice_info.full_name` differ).
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_infos = [v._save_slice_info for v in variable_or_list]  # pylint:disable=protected-access
    # Validate first so a mismatch does not leave the list partially rewired.
    slice_name = None
    for slice_info in slice_infos:
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
    for v, slice_info in zip(variable_or_list, slice_infos):
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")
474
475
def _is_variable(x):
  """Returns True if `x` is a `variables.Variable` or a resource variable."""
  if isinstance(x, variables.Variable):
    return True
  return resource_variable_ops.is_resource_variable(x)
479
480
def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable.

  Gathers `all_vars[name + "/part_0"]`, `all_vars[name + "/part_1"]`, ... up to
  the first missing index. Returns `None` when `name + "/part_0"` is absent.
  """
  parts = []
  index = 0
  while True:
    part_key = name + "/part_%d" % index
    if part_key not in all_vars:
      break
    parts.append(all_vars[part_key])
    index += 1
  return parts or None
491