1"""Manages a graph of Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import weakref
22
23from tensorflow.core.protobuf import trackable_object_graph_pb2
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.training import optimizer as optimizer_v1
28from tensorflow.python.training.saving import saveable_object as saveable_object_lib
29from tensorflow.python.training.saving import saveable_object_util
30from tensorflow.python.training.tracking import base
31from tensorflow.python.training.tracking import object_identity
32from tensorflow.python.training.tracking import tracking
33
34
# Escape character prepended to generated keywords so that they cannot collide
# with user-specified local names (which have this character doubled by
# `_escape_local_name`).
_ESCAPE_CHAR = "."

# Keyword marking that the following part of a checkpoint variable name is a
# slot name. Slot variables are keyed as:
#
#   <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# where <path to variable> is the full path from the checkpoint root to the
# variable the slot belongs to.
_OPTIMIZER_SLOTS_NAME = "%sOPTIMIZER_SLOT" % (_ESCAPE_CHAR,)

# Keyword separating the path to an object from the name of one of its
# attributes in checkpoint keys:
#
#   <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
_OBJECT_ATTRIBUTES_NAME = "%sATTRIBUTES" % (_ESCAPE_CHAR,)
49
50
def _escape_local_name(name):
  """Escape a local name so it can be embedded in a slash-separated path.

  Slashes are allowed in local names for compatibility (e.g. names that came
  from `Layer.add_variable`), but slashes also separate the edges traversed to
  reach an object. The escape character is therefore doubled first, and each
  forward slash is then replaced by the escape character followed by "S".

  Args:
    name: The local name to escape.

  Returns:
    The escaped name.
  """
  # Doubling must happen before slash replacement so that the escape
  # characters introduced for slashes are not themselves doubled.
  doubled_escapes = name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR * 2)
  escaped_slash = _ESCAPE_CHAR + "S"
  return doubled_escapes.replace(r"/", escaped_slash)
59
60
def _object_prefix_from_path(path_to_root):
  """Build an object's checkpoint name prefix from its path to the root.

  Args:
    path_to_root: A sequence of references (each with a `.name`) describing
      the edges from the checkpoint root to the object.

  Returns:
    The escaped local names along the path, joined with "/". An empty path
    (the root itself) yields "".
  """
  escaped_names = [_escape_local_name(reference.name)
                   for reference in path_to_root]
  return "/".join(escaped_names)
65
66
def _slot_variable_naming_for_optimizer(optimizer_path):
  """Make a function for naming slot variables in an optimizer.

  Slot variables are named:

    <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable name> is exactly the checkpoint name used for the original
  variable, including the path from the checkpoint root and the local name in
  the object which owns it. Slot variables are only saved when the variable
  they are slotting for is also saved.

  Args:
    optimizer_path: The checkpoint name prefix of the optimizer.

  Returns:
    A callable `(variable_path, slot_name) -> str` giving the checkpoint name
    of a slot variable.
  """
  # Everything between the variable path and the slot name is fixed for a
  # given optimizer, so compute it once.
  infix = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path)

  def _name_slot_variable(variable_path, slot_name):
    """With an optimizer specified, name a slot variable."""
    return "".join((variable_path, infix, _escape_local_name(slot_name)))

  return _name_slot_variable
87
88
def _serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gather and name slot variables.

  For every optimizer in `trackable_objects`, and every (slot name, variable)
  pair for which the optimizer returns a slot variable, registers that slot
  variable as a new node of the object graph.

  Args:
    trackable_objects: A list of the `Trackable` objects gathered so far, in
      node-id order. Modified in place: slot variables are appended.
    node_ids: An object-identity dictionary mapping each object in
      `trackable_objects` to its index. Modified in place with entries for the
      appended slot variables.
    object_names: An object-identity dictionary mapping each object to its
      checkpoint name prefix. Modified in place with entries for the appended
      slot variables.

  Returns:
    An object-identity dictionary mapping each optimizer that owns at least one
    slot variable to a list of `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable has checkpoint dependencies of its
      own, or if it is already a regular node of the object graph.
  """
  # Snapshot: slot variables are appended to `trackable_objects` below and
  # must not themselves be scanned as optimizers or slotted-for variables.
  non_slot_objects = list(trackable_objects)
  slot_variables = object_identity.ObjectIdentityDictionary()
  for trackable in non_slot_objects:
    # `dir()` is used rather than `hasattr()` because `hasattr()` falls back to
    # an object's custom `__getattr__` when the attribute is missing, which can
    # run arbitrary code (or raise) merely to answer this duck-typing check.
    if (isinstance(trackable, optimizer_v1.Optimizer)
        # TODO(b/110718070): Fix Keras imports.
        or "_create_or_restore_slot_variable" in dir(trackable)):
      naming_scheme = _slot_variable_naming_for_optimizer(
          optimizer_path=object_names[trackable])
      slot_names = trackable.get_slot_names()
      for slot_name in slot_names:
        for original_variable_node_id, original_variable in enumerate(
            non_slot_objects):
          try:
            slot_variable = trackable.get_slot(
                original_variable, slot_name)
          except (AttributeError, KeyError):
            slot_variable = None
          if slot_variable is None:
            continue
          slot_variable._maybe_initialize_trackable()  # pylint: disable=protected-access
          if slot_variable._checkpoint_dependencies:  # pylint: disable=protected-access
            # TODO(allenl): Gather dependencies of slot variables.
            raise NotImplementedError(
                "Currently only variables with no dependencies can be saved as "
                "slot variables. File a feature request if this limitation "
                "bothers you.")
          if slot_variable in node_ids:
            raise NotImplementedError(
                "A slot variable was re-used as a dependency of a "
                "Trackable object. This is not currently allowed. File a "
                "feature request if this limitation bothers you.")
          checkpoint_name = naming_scheme(
              variable_path=object_names[original_variable],
              slot_name=slot_name)
          object_names[slot_variable] = checkpoint_name
          # The slot variable becomes the next node in the graph.
          slot_variable_node_id = len(trackable_objects)
          node_ids[slot_variable] = slot_variable_node_id
          trackable_objects.append(slot_variable)
          slot_variable_proto = (
              trackable_object_graph_pb2.TrackableObjectGraph
              .TrackableObject.SlotVariableReference(
                  slot_name=slot_name,
                  original_variable_node_id=original_variable_node_id,
                  slot_variable_node_id=slot_variable_node_id))
          slot_variables.setdefault(trackable, []).append(
              slot_variable_proto)
  return slot_variables
138
139
class ObjectGraphView(object):
  """Gathers and serializes an object graph."""

  def __init__(self, root, saveables_cache=None):
    """Configure the graph view.

    Args:
      root: A `Trackable` object whose variables (including the variables
        of dependencies, recursively) should be saved. May be a weak reference.
      saveables_cache: A dictionary mapping `Trackable` objects ->
        attribute names -> SaveableObjects, used to avoid re-creating
        SaveableObjects when graph building.
    """
    self._root_ref = root
    self._saveables_cache = saveables_cache

  def list_dependencies(self, obj):
    """Lists the direct checkpoint dependencies of `obj`.

    Args:
      obj: A `Trackable` object.

    Returns:
      `obj._checkpoint_dependencies`, after making sure `obj`'s Trackable state
      is initialized. Entries are consumed in this class both as
      `(name, ref)` pairs and via their `.name` / `.ref` attributes.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    return obj._checkpoint_dependencies
    # pylint: enable=protected-access

  @property
  def saveables_cache(self):
    """Maps Trackable objects -> attribute names -> list(SaveableObjects).

    Used to avoid re-creating SaveableObjects when graph building. None when
    executing eagerly.

    Returns:
      The cache (an object-identity dictionary), or None if caching is disabled.
    """
    return self._saveables_cache

  @property
  def root(self):
    """The root `Trackable`, dereferenced if it was given as a weak reference."""
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      assert derefed is not None, (
          "The root object of this ObjectGraphView has been garbage collected.")
      return derefed
    else:
      return self._root_ref

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns:
      A tuple of (bfs_sorted, path_to_root):
        bfs_sorted: A list of every reachable object in breadth-first order,
          starting with the root.
        path_to_root: An object-identity dictionary mapping each reachable
          object to a tuple of `TrackableReference`s along a shortest path
          from the root (the empty tuple for the root itself).

    Raises:
      NotImplementedError: If a reachable object is a `NotTrackable`.
    """
    root = self.root  # Dereference a weakly-referenced root only once.
    bfs_sorted = []
    to_visit = collections.deque([root])
    path_to_root = object_identity.ObjectIdentityDictionary()
    path_to_root[root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      if isinstance(current_trackable, tracking.NotTrackable):
        raise NotImplementedError(
            ("The object %s does not support object-based saving. File a "
             "feature request if this limitation bothers you. In the meantime, "
             "you can remove the dependency on this object and save everything "
             "else.")
            % (current_trackable,))
      bfs_sorted.append(current_trackable)
      for name, dependency in self.list_dependencies(current_trackable):
        # The first path recorded is a shortest one since nodes are visited in
        # breadth-first order.
        if dependency not in path_to_root:
          path_to_root[dependency] = (
              path_to_root[current_trackable] + (
                  base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, path_to_root

  def _add_attributes_to_object_graph(
      self, trackable_objects, object_graph_proto, node_ids, object_names,
      object_map):
    """Create SaveableObjects and corresponding SerializedTensor protos.

    Args:
      trackable_objects: Objects in node-id order, lined up with
        `object_graph_proto.nodes`.
      object_graph_proto: A `TrackableObjectGraph` whose nodes receive
        `attributes` entries. Modified in place.
      node_ids: Object-identity dictionary mapping objects to node ids.
      object_names: Object-identity dictionary mapping objects to their
        checkpoint name prefixes.
      object_map: Optional mapping from an object to a substitute whose
        SaveableObjects are gathered in its place, or None.

    Returns:
      A tuple of (named_saveable_objects, feed_additions): a list of
      SaveableObjects, and either None (no SaveableObject caching) or a
      dictionary mapping Tensors to the values to feed when saving.

    Raises:
      AssertionError: If a SaveableObject's name does not contain the
        checkpoint key of its attribute, or if two objects try to feed a value
        for the same Tensor.
    """
    named_saveable_objects = []
    if self._saveables_cache is None:
      # No SaveableObject caching. Either we're executing eagerly, or building a
      # static save which is specialized to the current Python state.
      feed_additions = None
    else:
      # If we are caching SaveableObjects, we need to build up a feed_dict with
      # functions computing volatile Python state to be saved with the
      # checkpoint.
      feed_additions = {}
    for checkpoint_id, (trackable, object_proto) in enumerate(
        zip(trackable_objects, object_graph_proto.nodes)):
      assert node_ids[trackable] == checkpoint_id
      object_name = object_names[trackable]
      if object_map is None:
        object_to_save = trackable
      else:
        object_to_save = object_map.get(trackable, trackable)
      if self._saveables_cache is not None:
        cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
      else:
        cached_attributes = None

      for name, saveable_factory in (
          object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
        attribute = object_proto.attributes.add()
        attribute.name = name
        attribute.checkpoint_key = "%s/%s/%s" % (
            object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
        if cached_attributes is None:
          saveables = None
        else:
          saveables = cached_attributes.get(name, None)
          if saveables is not None:
            for saveable in saveables:
              if attribute.checkpoint_key not in saveable.name:
                # The checkpoint key for this SaveableObject is different. We
                # need to re-create it.
                saveables = None
                del cached_attributes[name]
                break
        if saveables is None:
          if callable(saveable_factory):
            maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
          else:
            maybe_saveable = saveable_factory
          if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
            saveables = (maybe_saveable,)
          else:
            # Figure out the name-based Saver's name for this variable. If it's
            # already a SaveableObject we'd just get the checkpoint key back, so
            # we leave full_name blank.
            saver_dict = saveable_object_util.op_list_to_dict(
                [maybe_saveable], convert_variable_to_tensor=False)
            full_name, = saver_dict.keys()
            saveables = tuple(saveable_object_util.saveable_objects_for_op(
                op=maybe_saveable, name=attribute.checkpoint_key))
            for saveable in saveables:
              saveable.full_name = full_name
          for saveable in saveables:
            if attribute.checkpoint_key not in saveable.name:
              # Arguments follow the order of the placeholders: the object,
              # the SaveableObject's name, the attribute name, the key.
              raise AssertionError(
                  ("The object %s produced a SaveableObject with name '%s' for "
                   "attribute '%s'. Expected a name containing '%s'.")
                  % (trackable, saveable.name, name,
                     attribute.checkpoint_key))
          if cached_attributes is not None:
            cached_attributes[name] = saveables

        # An attribute is optional to restore only if all of its
        # SaveableObjects are.
        optional_restore = None
        for saveable in saveables:
          if optional_restore is None:
            optional_restore = saveable.optional_restore
          else:
            optional_restore = optional_restore and saveable.optional_restore

          if hasattr(saveable, "full_name"):
            attribute.full_name = saveable.full_name
          if isinstance(saveable, base.PythonStateSaveable):
            if feed_additions is None:
              assert self._saveables_cache is None
              # If we're not caching saveables, then we're either executing
              # eagerly or building a static save/restore (e.g. for a
              # SavedModel). In either case, we should embed the current Python
              # state in the graph rather than relying on a feed dict.
              saveable = saveable.freeze()
            else:
              saveable_feed_dict = saveable.feed_dict_additions()
              for new_feed_key in saveable_feed_dict.keys():
                if new_feed_key in feed_additions:
                  raise AssertionError(
                      ("The object %s tried to feed a value for the Tensor %s "
                       "when saving, but another object is already feeding a "
                       "value.")
                      % (trackable, new_feed_key))
              feed_additions.update(saveable_feed_dict)
          named_saveable_objects.append(saveable)
        if optional_restore is None:
          optional_restore = False
        attribute.optional_restore = optional_restore

    return named_saveable_objects, feed_additions

  def _fill_object_graph_proto(self, trackable_objects,
                               node_ids,
                               slot_variables,
                               object_graph_proto=None):
    """Add a node to `object_graph_proto` for every object in the graph.

    Each node records the object's children (by node id and local name) and
    any slot variable references owned by the object. Slot variables appended
    to `trackable_objects` get nodes as well.

    Args:
      trackable_objects: Objects in node-id order.
      node_ids: Object-identity dictionary mapping objects to node ids.
      slot_variables: Object-identity dictionary mapping optimizers to lists of
        `SlotVariableReference` protos.
      object_graph_proto: An optional `TrackableObjectGraph` to extend. A new
        one is created when None.

    Returns:
      The filled-in `TrackableObjectGraph`.
    """
    if object_graph_proto is None:
      object_graph_proto = (
          trackable_object_graph_pb2.TrackableObjectGraph())
    for checkpoint_id, trackable in enumerate(trackable_objects):
      assert node_ids[trackable] == checkpoint_id
      object_proto = object_graph_proto.nodes.add()
      object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
      for child in self.list_dependencies(trackable):
        child_proto = object_proto.children.add()
        child_proto.node_id = node_ids[child.ref]
        child_proto.local_name = child.name
    return object_graph_proto

  @staticmethod
  def _names_and_node_ids(trackable_objects, path_to_root):
    """Computes checkpoint name prefixes and node ids for gathered objects.

    Args:
      trackable_objects: Objects in breadth-first (node-id) order.
      path_to_root: Object-identity dictionary mapping each object to its path
        of references from the root.

    Returns:
      A tuple of (object_names, node_ids), both object-identity dictionaries.
    """
    object_names = object_identity.ObjectIdentityDictionary()
    for obj, path in path_to_root.items():
      object_names[obj] = _object_prefix_from_path(path)
    node_ids = object_identity.ObjectIdentityDictionary()
    for node_id, node in enumerate(trackable_objects):
      node_ids[node] = node_id
    return object_names, node_ids

  def _serialize_gathered_objects(self, trackable_objects, path_to_root,
                                  object_map=None):
    """Create SaveableObjects and protos for gathered objects.

    Returns:
      A tuple of (named_saveable_objects, object_graph_proto, feed_additions).
    """
    object_names, node_ids = self._names_and_node_ids(
        trackable_objects, path_to_root)
    # Appends slot variables to `trackable_objects` and registers them in
    # `node_ids` / `object_names`.
    slot_variables = _serialize_slot_variables(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        object_names=object_names)
    object_graph_proto = self._fill_object_graph_proto(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        slot_variables=slot_variables)
    named_saveable_objects, feed_additions = (
        self._add_attributes_to_object_graph(
            trackable_objects=trackable_objects,
            object_graph_proto=object_graph_proto,
            node_ids=node_ids,
            object_names=object_names,
            object_map=object_map))
    return named_saveable_objects, object_graph_proto, feed_additions

  def serialize_object_graph(self):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Trackable._add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Returns:
      A tuple of (named_saveable_objects, object_graph_proto, feed_additions):
        named_saveable_objects: A list of SaveableObjects for every
          checkpointed attribute of every object in the graph.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving, or None when SaveableObject caching is disabled.

    Raises:
      NotImplementedError: If an object in the graph does not support
        object-based saving, or if a slot variable has dependencies of its own
        or is also a regular dependency in the graph.
      AssertionError: If a SaveableObject's name does not contain its
        checkpoint key, or two objects try to feed the same Tensor.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    return self._serialize_gathered_objects(
        trackable_objects, path_to_root)

  def frozen_saveable_objects(self, object_map=None, to_graph=None):
    """Creates SaveableObjects with the current object graph frozen.

    Args:
      object_map: Optional mapping from an object to a substitute whose
        SaveableObjects are saved in its place.
      to_graph: Optional graph to build the SaveableObjects in. When None, the
        current default context is used.

    Returns:
      A list of SaveableObjects, ending with a `NoRestoreSaveable` holding the
      serialized object graph proto under `base.OBJECT_GRAPH_PROTO_KEY`.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    if to_graph:
      target_context = to_graph.as_default
    else:
      target_context = ops.NullContextmanager
    with target_context():
      named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
          trackable_objects,
          path_to_root,
          object_map)
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
      named_saveable_objects.append(
          base.NoRestoreSaveable(
              tensor=object_graph_tensor,
              name=base.OBJECT_GRAPH_PROTO_KEY))
    return named_saveable_objects

  def objects_ids_and_slot_variables(self):
    """Traverse the object graph and list all accessible objects.

    Looks for `Trackable` objects which are dependencies of
    `root_trackable`. Includes slot variables only if the variable they are
    slotting for and the optimizer are dependencies of `root_trackable`
    (i.e. if they would be saved with a checkpoint).

    Returns:
      A tuple of (trackable objects, object -> node id, slot variables)
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    object_names, node_ids = self._names_and_node_ids(
        trackable_objects, path_to_root)
    slot_variables = _serialize_slot_variables(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        object_names=object_names)
    return trackable_objects, node_ids, slot_variables

  def list_objects(self):
    """Traverse the object graph and list all accessible objects."""
    trackable_objects, _, _ = self.objects_ids_and_slot_variables()
    return trackable_objects
432