1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for working with and creating SaveableObjects."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import six
22
23from tensorflow.python.eager import context
24from tensorflow.python.eager import def_function
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import device as pydev
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import tensor_util
32from tensorflow.python.framework import type_spec
33
34
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import resource_variable_ops
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.training.saving import saveable_object
41from tensorflow.python.training.tracking import base as trackable
42from tensorflow.python.util import nest
43from tensorflow.python.util import object_identity
44
45
# Types of the ops whose outputs are variables (or reads of variables) that
# should be saved; `_tensor_comes_from_variable` tests membership in this set.
_VARIABLE_OPS = {
    "Variable",
    "VariableV2",
    "AutoReloadVariable",
    "VarHandleOp",
    "ReadVariableOp",
}
52
53
def set_cpu0(device_string):
  """Returns a copy of `device_string` retargeted at /CPU:0.

  Only the device type and device index of the parsed spec are overwritten
  (via `DeviceSpec.replace`); every other field is carried over unchanged.
  A string that already refers to /CPU:0 therefore comes back equivalent.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  spec = pydev.DeviceSpec.from_string(device_string)
  return spec.replace(device_type="CPU", device_index=0).to_string()
68
69
class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject wrapping a reference (non-resource) variable.

  The variable itself is used as the tensor to save; restoring writes the
  checkpointed value back with `state_ops.assign`.
  """

  def __init__(self, var, slice_spec, name):
    save_spec = saveable_object.SaveSpec(
        var, slice_spec, name, dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the wrapped variable.

    Args:
      restored_tensors: Tensors read from the checkpoint; only the first one
        is used.
      restored_shapes: Optional list of shapes. When given, the restored
        tensor is reshaped to `restored_shapes[0]` and the assignment does
        not validate shapes.

    Returns:
      The result of `state_ops.assign`.
    """
    value = restored_tensors[0]
    reshaped = restored_shapes is not None
    if reshaped:
      value = array_ops.reshape(value, restored_shapes[0])
    # Shape validation is only requested when no explicit reshape happened and
    # the variable's static shape is fully known.
    check_shape = (not reshaped and
                   self.op.get_shape().is_fully_defined())
    return state_ops.assign(self.op, value, validate_shape=check_shape)
86
87
class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables.

  Accepts either a resource variable or a Tensor produced by reading one.
  For a variable, the value to save is produced lazily by a callable that
  copies the current value to CPU; restoring assigns through the variable
  handle on the variable's original device.
  """

  def __init__(self, var, slice_spec, name):
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      # A read of the variable: its op's first input is the resource handle.
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def _make_reader(variable):
        """Returns a thunk that reads `variable` and copies the value to CPU."""

        def _read():
          with ops.device(variable.device):
            if context.executing_eagerly() and not variable.is_initialized():
              # A SaveSpec tensor value of `None` indicates that the variable
              # is uninitialized.
              return None
            value = variable.read_value()
            # Copy to the CPU of the same machine so that variables placed on
            # non-CPU devices can still be checkpointed.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(value)

        return _read

      self.handle_op = var.handle
      tensor = _make_reader(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the variable through its handle.

    Args:
      restored_tensors: Tensors read from the checkpoint; only the first one
        is used.
      restored_shapes: Optional list of shapes; when given, the restored
        tensor is reshaped to `restored_shapes[0]` before assignment.

    Returns:
      The result of `shape_safe_assign_variable_handle`.
    """
    value = restored_tensors[0]
    if restored_shapes is not None:
      value = array_ops.reshape(value, restored_shapes[0])
    with ops.device(self._var_device):
      # Copy the restored value onto the variable's device before assigning.
      value_on_device = array_ops.identity(value)
      return resource_variable_ops.shape_safe_assign_variable_handle(
          self.handle_op, self._var_shape, value_on_device)
133
134
def _tensor_comes_from_variable(v):
  """Returns True iff `v` is a Tensor whose op type is in `_VARIABLE_OPS`."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS
137
138
def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject. A list/tuple of sliced Variables, a
      `PartitionedVariable`, or a non-variable `Trackable` (whose
      `_gather_saveables_for_checkpoint` factories are expanded recursively)
      is also accepted.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or if `op` converts to a Tensor that
      does not come from a variable op.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  # Op types that identify reference (non-resource) variables.
  ref_variable_op_types = ("Variable", "VariableV2", "AutoReloadVariable")
  if not isinstance(name, six.string_types):
    # `(name,)` keeps a tuple-valued `name` from being unpacked as the format
    # arguments, which would garble (or break) the error message.
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "trackable operations. Name is not a string: %s" % (name,))
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices, which must all belong to the same full variable.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      # `(variable,)` so that a tuple element raises the documented
      # ValueError instead of a formatting TypeError.
      if not isinstance(variable, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % (variable,))
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % (variable,))
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      if variable.op.type in ref_variable_op_types:
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(variable, variable._save_slice_info.spec,
                                       name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      # Use fresh names instead of rebinding `op`, which is the Trackable
      # currently being expanded.
      maybe_saveable = factory(full_name) if callable(factory) else factory
      for saveable in saveable_objects_for_op(maybe_saveable,
                                              maybe_saveable.name):
        yield saveable
    # pylint: enable=protected-access
  else:
    # A variable or tensor.
    if isinstance(op, resource_variable_ops.BaseResourceVariable):
      if op._in_graph_mode:  # pylint: disable=protected-access
        variable = op._graph_element  # pylint: disable=protected-access
      else:
        variable = op
      yield ResourceVariableSaveable(variable, "", name)
    else:
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         "executing eagerly, got type: %s." % type(op))

      variable = ops.convert_to_tensor(op, as_ref=True)
      if not _tensor_comes_from_variable(variable):
        raise TypeError("names_to_saveables must be a dict mapping string "
                        "names to Tensors/Variables. Not a variable: %s" %
                        (variable,))
      if variable.op.type in ref_variable_op_types:
        yield ReferenceVariableSaveable(variable, "", name)
      else:
        yield ResourceVariableSaveable(variable, "", name)
224
225
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  SaveableObjects, and the saveables gathered from non-variable Trackables,
  are inserted without a duplicate-name check: a later entry with the same
  name replaces an earlier one. Every other kind of entry raises on a
  duplicate name.

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name.  Variables with save_slice_info are grouped together under the
    same key in no particular order.

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  # List casting is necessary to support sets.
  op_list = nest.flatten(list(op_list))
  # When ResourceVariables are converted to Tensors, read ops are added to the
  # graph. Sorting the op_list ensures that the resulting graph is always
  # constructed in a deterministic way:
  op_list = sorted(op_list, key=lambda x: x.name)
  names_to_saveables = {}
  # pylint: disable=protected-access
  for var in op_list:
    resource_or_ref_variable = isinstance(
        var, (resource_variable_ops.BaseResourceVariable,
              variables.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.PartitionedVariable):
      if var.name in names_to_saveables:
        raise ValueError("At least two variables have the same name: %s" %
                         var.name)
      names_to_saveables[var.name] = var
    elif isinstance(var, variables.Variable) and var._save_slice_info:
      # Slices are grouped under the name of the full variable they belong to.
      name = var._save_slice_info.full_name
      if name in names_to_saveables:
        if not isinstance(names_to_saveables[name], list):
          raise ValueError("Mixing slices and non-slices with the same name: "
                           "%s" % name)
        names_to_saveables[name].append(var)
      else:
        names_to_saveables[name] = [var]
    elif isinstance(var, trackable.Trackable) and not resource_or_ref_variable:
      # Expand the Trackable into the saveables its factories produce.
      trackable_saveables = [
          (factory() if callable(factory) else factory)
          for factory in var._gather_saveables_for_checkpoint().values()]
      names_to_saveables.update(
          op_list_to_dict(trackable_saveables))
    else:
      # Variables (reference and resource) have an _in_graph_mode property
      # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              "is enabled, type: %s." % type(var))
        set_var = names_to_saveables.setdefault(var._shared_name, var)
        if set_var is not var:
          raise ValueError(
              ("Two different ResourceVariable objects with the same "
               "shared_name '%s' were passed to the Saver. This likely means "
               "that they were created in different Graphs or isolation "
               "contexts, and may not be checkpointed together.") %
              (var._shared_name,))
      else:
        if convert_variable_to_tensor:
          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.convert_to_tensor(var, as_ref=True)
          if not _tensor_comes_from_variable(var):
            raise TypeError("Variable to save is not a Variable: %s" % var)
        # A read is keyed by the name of the op producing its handle input,
        # so it shares a key with the variable it reads.
        if var.op.type == "ReadVariableOp":
          name = var.op.inputs[0].op.name
        else:
          name = var.op.name
        if name in names_to_saveables:
          raise ValueError("At least two variables have the same name: %s" %
                           name)
        names_to_saveables[name] = var
  # pylint: enable=protected-access

  return names_to_saveables
317
318
def _add_saveable(saveables, seen_ops, saveable):
  """Appends `saveable` to `saveables`, rejecting an already-seen op.

  Args:
    saveables: List the SaveableObject is appended to.
    seen_ops: Set of the ops of the saveables processed so far; updated in
      place so that each op is saved only once.
    saveable: The saveable to add.

  Raises:
    ValueError: If `saveable.op` is not None and is already in `seen_ops`.
  """
  op = saveable.op
  already_seen = op is not None and op in seen_ops
  if already_seen:
    raise ValueError("The same saveable will be restored with two names: %s" %
                     saveable.name)
  saveables.append(saveable)
  seen_ops.add(op)
336
337
def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
       v is an operation to save or a BaseSaverBuilder.Saver. A non-dict
       value is first converted with `op_list_to_dict`.

  Returns:
    A list of SaveableObjects, grouped in name order.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  result = []
  seen_ops = object_identity.ObjectIdentitySet()
  # Sort on the name alone so that the ops themselves are never compared.
  ordered_items = sorted(names_to_saveables.items(), key=lambda item: item[0])
  for saveable_name, target in ordered_items:
    for saveable in saveable_objects_for_op(target, saveable_name):
      _add_saveable(result, seen_ops, saveable)
  return result
365
366
def trace_save_restore_functions(object_to_save):
  """Gathers all SaveableObjects and traces the save and restore ops.

  Args:
    object_to_save: Object whose `_gather_saveables_for_checkpoint` factories
      are traced.

  Returns:
    A dict mapping each saveable name to a `(save_function,
    restore_function)` tuple. Non-callable factories are skipped (with a debug
    log when they are SaveableObjects), as are factories whose tracing
    produces no save function.
  """
  saveable_map = {}
  # pylint: disable=protected-access
  factories = object_to_save._gather_saveables_for_checkpoint()
  # pylint: enable=protected-access
  for name, saveable_factory in factories.items():
    if callable(saveable_factory):
      if is_factory_for_restored_saveable_object(saveable_factory):
        # Reuse the functions already bound into the restored factory.
        bound = saveable_factory.keywords
        saveable_map[name] = (bound["save_function"],
                              bound["restore_function"])
        continue
      save_fn, restore_fn = _trace_save_and_restore_function(
          saveable_factory, object_to_save)
      if save_fn is not None:
        saveable_map[name] = (save_fn, restore_fn)
    elif isinstance(saveable_factory, saveable_object.SaveableObject):
      logging.debug(
          "Trackable {} should return callable factories, not SaveableObjects"
          " in `_gather_saveables_for_checkpoint`. This could lead to "
          "problems loading the SavedModel back into Python."
          .format(object_to_save))
  return saveable_map
390
391
def _trace_save_and_restore_function(saveable_factory, object_to_save):
  """Traces the save and restore concrete functions.

  Args:
    saveable_factory: Callable invoked as `saveable_factory(name=key)` that
      returns a SaveableObject or an iterable of SaveableObjects.
    object_to_save: The object the factory belongs to; used only in the
      warning message.

  Returns:
    A `(concrete_save_fn, concrete_restore_fn)` pair, or `(None, None)` when
    any produced saveable is a `trackable.PythonStateSaveable` or a
    `trackable.NoRestoreSaveable`.
  """
  # Filled in as a side effect of tracing `save_fn`; `restore_fn` restores
  # into these same SaveableObjects.
  saveables = []

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def save_fn(checkpoint_key):
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    saveables[:] = maybe_saveable

    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  # Tracing happens here; this is what populates `saveables`.
  concrete_save_fn = save_fn.get_concrete_function()
  if any(isinstance(saveable, trackable.PythonStateSaveable)
         for saveable in saveables):
    logging.warning(
        "Note that object {} stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.".format(object_to_save))
    return None, None
  if any(isinstance(saveable, trackable.NoRestoreSaveable)
         for saveable in saveables):
    return None, None

  # One flat type spec per SaveSpec (in the order `save_fn` returns them),
  # plus a nested structure recording which specs belong to which saveable.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    # `saveable_tensors` is the per-saveable slice of the flat inputs; a
    # distinct name avoids shadowing the `restored_tensors` parameter.
    for saveable, saveable_tensors in zip(saveables,
                                          structured_restored_tensors):
      saveable.restore(saveable_tensors, restored_shapes=None)
    return 1

  concrete_restore_fn = restore_fn.get_concrete_function()
  return concrete_save_fn, concrete_restore_fn
444
445
class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject restored from SavedModel using the traced save/restore.

  Saving calls `save_function` with the checkpoint key to obtain the SaveSpec
  dicts; restoring forwards one restored tensor per spec to
  `restore_function`.
  """

  def __init__(self, save_function, restore_function, name):
    self.save_function = save_function
    self.restore_function = restore_function

    if not tensor_util.is_tf_type(name):
      # `init_scope` lifts the constant out of any function being traced.
      with ops.init_scope():
        name_tensor = constant_op.constant(name)
    else:
      name_tensor = name
    spec_dicts = save_function(name_tensor)
    specs = []
    for entry in spec_dicts:
      specs.append(saveable_object.SaveSpec(
          entry["tensor"], entry["slice_spec"], entry["name"]))
    super(RestoredSaveableObject, self).__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # unused
    spec_count = len(self.specs)
    # Index explicitly: too few restored tensors raises IndexError.
    args = [restored_tensors[index] for index in range(spec_count)]
    return self.restore_function(*args)
467
468
def restored_saved_object_factory(save_function, restore_function):
  """Returns a factory that builds `RestoredSaveableObject`s.

  The factory is a `functools.partial` with `save_function` and
  `restore_function` bound as keyword arguments, so callers only supply
  `name`.
  """
  bound_keywords = dict(save_function=save_function,
                        restore_function=restore_function)
  return functools.partial(RestoredSaveableObject, **bound_keywords)
473
474
def create_saveable_object(factory, name, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  When creating the frozen saver for SavedModel, the save and restore ops are
  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
  save and restore, the function captures must be mapped to the new graph.

  Args:
    factory: Factory method for creating the SaveableObject.
    name: Checkpoint key of this SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures. When None, or when `factory` is not a restored-saveable
      factory, `factory(name=name)` is returned directly.

  Returns:
    a SaveableObject.
  """
  call_factory_directly = (
      call_with_mapped_captures is None or
      not is_factory_for_restored_saveable_object(factory))
  if call_factory_directly:
    return factory(name=name)

  traced_save = factory.keywords["save_function"]

  def save_fn(checkpoint_key):
    return call_with_mapped_captures(traced_save, [checkpoint_key])

  traced_restore = factory.keywords["restore_function"]

  def restore_fn(*restored_tensors):
    return call_with_mapped_captures(traced_restore, restored_tensors)

  return factory(save_function=save_fn, restore_function=restore_fn,
                 name=name)
504
505
def is_factory_for_restored_saveable_object(factory):
  """Returns True iff `factory` is a `functools.partial` of RestoredSaveableObject.

  Such factories are produced by `restored_saved_object_factory`.
  """
  if not isinstance(factory, functools.partial):
    return False
  return factory.func is RestoredSaveableObject
509