1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Helper classes listing and validating attributes to serialize to SavedModel.
"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import def_function
22from tensorflow.python.eager import function as defun
23from tensorflow.python.keras.saving.saved_model import constants
24from tensorflow.python.keras.utils.generic_utils import LazyLoader
25from tensorflow.python.training.tracking import base as trackable
26from tensorflow.python.training.tracking.tracking import AutoTrackable
27
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
# Module-level LazyLoader proxies for Keras modules used by
# `SerializedAttributes.new`. NOTE(review): presumably deferred to avoid
# circular imports with modules that import this one -- confirm.
base_layer = LazyLoader("base_layer", globals(),
                        "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader("training_lib", globals(),
                          "tensorflow.python.keras.engine.training")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader("recurrent", globals(),
                       "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes
43
44
class SerializedAttributes(object):
  """Class that tracks and validates all serialization attributes.

  Keras models contain many Python-defined components. For example, the
  trainable_variables property lists the model's trainable variables by
  recursively retrieving the trainable variables from each of the child layers.
  Another example is model.call, a python function that calls child layers and
  adds ops to the backend graph.

  Only TensorFlow checkpointable objects and functions can be serialized to
  SavedModel. Serializing a Keras model as-is results in a checkpointable object
  that does not resemble a Keras model at all. Thus, extra checkpointable
  objects and functions must be created during serialization.

  **Defining new serialized attributes**
  Child classes should be defined using:
    SerializedAttributes.with_attributes(
        'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
  This class is used to cache generated checkpointable objects and functions,
  ensuring that new objects and functions are generated a single time.

  **Usage during serialization**
  Each Layer/Model object should have a corresponding instance of
  SerializedAttributes. Create a new instance by calling
  `SerializedAttributes.new(obj)`. Objects and functions may be saved using
  `.set_and_validate_objects`/`.set_and_validate_functions`.
  The properties `.checkpointable_objects` and `.functions` return the cached
  values.

  **Adding/changing attributes to save to SavedModel**
  1. Change the call to `SerializedAttributes.with_attributes` in the correct
     class:
     - CommonEndpoints: Base attributes to be added during serialization. If
       these attributes are present in a Trackable object, it can be
       deserialized to a Keras Model.
     - LayerAttributes: Attributes to serialize for Layer objects.
     - ModelAttributes: Attributes to serialize for Model objects.
  2. Update class docstring
  3. Update arguments to any calls to `set_and_validate_*`. For example, if
     `call_raw_tensors` is added to the ModelAttributes function list, then
     a `call_raw_tensors` function should be passed to
     `set_and_validate_functions`.

  **Common endpoints vs other attributes**
  Only common endpoints are attached directly to the root object. Keras-specific
  attributes are saved to a separate trackable object with the name "keras_api".
  The number of objects attached to the root is limited because any naming
  conflicts will cause user code to break.

  Another reason is that this will only affect users who call
  `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
  advanced users who are likely to have defined their own tf.functions and
  trackable objects. The added Keras-specific attributes are kept out of the way
  in the "keras_api" namespace.

  Properties defined in this class may be used to filter out keras-specific
  attributes:
  - `functions_to_serialize`: Returns dict of functions to attach to the root
      object.
  - `objects_to_serialize`: Returns dict of objects to attach to the root
      object (including separate trackable object containing keras-specific
      attributes)

  All changes to the serialized attributes must be backwards-compatible, so
  attributes should not be removed or modified without sufficient justification.
  """

  @staticmethod
  def with_attributes(
      name, checkpointable_objects=None, functions=None, copy_from=None):
    """Creates a subclass with all attributes as specified in the arguments.

    The argument lists are not modified. They are copied before the attributes
    from `copy_from` are merged in.

    Args:
      name: Name of subclass
      checkpointable_objects: List of checkpointable objects to be serialized
        in the SavedModel.
      functions: List of functions to be serialized in the SavedModel.
      copy_from: List of other SerializedAttributes subclasses. The returned
        class will copy checkpoint objects/functions from each subclass.

    Returns:
      Child class with attributes as defined in the `checkpointable_objects`
      and `functions` lists. The merged names are stored as sets in the class
      attributes `all_checkpointable_objects` and `all_functions`.
    """
    # Copy so that merging in `copy_from` attributes never mutates lists owned
    # by the caller (the previous in-place `extend` did).
    checkpointable_objects = list(checkpointable_objects or [])
    functions = list(functions or [])

    for cls in copy_from or []:
      checkpointable_objects.extend(cls.all_checkpointable_objects)
      functions.extend(cls.all_functions)

    classdict = {
        'all_checkpointable_objects': set(checkpointable_objects),
        'all_functions': set(functions)}
    return type(name, (SerializedAttributes,), classdict)

  @staticmethod
  def new(obj):
    """Returns a new SerializedAttributes object matching the type of `obj`.

    Types are checked from most to least specific (Model, Metric, RNN, then
    Layer). NOTE(review): this order presumably matters because the first three
    are Layer subclasses. Confirm before reordering.

    Args:
      obj: The Keras object being serialized.

    Returns:
      A new ModelAttributes, MetricAttributes, RNNAttributes or LayerAttributes
      instance.

    Raises:
      TypeError: If `obj` is not a Keras Layer (or one of the types above).
    """
    if isinstance(obj, training_lib.Model):
      return ModelAttributes()
    elif isinstance(obj, metrics.Metric):
      return MetricAttributes()
    elif isinstance(obj, recurrent.RNN):
      return RNNAttributes()
    elif isinstance(obj, base_layer.Layer):
      return LayerAttributes()
    else:
      raise TypeError('Internal error during serialization: Expected Keras '
                      'Layer object, got {} of type {}'.format(obj, type(obj)))

  def __init__(self):
    # Name -> object/function caches filled in by the `set_and_validate_*`
    # methods.
    self._object_dict = {}
    self._function_dict = {}
    # Holds every saved attribute. It is attached to the root under
    # `constants.KERAS_ATTR` by `objects_to_serialize`.
    self._keras_trackable = AutoTrackable()

  @property
  def functions(self):
    """Returns dictionary of all saved functions whose value is not None."""
    return {key: value for key, value in self._function_dict.items()
            if value is not None}

  @property
  def checkpointable_objects(self):
    """Returns dictionary of all saved objects whose value is not None."""
    return {key: value for key, value in self._object_dict.items()
            if value is not None}

  @property
  def functions_to_serialize(self):
    """Returns functions to attach to the root object during serialization.

    Only functions named in `CommonEndpoints.all_functions` are included.
    """
    return {key: value for key, value in self.functions.items()
            if key in CommonEndpoints.all_functions}

  @property
  def objects_to_serialize(self):
    """Returns objects to attach to the root object during serialization.

    Includes the objects named in `CommonEndpoints.all_checkpointable_objects`,
    plus the trackable that holds all Keras-specific attributes, under the key
    `constants.KERAS_ATTR`.
    """
    objects = {key: value for key, value in self.checkpointable_objects.items()
               if key in CommonEndpoints.all_checkpointable_objects}
    objects[constants.KERAS_ATTR] = self._keras_trackable
    return objects

  def set_and_validate_functions(self, function_dict):
    """Saves function dictionary, and validates dictionary values.

    Every name in `all_functions` must be a key of `function_dict`. Keys that
    are not in `all_functions` are ignored. Each accepted value is cached and
    also set as an attribute of the Keras-specific trackable.

    Args:
      function_dict: Dict mapping function names to `defun.Function` /
        `def_function.Function` objects. A value may be None for functions
        that are not available; such entries are stored but filtered out of
        `functions`.

    Returns:
      The `functions` dict (all saved functions that are not None).

    Raises:
      ValueError: If a required function name is missing from `function_dict`,
        or a value is neither None nor a function.
    """
    for key in self.all_functions:
      if key not in function_dict:
        raise ValueError('Function {} missing from serialized function dict.'
                         .format(key))
      fn = function_dict[key]
      # Not all functions are required, so None is accepted as a placeholder.
      if fn is not None and not isinstance(
          fn, (defun.Function, def_function.Function)):
        raise ValueError(
            'Function dictionary contained a non-function object: {} (for key'
            ' {})'.format(fn, key))
      self._function_dict[key] = fn
      setattr(self._keras_trackable, key, fn)
    return self.functions

  def set_and_validate_objects(self, object_dict):
    """Saves objects to a dictionary, and validates the values.

    Every name in `all_checkpointable_objects` must be a key of `object_dict`.
    Keys that are not in `all_checkpointable_objects` are ignored. Each
    accepted value is cached and also set as an attribute of the
    Keras-specific trackable.

    Args:
      object_dict: Dict mapping object names to `trackable.Trackable` objects.

    Returns:
      The `checkpointable_objects` dict.

    Raises:
      ValueError: If a required object name is missing from `object_dict`, or
        a value is not a `trackable.Trackable` (None is rejected as well).
    """
    for key in self.all_checkpointable_objects:
      if key not in object_dict:
        raise ValueError(
            'Object {} missing from serialized object dict.'.format(key))
      obj = object_dict[key]
      if not isinstance(obj, trackable.Trackable):
        raise ValueError(
            'Object dictionary contained a non-trackable object: {} (for key'
            ' {})'.format(obj, key))
      self._object_dict[key] = obj
      setattr(self._keras_trackable, key, obj)
    return self.checkpointable_objects
219
220
class CommonEndpoints(
    SerializedAttributes.with_attributes(
        'CommonEndpoints',
        functions=[
            '__call__',
            'call_and_return_all_conditional_losses',
            '_default_save_signature',
        ],
        checkpointable_objects=[
            'variables',
            'trainable_variables',
            'regularization_losses',
        ],
    )):
  """Endpoints shared by every saved object that Keras can load back.

  Attributes saved:
    variables: Every variable in the model and its sublayers.
    trainable_variables: Every trainable variable in the model and its
      sublayers.
    regularization_losses: Every unconditional loss (a loss that does not
      depend on the inputs) in the model and its sublayers.
    __call__: Function mapping inputs to the outputs of the model's call
      function.
    call_and_return_all_conditional_losses: Function returning the tuple
      (call function outputs, list of every loss that depends on the inputs).
    _default_save_signature: Traced model call function. Only present when
      the top-level exported object is a Keras model.
  """
242
243
class LayerAttributes(
    SerializedAttributes.with_attributes(
        'LayerAttributes',
        copy_from=[CommonEndpoints],
        functions=[
            'call_and_return_conditional_losses',
            'activity_regularizer_fn',
        ],
        checkpointable_objects=[
            'non_trainable_variables',
            'layers',
            'metrics',
            'layer_regularization_losses',
            'layer_metrics',
        ],
    )):
  """Checkpointable objects and functions saved to the SavedModel for a Layer.

  Attributes saved (in addition to everything in CommonEndpoints):
    non_trainable_variables: Non-trainable variables of the layer and its
      sublayers.
    layers: Every sublayer.
    metrics: Every metric in the layer and its sublayers.
    layer_regularization_losses: Losses owned by this layer alone.
    layer_metrics: Metrics owned by this layer.
    call_and_return_conditional_losses: Function mapping inputs to the tuple
      (call function outputs, list of input-dependent losses). The activity
      regularizer loss is left out of that list and saved separately, so a
      deserialized Layer can define a different activity regularizer.
    activity_regularizer_fn: Callable that returns the activity regularizer
      loss.
  """
268
269
class ModelAttributes(
    SerializedAttributes.with_attributes(
        'ModelAttributes',
        copy_from=[LayerAttributes],
    )):
  """Checkpointable objects and functions saved to the SavedModel for a Model.

  Attributes saved:
    Exactly those of LayerAttributes (which already includes CommonEndpoints);
    no Model-specific attributes are added yet.
  """
  # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
  #  list all losses and metrics defined by `model.compile`.
280
281
class MetricAttributes(
    SerializedAttributes.with_attributes(
        'MetricAttributes',
        checkpointable_objects=['variables'],
    )):
  """Attributes saved to the SavedModel for Metric objects.

  Attributes saved:
    variables: Every variable of the metric.

  No functions are saved for metrics.
  """
294
295
class RNNAttributes(
    SerializedAttributes.with_attributes(
        'RNNAttributes',
        copy_from=[LayerAttributes],
        checkpointable_objects=['states'],
    )):
  """Checkpointable objects and functions saved to the SavedModel for an RNN.

  Attributes saved (in addition to everything in LayerAttributes, which
  includes CommonEndpoints):
    states: The RNN's state variables.
  """
306
307