1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Saves and restores variables inside traced @tf.functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import saver_pb2
22from tensorflow.python.eager import def_function
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_spec
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import io_ops
28from tensorflow.python.training.saving import saveable_object
29from tensorflow.python.util import nest
30
31
class Saver(object):
  """Minimal helper that saves and restores checkpoints.

  This is a low-level utility: Tensors are stored under the keys that the
  `SaveableObject`s specify. Higher-level, object-based checkpointing
  utilities are built on top of it.
  """

  def __init__(self, saveable_objects):
    """Records the `SaveableObject`s this saver will save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s. It is copied into a new
        list before being stored.

    Raises:
      ValueError: If any element is not a `SaveableObject`.
    """
    checked = list(saveable_objects)
    for candidate in checked:
      if isinstance(candidate, saveable_object.SaveableObject):
        continue
      raise ValueError(
          "Saver expected a list of SaveableObjects, got %s." % (candidate,))
    self._saveable_objects = checked

  def to_proto(self):
    """Builds a `SaverDef` whose names refer to ops in the current graph."""
    prefix_placeholder = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    # TODO(allenl): Add save and restore function names to the proto directly.
    signature = (tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),)
    # Autograph stays off: it introduces reference cycles which must be
    # collected when a function is created and destroyed (as in
    # tf.saved_model.save). It is also unnecessary here, so leaving it off may
    # be slightly faster.
    #
    # TODO(b/121302372): We should be able to decorate save() and restore()
    # unconditionally.
    traced_save = def_function.function(
        self.save, input_signature=signature, autograph=False)
    save_output = traced_save(prefix_placeholder)
    traced_restore = def_function.function(
        self.restore, input_signature=signature, autograph=False)
    restore_output = traced_restore(prefix_placeholder)
    return saver_pb2.SaverDef(
        filename_tensor_name=prefix_placeholder.name,
        save_tensor_name=save_output.name,
        restore_op_name=restore_output.op.name,
        version=saver_pb2.SaverDef.V2)

  def save(self, file_prefix):
    """Writes the saveable objects to a checkpoint under `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.

    Returns:
      A scalar string Tensor containing `file_prefix` with control dependencies
      on the save ops.
    """
    names = []
    values = []
    slice_specs = []
    # Walk every spec of every saveable in order, collecting the three
    # parallel lists that `save_v2` takes.
    every_spec = (
        spec
        for saveable in self._saveable_objects
        for spec in saveable.specs)
    for spec in every_spec:
      names.append(spec.name)
      values.append(spec.tensor)
      slice_specs.append(spec.slice_spec)
    with ops.device("cpu:0"):
      write_op = io_ops.save_v2(file_prefix, names, slice_specs, values)
      # The returned prefix only becomes available once the write has run.
      with ops.control_dependencies([write_op]):
        return array_ops.identity(file_prefix)

  def restore(self, file_prefix):
    """Reads the saveable objects back from a checkpoint at `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      A scalar string Tensor containing `file_prefix` with control dependencies
      on the restore ops.
    """
    pending_restores = restore_from_saveable_objects(
        file_prefix, self._saveable_objects)
    with ops.device("cpu:0"), ops.control_dependencies(pending_restores):
      return array_ops.identity(file_prefix)
116
def restore_from_saveable_objects(file_prefix, saveable_objects):
  """Reads from a checkpoint and returns restore ops for `saveable_objects`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    saveable_objects: An iterable of objects exposing `specs` (each spec
      providing `name`, `slice_spec` and `dtype`) and a
      `restore(restored_tensors, restored_shapes)` method.

  Returns:
    A list holding, for each saveable in order, the value returned by its
    `restore` method. The list is empty when `saveable_objects` is empty.
  """
  # Materialize once: the saveables are walked twice below, and a one-shot
  # iterator would otherwise be exhausted after the first pass.
  saveable_objects = list(saveable_objects)
  restore_specs = []
  tensor_structure = []
  for saveable in saveable_objects:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      saveable_tensor_structure.append(spec.name)
      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
  if not restore_specs:
    # Nothing to read. `zip(*[])` yields no items, so the three-way unpack
    # below would raise ValueError; skip the checkpoint read and hand every
    # saveable an empty tensor list instead.
    return [saveable.restore([], restored_shapes=None)
            for saveable in saveable_objects]
  tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
  with ops.device("cpu:0"):
    restored_tensors = io_ops.restore_v2(
        file_prefix, tensor_names, tensor_slices, tensor_dtypes)
  # Regroup the flat list of restored tensors so each saveable receives
  # exactly the tensors that correspond to its own specs.
  structured_restored_tensors = nest.pack_sequence_as(
      tensor_structure, restored_tensors)
  restore_ops = []
  for saveable, saveable_tensors in zip(saveable_objects,
                                        structured_restored_tensors):
    restore_ops.append(saveable.restore(saveable_tensors,
                                        restored_shapes=None))
  return restore_ops
139