1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Layer serialization/deserialization functions.
16"""
17# pylint: disable=wildcard-import
18# pylint: disable=unused-import
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24from tensorflow.python import tf2
25from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
26from tensorflow.python.keras.engine.input_layer import Input
27from tensorflow.python.keras.engine.input_layer import InputLayer
28from tensorflow.python.keras.layers.advanced_activations import *
29from tensorflow.python.keras.layers.convolutional import *
30from tensorflow.python.keras.layers.convolutional_recurrent import *
31from tensorflow.python.keras.layers.core import *
32from tensorflow.python.keras.layers.cudnn_recurrent import *
33from tensorflow.python.keras.layers.embeddings import *
34from tensorflow.python.keras.layers.local import *
35from tensorflow.python.keras.layers.merge import *
36from tensorflow.python.keras.layers.noise import *
37from tensorflow.python.keras.layers.normalization import *
38from tensorflow.python.keras.layers.pooling import *
39from tensorflow.python.keras.layers.recurrent import *
40from tensorflow.python.keras.layers.wrappers import *
41from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
42from tensorflow.python.util.tf_export import keras_export
43
44if tf2.enabled():
45  from tensorflow.python.keras.layers.normalization_v2 import *  # pylint: disable=g-import-not-at-top
46  from tensorflow.python.keras.layers.recurrent_v2 import *     # pylint: disable=g-import-not-at-top
47
48
@keras_export('keras.layers.serialize')
def serialize(layer):
  """Returns a serializable description of `layer`.

  Arguments:
      layer: object exposing a `get_config()` method (normally a Keras layer).

  Returns:
      A dict with two entries: 'class_name', the name of the layer's class,
      and 'config', whatever `layer.get_config()` returned.
  """
  layer_class_name = layer.__class__.__name__
  layer_config = layer.get_config()
  return {'class_name': layer_class_name, 'config': layer_config}
52
53
@keras_export('keras.layers.deserialize')
def deserialize(config, custom_objects=None):
  """Instantiates a layer from a config dictionary.

  Arguments:
      config: dict of the form {'class_name': str, 'config': dict}
      custom_objects: dict mapping class names (or function names)
          of custom (non-Keras) objects to class/functions

  Returns:
      Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Deliberately imported inside the function; presumably `models` depends
  # on this module, so a top-level import would be circular -- TODO confirm.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

  # Name -> object table used to resolve `config['class_name']`: every name
  # visible in this module (all layer classes brought in by the wildcard
  # imports) plus the model classes. Build it on a shallow copy of
  # `globals()` so that calling this function does not mutate the module's
  # own namespace (assigning into `globals()` directly would inject
  # Network/Model/Sequential into this module on every call).
  globs = dict(globals())
  globs['Network'] = models.Network
  globs['Model'] = models.Model
  globs['Sequential'] = models.Sequential

  return deserialize_keras_object(
      config,
      module_objects=globs,
      custom_objects=custom_objects,
      printable_module_name='layer')
77