1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Layer serialization/deserialization functions. 16""" 17# pylint: disable=wildcard-import 18# pylint: disable=unused-import 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import threading 25 26from tensorflow.python import tf2 27from tensorflow.python.keras.engine import base_layer 28from tensorflow.python.keras.engine import input_layer 29from tensorflow.python.keras.engine import input_spec 30from tensorflow.python.keras.layers import advanced_activations 31from tensorflow.python.keras.layers import convolutional 32from tensorflow.python.keras.layers import convolutional_recurrent 33from tensorflow.python.keras.layers import core 34from tensorflow.python.keras.layers import cudnn_recurrent 35from tensorflow.python.keras.layers import dense_attention 36from tensorflow.python.keras.layers import einsum_dense 37from tensorflow.python.keras.layers import embeddings 38from tensorflow.python.keras.layers import local 39from tensorflow.python.keras.layers import merge 40from tensorflow.python.keras.layers import multi_head_attention 41from tensorflow.python.keras.layers import noise 42from tensorflow.python.keras.layers import normalization 43from tensorflow.python.keras.layers import normalization_v2 
from tensorflow.python.keras.layers import pooling
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.keras.layers import recurrent_v2
from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
from tensorflow.python.keras.layers import wrappers
from tensorflow.python.keras.layers.preprocessing import category_crossing
from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.keras.layers.preprocessing import hashing
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.layers.preprocessing import integer_lookup as preprocessing_integer_lookup
from tensorflow.python.keras.layers.preprocessing import integer_lookup_v1 as preprocessing_integer_lookup_v1
from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1
from tensorflow.python.keras.layers.preprocessing import string_lookup as preprocessing_string_lookup
from tensorflow.python.keras.layers.preprocessing import string_lookup_v1 as preprocessing_string_lookup_v1
from tensorflow.python.keras.layers.preprocessing import text_vectorization as preprocessing_text_vectorization
from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1 as preprocessing_text_vectorization_v1
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_inspect as inspect
from tensorflow.python.util.tf_export import keras_export


# Modules scanned for built-in `Layer` subclasses regardless of TF version.
ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
               convolutional_recurrent, core, cudnn_recurrent, dense_attention,
               embeddings, einsum_dense, local, merge, noise, normalization,
               pooling, image_preprocessing, preprocessing_integer_lookup_v1,
               preprocessing_normalization_v1, preprocessing_string_lookup_v1,
               preprocessing_text_vectorization_v1, recurrent, wrappers,
               hashing, category_crossing, category_encoding, discretization,
               multi_head_attention)
# Modules scanned only when TF2 behavior is enabled; they are populated after
# ALL_MODULES so that V2 layer classes replace their V1 counterparts.
ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2,
                  preprocessing_integer_lookup, preprocessing_normalization,
                  preprocessing_string_lookup, preprocessing_text_vectorization)
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
LOCAL = threading.local()


def populate_deserializable_objects():
  """Populates dict ALL_OBJECTS with every built-in layer.

  The mapping is stored on the thread-local `LOCAL` as `LOCAL.ALL_OBJECTS`,
  together with `LOCAL.GENERATED_WITH_V2`, which records the value of
  `tf2.enabled()` at generation time. If a non-empty mapping already exists
  for the current TF version, this function returns without doing any work.

  The mapping is assembled in a local dict and published only once it is
  complete. If any step raises (for example one of the deferred imports
  below), the previously cached state is left untouched. A later call then
  retries instead of returning early with a partially-filled mapping.
  """
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  # Query once so the cached flag always matches the modules actually used.
  v2_enabled = tf2.enabled()

  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
    # Objects dict is already generated for the proper TF version:
    # do nothing.
    return

  base_cls = base_layer.Layer

  def _is_layer_class(obj):
    """Returns True if `obj` is a class deriving from `base_layer.Layer`."""
    return inspect.isclass(obj) and issubclass(obj, base_cls)

  objects = {}
  generic_utils.populate_dict_with_module_objects(
      objects, ALL_MODULES, obj_filter=_is_layer_class)

  # Overwrite certain V1 objects with V2 versions
  if v2_enabled:
    generic_utils.populate_dict_with_module_objects(
        objects, ALL_V2_MODULES, obj_filter=_is_layer_class)

  # These deserialization aliases are added for backward compatibility,
  # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
  # were used as class name for v1 and v2 version of BatchNormalization,
  # respectively. Here we explicitly convert them to their canonical names.
  objects['BatchNormalizationV1'] = normalization.BatchNormalization
  objects['BatchNormalizationV2'] = normalization_v2.BatchNormalization

  # Prevent circular dependencies.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.premade.linear import LinearModel  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.premade.wide_deep import WideDeepModel  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.feature_column.sequence_feature_column import SequenceFeatures  # pylint: disable=g-import-not-at-top

  objects['Input'] = input_layer.Input
  objects['InputSpec'] = input_spec.InputSpec
  objects['Functional'] = models.Functional
  objects['Model'] = models.Model
  objects['SequenceFeatures'] = SequenceFeatures
  objects['Sequential'] = models.Sequential
  objects['LinearModel'] = LinearModel
  objects['WideDeepModel'] = WideDeepModel

  # DenseFeatures has separate V1/V2 implementations; pick the one that
  # matches the active TF behavior.
  if v2_enabled:
    from tensorflow.python.keras.feature_column.dense_features_v2 import DenseFeatures  # pylint: disable=g-import-not-at-top
  else:
    from tensorflow.python.keras.feature_column.dense_features import DenseFeatures  # pylint: disable=g-import-not-at-top
  objects['DenseFeatures'] = DenseFeatures

  # Merge layers, function versions.
  for fn_name in ('add', 'subtract', 'multiply', 'average', 'maximum',
                  'minimum', 'concatenate', 'dot'):
    objects[fn_name] = getattr(merge, fn_name)

  # Publish only after every step succeeded, so a failure above never leaves
  # a partial mapping marked as up to date.
  LOCAL.ALL_OBJECTS = objects
  LOCAL.GENERATED_WITH_V2 = v2_enabled


@keras_export('keras.layers.serialize')
def serialize(layer):
  """Serializes a layer instance.

  Args:
    layer: The layer object to serialize.

  Returns:
    The result of `generic_utils.serialize_keras_object(layer)`.
  """
  return generic_utils.serialize_keras_object(layer)


@keras_export('keras.layers.deserialize')
def deserialize(config, custom_objects=None):
  """Instantiates a layer from a config dictionary.

  Args:
      config: dict of the form {'class_name': str, 'config': dict}
      custom_objects: dict mapping class names (or function names)
          of custom (non-Keras) objects to class/functions

  Returns:
      Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Make sure the built-in name -> object mapping exists for this thread
  # and the current TF version before looking names up in it.
  populate_deserializable_objects()
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=LOCAL.ALL_OBJECTS,
      custom_objects=custom_objects,
      printable_module_name='layer')