1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions implementing Layer SavedModel serialization."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.keras.mixed_precision import policy
22from tensorflow.python.keras.saving.saved_model import base_serialization
23from tensorflow.python.keras.saving.saved_model import constants
24from tensorflow.python.keras.saving.saved_model import save_impl
25from tensorflow.python.keras.saving.saved_model import serialized_attributes
26from tensorflow.python.keras.utils import generic_utils
27from tensorflow.python.training.tracking import data_structures
28from tensorflow.python.util import nest
29
30
class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel saver for a generic Keras `Layer`.

  Supplies the object identifier, the Python-side metadata, and the trackable
  objects/functions that are written for the wrapped layer (`self.obj`).
  """

  @property
  def object_identifier(self):
    """Identifier recorded in the SavedModel for generic layers."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Python metadata describing the layer (see `_python_properties_internal`)."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Builds the metadata dictionary for the wrapped layer.

    Returns:
      A dict holding the layer's name, trainability, training-arg flag,
      serialized dtype policy, batch input shape (or None), statefulness and
      must-restore-from-config flag. It is then updated with the output of
      `get_serialized(layer)`, and gains 'input_spec', 'activity_regularizer'
      and 'build_input_shape' entries only when those are available.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj
    # pylint: disable=protected-access
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,
        'dtype': policy.serialize(layer._dtype_policy),
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,
    }
    # pylint: enable=protected-access

    metadata.update(get_serialized(layer))

    spec = layer.input_spec
    if spec is not None:
      # The input_spec was already type-checked by the Layer property setter;
      # falsy leaves (e.g. None) are stored as None.
      metadata['input_spec'] = nest.map_structure(
          lambda leaf: generic_utils.serialize_keras_object(leaf) if leaf else None,
          spec)

    regularizer = layer.activity_regularizer
    # Only regularizers exposing `get_config` can be serialized here.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      metadata['activity_regularizer'] = (
          generic_utils.serialize_keras_object(regularizer))

    build_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_shape is not None:
      metadata['build_input_shape'] = build_shape
    return metadata

  def objects_to_serialize(self, serialization_cache):
    """Returns the trackable objects to save for this layer."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Returns the functions to save for this layer."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Returns the `SerializedAttributes` for the layer, creating them once.

    Results are memoized per layer in the dict stored under
    `constants.KERAS_CACHE_KEY` inside `serialization_cache`.

    Args:
      serialization_cache: Dict shared across the serialization pass.

    Returns:
      The layer's `SerializedAttributes`. They are left unpopulated when the
      layer should skip serialization or must be restored from config.
    """
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    layer = self.obj
    if layer in keras_cache:
      return keras_cache[layer]

    # The attributes are cached before they are populated.
    # NOTE(review): presumably so re-entrant lookups of the same layer during
    # wrapping hit the cache -- confirm before reordering.
    attributes = serialized_attributes.SerializedAttributes.new(layer)
    keras_cache[layer] = attributes

    skip = (save_impl.should_skip_serialization(layer) or
            layer._must_restore_from_config)  # pylint: disable=protected-access
    if not skip:
      object_dict, function_dict = self._get_serialized_attributes_internal(
          serialization_cache)
      attributes.set_and_validate_objects(object_dict)
      attributes.set_and_validate_functions(function_dict)
    return attributes

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Wraps the layer's objects and functions for saving.

    Returns:
      A tuple `(objects, functions)` of dictionaries produced by
      `save_impl.wrap_layer_objects` / `save_impl.wrap_layer_functions`.
    """
    layer_objects = save_impl.wrap_layer_objects(
        self.obj, serialization_cache)
    layer_functions = save_impl.wrap_layer_functions(
        self.obj, serialization_cache)
    # The attribute validator requires a '_default_save_signature' entry in
    # the function dict, even when its value is None.
    layer_functions['_default_save_signature'] = None
    return layer_objects, layer_functions
108
109
110# TODO(kathywu): Move serialization utils (and related utils from
111# generic_utils.py) to a separate file.
def get_serialized(obj):
  """Returns `generic_utils.serialize_keras_object(obj)`.

  The call is made inside `generic_utils.skip_failed_serialization()`. The
  resulting config dictionary is stored so it may be used when reviving the
  object: on load, the program first tries to revive the object from this
  config, and falls back to reviving it from the SavedModel if that fails.

  Args:
    obj: The object (e.g. a layer) to serialize.
  """
  with generic_utils.skip_failed_serialization():
    serialized = generic_utils.serialize_keras_object(obj)
    return serialized
118
119
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel saver for a Keras `InputLayer`.

  Only Python metadata is recorded; no trackable objects or functions are
  attached to the saved input layer.
  """

  @property
  def object_identifier(self):
    """Identifier recorded in the SavedModel for input layers."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata for the input layer: class, name, dtype, flags, shape, config."""
    layer = self.obj
    properties = {}
    properties['class_name'] = type(layer).__name__
    properties['name'] = layer.name
    properties['dtype'] = layer.dtype
    properties['sparse'] = layer.sparse
    properties['ragged'] = layer.ragged
    properties['batch_input_shape'] = layer._batch_input_shape  # pylint: disable=protected-access
    properties['config'] = layer.get_config()
    return properties

  def objects_to_serialize(self, serialization_cache):
    """Input layers contribute no trackable objects."""
    return {}

  def functions_to_serialize(self, serialization_cache):
    """Input layers contribute no functions."""
    return {}
144
145
class RNNSavedModelSaver(LayerSavedModelSaver):
  """SavedModel saver for RNN layers, which additionally saves `states`."""

  @property
  def object_identifier(self):
    """Identifier recorded in the SavedModel for RNN layers."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Extends the parent's serialized objects with the layer's `states`.

    Returns:
      A tuple `(objects, functions)`, where `objects['states']` holds the RNN
      layer's states in trackable form.
    """
    objects, functions = super(
        RNNSavedModelSaver, self)._get_serialized_attributes_internal(
            serialization_cache)
    wrapped_states = data_structures.wrap_or_unwrap(self.obj.states)
    # Save/load requires every object to be trackable, but wrap_or_unwrap()
    # leaves a tuple as a plain tuple when it contains no trackable objects.
    # Force such tuples into the trackable _TupleWrapper.
    if isinstance(wrapped_states, tuple):
      wrapped_states = data_structures._TupleWrapper(wrapped_states)  # pylint: disable=protected-access
    objects['states'] = wrapped_states
    return objects, functions
166