1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for working with and creating SaveableObjects."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import six
21
22from tensorflow.python.eager import context
23from tensorflow.python.framework import device as pydev
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import resource_variable_ops
27from tensorflow.python.ops import state_ops
28from tensorflow.python.ops import variables
29from tensorflow.python.training.saving import saveable_object
30from tensorflow.python.training.tracking import base as trackable
31
32
33# Op names which identify variable reads which should be saved.
34_VARIABLE_OPS = set(["Variable",
35                     "VariableV2",
36                     "AutoReloadVariable",
37                     "VarHandleOp",
38                     "ReadVariableOp"])
39
40
def set_cpu0(device_string):
  """Returns `device_string` rewritten to name CPU device 0.

  The string is parsed into a `DeviceSpec`. Its device type is forced to
  "CPU" and its device index to 0, and it is serialized back. Every other
  parsed field (job, replica, task) is carried over unchanged, so a device
  already on /CPU:0 keeps the same placement.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  spec = pydev.DeviceSpec.from_string(device_string)
  spec.device_type, spec.device_index = "CPU", 0
  return spec.to_string()
56
57
class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject for reference (non-resource) variables.

  The variable itself is the tensor that gets saved. Restoring assigns the
  loaded value back to it with `state_ops.assign`.
  """

  def __init__(self, var, slice_spec, name):
    """Wraps `var` in a single SaveSpec using `var`'s dtype.

    Args:
      var: The reference variable (or its tensor) to save.
      slice_spec: Slice specification string for `var`.
      name: Name under which the value is saved.
    """
    super(ReferenceVariableSaveable, self).__init__(
        var,
        [saveable_object.SaveSpec(var, slice_spec, name, dtype=var.dtype)],
        name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the wrapped variable.

    Args:
      restored_tensors: Sequence whose first element is the loaded value.
      restored_shapes: None, or a sequence whose first element is the shape
        the loaded value is reshaped to before assignment.

    Returns:
      The op/tensor returned by `state_ops.assign`.
    """
    value = restored_tensors[0]
    reshaped = restored_shapes is not None
    if reshaped:
      value = array_ops.reshape(value, restored_shapes[0])
    # Only ask assign() to validate shapes when no explicit reshape was applied
    # and the variable's static shape is fully known.
    validate = (not reshaped) and self.op.get_shape().is_fully_defined()
    return state_ops.assign(self.op, value, validate_shape=validate)
74
75
class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject for resource variables and tensors read from them.

  Accepts either a `ResourceVariable` or a graph Tensor whose op's first
  input is the variable's resource handle. Restoring copies the loaded value
  to the variable's device and writes it through that handle with
  `resource_variable_ops.shape_safe_assign_variable_handle`.
  """

  def __init__(self, var, slice_spec, name):
    """Builds the single SaveSpec for `var`.

    Args:
      var: A `ResourceVariable`, or a Tensor whose op's first input is the
        resource handle (e.g. a read of the variable).
      slice_spec: Slice specification string for `var`.
      name: Name under which the value is saved.

    Raises:
      ValueError: If `var` is neither a Tensor nor a `ResourceVariable`.
    """
    # Remembered for restore(), which assigns on this device with this shape.
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      # The resource handle is the first input of the op producing `var`.
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif isinstance(var, resource_variable_ops.ResourceVariable):
      self.handle_op = var.handle

      def _read_to_cpu():
        # Read on the variable's own device, then copy to the CPU of the
        # same machine so variables placed on non-CPU devices can still be
        # checkpointed.
        with ops.device(var.device):
          value = var.read_value()
          with ops.device("/device:CPU:0"):
            return array_ops.identity(value)

      # The SaveSpec receives the callable itself, not its result.
      tensor = _read_to_cpu
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Writes the first restored tensor into the variable.

    Args:
      restored_tensors: Sequence whose first element is the loaded value.
      restored_shapes: None, or a sequence whose first element is the shape
        the loaded value is reshaped to first.

    Returns:
      The result of `shape_safe_assign_variable_handle`.
    """
    value = restored_tensors[0]
    if restored_shapes is not None:
      value = array_ops.reshape(value, restored_shapes[0])
    with ops.device(self._var_device):
      # Copy onto the variable's device, then assign through the handle.
      on_device = array_ops.identity(value)
      return resource_variable_ops.shape_safe_assign_variable_handle(
          self.handle_op, self._var_shape, on_device)
116
117
def _tensor_comes_from_variable(v):
  """Returns True iff `v` is an `ops.Tensor` produced by an op in _VARIABLE_OPS."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS
120
121
def _saveable_for_variable(variable, slice_spec, name):
  """Wraps a variable (or variable tensor) in the matching SaveableObject.

  Reference-variable op types get a `ReferenceVariableSaveable`; every other
  op type is handed to `ResourceVariableSaveable`.
  """
  if variable.op.type in ("Variable", "VariableV2", "AutoReloadVariable"):
    return ReferenceVariableSaveable(variable, slice_spec, name)
  return ResourceVariableSaveable(variable, slice_spec, name)


def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  This is a generator: validation errors are raised when it is iterated, not
  when it is called.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject. Lists, tuples and `PartitionedVariable`s are treated as
      the slices of one partitioned tensor. `Trackable`s other than Variables
      are expanded through `_gather_saveables_for_checkpoint()`.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string, or `op` converts to a tensor that
      does not come from a variable op.
    ValueError: For operations with no known conversion to SaveableObject,
      slices that are not Variables / not sliced / from different tensors,
      or non-resource inputs while executing eagerly.
  """
  if not isinstance(name, six.string_types):
    # `(name,)` keeps tuple-valued names from breaking the %-formatting
    # itself (which would raise an unrelated "not all arguments converted").
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        "trackable operations. Name is not a string: %s" % (name,))
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices; all must belong to the same full tensor.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if not isinstance(variable, variables.Variable):
        raise ValueError("Slices must all be Variables: %s" % (variable,))
      if not variable._save_slice_info:
        raise ValueError("Slices must all be slices: %s" % (variable,))
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            "Slices must all be from the same tensor: %s != %s" %
            (slice_name, variable._save_slice_info.full_name))
      yield _saveable_for_variable(
          variable, variable._save_slice_info.spec, name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    # pylint: disable=protected-access
    for attr, factory in op._gather_saveables_for_checkpoint().items():
      if attr == trackable.VARIABLE_VALUE_KEY:
        # Keep original name for classes masquerading as variables.
        full_name = name
      else:
        full_name = name + "_" + attr
      saveable = factory(full_name) if callable(factory) else factory
      # Recurse under the saveable's own name. A pre-built (non-callable)
      # saveable therefore keeps its name and `full_name` is not applied.
      for converted in saveable_objects_for_op(saveable, saveable.name):
        yield converted
    # pylint: enable=protected-access
  elif isinstance(op, resource_variable_ops.ResourceVariable):
    # Graph-mode variables are saved via their graph element; eager
    # variables are saved directly.
    # pylint: disable=protected-access
    variable = op._graph_element if op._in_graph_mode else op
    # pylint: enable=protected-access
    yield ResourceVariableSaveable(variable, "", name)
  else:
    # A non-resource variable or a tensor; only supported in graph mode.
    with ops.init_scope():
      if context.executing_eagerly():
        raise ValueError("Can only save/restore ResourceVariables when "
                         "executing eagerly, got type: %s." % type(op))

    variable = ops.internal_convert_to_tensor(op, as_ref=True)
    if not _tensor_comes_from_variable(variable):
      raise TypeError("names_to_saveables must be a dict mapping string "
                      "names to Tensors/Variables. Not a variable: %s" %
                      variable)
    yield _saveable_for_variable(variable, "", name)
208
209
def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Groups a flat collection of saveables into a name -> saveable dict.

  Args:
    op_list: A list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether graph-mode Variables without slice
      info are replaced by their Tensor before being keyed. Eager variables
      are never converted.

  Returns:
    A dict keyed by checkpoint name. The values are as follows:
      - SaveableObjects, under `.name`.
      - PartitionedVariables, under `.name`.
      - Lists of sliced Variables, under `_save_slice_info.full_name`, in no
        particular order.
      - Eager ResourceVariables, under `_shared_name`.
      - Graph-mode variables/tensors, under the name of the op that owns
        them. For a `ReadVariableOp`, that is the op producing its first
        input.

  Raises:
    TypeError: If `op_list` is not a list/tuple/set, or a graph-mode entry
      does not resolve to a variable tensor.
    ValueError: If at least two saveables share the same name, slices and
      non-slices are mixed under one name, or a non-resource variable was
      created outside graph mode.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    "list: %s" % op_list)
  result = {}
  # Converting ResourceVariables to Tensors adds read ops to the graph;
  # visiting entries in name order keeps that graph construction
  # deterministic.
  # pylint: disable=protected-access
  for item in sorted(op_list, key=lambda x: x.name):
    if isinstance(item, saveable_object.SaveableObject):
      # No duplicate check here: a later SaveableObject with the same name
      # overwrites an earlier one.
      result[item.name] = item
      continue

    if isinstance(item, variables.PartitionedVariable):
      if item.name in result:
        raise ValueError("At least two variables have the same name: %s" %
                         item.name)
      result[item.name] = item
      continue

    if isinstance(item, variables.Variable) and item._save_slice_info:
      full_name = item._save_slice_info.full_name
      if full_name not in result:
        result[full_name] = [item]
      elif isinstance(result[full_name], list):
        result[full_name].append(item)
      else:
        raise ValueError("Mixing slices and non-slices with the same name: "
                         "%s" % full_name)
      continue

    if (isinstance(item, trackable.Trackable)
        and not isinstance(item, variables.Variable)):
      # Factories are called with no arguments. The recursive call uses the
      # default `convert_variable_to_tensor`, and its entries overwrite any
      # existing ones with the same name.
      nested = [f() if callable(f) else f
                for f in item._gather_saveables_for_checkpoint().values()]
      result.update(op_list_to_dict(nested))
      continue

    # Variables (reference and resource) expose `_in_graph_mode`; Tensors do
    # not, and are only ever seen while graph building.
    if not getattr(item, "_in_graph_mode", True):
      if not isinstance(item, resource_variable_ops.ResourceVariable):
        raise ValueError(
            "Can only save/restore ResourceVariables when eager execution "
            "is enabled, type: %s." % type(item))
      if result.setdefault(item._shared_name, item) is not item:
        raise ValueError(
            ("Two different ResourceVariable objects with the same "
             "shared_name '%s' were passed to the Saver. This likely means "
             "that they were created in different Graphs or isolation "
             "contexts, and may not be checkpointed together.") %
            (item._shared_name,))
      continue

    if convert_variable_to_tensor:
      if isinstance(item, resource_variable_ops.ResourceVariable):
        item = item._graph_element  # pylint: disable=protected-access
      else:
        item = ops.internal_convert_to_tensor(item, as_ref=True)
      if not _tensor_comes_from_variable(item):
        raise TypeError("Variable to save is not a Variable: %s" % item)
    # Reads are keyed by the op that produces the resource handle.
    if item.op.type == "ReadVariableOp":
      owner = item.op.inputs[0].op
    else:
      owner = item.op
    key = owner.name
    if key in result:
      raise ValueError("At least two variables have the same name: %s" %
                       key)
    result[key] = item
  # pylint: enable=protected-access
  return result
296
297
298def _add_saveable(saveables, seen_ops, saveable):
299  """Adds the saveable to the saveables list.
300
301  Args:
302    saveables: List to append the SaveableObject to.
303    seen_ops: Set of the ops of the saveables already processed.  Used to
304      check that each saveable is only saved once.
305    saveable: The saveable.
306
307  Raises:
308    ValueError: If the saveable has already been processed.
309  """
310  if saveable.op in seen_ops:
311    raise ValueError("The same saveable will be restored with two names: %s" %
312                     saveable.name)
313  saveables.append(saveable)
314  seen_ops.add(saveable.op)
315
316
def validate_and_slice_inputs(names_to_saveables):
  """Returns the SaveableObjects a Saver should use for `names_to_saveables`.

  Non-dict inputs are first grouped by name with `op_list_to_dict`. Entries
  are then converted in name order, and each op may be saved only once.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
       v is an operation to save or a BaseSaverBuilder.Saver. Any non-dict
       value is passed to `op_list_to_dict`.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  result = []
  seen_ops = set()
  # Sort by key only; the values (ops) are not necessarily comparable.
  for name in sorted(names_to_saveables):
    value = names_to_saveables[name]
    for saveable in saveable_objects_for_op(value, name):
      _add_saveable(result, seen_ops, saveable)
  return result
344