1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles types registrations for tf.saved_model.load."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import versions_pb2
22from tensorflow.core.protobuf import saved_object_graph_pb2
23
24
class VersionedTypeRegistration(object):
  """Describes one version of a type that can be revived on load."""

  def __init__(self, object_factory, version, min_producer_version,
               min_consumer_version, bad_consumers=None, setter=setattr):
    """Identify a revived type version.

    Args:
      object_factory: A callable which takes a SavedUserObject proto and returns
        a trackable object. Dependencies are added later via `setter`.
      version: An integer, the producer version of this wrapper type. When
        making incompatible changes to a wrapper, add a new
        `VersionedTypeRegistration` with an incremented `version`. The most
        recent version will be saved, and all registrations with a matching
        identifier will be searched for the highest compatible version to use
        when loading.
      min_producer_version: The minimum producer version number required to use
        this `VersionedTypeRegistration` when loading a proto.
      min_consumer_version: `VersionedTypeRegistration`s with a version number
        less than `min_consumer_version` will not be used to load a proto saved
        with this object. `min_consumer_version` should be set to the lowest
        version number which can successfully load protos saved by this
        object. If no matching registration is available on load, the object
        will be revived with a generic trackable type.

        `min_consumer_version` and `bad_consumers` are a blunt tool, and using
        them will generally break forward compatibility: previous versions of
        TensorFlow will revive newly saved objects as opaque trackable
        objects rather than wrapped objects. When updating wrappers, prefer
        saving new information but preserving compatibility with previous
        wrapper versions. They are, however, useful for ensuring that
        previously-released buggy wrapper versions degrade gracefully rather
        than throwing exceptions when presented with newly-saved SavedModels.
      bad_consumers: A list of consumer versions which are incompatible (in
        addition to any version less than `min_consumer_version`). Defaults to
        an empty list; a list passed in is stored as-is (not copied).
      setter: A callable with the same signature as `setattr` to use when adding
        dependencies to generated objects.
    """
    self._object_factory = object_factory
    self.version = version
    self._min_producer_version = min_producer_version
    self._min_consumer_version = min_consumer_version
    self._bad_consumers = [] if bad_consumers is None else bad_consumers
    self.setter = setter
    # Filled in when the registration is passed to `register_revived_type`.
    self.identifier = None

  def to_proto(self):
    """Builds the SavedUserObject proto that describes this registration."""
    # For now wrappers just use dependencies to save their state, so the
    # SavedUserObject doesn't depend on the object being saved.
    # TODO(allenl): Add a wrapper which uses its own proto.
    version_def = versions_pb2.VersionDef(
        producer=self.version,
        min_consumer=self._min_consumer_version,
        bad_consumers=self._bad_consumers)
    return saved_object_graph_pb2.SavedUserObject(
        identifier=self.identifier, version=version_def)

  def from_proto(self, proto):
    """Revives a trackable object by passing `proto` to the object factory."""
    factory = self._object_factory
    return factory(proto)

  def should_load(self, proto):
    """Returns True if this registration may revive the SavedUserObject `proto`.

    The identifiers must match. This registration's version must be at least
    the proto's `min_consumer` and must not be listed in its `bad_consumers`.
    The proto's producer version must be at least this registration's
    `min_producer_version`.
    """
    if proto.identifier != self.identifier:
      return False
    saved_version = proto.version
    return (saved_version.min_consumer <= self.version
            and saved_version.producer >= self._min_producer_version
            and self.version not in saved_version.bad_consumers)
101
102
# Maps each registered string identifier to a tuple of
# (predicate, [VersionedTypeRegistration]). The list is stored ordered from
# the highest `version` to the lowest.
_REVIVED_TYPE_REGISTRY = dict()
# Identifiers in the order they were registered; `serialize` consults the
# predicates in this order and uses the first one that matches.
_TYPE_IDENTIFIERS = list()
106
107
def register_revived_type(identifier, predicate, versions):
  """Register a type for revived objects.

  Args:
    identifier: A unique string identifying this class of objects.
    predicate: A Boolean predicate for this registration. Takes a
      trackable object as an argument. If True, the registrations in
      `versions` may be used to save and restore the object (the highest
      version is used when saving).
    versions: A non-empty iterable of `VersionedTypeRegistration` objects with
      distinct `version` numbers. The caller's sequence is not reordered; a
      copy sorted from highest to lowest version is stored. Each
      registration's `identifier` is set to `identifier`, but only after
      every validation check has passed.

  Raises:
    AssertionError: If `versions` is empty, if two registrations share the
      same version number, or if `identifier` is already registered.
  """
  # Keep registrations in order of version. We always use the highest matching
  # version (respecting the min consumer version and bad consumers). Sort a
  # copy so the caller's list is left untouched and any iterable is accepted.
  versions = sorted(versions, key=lambda reg: reg.version, reverse=True)
  if not versions:
    raise AssertionError("Need at least one version of a registered type.")
  version_numbers = set()
  for registration in versions:
    if registration.version in version_numbers:
      raise AssertionError(
          "Got multiple registrations with version {} for type {}".format(
              registration.version, identifier))
    version_numbers.add(registration.version)
  if identifier in _REVIVED_TYPE_REGISTRY:
    raise AssertionError(
        "Duplicate registrations for type {}".format(identifier))

  # Only mutate the registrations and the global registry once all checks have
  # passed, so a rejected call leaves no partial state behind.
  for registration in versions:
    # Copy over the identifier for use in generating protos
    registration.identifier = identifier
  _REVIVED_TYPE_REGISTRY[identifier] = (predicate, versions)
  _TYPE_IDENTIFIERS.append(identifier)
138
139
def serialize(obj):
  """Builds a SavedUserObject proto for the trackable `obj`, or returns None.

  Registered types are consulted in registration order. The first whose
  predicate accepts `obj` is used, and the proto always comes from its most
  recent (first stored) version. Returns None if no predicate matches.
  """
  for type_id in _TYPE_IDENTIFIERS:
    accepts, registrations = _REVIVED_TYPE_REGISTRY[type_id]
    if not accepts(obj):
      continue
    newest = registrations[0]
    return newest.to_proto()
  return None
148
149
def deserialize(proto):
  """Create a trackable object from a SavedUserObject proto.

  Registrations for `proto.identifier` are tried in stored order, and the
  first whose `should_load` accepts the proto is used.

  Args:
    proto: A SavedUserObject to deserialize.

  Returns:
    A tuple of (trackable, assignment_fn) where assignment_fn has the same
    signature as setattr and should be used to add dependencies to
    `trackable` when they are available. Returns None if the identifier is
    not registered or no registration accepts the proto.
  """
  _, registrations = _REVIVED_TYPE_REGISTRY.get(
      proto.identifier, (None, None))
  for registration in registrations or ():
    if not registration.should_load(proto):
      continue
    revived = registration.from_proto(proto)
    return revived, registration.setter
  return None
168
169
def registered_identifiers():
  """Returns the identifiers registered so far.

  The result is the registry dict's keys view. It is not a snapshot, so it
  also reflects registrations added after the call.
  """
  registry = _REVIVED_TYPE_REGISTRY
  return registry.keys()
172
173
def get_setter(proto):
  """Returns the dependency setter for the registration that loads `proto`.

  Looks up the registrations for `proto.identifier` and returns the `setter`
  of the first one (in stored order) whose `should_load` accepts `proto`.
  Returns None if the identifier is not registered or nothing accepts it.
  """
  registrations = _REVIVED_TYPE_REGISTRY.get(
      proto.identifier, (None, None))[1]
  if registrations is None:
    return None
  # Generator keeps evaluation lazy: stops at the first accepting entry.
  match = next(
      (reg for reg in registrations if reg.should_load(proto)), None)
  return None if match is None else match.setter
182