1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras layers API."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python import tf2
22
23# Generic layers.
24# pylint: disable=g-bad-import-order
25# pylint: disable=g-import-not-at-top
26from tensorflow.python.keras.engine.input_layer import Input
27from tensorflow.python.keras.engine.input_layer import InputLayer
28from tensorflow.python.keras.engine.input_spec import InputSpec
29from tensorflow.python.keras.engine.base_layer import Layer
30from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer
31
32# Image preprocessing layers.
33from tensorflow.python.keras.layers.preprocessing.image_preprocessing import CenterCrop
34from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomCrop
35from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomFlip
36from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomContrast
37from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomHeight
38from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomRotation
39from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomTranslation
40from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomWidth
41from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomZoom
42from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Resizing
43from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Rescaling
44
45# Preprocessing layers.
46if tf2.enabled():
47  from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup
48  from tensorflow.python.keras.layers.preprocessing.integer_lookup_v1 import IntegerLookup as IntegerLookupV1
49  IntegerLookupV2 = IntegerLookup
50  from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
51  from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
52  NormalizationV2 = Normalization
53  from tensorflow.python.keras.layers.preprocessing.string_lookup import StringLookup
54  from tensorflow.python.keras.layers.preprocessing.string_lookup_v1 import StringLookup as StringLookupV1
55  StringLookupV2 = StringLookup
56  from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
57  from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
58  TextVectorizationV2 = TextVectorization
59else:
60  from tensorflow.python.keras.layers.preprocessing.integer_lookup_v1 import IntegerLookup
61  from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup as IntegerLookupV2
62  IntegerLookupV1 = IntegerLookup
63  from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization
64  from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2
65  NormalizationV1 = Normalization
66  from tensorflow.python.keras.layers.preprocessing.string_lookup_v1 import StringLookup
67  from tensorflow.python.keras.layers.preprocessing.string_lookup import StringLookup as StringLookupV2
68  StringLookupV1 = StringLookup
69  from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
70  from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
71  TextVectorizationV1 = TextVectorization
72from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing
73from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding
74from tensorflow.python.keras.layers.preprocessing.discretization import Discretization
75from tensorflow.python.keras.layers.preprocessing.hashing import Hashing
76
77# Advanced activations.
78from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
79from tensorflow.python.keras.layers.advanced_activations import PReLU
80from tensorflow.python.keras.layers.advanced_activations import ELU
81from tensorflow.python.keras.layers.advanced_activations import ReLU
82from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU
83from tensorflow.python.keras.layers.advanced_activations import Softmax
84
85# Convolution layers.
86from tensorflow.python.keras.layers.convolutional import Conv1D
87from tensorflow.python.keras.layers.convolutional import Conv2D
88from tensorflow.python.keras.layers.convolutional import Conv3D
89from tensorflow.python.keras.layers.convolutional import Conv1DTranspose
90from tensorflow.python.keras.layers.convolutional import Conv2DTranspose
91from tensorflow.python.keras.layers.convolutional import Conv3DTranspose
92from tensorflow.python.keras.layers.convolutional import SeparableConv1D
93from tensorflow.python.keras.layers.convolutional import SeparableConv2D
94
95# Convolution layer aliases.
96from tensorflow.python.keras.layers.convolutional import Convolution1D
97from tensorflow.python.keras.layers.convolutional import Convolution2D
98from tensorflow.python.keras.layers.convolutional import Convolution3D
99from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose
100from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose
101from tensorflow.python.keras.layers.convolutional import SeparableConvolution1D
102from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D
103from tensorflow.python.keras.layers.convolutional import DepthwiseConv2D
104
105# Image processing layers.
106from tensorflow.python.keras.layers.convolutional import UpSampling1D
107from tensorflow.python.keras.layers.convolutional import UpSampling2D
108from tensorflow.python.keras.layers.convolutional import UpSampling3D
109from tensorflow.python.keras.layers.convolutional import ZeroPadding1D
110from tensorflow.python.keras.layers.convolutional import ZeroPadding2D
111from tensorflow.python.keras.layers.convolutional import ZeroPadding3D
112from tensorflow.python.keras.layers.convolutional import Cropping1D
113from tensorflow.python.keras.layers.convolutional import Cropping2D
114from tensorflow.python.keras.layers.convolutional import Cropping3D
115
116# Core layers.
117from tensorflow.python.keras.layers.core import Masking
118from tensorflow.python.keras.layers.core import Dropout
119from tensorflow.python.keras.layers.core import SpatialDropout1D
120from tensorflow.python.keras.layers.core import SpatialDropout2D
121from tensorflow.python.keras.layers.core import SpatialDropout3D
122from tensorflow.python.keras.layers.core import Activation
123from tensorflow.python.keras.layers.core import Reshape
124from tensorflow.python.keras.layers.core import Permute
125from tensorflow.python.keras.layers.core import Flatten
126from tensorflow.python.keras.layers.core import RepeatVector
127from tensorflow.python.keras.layers.core import Lambda
128from tensorflow.python.keras.layers.core import Dense
129from tensorflow.python.keras.layers.core import ActivityRegularization
130
131# Dense Attention layers.
132from tensorflow.python.keras.layers.dense_attention import AdditiveAttention
133from tensorflow.python.keras.layers.dense_attention import Attention
134
135# Embedding layers.
136from tensorflow.python.keras.layers.embeddings import Embedding
137
# Einsum-based dense layer.
139from tensorflow.python.keras.layers.einsum_dense import EinsumDense
140
141# Multi-head Attention layer.
142from tensorflow.python.keras.layers.multi_head_attention import MultiHeadAttention
143
144# Locally-connected layers.
145from tensorflow.python.keras.layers.local import LocallyConnected1D
146from tensorflow.python.keras.layers.local import LocallyConnected2D
147
148# Merge layers.
149from tensorflow.python.keras.layers.merge import Add
150from tensorflow.python.keras.layers.merge import Subtract
151from tensorflow.python.keras.layers.merge import Multiply
152from tensorflow.python.keras.layers.merge import Average
153from tensorflow.python.keras.layers.merge import Maximum
154from tensorflow.python.keras.layers.merge import Minimum
155from tensorflow.python.keras.layers.merge import Concatenate
156from tensorflow.python.keras.layers.merge import Dot
157from tensorflow.python.keras.layers.merge import add
158from tensorflow.python.keras.layers.merge import subtract
159from tensorflow.python.keras.layers.merge import multiply
160from tensorflow.python.keras.layers.merge import average
161from tensorflow.python.keras.layers.merge import maximum
162from tensorflow.python.keras.layers.merge import minimum
163from tensorflow.python.keras.layers.merge import concatenate
164from tensorflow.python.keras.layers.merge import dot
165
166# Noise layers.
167from tensorflow.python.keras.layers.noise import AlphaDropout
168from tensorflow.python.keras.layers.noise import GaussianNoise
169from tensorflow.python.keras.layers.noise import GaussianDropout
170
171# Normalization layers.
172from tensorflow.python.keras.layers.normalization import LayerNormalization
173from tensorflow.python.keras.layers.normalization_v2 import SyncBatchNormalization
174
175if tf2.enabled():
176  from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
177  from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1
178  BatchNormalizationV2 = BatchNormalization
179else:
180  from tensorflow.python.keras.layers.normalization import BatchNormalization
181  from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
182  BatchNormalizationV1 = BatchNormalization
183
184# Kernelized layers.
185from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
186
187# Pooling layers.
188from tensorflow.python.keras.layers.pooling import MaxPooling1D
189from tensorflow.python.keras.layers.pooling import MaxPooling2D
190from tensorflow.python.keras.layers.pooling import MaxPooling3D
191from tensorflow.python.keras.layers.pooling import AveragePooling1D
192from tensorflow.python.keras.layers.pooling import AveragePooling2D
193from tensorflow.python.keras.layers.pooling import AveragePooling3D
194from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D
195from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D
196from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D
197from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D
198from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D
199from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D
200
201# Pooling layer aliases.
202from tensorflow.python.keras.layers.pooling import MaxPool1D
203from tensorflow.python.keras.layers.pooling import MaxPool2D
204from tensorflow.python.keras.layers.pooling import MaxPool3D
205from tensorflow.python.keras.layers.pooling import AvgPool1D
206from tensorflow.python.keras.layers.pooling import AvgPool2D
207from tensorflow.python.keras.layers.pooling import AvgPool3D
208from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D
209from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D
210from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D
211from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D
212from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D
213from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D
214
215# Recurrent layers.
216from tensorflow.python.keras.layers.recurrent import RNN
217from tensorflow.python.keras.layers.recurrent import AbstractRNNCell
218from tensorflow.python.keras.layers.recurrent import StackedRNNCells
219from tensorflow.python.keras.layers.recurrent import SimpleRNNCell
220from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
221from tensorflow.python.keras.layers.recurrent import SimpleRNN
222
223if tf2.enabled():
224  from tensorflow.python.keras.layers.recurrent_v2 import GRU
225  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell
226  from tensorflow.python.keras.layers.recurrent_v2 import LSTM
227  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell
228  from tensorflow.python.keras.layers.recurrent import GRU as GRUV1
229  from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1
230  from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1
231  from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1
232  GRUV2 = GRU
233  GRUCellV2 = GRUCell
234  LSTMV2 = LSTM
235  LSTMCellV2 = LSTMCell
236else:
237  from tensorflow.python.keras.layers.recurrent import GRU
238  from tensorflow.python.keras.layers.recurrent import GRUCell
239  from tensorflow.python.keras.layers.recurrent import LSTM
240  from tensorflow.python.keras.layers.recurrent import LSTMCell
241  from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2
242  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2
243  from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2
244  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2
245  GRUV1 = GRU
246  GRUCellV1 = GRUCell
247  LSTMV1 = LSTM
248  LSTMCellV1 = LSTMCell
249
250# Convolutional-recurrent layers.
251from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D
252
253# CuDNN recurrent layers.
254from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNLSTM
255from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNGRU
256
# Wrapper layers.
258from tensorflow.python.keras.layers.wrappers import Wrapper
259from tensorflow.python.keras.layers.wrappers import Bidirectional
260from tensorflow.python.keras.layers.wrappers import TimeDistributed
261
# RNN Cell wrappers.
263from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DeviceWrapper
264from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper
265from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
266
267# Serialization functions
268from tensorflow.python.keras.layers import serialization
269from tensorflow.python.keras.layers.serialization import deserialize
270from tensorflow.python.keras.layers.serialization import serialize
271
272
class VersionAwareLayers(object):
  """Utility to be used internally to access layers in a V1/V2-aware fashion.

  When using layers within the Keras codebase, under the constraint that
  e.g. `layers.BatchNormalization` should be the `BatchNormalization` version
  corresponding to the current runtime (TF1 or TF2), do not simply access
  `layers.BatchNormalization` since it would ignore e.g. an early
  `compat.v2.disable_v2_behavior()` call. Instead, use an instance
  of `VersionAwareLayers` (which you can use just like the `layers` module).
  """

  def __getattr__(self, name):
    """Looks up `name` in the serialization module's object registry.

    Python invokes this hook only after regular attribute lookup on the
    instance and its class has failed.

    Args:
      name: String, the attribute (typically a layer class name) requested.

    Returns:
      The object stored under `name` in `serialization.LOCAL.ALL_OBJECTS`.

    Raises:
      AttributeError: If `name` is not present in the registry.
    """
    # Called on every lookup, before the registry is consulted.
    # NOTE(review): presumably this refreshes the registry so it matches the
    # current TF1/TF2 mode -- confirm against `serialization`.
    serialization.populate_deserializable_objects()
    if name in serialization.LOCAL.ALL_OBJECTS:
      return serialization.LOCAL.ALL_OBJECTS[name]
    # `object` does not define `__getattr__`, so delegating to
    # `super().__getattr__` would fail with the misleading message
    # "'super' object has no attribute '__getattr__'". `__getattribute__`
    # raises the standard `AttributeError` naming this class and `name`.
    return super(VersionAwareLayers, self).__getattribute__(name)
289
# Remove the `__future__` feature names from the module namespace so they are
# not exposed as attributes of this package.
del absolute_import, division, print_function
293