# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer serialization/deserialization functions.
"""
# pylint: disable=wildcard-import
# pylint: disable=unused-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import tf2
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.layers.advanced_activations import *
from tensorflow.python.keras.layers.convolutional import *
from tensorflow.python.keras.layers.convolutional_recurrent import *
from tensorflow.python.keras.layers.core import *
from tensorflow.python.keras.layers.cudnn_recurrent import *
from tensorflow.python.keras.layers.embeddings import *
from tensorflow.python.keras.layers.local import *
from tensorflow.python.keras.layers.merge import *
from tensorflow.python.keras.layers.noise import *
from tensorflow.python.keras.layers.normalization import *
from tensorflow.python.keras.layers.pooling import *
from tensorflow.python.keras.layers.recurrent import *
from tensorflow.python.keras.layers.wrappers import *
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.util.tf_export import keras_export

# When TF2 behavior is enabled, the v2 normalization/recurrent modules are
# wildcard-imported after the v1 ones, so any names they share resolve to the
# v2 definitions in this module's namespace.
if tf2.enabled():
  from tensorflow.python.keras.layers.normalization_v2 import *  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.layers.recurrent_v2 import *  # pylint: disable=g-import-not-at-top


@keras_export('keras.layers.serialize')
def serialize(layer):
  """Converts a layer into a plain config dictionary.

  Arguments:
      layer: layer instance to serialize; it must provide `get_config()`.

  Returns:
      A dict with two keys: 'class_name', the name of the layer's class, and
      'config', whatever `layer.get_config()` returns. This matches the shape
      `deserialize` documents for its `config` argument.
  """
  class_name = layer.__class__.__name__
  layer_config = layer.get_config()
  return {'class_name': class_name, 'config': layer_config}


@keras_export('keras.layers.deserialize')
def deserialize(config, custom_objects=None):
  """Builds a layer instance back from a config dictionary.

  Arguments:
      config: dict of the form {'class_name': str, 'config': dict}
      custom_objects: dict mapping class names (or function names)
          of custom (non-Keras) objects to class/functions

  Returns:
      Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Imported inside the function rather than at module top (presumably to
  # avoid a circular import between `models` and the layers package --
  # confirm).
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

  # This module's global namespace already holds every layer class brought in
  # by the wildcard imports above; the model classes are added to it so they
  # can be looked up by class name as well.
  # NOTE(review): this writes into the module's real globals on every call,
  # not into a copy.
  namespace = globals()
  for model_class_name in ('Network', 'Model', 'Sequential'):
    namespace[model_class_name] = getattr(models, model_class_name)

  return deserialize_keras_object(
      config,
      module_objects=namespace,
      custom_objects=custom_objects,
      printable_module_name='layer')