1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tools to work with checkpoints."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.python import pywrap_tensorflow
24from tensorflow.python.distribute import distribution_strategy_context
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import io_ops
27from tensorflow.python.ops import resource_variable_ops
28from tensorflow.python.ops import variable_scope as vs
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import gfile
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.training import checkpoint_management
33from tensorflow.python.training.saving import saveable_object_util
34from tensorflow.python.util.tf_export import tf_export
35
36
# Public API of this module; each of these functions is also exported under
# `tf.train.*` via its `tf_export` decorator.
__all__ = [
    "load_checkpoint",
    "load_variable",
    "list_variables",
    "init_from_checkpoint",
]
40
41
@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Opens a `CheckpointReader` on the checkpoint named by `ckpt_dir_or_file`.

  A directory argument is resolved to the latest checkpoint it contains, so
  when several checkpoints are present the newest one is read.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  resolved_path = _get_checkpoint_filename(ckpt_dir_or_file)
  if resolved_path is not None:
    return pywrap_tensorflow.NewCheckpointReader(resolved_path)
  # Only a directory without any checkpoint resolves to None.
  raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                   "given directory %s" % ckpt_dir_or_file)
65
66
@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Reads the checkpointed value of the variable called `name`.

  A trailing ":0" output suffix is dropped before the lookup, so both
  "scope/w:0" and "scope/w" refer to the same checkpoint entry.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  lookup_key = name[:-2] if name.endswith(":0") else name
  return load_checkpoint(ckpt_dir_or_file).get_tensor(lookup_key)
83
84
@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists every variable stored in a checkpoint together with its shape.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(name, shape)`, ordered by name.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  shape_by_name = reader.get_variable_to_shape_map()
  return [(var_name, shape_by_name[var_name])
          for var_name in sorted(shape_by_name.keys())]
102
103
@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (i.e. no scope).

  In scope-to-scope mappings both the key and the value should end with '/':
  the value's scope is taken to be everything before its last '/', and a key
  that does not end with '/' is rejected with a `ValueError`.

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.
  """

  def init_from_checkpoint_fn(_):
    # The positional argument supplied by `merge_call` (presumably the
    # distribution strategy) is not needed here.
    _init_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if distribution_strategy_context.get_cross_replica_context():
    # Already in cross-replica context: override the initializers directly.
    init_from_checkpoint_fn(None)
  else:
    # Inside a replica context, route the work through `merge_call` instead of
    # running it directly.
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)
193
194
def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation.

  Walks `assignment_map` in sorted key order and overrides the initializer of
  each matched current variable so it is restored from the checkpoint. Two
  kinds of entries are handled:

    * Direct mappings: the value is a variable, a list of variables, or the
      name of a variable (or of a partitioned variable, via its `/part_N`
      pieces) in the default variable store. The key must then be a tensor
      name present in the checkpoint.
    * Scope mappings: the value is not found as a variable. The key must then
      end with '/', and every store variable under the value's scope is mapped
      to the tensor with the same relative name under the key's scope.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict from checkpoint tensor name (or scope) to current
      variable(s), variable name, or scope name.

  Raises:
    ValueError: If a mapped tensor is missing from the checkpoint, if a
      variable's shape is incompatible with the checkpoint tensor's shape, or
      if an unresolved (scope) value is paired with a key not ending in '/'.
  """
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  # Dict keys are unique, so these (key, value) tuples are ordered by key
  # alone and the values are never compared with each other.
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    # NOTE(review): an empty list also passes this check (`all` over nothing
    # is True); the tensor must still exist, but no variable is initialized.
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        # Note: the message embeds the checkpoint's entire name->shape map.
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        # List of partitions: the joined names are only used for logging; no
        # shape check is done here for the individual slices.
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      # Scope-to-scope mapping. `store_vars` is always bound at this point:
      # `var` can only be None if the variable-store lookup branch above ran.
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      # The scope is everything before the value's last '/', e.g.
      # 'scope_name/' -> 'scope_name' and '/' -> '' (the root).
      # NOTE(review): a value containing no '/' also yields '' and so matches
      # every store variable, and the first character of each name is then
      # stripped below; values are expected to end with '/'.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          # Truncation happens at the first '/part_', so all partitions of a
          # variable collapse into one base name in the set.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        # Drop the '/' that separated the scope from the relative name; a
        # root value ('/') has no scope prefix, so nothing is dropped.
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          # Base name of a partitioned variable: gather its `/part_N` pieces.
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)
280
281
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Resolves `ckpt_dir_or_file` to the path of a concrete checkpoint.

  A directory is replaced by the latest checkpoint recorded in it (which may
  be None when it holds none); any other path is passed through untouched.
  """
  if not gfile.IsDirectory(ckpt_dir_or_file):
    return ckpt_dir_or_file
  return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
287
288
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint. The restored tensor also becomes the variable's
  initial value.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors ("" for a
      whole, unpartitioned tensor).
    name: Name of the operation.

  Raises:
    ValueError: If `variable` does not map to exactly one saveable object.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    # Loop variable is `saveable_name` so it does not shadow the `name`
    # parameter used for the restore op above.
    saveable_objects = [
        saveable
        for saveable_name, op in names_to_saveables.items()
        for saveable in saveable_object_util.saveable_objects_for_op(
            op, saveable_name)
    ]

    # Should be only one variable. Checked explicitly rather than with
    # `assert` so that running under `python -O` cannot silently restore just
    # the first of several saveables.
    if len(saveable_objects) != 1:
      raise ValueError(
          "Expected exactly one saveable object for variable %s, got %d." %
          (variable.name, len(saveable_objects)))
  # Built outside the device scopes above, so the assign op is not forced
  # onto the CPU device.
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
  # pylint:enable=protected-access
328  # pylint:enable=protected-access
329
330
def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables, restoring each one with its own slice spec. A single variable is
  restored in full (empty slice spec).

  Args:
    variable_or_list: `tf.Variable` object or a list (or tuple) of
      `tf.Variable` objects that are slices of one partitioned variable.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: if the objects in `variable_or_list` are not all slices of the
      same full tensor. This is checked before any initializer is modified.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices. Collect and validate every slice's info first so that a
    # mismatch raises before any variable's initializer has been overridden,
    # instead of leaving the list partially re-initialized.
    # pylint:disable=protected-access
    slice_infos = [v._save_slice_info for v in variable_or_list]
    # pylint:enable=protected-access
    slice_name = None
    for slice_info in slice_infos:
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
    for v, slice_info in zip(variable_or_list, slice_infos):
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")
360
361
def _is_variable(x):
  """Returns True if `x` is a `variables.Variable` or a resource variable."""
  if isinstance(x, variables.Variable):
    return True
  return resource_variable_ops.is_resource_variable(x)
365
def _collect_partitioned_variable(name, all_vars):
  """Gathers the pieces of a partitioned variable from `all_vars`.

  Pieces are looked up under `name + "/part_0"`, `name + "/part_1"`, ... and
  collected in index order up to the first missing index.

  Args:
    name: Base name of the (possibly) partitioned variable.
    all_vars: Mapping from variable name to variable.

  Returns:
    List of `tf.Variable` that comprise the partitioned variable, or None if
    `name + "/part_0"` is not a key of `all_vars`.
  """
  parts = []
  while True:
    part_key = name + "/part_%d" % len(parts)
    if part_key not in all_vars:
      break
    parts.append(all_vars[part_key])
  return parts if parts else None
376