1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes that list&validate all attributes to serialize to SavedModel."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import six
23
24from tensorflow.python.keras.saving.saved_model import json_utils
25from tensorflow.python.keras.saving.saved_model import utils
26from tensorflow.python.training.tracking import tracking
27
28
@six.add_metaclass(abc.ABCMeta)
class SavedModelSaver(object):
  """Saver defining the methods and properties used to serialize Keras objects.

  Subclasses must implement `object_identifier`, `python_properties`,
  `objects_to_serialize` and `functions_to_serialize`. The public
  `list_extra_dependencies_for_serialization` and
  `list_functions_for_serialization` methods wrap the latter two, and both
  return an empty dict when `utils.should_save_traces()` is False.
  """

  def __init__(self, obj):
    """Creates a saver for `obj`.

    Args:
      obj: The object being serialized. It is also passed to
        `AutoTrackable._list_functions_for_serialization` when listing the
        functions to save.
    """
    self.obj = obj

  @abc.abstractproperty
  def object_identifier(self):
    """String stored in object identifier field in the SavedModel proto.

    Returns:
      A string with the object identifier, which is used at load time.
    """
    raise NotImplementedError

  @property
  def tracking_metadata(self):
    """String stored in metadata field in the SavedModel proto.

    Returns:
      A JSON string, produced by encoding `python_properties`, that stores the
      information necessary for recreating this object.
    """
    # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
    # object is in the python property)
    return json_utils.Encoder().encode(self.python_properties)

  def list_extra_dependencies_for_serialization(self, serialization_cache):
    """Lists extra dependencies to serialize to SavedModel.

    By overriding this method, extra dependencies can be attached to the
    serialized Layer. For example, this is used to save the list of `variables`
    and `trainable_variables`, which are python properties in a Layer object,
    but are represented as a static list in the SavedModel.

    Args:
      serialization_cache: A dictionary shared between all objects in the same
        object graph. This object is passed to both
        `_list_extra_dependencies_for_serialization` and
        `_list_functions_for_serialization`.

    Returns:
      A dictionary mapping attribute names to trackable objects. The entire list
      of attributes is listed in the `saved_model._LayerAttributes` class.
      The dictionary is empty when `utils.should_save_traces()` returns False.
    """
    # When traces are not being saved, no extra dependencies are attached.
    if not utils.should_save_traces():
      return {}

    return self.objects_to_serialize(serialization_cache)

  def list_functions_for_serialization(self, serialization_cache):
    """Lists extra functions to serialize to the SavedModel.

    Combines the functions returned by `functions_to_serialize` with the
    user-defined tf.functions that the parent `AutoTrackable` class lists for
    `self.obj`. If both define the same name, the `AutoTrackable` entry wins.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
      A new dictionary mapping attribute names to `Function` or
      `ConcreteFunction`. The dictionary is empty when
      `utils.should_save_traces()` returns False.
    """
    if not utils.should_save_traces():
      return {}

    # Copy before merging: implementations of `functions_to_serialize` are
    # asked to return cached results, so updating the returned dict in place
    # would write the AutoTrackable functions back into that cache.
    fns = dict(self.functions_to_serialize(serialization_cache))

    # The parent AutoTrackable class saves all user-defined tf.functions, and
    # returns them in _list_functions_for_serialization(). Add these functions
    # to the dict.
    fns.update(
        tracking.AutoTrackable._list_functions_for_serialization(  # pylint:disable=protected-access
            self.obj, serialization_cache))
    return fns

  @abc.abstractproperty
  def python_properties(self):
    """Returns dictionary of python properties to save in the metadata.

    This dictionary must be serializable and deserializable to/from JSON.

    When loading, the items in this dict are used to initialize the object and
    define attributes in the revived object.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def objects_to_serialize(self, serialization_cache):
    """Returns dictionary of extra trackable objects to serialize.

    See `functions_to_serialize` for an explanation of this function's
    effects.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
        A dictionary mapping attribute names to trackable objects.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def functions_to_serialize(self, serialization_cache):
    """Returns extra functions to include when serializing a Keras object.

    Normally, when exporting an object to SavedModel, only the functions and
    objects defined by the user are saved. For example:

    ```
    obj = tf.Module()
    obj.v = tf.Variable(1.)

    @tf.function
    def foo(...): ...

    obj.foo = foo

    w = tf.Variable(1.)

    tf.saved_model.save(obj, 'path/to/saved/model')
    loaded = tf.saved_model.load('path/to/saved/model')

    loaded.v  # Variable with the same value as obj.v
    loaded.foo  # Equivalent to obj.foo
    loaded.w  # AttributeError
    ```

    Assigning trackable objects to attributes creates a graph, which is used for
    both checkpointing and SavedModel serialization.

    When the graph generated from attribute tracking is insufficient, extra
    objects and functions may be added at serialization time. For example,
    most models do not have their call function wrapped with a @tf.function
    decorator. This results in `model.call` not being saved. Since Keras objects
    should be revivable from the SavedModel format, the call function is added
    as an extra function to serialize.

    This function and `objects_to_serialize` are called multiple times when
    exporting to SavedModel. Please use the cache to avoid generating new
    functions and objects. A fresh cache is created for each SavedModel export.

    Args:
      serialization_cache: Dictionary passed to all objects in the same object
        graph during serialization.

    Returns:
        A dictionary mapping attribute names to `Function` or
        `ConcreteFunction`.
    """
    raise NotImplementedError
180