1"""Manages a graph of Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import weakref
22
23from tensorflow.core.protobuf import trackable_object_graph_pb2
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.training import optimizer as optimizer_v1
28from tensorflow.python.training.saving import saveable_object as saveable_object_lib
29from tensorflow.python.training.saving import saveable_object_util
30from tensorflow.python.training.tracking import base
31from tensorflow.python.training.tracking import tracking
32from tensorflow.python.util import object_identity
33
34
# Prefix character for reserved checkpoint-key components. `_escape_local_name`
# doubles every occurrence of this character in user-provided local names, so
# reserved components built from it avoid conflicts with user-specified names.
_ESCAPE_CHAR = "."

# Marks that the following part of a checkpoint key names a slot variable.
# Slot variable keys have the form:
#
#   <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# where <path to variable> is the full path from the checkpoint root to the
# variable the slot belongs to.
_OPTIMIZER_SLOTS_NAME = "%sOPTIMIZER_SLOT" % (_ESCAPE_CHAR,)
# Separates the path to an object from the name of one of its attributes in
# checkpoint keys:
#   <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
_OBJECT_ATTRIBUTES_NAME = "%sATTRIBUTES" % (_ESCAPE_CHAR,)
49
50
def _escape_local_name(name):
  """Escapes a local (edge) name for use as one component of a checkpoint key.

  Slashes have historically been accepted in local names (for example by
  `Layer.add_variable`), but checkpoint keys use "/" to separate the edges
  traversed from the root. To keep the two apart, the escape character is
  first doubled, and then every "/" is rewritten as the escape character
  followed by "S".

  Args:
    name: The local name to escape.

  Returns:
    The escaped name.
  """
  doubled_escapes = name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR * 2)
  return doubled_escapes.replace("/", _ESCAPE_CHAR + "S")
59
60
def _object_prefix_from_path(path_to_root):
  """Builds the checkpoint-key prefix for an object from its path.

  Args:
    path_to_root: A sequence of references, each with a `name` attribute,
      describing the edges traversed from the root to the object.

  Returns:
    The escaped edge names joined with "/". An empty path yields "".
  """
  escaped_names = [_escape_local_name(reference.name)
                   for reference in path_to_root]
  return "/".join(escaped_names)
65
66
def _slot_variable_naming_for_optimizer(optimizer_path):
  """Returns a function that builds checkpoint names for an optimizer's slots.

  Slot variable names have the form

    <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable name> is exactly the checkpoint name of the original
  variable: its path from the checkpoint root plus its local name in the
  object which owns it. Slot variables are only saved when the variable they
  slot for is saved as well.

  Args:
    optimizer_path: The checkpoint name prefix of the optimizer.

  Returns:
    A callable `(variable_path, slot_name) -> str`.
  """
  # Everything between the variable's path and the escaped slot name depends
  # only on the optimizer, so build it once.
  middle = "/{}/{}/".format(_OPTIMIZER_SLOTS_NAME, optimizer_path)

  def _name_slot_variable(variable_path, slot_name):
    """Names the slot `slot_name` of the variable at `variable_path`."""
    return "".join((variable_path, middle, _escape_local_name(slot_name)))

  return _name_slot_variable
87
88
def _serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gather and name slot variables.

  An object is treated as an optimizer if either:
    * it is a v1 `Optimizer`; or
    * it exposes `_create_or_restore_slot_variable`.

  For every such optimizer in `trackable_objects`, this function asks it for
  each of its slot names on every object in `trackable_objects`, as passed in.
  Every slot variable found is recorded.

  Note: this mutates its arguments. Each discovered slot variable is:
    * appended to `trackable_objects`;
    * assigned the next node id in `node_ids`;
    * given a checkpoint name in `object_names`.

  Args:
    trackable_objects: A list of objects whose list position is their node
      id. Slot variables are appended to it.
    node_ids: An object-identity dictionary mapping each object in
      `trackable_objects` to its node id. Updated with slot variables.
    object_names: An object-identity dictionary mapping objects to their
      checkpoint name prefixes. It is read for every optimizer and every
      variable that has a slot, and updated with slot variables.

  Returns:
    An object-identity dictionary mapping each optimizer that has at least one
    slot variable to a list of `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable has checkpoint dependencies of its
      own, or if a slot variable already has a node id (for example because
      it is also a regular dependency in the object graph).
  """
  # Snapshot the objects, so slot variables appended to `trackable_objects`
  # below are not themselves scanned as optimizers or slotted variables.
  non_slot_objects = list(trackable_objects)
  slot_variables = object_identity.ObjectIdentityDictionary()
  for trackable in non_slot_objects:
    if (isinstance(trackable, optimizer_v1.Optimizer)
        # TODO(b/110718070): Fix Keras imports.
        # Note: dir() is used rather than hasattr() here to avoid triggering
        # custom __getattr__ code, see b/152031870 for context.
        or "_create_or_restore_slot_variable" in dir(trackable)):
      naming_scheme = _slot_variable_naming_for_optimizer(
          optimizer_path=object_names[trackable])
      slot_names = trackable.get_slot_names()
      for slot_name in slot_names:
        # Positions in the snapshot are node ids, since callers in this file
        # build `node_ids` by enumerating `trackable_objects`.
        for original_variable_node_id, original_variable in enumerate(
            non_slot_objects):
          try:
            slot_variable = trackable.get_slot(
                original_variable, slot_name)
          except (AttributeError, KeyError):
            # A failed lookup is treated the same as "no such slot".
            slot_variable = None
          if slot_variable is None:
            continue
          slot_variable._maybe_initialize_trackable()  # pylint: disable=protected-access
          if slot_variable._checkpoint_dependencies:  # pylint: disable=protected-access
            # TODO(allenl): Gather dependencies of slot variables.
            raise NotImplementedError(
                "Currently only variables with no dependencies can be saved as "
                "slot variables. File a feature request if this limitation "
                "bothers you.")
          if slot_variable in node_ids:
            raise NotImplementedError(
                ("A slot variable was re-used as a dependency of a "
                 "Trackable object: %s. This is not currently "
                 "allowed. File a feature request if this limitation bothers "
                 "you.") % slot_variable)
          checkpoint_name = naming_scheme(
              variable_path=object_names[original_variable],
              slot_name=slot_name)
          object_names[slot_variable] = checkpoint_name
          # The slot variable takes the next free node id, after every object
          # (and slot variable) already in the list.
          slot_variable_node_id = len(trackable_objects)
          node_ids[slot_variable] = slot_variable_node_id
          trackable_objects.append(slot_variable)
          slot_variable_proto = (
              trackable_object_graph_pb2.TrackableObjectGraph
              .TrackableObject.SlotVariableReference(
                  slot_name=slot_name,
                  original_variable_node_id=original_variable_node_id,
                  slot_variable_node_id=slot_variable_node_id))
          slot_variables.setdefault(trackable, []).append(
              slot_variable_proto)
  return slot_variables
141
142
class ObjectGraphView(object):
  """Gathers and serializes an object graph.

  The graph contains:
    * `root`;
    * every object reachable from it through `list_dependencies`;
    * the slot variables of any optimizers found in the graph.
  """

  def __init__(self, root, saveables_cache=None, attached_dependencies=None):
    """Configure the graph view.

    Args:
      root: A `Trackable` object whose variables (including the variables
        of dependencies, recursively) should be saved. May be a weak reference.
      saveables_cache: A dictionary mapping `Trackable` objects ->
        attribute names -> SaveableObjects, used to avoid re-creating
        SaveableObjects when graph building.
      attached_dependencies: Dependencies to attach to the root object. Used
        when saving a Checkpoint with a defined root object.
    """
    self._root_ref = root
    self._saveables_cache = saveables_cache
    self._attached_dependencies = attached_dependencies

  def list_dependencies(self, obj):
    """Returns the checkpoint dependencies of `obj`.

    For the root object, any `attached_dependencies` are appended to a copy of
    its own dependency list, so the object's list itself is not modified.

    Args:
      obj: A `Trackable` object.

    Returns:
      The dependency references of `obj`. Each reference unpacks to
      `(name, object)` and exposes `.name` and `.ref`.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    dependencies = obj._checkpoint_dependencies
    # pylint: enable=protected-access

    if obj is self.root and self._attached_dependencies:
      dependencies = dependencies.copy()
      dependencies.extend(self._attached_dependencies)
    return dependencies

  @property
  def saveables_cache(self):
    """Maps Trackable objects -> attribute names -> list(SaveableObjects).

    Used to avoid re-creating SaveableObjects when graph building. None when
    executing eagerly.

    Returns:
      The cache (an object-identity dictionary), or None if caching is disabled.
    """
    return self._saveables_cache

  @property
  def attached_dependencies(self):
    """Returns list of dependencies that should be saved in the checkpoint.

    These dependencies are not tracked by root, but are in the checkpoint.
    This is defined when the user creates a Checkpoint with both root and kwargs
    set.

    Returns:
      A list of TrackableReferences.
    """
    return self._attached_dependencies

  @property
  def root(self):
    """The root object, dereferenced if it was passed as a weak reference."""
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      # The root must still be alive while the view is in use.
      assert derefed is not None
      return derefed
    else:
      return self._root_ref

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns:
      A tuple `(bfs_sorted, path_to_root)`:
        bfs_sorted: The reachable objects in breadth-first order, starting with
          the root.
        path_to_root: An object-identity dictionary mapping each reachable
          object to a tuple of `TrackableReference`s leading from the root to
          it (an empty tuple for the root itself).

    Raises:
      NotImplementedError: If a reachable object is a `NotTrackable`.
    """
    bfs_sorted = []
    to_visit = collections.deque([self.root])
    path_to_root = object_identity.ObjectIdentityDictionary()
    path_to_root[self.root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      if isinstance(current_trackable, tracking.NotTrackable):
        raise NotImplementedError(
            ("The object %s does not support object-based saving. File a "
             "feature request if this limitation bothers you. In the meantime, "
             "you can remove the dependency on this object and save everything "
             "else.")
            % (current_trackable,))
      bfs_sorted.append(current_trackable)
      for name, dependency in self.list_dependencies(current_trackable):
        # The first time an object is seen in BFS order is via a shortest path.
        if dependency not in path_to_root:
          path_to_root[dependency] = (
              path_to_root[current_trackable] + (
                  base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, path_to_root

  def _add_attributes_to_object_graph(
      self, trackable_objects, object_graph_proto, node_ids, object_names,
      object_map, call_with_mapped_captures):
    """Create SaveableObjects and corresponding SerializedTensor protos.

    For each object, this function:
      * calls `_gather_saveables_for_checkpoint` on the object, or on its
        replacement in `object_map` if one is given;
      * adds one attribute entry per saveable to the object's proto node;
      * builds (or reuses from the cache) the SaveableObjects to save.

    Args:
      trackable_objects: List of objects, indexed by node id.
      object_graph_proto: A `TrackableObjectGraph` whose `nodes` line up with
        `trackable_objects`. Attribute entries are added to it in place.
      node_ids: Object-identity dictionary mapping objects to node ids.
      object_names: Object-identity dictionary mapping objects to checkpoint
        name prefixes.
      object_map: Optional mapping from objects to substitute objects whose
        saveables are gathered instead, or None.
      call_with_mapped_captures: Passed through to
        `saveable_object_util.create_saveable_object`.

    Returns:
      A tuple `(named_saveable_objects, feed_additions)`:
        named_saveable_objects: A list of SaveableObjects.
        feed_additions: None when there is no saveables cache; otherwise a
          dict mapping tensors to the values that must be fed when saving.

    Raises:
      AssertionError: If a SaveableObject's name does not contain its
        attribute's checkpoint key, or if two objects feed the same tensor.
    """
    named_saveable_objects = []
    if self._saveables_cache is None:
      # No SaveableObject caching. Either we're executing eagerly, or building a
      # static save which is specialized to the current Python state.
      feed_additions = None
    else:
      # If we are caching SaveableObjects, we need to build up a feed_dict with
      # functions computing volatile Python state to be saved with the
      # checkpoint.
      feed_additions = {}
    for checkpoint_id, (trackable, object_proto) in enumerate(
        zip(trackable_objects, object_graph_proto.nodes)):
      assert node_ids[trackable] == checkpoint_id
      object_name = object_names[trackable]
      if object_map is None:
        object_to_save = trackable
      else:
        object_to_save = object_map.get(trackable, trackable)
      if self._saveables_cache is not None:
        cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
      else:
        cached_attributes = None

      for name, saveable_factory in (
          object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
        attribute = object_proto.attributes.add()
        attribute.name = name
        attribute.checkpoint_key = "%s/%s/%s" % (
            object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
        if cached_attributes is None:
          saveables = None
        else:
          saveables = cached_attributes.get(name, None)
          if saveables is not None:
            for saveable in saveables:
              if attribute.checkpoint_key not in saveable.name:
                # The checkpoint key for this SaveableObject is different. We
                # need to re-create it.
                saveables = None
                del cached_attributes[name]
                break
        if saveables is None:
          if callable(saveable_factory):
            maybe_saveable = saveable_object_util.create_saveable_object(
                saveable_factory, attribute.checkpoint_key,
                call_with_mapped_captures)
          else:
            maybe_saveable = saveable_factory
          if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
            saveables = (maybe_saveable,)
          else:
            # Figure out the name-based Saver's name for this variable. If it's
            # already a SaveableObject we'd just get the checkpoint key back, so
            # we leave full_name blank.
            saver_dict = saveable_object_util.op_list_to_dict(
                [maybe_saveable], convert_variable_to_tensor=False)
            full_name, = saver_dict.keys()
            saveables = tuple(saveable_object_util.saveable_objects_for_op(
                op=maybe_saveable, name=attribute.checkpoint_key))
            for saveable in saveables:
              saveable.full_name = full_name
          for saveable in saveables:
            if attribute.checkpoint_key not in saveable.name:
              # Report the SaveableObject's own name in the "with name" slot
              # and the attribute name in the "for attribute" slot.
              raise AssertionError(
                  ("The object %s produced a SaveableObject with name '%s' for "
                   "attribute '%s'. Expected a name containing '%s'.")
                  % (trackable, saveable.name, name,
                     attribute.checkpoint_key))
          if cached_attributes is not None:
            cached_attributes[name] = saveables

        # An attribute is optional to restore only if all of its saveables are.
        optional_restore = None
        for saveable in saveables:
          if optional_restore is None:
            optional_restore = saveable.optional_restore
          else:
            optional_restore = optional_restore and saveable.optional_restore

          if hasattr(saveable, "full_name"):
            attribute.full_name = saveable.full_name
          if isinstance(saveable, base.PythonStateSaveable):
            if feed_additions is None:
              assert self._saveables_cache is None
              # If we're not caching saveables, then we're either executing
              # eagerly or building a static save/restore (e.g. for a
              # SavedModel). In either case, we should embed the current Python
              # state in the graph rather than relying on a feed dict.
              saveable = saveable.freeze()
            else:
              saveable_feed_dict = saveable.feed_dict_additions()
              for new_feed_key in saveable_feed_dict.keys():
                if new_feed_key in feed_additions:
                  raise AssertionError(
                      ("The object %s tried to feed a value for the Tensor %s "
                       "when saving, but another object is already feeding a "
                       "value.")
                      % (trackable, new_feed_key))
              feed_additions.update(saveable_feed_dict)
          named_saveable_objects.append(saveable)
        if optional_restore is None:
          optional_restore = False
        attribute.optional_restore = optional_restore

    return named_saveable_objects, feed_additions

  def _fill_object_graph_proto(self, trackable_objects,
                               node_ids,
                               slot_variables,
                               object_graph_proto=None):
    """Name non-slot `Trackable`s and add them to `object_graph_proto`.

    Adds one node per object, in node-id order, holding the object's slot
    variable references and its children (from `list_dependencies`).

    Args:
      trackable_objects: List of objects, indexed by node id.
      node_ids: Object-identity dictionary mapping objects to node ids.
      slot_variables: Object-identity dictionary mapping optimizers to lists
        of `SlotVariableReference` protos.
      object_graph_proto: Optional `TrackableObjectGraph` to append to. A new
        one is created when None.

    Returns:
      The filled `TrackableObjectGraph`.
    """
    if object_graph_proto is None:
      object_graph_proto = (
          trackable_object_graph_pb2.TrackableObjectGraph())
    for checkpoint_id, trackable in enumerate(trackable_objects):
      assert node_ids[trackable] == checkpoint_id
      object_proto = object_graph_proto.nodes.add()
      object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
      for child in self.list_dependencies(trackable):
        child_proto = object_proto.children.add()
        child_proto.node_id = node_ids[child.ref]
        child_proto.local_name = child.name
    return object_graph_proto

  @staticmethod
  def _gather_names_ids_and_slot_variables(trackable_objects, path_to_root):
    """Names gathered objects, assigns node ids and appends slot variables.

    Args:
      trackable_objects: List of objects in traversal order. Slot variables
        are appended to it in place.
      path_to_root: Object-identity dictionary mapping objects to their paths
        from the root.

    Returns:
      A tuple `(object_names, node_ids, slot_variables)`. `object_names` and
      `node_ids` are object-identity dictionaries that also cover the appended
      slot variables.
    """
    object_names = object_identity.ObjectIdentityDictionary()
    for obj, path in path_to_root.items():
      object_names[obj] = _object_prefix_from_path(path)
    node_ids = object_identity.ObjectIdentityDictionary()
    for node_id, node in enumerate(trackable_objects):
      node_ids[node] = node_id
    slot_variables = _serialize_slot_variables(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        object_names=object_names)
    return object_names, node_ids, slot_variables

  def _serialize_gathered_objects(self, trackable_objects, path_to_root,
                                  object_map=None,
                                  call_with_mapped_captures=None):
    """Create SaveableObjects and protos for gathered objects.

    Returns:
      A tuple `(named_saveable_objects, object_graph_proto, feed_additions)`.
    """
    object_names, node_ids, slot_variables = (
        self._gather_names_ids_and_slot_variables(
            trackable_objects, path_to_root))
    object_graph_proto = self._fill_object_graph_proto(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        slot_variables=slot_variables)
    named_saveable_objects, feed_additions = (
        self._add_attributes_to_object_graph(
            trackable_objects=trackable_objects,
            object_graph_proto=object_graph_proto,
            node_ids=node_ids,
            object_names=object_names,
            object_map=object_map,
            call_with_mapped_captures=call_with_mapped_captures))
    return named_saveable_objects, object_graph_proto, feed_additions

  def serialize_object_graph(self):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Trackable._add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Returns:
      A tuple of (named_saveable_objects, object_graph_proto, feed_additions):
        named_saveable_objects: A list of SaveableObjects to save.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving, or None when no saveables cache is in use.

    Raises:
      NotImplementedError: If the graph contains an object that does not
        support object-based saving, or a slot variable that cannot be saved.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    return self._serialize_gathered_objects(
        trackable_objects, path_to_root)

  def frozen_saveable_objects(self, object_map=None, to_graph=None,
                              call_with_mapped_captures=None):
    """Creates SaveableObjects with the current object graph frozen.

    Args:
      object_map: Optional mapping from objects to substitute objects whose
        saveables are gathered instead.
      to_graph: Optional graph to build the SaveableObjects in. When falsy,
        the current context is used.
      call_with_mapped_captures: Passed through when creating SaveableObjects
        from factories.

    Returns:
      A list of SaveableObjects. The last entry is a `NoRestoreSaveable` that
      holds the serialized object graph proto (a string constant placed on
      CPU) under `base.OBJECT_GRAPH_PROTO_KEY`.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    if to_graph:
      target_context = to_graph.as_default
    else:
      target_context = ops.NullContextmanager
    with target_context():
      named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
          trackable_objects,
          path_to_root,
          object_map,
          call_with_mapped_captures)
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
      named_saveable_objects.append(
          base.NoRestoreSaveable(
              tensor=object_graph_tensor,
              name=base.OBJECT_GRAPH_PROTO_KEY))
    return named_saveable_objects

  def objects_ids_and_slot_variables_and_paths(self):
    """Traverse the object graph and list all accessible objects.

    Looks for `Trackable` objects which are dependencies of
    `root_trackable`. Includes slot variables only if the variable they are
    slotting for and the optimizer are dependencies of `root_trackable`
    (i.e. if they would be saved with a checkpoint).

    Returns:
      A tuple of (trackable objects, paths from root for each object,
                  object -> node id, slot variables)
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    _, node_ids, slot_variables = self._gather_names_ids_and_slot_variables(
        trackable_objects, path_to_root)
    return trackable_objects, path_to_root, node_ids, slot_variables

  def objects_ids_and_slot_variables(self):
    """Like `objects_ids_and_slot_variables_and_paths`, without the paths.

    Returns:
      A tuple of (trackable objects, object -> node id, slot variables).
    """
    trackable_objects, _, node_ids, slot_variables = (
        self.objects_ids_and_slot_variables_and_paths())
    return trackable_objects, node_ids, slot_variables

  def list_objects(self):
    """Traverse the object graph and list all accessible objects.

    Returns:
      The list of reachable objects in node-id order, including saved slot
      variables.
    """
    trackable_objects, _, _ = self.objects_ids_and_slot_variables()
    return trackable_objects