1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Layer serialization/deserialization functions.
16"""
17# pylint: disable=wildcard-import
18# pylint: disable=unused-import
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import threading
25
26from tensorflow.python import tf2
27from tensorflow.python.keras.engine import base_layer
28from tensorflow.python.keras.engine import input_layer
29from tensorflow.python.keras.engine import input_spec
30from tensorflow.python.keras.layers import advanced_activations
31from tensorflow.python.keras.layers import convolutional
32from tensorflow.python.keras.layers import convolutional_recurrent
33from tensorflow.python.keras.layers import core
34from tensorflow.python.keras.layers import cudnn_recurrent
35from tensorflow.python.keras.layers import dense_attention
36from tensorflow.python.keras.layers import einsum_dense
37from tensorflow.python.keras.layers import embeddings
38from tensorflow.python.keras.layers import local
39from tensorflow.python.keras.layers import merge
40from tensorflow.python.keras.layers import multi_head_attention
41from tensorflow.python.keras.layers import noise
42from tensorflow.python.keras.layers import normalization
43from tensorflow.python.keras.layers import normalization_v2
44from tensorflow.python.keras.layers import pooling
45from tensorflow.python.keras.layers import recurrent
46from tensorflow.python.keras.layers import recurrent_v2
47from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
48from tensorflow.python.keras.layers import wrappers
49from tensorflow.python.keras.layers.preprocessing import category_crossing
50from tensorflow.python.keras.layers.preprocessing import category_encoding
51from tensorflow.python.keras.layers.preprocessing import discretization
52from tensorflow.python.keras.layers.preprocessing import hashing
53from tensorflow.python.keras.layers.preprocessing import image_preprocessing
54from tensorflow.python.keras.layers.preprocessing import integer_lookup as preprocessing_integer_lookup
55from tensorflow.python.keras.layers.preprocessing import integer_lookup_v1 as preprocessing_integer_lookup_v1
56from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
57from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1
58from tensorflow.python.keras.layers.preprocessing import string_lookup as preprocessing_string_lookup
59from tensorflow.python.keras.layers.preprocessing import string_lookup_v1 as preprocessing_string_lookup_v1
60from tensorflow.python.keras.layers.preprocessing import text_vectorization as preprocessing_text_vectorization
61from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1 as preprocessing_text_vectorization_v1
62from tensorflow.python.keras.utils import generic_utils
63from tensorflow.python.keras.utils import tf_inspect as inspect
64from tensorflow.python.util.tf_export import keras_export
65
66
# Modules scanned by `populate_deserializable_objects`: every class found in
# them that is a subclass of `base_layer.Layer` is registered under its name.
# NOTE(review): if two of these modules export the same class name, the order
# of this tuple may decide which one ends up registered — confirm against
# `generic_utils.populate_dict_with_module_objects` before reordering.
ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
               convolutional_recurrent, core, cudnn_recurrent, dense_attention,
               embeddings, einsum_dense, local, merge, noise, normalization,
               pooling, image_preprocessing, preprocessing_integer_lookup_v1,
               preprocessing_normalization_v1, preprocessing_string_lookup_v1,
               preprocessing_text_vectorization_v1, recurrent, wrappers,
               hashing, category_crossing, category_encoding, discretization,
               multi_head_attention)
# Modules scanned only when `tf2.enabled()` is true, after ALL_MODULES, so that
# their V2 layer classes are meant to replace same-named V1 entries.
ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2,
                  preprocessing_integer_lookup, preprocessing_normalization,
                  preprocessing_string_lookup, preprocessing_text_vectorization)
# `LOCAL.ALL_OBJECTS` is a lazily built, mutable registry (name -> layer class
# or function). It is rebuilt whenever the TF1/TF2 mode it was generated for
# (`LOCAL.GENERATED_WITH_V2`) no longer matches `tf2.enabled()`. Keeping it on a
# `threading.local` gives every thread its own copy, so one thread's rebuild
# never mutates a dict that another thread is reading.
LOCAL = threading.local()
81
82
def populate_deserializable_objects():
  """Populates `LOCAL.ALL_OBJECTS` with every built-in layer.

  The registry lives on the thread-local `LOCAL`, so each thread builds its own
  copy lazily. It is rebuilt only when it is empty or when it was generated for
  a different value of `tf2.enabled()` than the current one; otherwise this
  function returns immediately.

  Entries are written directly into `LOCAL.ALL_OBJECTS` while it is being
  built. A call made on the same thread before the build finishes therefore
  sees a non-empty dict and returns early. That could happen re-entrantly from
  one of the deferred imports, though that is not verified here.

  If the build raises part-way, the partially filled registry is discarded and
  the exception is re-raised. A later call then retries instead of treating an
  incomplete dict as up to date.

  Raises:
    Whatever the registration raises, e.g. an `ImportError` from one of the
    deferred imports. The cache is reset before the exception propagates.
  """
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  # Read the mode once so the flag we store always matches what we register.
  v2_enabled = tf2.enabled()
  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
    # Objects dict is already generated for the proper TF version:
    # do nothing.
    return

  LOCAL.ALL_OBJECTS = {}
  LOCAL.GENERATED_WITH_V2 = v2_enabled
  try:
    _register_all_objects(LOCAL.ALL_OBJECTS, v2_enabled)
  except Exception:
    # Never leave a half-built registry behind. A non-empty dict combined with
    # a matching GENERATED_WITH_V2 flag would make later calls return early,
    # and names that were never registered would fail to deserialize forever.
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None
    raise


def _register_all_objects(all_objects, v2_enabled):
  """Fills `all_objects` in place with built-in layers, models and merge fns.

  Args:
    all_objects: dict to populate; maps a name to a class or function.
    v2_enabled: whether TF2 behavior is enabled. When true, layer classes from
      `ALL_V2_MODULES` are registered after those from `ALL_MODULES`, and the
      V2 `DenseFeatures` is used. Otherwise the V1 `DenseFeatures` is used.
  """

  def is_layer_class(obj):
    # Only register classes deriving from the Keras base `Layer`.
    return inspect.isclass(obj) and issubclass(obj, base_layer.Layer)

  generic_utils.populate_dict_with_module_objects(
      all_objects, ALL_MODULES, obj_filter=is_layer_class)

  # Overwrite certain V1 objects with V2 versions.
  if v2_enabled:
    generic_utils.populate_dict_with_module_objects(
        all_objects, ALL_V2_MODULES, obj_filter=is_layer_class)

  # These deserialization aliases are added for backward compatibility,
  # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
  # were used as class name for v1 and v2 version of BatchNormalization,
  # respectively. Here we explicitly convert them to their canonical names.
  all_objects['BatchNormalizationV1'] = normalization.BatchNormalization
  all_objects['BatchNormalizationV2'] = normalization_v2.BatchNormalization

  # Deferred imports: these modules depend on this one (circular dependency).
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.premade.linear import LinearModel  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.premade.wide_deep import WideDeepModel  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.feature_column.sequence_feature_column import SequenceFeatures  # pylint: disable=g-import-not-at-top

  all_objects['Input'] = input_layer.Input
  all_objects['InputSpec'] = input_spec.InputSpec
  all_objects['Functional'] = models.Functional
  all_objects['Model'] = models.Model
  all_objects['SequenceFeatures'] = SequenceFeatures
  all_objects['Sequential'] = models.Sequential
  all_objects['LinearModel'] = LinearModel
  all_objects['WideDeepModel'] = WideDeepModel

  # The same public name maps to a different implementation per TF mode.
  if v2_enabled:
    from tensorflow.python.keras.feature_column.dense_features_v2 import DenseFeatures  # pylint: disable=g-import-not-at-top
  else:
    from tensorflow.python.keras.feature_column.dense_features import DenseFeatures  # pylint: disable=g-import-not-at-top
  all_objects['DenseFeatures'] = DenseFeatures

  # Merge layers, function versions.
  all_objects['add'] = merge.add
  all_objects['subtract'] = merge.subtract
  all_objects['multiply'] = merge.multiply
  all_objects['average'] = merge.average
  all_objects['maximum'] = merge.maximum
  all_objects['minimum'] = merge.minimum
  all_objects['concatenate'] = merge.concatenate
  all_objects['dot'] = merge.dot
151
152
@keras_export('keras.layers.serialize')
def serialize(layer):
  """Serializes `layer` by delegating to `generic_utils`.

  Args:
    layer: The layer (or model) instance to serialize.

  Returns:
    Whatever `generic_utils.serialize_keras_object` produces for `layer`.
    This is presumably the `{'class_name': ..., 'config': ...}` form that
    `deserialize` accepts (NOTE(review): confirm against generic_utils).
  """
  serialized = generic_utils.serialize_keras_object(layer)
  return serialized
156
157
@keras_export('keras.layers.deserialize')
def deserialize(config, custom_objects=None):
  """Builds a layer instance from its serialized config.

  Args:
      config: A dict shaped like `{'class_name': str, 'config': dict}`.
      custom_objects: Optional dict mapping the names of custom (non-Keras)
          classes or functions to the corresponding class/function objects.

  Returns:
      The reconstructed layer. This may also be a `Model`, `Sequential`,
      `Network`, or any other `Layer` subclass.
  """
  # Make sure this thread's registry of built-in objects exists and matches
  # the current TF1/TF2 mode before using it for name lookup.
  populate_deserializable_objects()
  builtin_objects = LOCAL.ALL_OBJECTS
  return generic_utils.deserialize_keras_object(
      config,
      printable_module_name='layer',
      custom_objects=custom_objects,
      module_objects=builtin_objects)
176