1"""An object-local variable management scheme."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22
23import six
24
25from tensorflow.python.eager import context
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_io_ops as io_ops
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.training.saving import saveable_object
34from tensorflow.python.util import tf_contextlib
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util.tf_export import tf_export
37
# Key where the object graph proto is saved in a TensorBundle
OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"

# A key indicating a variable's value in an object's checkpointed Tensors
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
# the object has no dependencies, then its value may be restored on object
# creation (avoiding double assignment when executing eagerly).
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
# Key under which an object's JSON-serialized configuration is stored in a
# checkpoint, when the object provides one.
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
47
# A single named checkpoint dependency of a Trackable object: "name" is the
# local name for the dependency, "ref" is the referenced Trackable itself.
TrackableReference = collections.namedtuple("TrackableReference", "name ref")


# TODO(bfontain):  Update once sharded initialization interface is finalized.
# Describes one shard of a partitioned value: the shard's shape plus its
# offset within the full variable.
ShardInfo = collections.namedtuple(
    "CheckpointInitialValueShardInfo", "shape offset")
61
62
class CheckpointInitialValueCallable(object):
  """Callable that produces a `CheckpointInitialValue` on demand.

  See `CheckpointInitialValue` for details.
  """

  def __init__(self, checkpoint_position):
    self._checkpoint_position = checkpoint_position

  @property
  def checkpoint_position(self):
    """The `CheckpointPosition` whose value will seed the variable."""
    return self._checkpoint_position

  @property
  def restore_uid(self):
    """Restore UID of the underlying checkpoint position."""
    return self._checkpoint_position.restore_uid

  def __call__(self, shape=None, dtype=None, shard_info=None):
    # The (shape, dtype) signature matches ordinary initializer callables.
    # dtype is accepted but unused: functools.partial wrappers (e.g.
    # base_layer_utils.py's make_variable) pass it in regardless, while the
    # checkpoint already determines the dtype.
    del dtype  # Unused; see note above.
    return CheckpointInitialValue(
        self._checkpoint_position, shape, shard_info=shard_info)
87
88
class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the
  initial value came from. Deferred restorations can then be sequenced in the
  order the user specified them, with assignment as the fallback when an
  initial value is not set (e.g. due to a custom getter interfering).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
  """

  def __init__(self, checkpoint_position, shape=None, shard_info=None):
    if shard_info:
      # Build the "<full shape> <offset,length:...>" spec RestoreV2 expects.
      dims = " ".join("%d" % dim for dim in shape)
      offsets_and_lengths = ":".join(
          "%d,%d" % pair for pair in zip(shard_info.offset, shard_info.shape))
      shape_and_slice = dims + " " + offsets_and_lengths
      # The variable being initialized only holds this shard, so the static
      # shape applied below must be the shard's shape, not the full shape.
      shape = shard_info.shape
    else:
      shape_and_slice = ""
    self.wrapped_value = checkpoint_position.value_tensors(
        {VARIABLE_VALUE_KEY: shape_and_slice})[VARIABLE_VALUE_KEY]
    if shape:
      # Give the initializer static shape information when available so the
      # resulting variable does not end up with an unknown shape.
      self.wrapped_value.set_shape(shape)
    self._checkpoint_position = checkpoint_position

  def __getattr__(self, attr):
    # Forward unknown attribute lookups to the wrapped restore tensor; fall
    # back to normal attribute lookup (which raises AttributeError) when
    # neither this object nor the wrapped tensor has the attribute.
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    """The `CheckpointPosition` this value was read from."""
    return self._checkpoint_position
129
130
class NoRestoreSaveable(saveable_object.SaveableObject):
  """A `SaveableObject` that writes a tensor but restores nothing."""

  def __init__(self, tensor, name, dtype=None, device=None):
    specs = [
        saveable_object.SaveSpec(tensor, "", name, dtype=dtype, device=device)
    ]
    super(NoRestoreSaveable, self).__init__(tensor, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    # There is no state to restore; return a no-op so callers always receive
    # an operation.
    return control_flow_ops.no_op()
141
142
@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
  """Interface for `SaveableObject`s backed by volatile Python state."""

  @abc.abstractmethod
  def feed_dict_additions(self):
    """Indicates fresh state to feed when running a graph.

    Returns:
      A dictionary mapping `Tensor`s to current Python state.
    """
    pass

  @abc.abstractmethod
  def freeze(self):
    """Creates a plain `SaveableObject` capturing the current state.

    The current Python state is embedded as a constant; useful when executing
    eagerly, or when building a static tf.compat.v1.train.Saver from a
    snapshot of the state.

    Returns:
      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e.
      has no Python state associated with it).
    """
    pass
169
170
class PythonStringStateSaveable(PythonStateSaveable):
  """Saves Python string state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback=None):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state. Optional; defaults to doing nothing, in which case it is
        ignored by status assertions such as assert_consumed().
    """
    # When there is no restore callback, there is nothing to consume on
    # restore; optional_restore reports this so assert_consumed() can relax.
    self._has_trivial_state_callback = (restore_callback is None)
    self._restore_callback = restore_callback

    def _run_state_callback_in_init_scope():
      # Run the user's callback outside any function-building graph.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _run_state_callback_in_init_scope
    with ops.device("/cpu:0"):
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(PythonStringStateSaveable, self).__init__(
        self._save_string, [spec], name)

  @property
  def optional_restore(self):
    """For values with no restore, relaxes assert_consumed()."""
    return self._has_trivial_state_callback

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    return {self._save_string: self._state_callback()}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _current_state_as_constant():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return NoRestoreSaveable(
        tensor=_current_state_as_constant,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")

  def python_restore(self, restored_strings):
    """Called to restore Python state."""
    if self._restore_callback:
      restored, = restored_strings
      self._restore_callback(restored)

  def restore(self, restored_tensors, restored_shapes):
    """Called to restore TensorFlow state (nothing to do)."""
    return control_flow_ops.no_op()
230
231
class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  # No per-instance __dict__: these objects only ever hold two attributes and
  # many of them are created during a single restore.
  __slots__ = ["_checkpoint", "_proto_id"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, trackable):
    """Restore this value into `trackable`."""
    # Create restore ops in the outermost context so they are not trapped
    # inside a function-building graph.
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence and process slot variables.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto. When a different object was
      already bound to this proto, a warning is logged (no exception is
      raised) and False is returned.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      # If this object is an optimizer, replay slot restorations that were
      # queued before the optimizer existed (see the deferred queueing in the
      # loop below).
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        trackable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      # Process slot variables whose primary variable is this object.
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=trackable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))

        # `optimizer_object` can be a `Checkpoint` when user only needs the
        # attributes the optimizer holds, such as `iterations`. In those cases,
        # it would not have the optimizer's `_create_or_restore_slot_variable`
        # method.
        elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=trackable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning((
            "Inconsistent references when loading the checkpoint into this "
            "object graph. Either the Trackable object references in the "
            "Python program have changed in an incompatible way, or the "
            "checkpoint was generated in an incompatible program.\n\nTwo "
            "checkpoint references resolved to different objects (%s and %s)."),
                        current_assignment, trackable)
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    # Exactly one attribute (the variable's value) and no children means the
    # value can be fed directly as an initializer (see CheckpointInitialValue).
    return (len(attributes) == 1 and
            attributes[0].name == VARIABLE_VALUE_KEY and
            not self.object_proto.children)

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device(CPU or specified).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A 3-tuple (existing_restore_ops, named_saveables, python_saveables):
        existing_restore_ops: restore ops already created and cached for some
          attributes (graph building only; always empty when eager).
        named_saveables: dict mapping checkpoint keys to `SaveableObject`s
          which still need restore ops created for them.
        python_saveables: list of `PythonStateSaveable`s to restore directly
          in Python.
    """
    saveables = self.trackable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    python_saveables = []
    existing_restore_ops = []
    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        # Ops are not cached when executing eagerly.
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.graph_view.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, we should check if the Python
        # object has the attribute.
        saveable_factory = saveables.get(serialized_tensor.name, None)
        if saveable_factory is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          if not serialized_tensor.optional_restore:
            self._checkpoint.unused_attributes.setdefault(
                self._proto_id, []).append(serialized_tensor.name)
          continue
        if callable(saveable_factory):
          saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
        else:
          saveable = saveable_factory
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      if isinstance(saveable, PythonStateSaveable):
        python_saveables.append(saveable)
      else:
        named_saveables[serialized_tensor.checkpoint_key] = saveable
    return existing_restore_ops, named_saveables, python_saveables

  def restore_ops(self):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    (restore_ops, tensor_saveables,
     python_saveables) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
    return restore_ops

  @property
  def checkpoint(self):
    """The `_CheckpointRestoreCoordinator` this position belongs to."""
    return self._checkpoint

  @property
  def trackable(self):
    """The Python object bound to this position (requires bind_object)."""
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    """The `TrackableObjectGraph.TrackableObject` proto at this position."""
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    """UID of the restore this position is part of (for ordering)."""
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)
476
477
# Records a slot variable restoration that must wait until its optimizer has
# been created and tracked.
_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration",
    "original_variable slot_variable_id slot_name")

# Links a checkpointed slot variable to its optimizer and slot name.
_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        # The slot name (e.g. "m" for Adam's first moment).
        "slot_name",
    ])
494
495
def no_automatic_dependency_tracking(method):
  """Disables automatic dependency tracking on attribute assignment.

  Use to decorate any method of a Trackable object. Attribute assignments made
  inside that method will not add dependencies (also respected in Model). It
  is harmless on classes which do no automatic dependency tracking, so it is
  safe to use in base classes whose subclasses may also inherit from
  Trackable.

  Args:
    method: The method to decorate.

  Returns:
    A decorated method which sets and un-sets automatic dependency tracking
    for the object the method is called on (not thread safe).
  """

  def _method_wrapper(self, *args, **kwargs):
    # Save the current tracking flag, disable tracking for the duration of the
    # call, and restore the flag afterwards even if the method raises.
    saved_tracking = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      return method(self, *args, **kwargs)
    finally:
      self._self_setattr_tracking = saved_tracking  # pylint: disable=protected-access

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
524
525
@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
  """A context that disables manual dependency tracking for the given `obj`.

  Library methods sometimes track objects on their own; this context manager
  turns that off so the caller can do its own tracking instead.

  For example:

  class TestLayer(tf.keras.Layer):
    def build():
      with no_manual_dependency_tracking_scope(self):
        var = self.add_variable("name1")  # Creates a var and doesn't track it
      self._track_trackable("name2", var)  # We track variable with name `name2`

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies manually.
  """
  # pylint: disable=protected-access
  saved_tracking = getattr(obj, "_manual_tracking", True)
  obj._manual_tracking = False
  try:
    yield
  finally:
    # Restore the previous setting even if the body raised.
    obj._manual_tracking = saved_tracking
556
557
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects inheriting from Autotrackable automatically create dependencies on
  trackable objects through attribute assignment, and wrap data structures
  (lists or dicts) with trackable classes. This scope temporarily disables
  that behavior, similar to the `no_automatic_dependency_tracking` decorator.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  saved_tracking = getattr(obj, "_setattr_tracking", True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    # Restore the previous setting even if the body raised.
    obj._setattr_tracking = saved_tracking  # pylint: disable=protected-access
588
589
590@tf_export("__internal__.tracking.Trackable", v1=[])
591class Trackable(object):
592  """Base class for `Trackable` objects without automatic dependencies.
593
594  This class has no __setattr__ override for performance reasons. Dependencies
595  must be added explicitly. Unless attribute assignment is performance-critical,
596  use `AutoTrackable` instead. Use `Trackable` for `isinstance`
597  checks.
598  """
599
600  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
601  # _self_. We have some properties to forward semi-public attributes to their
602  # _self_ equivalents.
603
604  @property
605  def _setattr_tracking(self):
606    if not hasattr(self, "_self_setattr_tracking"):
607      self._self_setattr_tracking = True
608    return self._self_setattr_tracking
609
610  @_setattr_tracking.setter
611  def _setattr_tracking(self, value):
612    self._self_setattr_tracking = value
613
614  @property
615  def _update_uid(self):
616    return self._self_update_uid
617
618  @_update_uid.setter
619  def _update_uid(self, value):
620    self._self_update_uid = value
621
622  @property
623  def _unconditional_checkpoint_dependencies(self):
624    return self._self_unconditional_checkpoint_dependencies
625
626  @property
627  def _unconditional_dependency_names(self):
628    return self._self_unconditional_dependency_names
629
630  @property
631  def _name_based_restores(self):
632    return self._self_name_based_restores
633
634  # Trackable does not do automatic dependency tracking, but uses the
635  # no_automatic_dependency_tracking decorator so it can avoid adding
636  # dependencies if a subclass is Trackable / inherits from Model (both of
637  # which have __setattr__ overrides).
638  @no_automatic_dependency_tracking
639  def _maybe_initialize_trackable(self):
640    """Initialize dependency management.
641
642    Not __init__, since most objects will forget to call it.
643    """
644    if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
645      # __init__ already called. This check means that we don't need
646      # Trackable.__init__() in the constructor of every TensorFlow object.
647      return
648    # A list of TrackableReference objects. Some classes implementing
649    # `Trackable`, notably `Optimizer`s, may override the
650    # _checkpoint_dependencies property with conditional dependencies
651    # (e.g. based on the current graph when saving).
652    self._self_unconditional_checkpoint_dependencies = []
653    # Maps names -> Trackable objects
654    self._self_unconditional_dependency_names = {}
655    # Restorations for other Trackable objects on which this object may
656    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
657    # tack on conditional dependencies, and so need separate management of
658    # deferred dependencies too.
659    self._self_unconditional_deferred_dependencies = {}
660    # The UID of the highest assignment to this object. Used to ensure that the
661    # last requested assignment determines the final value of an object.
662    if hasattr(self, "_self_update_uid"):
663      raise AssertionError(
664          "Internal error: the object had an update UID set before its "
665          "initialization code was run.")
666    self._self_update_uid = -1
667    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
668    # instances, which should be checked when creating variables or other
669    # saveables. These are passed on recursively to all dependencies, since
670    # unlike object-based checkpoint restores we don't know which subgraph is
671    # being restored in advance. This mechanism is only necessary for
672    # restore-on-create when executing eagerly, and so is unused when graph
673    # building.
674    self._self_name_based_restores = set()
675
676    # Dictionary of SaveableObjects factories. This dictionary is defined when
677    # the object is loaded from the SavedModel. When writing a custom class,
678    # prefer overriding "_gather_saveables_from_checkpoint" to using this
679    # attribute.
680    self._self_saveable_object_factories = {}
681
682  @property
683  def _object_identifier(self):
684    """String used to identify this object in a SavedModel.
685
686    Generally, the object identifier is constant across objects of the same
687    class, while the metadata field is used for instance-specific data.
688
689    Returns:
690      String object identifier.
691    """
692    return "_generic_user_object"
693
694  @property
695  def _tracking_metadata(self):
696    """String containing object metadata, which is saved to the SavedModel."""
697    return ""
698
699  def _no_dependency(self, value):
700    """If automatic dependency tracking is enabled, ignores `value`."""
701    return value
702
  def _name_based_attribute_restore(self, checkpoint):
    """Restore the object's attributes from a name-based checkpoint."""
    # Remember the coordinator even when the restore below is skipped, so
    # later-created attributes can still consult it.
    self._self_name_based_restores.add(checkpoint)
    # Only apply the restore if it is newer than the last one this object saw.
    if self._self_update_uid < checkpoint.restore_uid:
      checkpoint.eager_restore(self)
      self._self_update_uid = checkpoint.restore_uid
709
710  @property
711  def _checkpoint_dependencies(self):
712    """All dependencies of this object.
713
714    May be overridden to include conditional dependencies.
715
716    Returns:
717      A list of `TrackableReference` objects indicating named
718      `Trackable` dependencies which should be saved along with this
719      object.
720    """
721    return self._self_unconditional_checkpoint_dependencies
722
723  @property
724  def _deferred_dependencies(self):
725    """A dictionary with deferred dependencies.
726
727    Stores restorations for other Trackable objects on which this object
728    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
729    conditional dependencies based the current graph, and so need separate
730    management of deferred dependencies too).
731
732    Returns:
733      A dictionary mapping from local name to a list of CheckpointPosition
734      objects.
735    """
736    return self._self_unconditional_deferred_dependencies
737
738  def _lookup_dependency(self, name):
739    """Look up a dependency by name.
740
741    May be overridden to include conditional dependencies.
742
743    Args:
744      name: The local name of the dependency.
745
746    Returns:
747      A `Trackable` object, or `None` if no dependency by this name was
748      found.
749    """
750    return self._self_unconditional_dependency_names.get(name, None)
751
752  def _add_variable_with_custom_getter(self,
753                                       name,
754                                       shape=None,
755                                       dtype=dtypes.float32,
756                                       initializer=None,
757                                       getter=None,
758                                       overwrite=False,
759                                       **kwargs_for_getter):
760    """Restore-on-create for a variable be saved with this `Trackable`.
761
762    If the user has requested that this object or another `Trackable` which
763    depends on this object be restored from a checkpoint (deferred loading
764    before variable object creation), `initializer` may be ignored and the value
765    from the checkpoint used instead.
766
767    Args:
768      name: A name for the variable. Must be unique within this object.
769      shape: The shape of the variable.
770      dtype: The data type of the variable.
771      initializer: The initializer to use. Ignored if there is a deferred
772        restoration left over from a call to
773        `_restore_from_checkpoint_position`.
774      getter: The getter to wrap which actually fetches the variable.
775      overwrite: If True, disables unique name and type checks.
776      **kwargs_for_getter: Passed to the getter.
777
778    Returns:
779      The new variable object.
780
781    Raises:
782      ValueError: If the variable name is not unique.
783    """
784    self._maybe_initialize_trackable()
785    with ops.init_scope():
786      if context.executing_eagerly():
787        # If this is a variable with a single Tensor stored in the checkpoint,
788        # we can set that value as an initializer rather than initializing and
789        # then assigning (when executing eagerly). This call returns None if
790        # there is nothing to restore.
791        checkpoint_initializer = self._preload_simple_restoration(
792            name=name)
793      else:
794        checkpoint_initializer = None
795      if (checkpoint_initializer is not None and
796          not (isinstance(initializer, CheckpointInitialValueCallable) and
797               (initializer.restore_uid > checkpoint_initializer.restore_uid))):
798        # If multiple Trackable objects are "creating" the same variable
799        # via the magic of custom getters, the one with the highest restore UID
800        # (the one called last) has to make the final initializer. If another
801        # custom getter interrupts this process by overwriting the initializer,
802        # then we'll catch that when we call _track_trackable. So this is
803        # "best effort" to set the initializer with the highest restore UID.
804        initializer = checkpoint_initializer
805    new_variable = getter(
806        name=name,
807        shape=shape,
808        dtype=dtype,
809        initializer=initializer,
810        **kwargs_for_getter)
811
812    # If we set an initializer and the variable processed it, tracking will not
813    # assign again. It will add this variable to our dependencies, and if there
814    # is a non-trivial restoration queued, it will handle that. This also
815    # handles slot variables.
816    if not overwrite or isinstance(new_variable, Trackable):
817      return self._track_trackable(new_variable, name=name, overwrite=overwrite)
818    else:
819      # TODO(allenl): Some variable types are not yet supported. Remove this
820      # fallback once all get_variable() return types are Trackable.
821      return new_variable
822
823  def _preload_simple_restoration(self, name):
824    """Return a dependency's value for restore-on-create.
825
826    Note the restoration is not deleted; if for some reason preload is called
827    and then not assigned to the variable (for example because a custom getter
828    overrides the initializer), the assignment will still happen once the
829    variable is tracked (determined based on checkpoint.restore_uid).
830
831    Args:
832      name: The object-local name of the dependency holding the variable's
833        value.
834
835    Returns:
836      An callable for use as a variable's initializer/initial_value, or None if
837      one should not be set (either because there was no variable with this name
838      in the checkpoint or because it needs more complex deserialization). Any
839      non-trivial deserialization will happen when the variable object is
840      tracked.
841    """
842    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
843    if not deferred_dependencies_list:
844      # Nothing to do; we don't have a restore for this dependency queued up.
845      return
846    for checkpoint_position in deferred_dependencies_list:
847      if not checkpoint_position.is_simple_variable():
848        # If _any_ pending restoration is too complicated to fit in an
849        # initializer (because it has dependencies, or because there are
850        # multiple Tensors to restore), bail and let the general tracking code
851        # handle it.
852        return None
853    checkpoint_position = max(
854        deferred_dependencies_list,
855        key=lambda restore: restore.checkpoint.restore_uid)
856    return CheckpointInitialValueCallable(
857        checkpoint_position=checkpoint_position)
858
859  def _track_trackable(self, trackable, name, overwrite=False):
860    """Declare a dependency on another `Trackable` object.
861
862    Indicates that checkpoints for this object should include variables from
863    `trackable`.
864
865    Variables in a checkpoint are mapped to `Trackable`s based on the names
866    provided when the checkpoint was written. To avoid breaking existing
867    checkpoints when modifying a class, neither variable names nor dependency
868    names (the names passed to `_track_trackable`) may change.
869
870    Args:
871      trackable: A `Trackable` which this object depends on.
872      name: A local name for `trackable`, used for loading checkpoints into the
873        correct objects.
874      overwrite: Boolean, whether silently replacing dependencies is OK. Used
875        for __setattr__, where throwing an error on attribute reassignment would
876        be inappropriate.
877
878    Returns:
879      `trackable`, for convenience when declaring a dependency and
880      assigning to a member variable in one statement.
881
882    Raises:
883      TypeError: If `trackable` does not inherit from `Trackable`.
884      ValueError: If another object is already tracked by this name.
885    """
886    self._maybe_initialize_trackable()
887    if not isinstance(trackable, Trackable):
888      raise TypeError(("Trackable._track_trackable() passed type %s, not a "
889                       "Trackable.") % (type(trackable),))
890    if not getattr(self, "_manual_tracking", True):
891      return trackable
892    new_reference = TrackableReference(name=name, ref=trackable)
893    current_object = self._lookup_dependency(name)
894    if (current_object is not None and current_object is not trackable):
895      if not overwrite:
896        raise ValueError(
897            ("Called Trackable._track_trackable() with name='%s', "
898             "but a Trackable with this name is already declared as a "
899             "dependency. Names must be unique (or overwrite=True).") % (name,))
900      # This is a weird thing to do, but we're not going to stop people from
901      # using __setattr__.
902      for index, (old_name, _) in enumerate(
903          self._self_unconditional_checkpoint_dependencies):
904        if name == old_name:
905          self._self_unconditional_checkpoint_dependencies[
906              index] = new_reference
907    elif current_object is None:
908      self._self_unconditional_checkpoint_dependencies.append(new_reference)
909      self._handle_deferred_dependencies(name=name, trackable=trackable)
910    self._self_unconditional_dependency_names[name] = trackable
911    return trackable
912
913  def _handle_deferred_dependencies(self, name, trackable):
914    """Pop and load any deferred checkpoint restores into `trackable`.
915
916    This method does not add a new dependency on `trackable`, but it does
917    check if any outstanding/deferred dependencies have been queued waiting for
918    this dependency to be added (matched based on `name`). If so,
919    `trackable` and its dependencies are restored. The restorations are
920    considered fulfilled and so are deleted.
921
922    `_track_trackable` is more appropriate for adding a
923    normal/unconditional dependency, and includes handling for deferred
924    restorations. This method allows objects such as `Optimizer` to use the same
925    restoration logic while managing conditional dependencies themselves, by
926    overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the
927    object's dependencies based on the context it is saved/restored in (a single
928    optimizer instance can have state associated with multiple graphs).
929
930    Args:
931      name: The name of the dependency within this object (`self`), used to
932        match `trackable` with values saved in a checkpoint.
933      trackable: The Trackable object to restore (inheriting from `Trackable`).
934    """
935    self._maybe_initialize_trackable()
936    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
937    deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
938    for checkpoint_position in sorted(
939        deferred_dependencies_list,
940        key=lambda restore: restore.checkpoint.restore_uid,
941        reverse=True):
942      checkpoint_position.restore(trackable)
943
944    # Pass on any name-based restores queued in this object.
945    for name_based_restore in sorted(
946        self._self_name_based_restores,
947        key=lambda checkpoint: checkpoint.restore_uid,
948        reverse=True):
949      trackable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access
950
951  def _restore_from_checkpoint_position(self, checkpoint_position):
952    """Restore this object and its dependencies (may be deferred)."""
953    # Attempt a breadth-first traversal, since presumably the user has more
954    # control over shorter paths. If we don't have all of the dependencies at
955    # this point, the end result is not breadth-first (since other deferred
956    # traversals will happen later).
957    visit_queue = collections.deque([checkpoint_position])
958    restore_ops = []
959    tensor_saveables = {}
960    python_saveables = []
961    while visit_queue:
962      current_position = visit_queue.popleft()
963      new_restore_ops, new_tensor_saveables, new_python_saveables = (
964          current_position.trackable  # pylint: disable=protected-access
965          ._single_restoration_from_checkpoint_position(
966              checkpoint_position=current_position,
967              visit_queue=visit_queue))
968      restore_ops.extend(new_restore_ops)
969      tensor_saveables.update(new_tensor_saveables)
970      python_saveables.extend(new_python_saveables)
971    restore_ops.extend(
972        current_position.checkpoint.restore_saveables(
973            tensor_saveables, python_saveables))
974    return restore_ops
975
976  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
977                                                   visit_queue):
978    """Restore this object, and either queue its dependencies or defer them."""
979    self._maybe_initialize_trackable()
980    checkpoint = checkpoint_position.checkpoint
981    # If the UID of this restore is lower than our current update UID, we don't
982    # need to actually restore the object. However, we should pass the
983    # restoration on to our dependencies.
984    if checkpoint.restore_uid > self._self_update_uid:
985      restore_ops, tensor_saveables, python_saveables = (
986          checkpoint_position.gather_ops_or_named_saveables())
987      self._self_update_uid = checkpoint.restore_uid
988    else:
989      restore_ops = ()
990      tensor_saveables = {}
991      python_saveables = ()
992    for child in checkpoint_position.object_proto.children:
993      child_position = CheckpointPosition(
994          checkpoint=checkpoint, proto_id=child.node_id)
995      local_object = self._lookup_dependency(child.local_name)
996      if local_object is None:
997        # We don't yet have a dependency registered with this name. Save it
998        # in case we do.
999        self._deferred_dependencies.setdefault(child.local_name,
1000                                               []).append(child_position)
1001      else:
1002        if child_position.bind_object(trackable=local_object):
1003          # This object's correspondence is new, so dependencies need to be
1004          # visited. Delay doing it so that we get a breadth-first dependency
1005          # resolution order (shallowest paths first). The caller is responsible
1006          # for emptying visit_queue.
1007          visit_queue.append(child_position)
1008    return restore_ops, tensor_saveables, python_saveables
1009
1010  def _gather_saveables_for_checkpoint(self):
1011    """Returns a dictionary of values to checkpoint with this object.
1012
1013    Keys in the returned dictionary are local to this object and in a separate
1014    namespace from dependencies. Values may either be `SaveableObject` factories
1015    or variables easily converted to `SaveableObject`s (as in
1016    `tf.compat.v1.train.Saver`'s
1017    `var_list` constructor argument).
1018
1019    `SaveableObjects` have a name set, which Trackable needs to generate
1020    itself. So rather than returning `SaveableObjects` directly, this method
1021    should return a dictionary of callables which take `name` arguments and
1022    return `SaveableObjects` with that name.
1023
1024    If this object may also be passed to the global-name-based
1025    `tf.compat.v1.train.Saver`,
1026    the returned callables should have a default value for their name argument
1027    (i.e. be callable with no arguments).
1028
1029    Returned values must be saved only by this object; if any value may be
1030    shared, it should instead be a dependency. For example, variable objects
1031    save their own values with the key `VARIABLE_VALUE_KEY`, but objects which
1032    reference variables simply add a dependency.
1033
1034    Returns:
1035      The dictionary mapping attribute names to `SaveableObject` factories
1036      described above. For example:
1037      {VARIABLE_VALUE_KEY:
1038       lambda name="global_name_for_this_object":
1039       SaveableObject(name=name, ...)}
1040    """
1041    return self._self_saveable_object_factories
1042
1043  def _list_extra_dependencies_for_serialization(self, serialization_cache):
1044    """Lists extra dependencies to serialize.
1045
1046    Internal sub-classes can override this method to return extra dependencies
1047    that should be saved with the object during SavedModel serialization. For
1048    example, this is used to save `trainable_variables` in Keras models. The
1049    python property `trainable_variables` contains logic to iterate through the
1050    weights from the model and its sublayers. The serialized Keras model saves
1051    `trainable_weights` as a trackable list of variables.
1052
1053    PLEASE NOTE when overriding this method:
1054      1. This function may only generate new trackable objects the first time it
1055         is called.
1056      2. The returned dictionary must not have naming conflicts with
1057         dependencies tracked by the root. In other words, if the root is
1058         tracking `object_1` with name 'x', and this functions returns
1059         `{'x': object_2}`, an error is raised when saving.
1060
1061    Args:
1062      serialization_cache: A dictionary shared between all objects in the same
1063        object graph. This object is passed to both
1064        `_list_extra_dependencies_for_serialization` and
1065        `_list_functions_for_serialization`.
1066
1067    Returns:
1068      A dictionary mapping attribute names to trackable objects.
1069    """
1070    del serialization_cache
1071    return dict()
1072
1073  def _list_functions_for_serialization(self, serialization_cache):
1074    """Lists the functions of this trackable to serialize.
1075
1076    Internal sub-classes can override this with specific logic. E.g.
1077    `AutoTrackable` provides an implementation that returns the `attr`
1078    that return functions.
1079
1080    Args:
1081      serialization_cache: Dictionary passed to all objects in the same object
1082        graph during serialization.
1083
1084    Returns:
1085        A dictionary mapping attribute names to `Function` or
1086        `ConcreteFunction`.
1087    """
1088    del serialization_cache
1089    return dict()
1090
1091  def _map_resources(self, save_options):  # pylint: disable=unused-argument
1092    """Makes new resource handle ops corresponding to existing resource tensors.
1093
1094    Internal sub-classes can override this to inform model saving how to add new
1095    resource handle ops to the main GraphDef of a SavedModel (TF 1.x style
1096    graph), which allows session based APIs (e.g, C++ loader API) to interact
1097    with resources owned by this object.
1098
1099    Args:
1100      save_options: A tf.saved_model.SaveOptions instance.
1101
1102    Returns:
1103      A tuple of (object_map, resource_map):
1104        object_map: A dictionary mapping from objects that hold existing
1105          resource tensors to replacement objects created to hold the new
1106          resource tensors.
1107        resource_map: A dictionary mapping from existing resource tensors to
1108          newly created resource tensors.
1109    """
1110    return {}, {}
1111