1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core Keras layers.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import functools
23import operator
24import sys
25import textwrap
26import types as python_types
27import warnings
28
29import numpy as np
30
31from tensorflow.python.eager import backprop
32from tensorflow.python.eager import context
33from tensorflow.python.eager import monitoring
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.keras import activations
39from tensorflow.python.keras import backend as K
40from tensorflow.python.keras import constraints
41from tensorflow.python.keras import initializers
42from tensorflow.python.keras import regularizers
43from tensorflow.python.keras.engine import keras_tensor
44from tensorflow.python.keras.engine.base_layer import Layer
45from tensorflow.python.keras.engine.input_spec import InputSpec
46from tensorflow.python.keras.layers.ops import core as core_ops
47from tensorflow.python.keras.utils import control_flow_util
48from tensorflow.python.keras.utils import conv_utils
49from tensorflow.python.keras.utils import generic_utils
50from tensorflow.python.keras.utils import tf_inspect
51from tensorflow.python.keras.utils import tf_utils
52from tensorflow.python.ops import array_ops
53from tensorflow.python.ops import math_ops
54from tensorflow.python.ops import nn
55from tensorflow.python.ops import variable_scope
56from tensorflow.python.ops.ragged import ragged_tensor
57from tensorflow.python.platform import tf_logging
58from tensorflow.python.training.tracking import base as trackable
59from tensorflow.python.util import dispatch
60from tensorflow.python.util import nest
61from tensorflow.python.util import tf_decorator
62from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
63from tensorflow.python.util.tf_export import get_symbol_from_name
64from tensorflow.python.util.tf_export import keras_export
65
# TODO(b/168039935): track dropout rate to decide whether/how to make a
# dropout rate fastpath.
# Boolean gauge written by `Dropout.__init__`: set to True when a Dropout layer
# is constructed with a constant Python int/float rate of 0, False otherwise.
keras_temporary_dropout_rate = monitoring.BoolGauge(
    '/tensorflow/api/keras/dropout/temp_rate_is_zero',
    'Temporarily record if Keras dropout layer was created w/ '
    'constant rate = 0')
72
73
74# pylint: disable=g-classes-have-attributes
@keras_export('keras.layers.Masking')
class Masking(Layer):
  """Masks a sequence by using a mask value to skip timesteps.

  A timestep (dimension #1 of the input tensor) is masked when every value
  at that timestep equals `mask_value`. Masked timesteps are skipped by all
  downstream layers that support masking, and they are zeroed in this
  layer's output.

  An exception will be raised if a downstream layer that does not support
  masking receives such an input mask.

  Example:

  Consider a Numpy data array `x` of shape `(samples, timesteps, features)`,
  to be fed to an LSTM layer. You want to mask timestep #3 and #5 because you
  lack data for these timesteps. You can:

  - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.`
  - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer:

  ```python
  samples, timesteps, features = 32, 10, 8
  inputs = np.random.random([samples, timesteps, features]).astype(np.float32)
  inputs[:, 3, :] = 0.
  inputs[:, 5, :] = 0.

  model = tf.keras.models.Sequential()
  model.add(tf.keras.layers.Masking(mask_value=0.,
                                    input_shape=(timesteps, features)))
  model.add(tf.keras.layers.LSTM(32))

  output = model(inputs)
  # The time step 3 and 5 will be skipped from LSTM calculation.
  ```

  See [the masking and padding guide](
    https://www.tensorflow.org/guide/keras/masking_and_padding)
  for more details.
  """

  def __init__(self, mask_value=0., **kwargs):
    super(Masking, self).__init__(**kwargs)
    self.supports_masking = True
    self.mask_value = mask_value
    # `call` attaches the mask to its output, so output and mask are
    # produced together.
    self._compute_output_and_mask_jointly = True

  def compute_mask(self, inputs, mask=None):
    # A timestep is kept if at least one of its last-axis values differs
    # from `mask_value`.
    differs = math_ops.not_equal(inputs, self.mask_value)
    return K.any(differs, axis=-1)

  def call(self, inputs):
    keep = K.any(math_ops.not_equal(inputs, self.mask_value),
                 axis=-1, keepdims=True)
    # Zero out masked timesteps (keepdims lets the mask broadcast).
    masked = inputs * math_ops.cast(keep, inputs.dtype)
    # Attach the mask directly to the output tensor.
    masked._keras_mask = array_ops.squeeze(keep, axis=-1)  # pylint: disable=protected-access
    return masked

  def compute_output_shape(self, input_shape):
    # Masking never changes the shape.
    return input_shape

  def get_config(self):
    merged = dict(super(Masking, self).get_config())
    merged.update({'mask_value': self.mask_value})
    return merged
140
141
@keras_export('keras.layers.Dropout')
class Dropout(Layer):
  """Applies Dropout to the input.

  The Dropout layer randomly sets input units to 0 with a frequency of `rate`
  at each step during training time, which helps prevent overfitting.
  Inputs not set to 0 are scaled up by 1/(1 - rate) such that the expected
  sum over all inputs is unchanged.

  Note that the Dropout layer only applies when `training` is set to True
  such that no values are dropped during inference. When using `model.fit`,
  `training` will be appropriately set to True automatically, and in other
  contexts, you can set the kwarg explicitly to True when calling the layer.

  (This is in contrast to setting `trainable=False` for a Dropout layer.
  `trainable` does not affect the layer's behavior, as Dropout does
  not have any variables/weights that can be frozen during training.)

  >>> tf.random.set_seed(0)
  >>> layer = tf.keras.layers.Dropout(.2, input_shape=(2,))
  >>> data = np.arange(10).reshape(5, 2).astype(np.float32)
  >>> print(data)
  [[0. 1.]
   [2. 3.]
   [4. 5.]
   [6. 7.]
   [8. 9.]]
  >>> outputs = layer(data, training=True)
  >>> print(outputs)
  tf.Tensor(
  [[ 0.    1.25]
   [ 2.5   3.75]
   [ 5.    6.25]
   [ 7.5   8.75]
   [10.    0.  ]], shape=(5, 2), dtype=float32)

  Args:
    rate: Float between 0 and 1. Fraction of the input units to drop.
      A Python `int`/`float` outside of `[0, 1]` raises a `ValueError` when
      the layer is constructed.
    noise_shape: 1D integer tensor representing the shape of the
      binary dropout mask that will be multiplied with the input.
      For instance, if your inputs have shape
      `(batch_size, timesteps, features)` and
      you want the dropout mask to be the same for all timesteps,
      you can use `noise_shape=(batch_size, 1, features)`.
      Entries that are `None` are filled in from the dynamic shape of the
      input at call time.
    seed: A Python integer to use as random seed.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).
      When omitted, the Keras learning phase is used.
  """

  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
    super(Dropout, self).__init__(**kwargs)
    # Fail fast on an invalid constant rate at construction time. Tensor or
    # variable rates cannot be checked here and are passed through unchanged.
    if isinstance(rate, (int, float)) and not 0 <= rate <= 1:
      raise ValueError('Invalid value {} received for `rate`, expected a '
                       'value between 0 and 1.'.format(rate))
    self.rate = rate
    # Monitoring: record whether this layer has a constant rate of exactly 0.
    if isinstance(rate, (int, float)) and not rate:
      keras_temporary_dropout_rate.get_cell().set(True)
    else:
      keras_temporary_dropout_rate.get_cell().set(False)
    self.noise_shape = noise_shape
    self.seed = seed
    self.supports_masking = True

  def _get_noise_shape(self, inputs):
    """Returns the noise shape for `inputs`, or None if none is configured.

    Subclasses of `Dropout` may override this method. An override takes
    precedence over `self.noise_shape` and allows custom noise shapes with
    dynamically sized inputs.
    """
    if self.noise_shape is None:
      return None

    concrete_inputs_shape = array_ops.shape(inputs)
    noise_shape = []
    for i, value in enumerate(self.noise_shape):
      # `None` entries take the corresponding runtime input dimension.
      noise_shape.append(concrete_inputs_shape[i] if value is None else value)
    return ops.convert_to_tensor_v2_with_dispatch(noise_shape)

  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    def dropped_inputs():
      return nn.dropout(
          inputs,
          noise_shape=self._get_noise_shape(inputs),
          seed=self.seed,
          rate=self.rate)

    # In inference mode the input is passed through unchanged.
    output = control_flow_util.smart_cond(training, dropped_inputs,
                                          lambda: array_ops.identity(inputs))
    return output

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'rate': self.rate,
        'noise_shape': self.noise_shape,
        'seed': self.seed
    }
    base_config = super(Dropout, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
244
245
@keras_export('keras.layers.SpatialDropout1D')
class SpatialDropout1D(Dropout):
  """Spatial 1D version of Dropout.

  Behaves like Dropout, except that it drops whole 1D feature maps rather
  than individual elements. When neighbouring frames inside a feature map
  are strongly correlated, as is typical in early convolution layers,
  element-wise dropout does not regularize the activations and mostly acts
  as a reduction of the effective learning rate. SpatialDropout1D instead
  promotes independence between feature maps and should be preferred in
  that situation.

  Args:
    rate: Float between 0 and 1. Fraction of the input units to drop.

  Call arguments:
    inputs: A 3D tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).

  Input shape:
    3D tensor with shape:
    `(samples, timesteps, channels)`

  Output shape:
    Same as input.

  References:
    - [Efficient Object Localization Using Convolutional
      Networks](https://arxiv.org/abs/1411.4280)
  """

  def __init__(self, rate, **kwargs):
    super(SpatialDropout1D, self).__init__(rate, **kwargs)
    self.input_spec = InputSpec(ndim=3)

  def _get_noise_shape(self, inputs):
    dynamic_shape = array_ops.shape(inputs)
    # Size 1 on the timestep axis: one keep/drop decision per
    # (sample, channel), shared across all timesteps.
    return (dynamic_shape[0], 1, dynamic_shape[2])
286
287
@keras_export('keras.layers.SpatialDropout2D')
class SpatialDropout2D(Dropout):
  """Spatial 2D version of Dropout.

  This version performs the same function as Dropout, however, it drops
  entire 2D feature maps instead of individual elements. If adjacent pixels
  within feature maps are strongly correlated (as is normally the case in
  early convolution layers) then regular dropout will not regularize the
  activations and will otherwise just result in an effective learning rate
  decrease. In this case, SpatialDropout2D will help promote independence
  between feature maps and should be used instead.

  Args:
    rate: Float between 0 and 1. Fraction of the input units to drop.
    data_format: 'channels_first' or 'channels_last'.
      In 'channels_first' mode, the channels dimension
      (the depth) is at index 1,
      in 'channels_last' mode it is at index 3.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Call arguments:
    inputs: A 4D tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).

  Input shape:
    4D tensor with shape:
    `(samples, channels, rows, cols)` if data_format='channels_first'
    or 4D tensor with shape:
    `(samples, rows, cols, channels)` if data_format='channels_last'.

  Output shape:
    Same as input.

  References:
    - [Efficient Object Localization Using Convolutional
      Networks](https://arxiv.org/abs/1411.4280)
  """

  def __init__(self, rate, data_format=None, **kwargs):
    super(SpatialDropout2D, self).__init__(rate, **kwargs)
    if data_format is None:
      data_format = K.image_data_format()
    if data_format not in {'channels_last', 'channels_first'}:
      raise ValueError('data_format must be in '
                       '{"channels_last", "channels_first"}')
    self.data_format = data_format
    self.input_spec = InputSpec(ndim=4)

  def _get_noise_shape(self, inputs):
    # Size 1 on both spatial axes: one keep/drop decision per
    # (sample, channel), so whole feature maps are dropped together.
    input_shape = array_ops.shape(inputs)
    if self.data_format == 'channels_first':
      return (input_shape[0], input_shape[1], 1, 1)
    elif self.data_format == 'channels_last':
      return (input_shape[0], 1, 1, input_shape[3])

  def get_config(self):
    # `data_format` must be serialized. Without it, a layer saved with
    # 'channels_first' could be rebuilt from its config with the default
    # `image_data_format()` and drop along the wrong axes.
    config = {'data_format': self.data_format}
    base_config = super(SpatialDropout2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
345
346
@keras_export('keras.layers.SpatialDropout3D')
class SpatialDropout3D(Dropout):
  """Spatial 3D version of Dropout.

  This version performs the same function as Dropout, however, it drops
  entire 3D feature maps instead of individual elements. If adjacent voxels
  within feature maps are strongly correlated (as is normally the case in
  early convolution layers) then regular dropout will not regularize the
  activations and will otherwise just result in an effective learning rate
  decrease. In this case, SpatialDropout3D will help promote independence
  between feature maps and should be used instead.

  Args:
    rate: Float between 0 and 1. Fraction of the input units to drop.
    data_format: 'channels_first' or 'channels_last'.
      In 'channels_first' mode, the channels dimension (the depth)
      is at index 1, in 'channels_last' mode it is at index 4.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Call arguments:
    inputs: A 5D tensor.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).

  Input shape:
    5D tensor with shape:
    `(samples, channels, dim1, dim2, dim3)` if data_format='channels_first'
    or 5D tensor with shape:
    `(samples, dim1, dim2, dim3, channels)` if data_format='channels_last'.

  Output shape:
    Same as input.

  References:
    - [Efficient Object Localization Using Convolutional
      Networks](https://arxiv.org/abs/1411.4280)
  """

  def __init__(self, rate, data_format=None, **kwargs):
    super(SpatialDropout3D, self).__init__(rate, **kwargs)
    if data_format is None:
      data_format = K.image_data_format()
    if data_format not in {'channels_last', 'channels_first'}:
      raise ValueError('data_format must be in '
                       '{"channels_last", "channels_first"}')
    self.data_format = data_format
    self.input_spec = InputSpec(ndim=5)

  def _get_noise_shape(self, inputs):
    # Size 1 on all three spatial axes: one keep/drop decision per
    # (sample, channel), so whole feature volumes are dropped together.
    input_shape = array_ops.shape(inputs)
    if self.data_format == 'channels_first':
      return (input_shape[0], input_shape[1], 1, 1, 1)
    elif self.data_format == 'channels_last':
      return (input_shape[0], 1, 1, 1, input_shape[4])

  def get_config(self):
    # `data_format` must be serialized. Without it, a layer saved with
    # 'channels_first' could be rebuilt from its config with the default
    # `image_data_format()` and drop along the wrong axes.
    config = {'data_format': self.data_format}
    base_config = super(SpatialDropout3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
403
404
@keras_export('keras.layers.Activation')
class Activation(Layer):
  """Applies an activation function to an output.

  Args:
    activation: Activation function, such as `tf.nn.relu`, or string name of
      built-in activation function, such as "relu".

  Usage:

  >>> layer = tf.keras.layers.Activation('relu')
  >>> output = layer([-3.0, -1.0, 0.0, 2.0])
  >>> list(output.numpy())
  [0.0, 0.0, 0.0, 2.0]
  >>> layer = tf.keras.layers.Activation(tf.nn.relu)
  >>> output = layer([-3.0, -1.0, 0.0, 2.0])
  >>> list(output.numpy())
  [0.0, 0.0, 0.0, 2.0]

  Input shape:
    Arbitrary. When this layer is the first layer in a model, pass the
    keyword argument `input_shape` (a tuple of integers that excludes the
    batch axis).

  Output shape:
    Same shape as input.
  """

  def __init__(self, activation, **kwargs):
    super(Activation, self).__init__(**kwargs)
    self.supports_masking = True
    # Resolve a string identifier (or a callable) to the activation callable.
    self.activation = activations.get(activation)

  def call(self, inputs):
    activation_fn = self.activation
    return activation_fn(inputs)

  def compute_output_shape(self, input_shape):
    # The output shape is reported as identical to the input shape.
    return input_shape

  def get_config(self):
    serialized_activation = activations.serialize(self.activation)
    merged = dict(super(Activation, self).get_config())
    merged['activation'] = serialized_activation
    return merged
448
449
@keras_export('keras.layers.Reshape')
class Reshape(Layer):
  """Layer that reshapes inputs into the given shape.

  Input shape:
    Arbitrary, although all dimensions in the input shape must be known/fixed.
    When this layer is the first layer in a model, pass the keyword argument
    `input_shape` (a tuple of integers that excludes the samples/batch size
    axis).

  Output shape:
    `(batch_size,) + target_shape`

  Example:

  >>> # as first layer in a Sequential model
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Reshape((3, 4), input_shape=(12,)))
  >>> # model.output_shape == (None, 3, 4), `None` is the batch size.
  >>> model.output_shape
  (None, 3, 4)

  >>> # as intermediate layer in a Sequential model
  >>> model.add(tf.keras.layers.Reshape((6, 2)))
  >>> model.output_shape
  (None, 6, 2)

  >>> # also supports shape inference using `-1` as dimension
  >>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
  >>> model.output_shape
  (None, 3, 2, 2)
  """

  def __init__(self, target_shape, **kwargs):
    """Creates a `tf.keras.layers.Reshape`  layer instance.

    Args:
      target_shape: Target shape. Tuple of integers, does not include the
        samples dimension (batch size).
      **kwargs: Any additional layer keyword arguments.
    """
    super(Reshape, self).__init__(**kwargs)
    self.target_shape = tuple(target_shape)

  def _fix_unknown_dimension(self, input_shape, output_shape):
    """Resolves the single `-1` (unknown) entry of a requested output shape.

    Closely mirrors Numpy's internal `_fix_unknown_dimension` in
    `numpy/core/src/multiarray/shape.c`.

    Args:
      input_shape: Shape of array being reshaped
      output_shape: Desired shape of the array with at most
        a single -1 which indicates a dimension that should be
        derived from the input shape.

    Returns:
      A list copy of `output_shape` in which the -1 entry (if any) is
      replaced by the size needed to preserve the total element count.

    Raises:
      ValueError: If more than one unknown (negative) dimension is given, or
        if the element counts of `input_shape` and `output_shape` cannot be
        made to match.
    """
    resolved = list(output_shape)
    size_error = ('total size of new array must be unchanged, '
                  'input_shape = {}, output_shape = {}'.format(
                      input_shape, resolved))

    known_size = 1
    unknown_index = None
    for position, size in enumerate(resolved):
      if size < 0:
        if unknown_index is not None:
          raise ValueError('Can only specify one unknown dimension.')
        unknown_index = position
        continue
      known_size *= size

    total = np.prod(input_shape, dtype=int)
    if unknown_index is None:
      # Fully specified target: the element counts must agree exactly.
      if total != known_size:
        raise ValueError(size_error)
      return resolved

    # One unknown entry: it must absorb the remaining factor evenly.
    if known_size == 0 or total % known_size != 0:
      raise ValueError(size_error)
    resolved[unknown_index] = total // known_size
    return resolved

  def compute_output_shape(self, input_shape):
    dims = tensor_shape.TensorShape(input_shape).as_list()
    result = [dims[0]]
    if None in dims[1:]:
      # Some non-batch input dims are unknown, so a -1 in the target shape
      # cannot be resolved statically; report it as None instead.
      result.extend(None if s == -1 else s for s in self.target_shape)
    else:
      result.extend(self._fix_unknown_dimension(dims[1:], self.target_shape))
    return tensor_shape.TensorShape(result)

  def call(self, inputs):
    batch_size = array_ops.shape(inputs)[0]
    result = array_ops.reshape(inputs, (batch_size,) + self.target_shape)
    if not context.executing_eagerly():
      # The reshape may lose static shape information (e.g. a `None` dim that
      # could be inferred), so set the statically computed shape explicitly.
      result.set_shape(self.compute_output_shape(inputs.shape))
    return result

  def get_config(self):
    merged = dict(super(Reshape, self).get_config())
    merged.update({'target_shape': self.target_shape})
    return merged
563
564
@keras_export('keras.layers.Permute')
class Permute(Layer):
  """Permutes the dimensions of the input according to a given pattern.

  Useful e.g. connecting RNNs and convnets.

  Example:

  ```python
  model = Sequential()
  model.add(Permute((2, 1), input_shape=(10, 64)))
  # now: model.output_shape == (None, 64, 10)
  # note: `None` is the batch dimension
  ```

  Args:
    dims: Tuple of integers giving the permutation pattern. The samples
      (batch) dimension is not included and indexing starts at 1, so
      `(2, 1)` swaps the first and second non-batch dimensions of the input.
      The values must be exactly `1..len(dims)` in some order, otherwise a
      `ValueError` is raised.

  Input shape:
    Arbitrary. When this layer is the first layer in a model, pass the
    keyword argument `input_shape` (a tuple of integers that excludes the
    samples axis).

  Output shape:
    Same as the input shape, but with the dimensions re-ordered according
    to the specified pattern.
  """

  def __init__(self, dims, **kwargs):
    super(Permute, self).__init__(**kwargs)
    self.dims = tuple(dims)
    ordered = sorted(dims)
    if ordered != list(range(1, len(dims) + 1)):
      raise ValueError(
          'Invalid permutation `dims` for Permute Layer: %s. '
          'The set of indices in `dims` must be consecutive and start from 1.' %
          (dims,))
    # Batch axis plus one axis per permuted dimension.
    self.input_spec = InputSpec(ndim=len(self.dims) + 1)

  def compute_output_shape(self, input_shape):
    in_dims = tensor_shape.TensorShape(input_shape).as_list()
    out_dims = list(in_dims)
    # Output axis `position` takes the size of input axis `source_axis`.
    for position, source_axis in enumerate(self.dims, start=1):
      out_dims[position] = in_dims[source_axis]
    return tensor_shape.TensorShape(out_dims)

  def call(self, inputs):
    # Axis 0 (batch) always stays in place.
    full_perm = (0,) + self.dims
    return array_ops.transpose(inputs, perm=full_perm)

  def get_config(self):
    merged = dict(super(Permute, self).get_config())
    merged.update({'dims': self.dims})
    return merged
621
622
@keras_export('keras.layers.Flatten')
class Flatten(Layer):
  """Flattens the input. Does not affect the batch size.

  Note: If inputs are shaped `(batch,)` without a feature axis, then
  flattening adds an extra channel dimension and output shape is `(batch, 1)`.

  Args:
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, ..., channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, ...)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Example:

  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Conv2D(64, 3, 3, input_shape=(3, 32, 32)))
  >>> model.output_shape
  (None, 1, 10, 64)

  >>> model.add(tf.keras.layers.Flatten())
  >>> model.output_shape
  (None, 640)

  """

  def __init__(self, data_format=None, **kwargs):
    super(Flatten, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(min_ndim=1)
    self._channels_first = (self.data_format == 'channels_first')

  def call(self, inputs):
    if self._channels_first:
      rank = inputs.shape.rank
      if rank and rank > 1:
        # Move the channels axis to the end so the flattened element order
        # matches the channels_last layout.
        inputs = array_ops.transpose(
            inputs, perm=[0] + list(range(2, rank)) + [1])

    if context.executing_eagerly():
      # Full static shape is guaranteed to be available in eager mode.
      # Performance: building a constant is much faster than passing a list.
      return array_ops.reshape(
          inputs, constant_op.constant([inputs.shape[0], -1]))

    static_shape = inputs.shape
    if static_shape.rank == 1:
      # `(batch,)` inputs gain a trailing axis of size 1.
      return array_ops.expand_dims_v2(inputs, axis=1)

    # Choose the reshape target that preserves the most static shape info.
    batch_dim = tensor_shape.dimension_value(static_shape[0])
    feature_dims = static_shape[1:]
    if feature_dims.is_fully_defined():
      flat_size = int(functools.reduce(operator.mul, feature_dims))
      target = constant_op.constant([-1, flat_size])
    elif batch_dim is not None:
      target = constant_op.constant([int(batch_dim), -1])
    else:
      target = [array_ops.shape_v2(inputs)[0], -1]
    return array_ops.reshape(inputs, target)

  def compute_output_shape(self, input_shape):
    dims = tensor_shape.TensorShape(input_shape).as_list()
    trailing = dims[1:]
    # The flattened size is only known when every non-batch dim is truthy.
    flat = np.prod(trailing, dtype=int) if np.all(trailing) else None
    head = [dims[0]] if dims else tensor_shape.TensorShape([1])
    return tensor_shape.TensorShape(head + [flat])

  def get_config(self):
    merged = super(Flatten, self).get_config()
    merged['data_format'] = self.data_format
    return merged
709
710
@keras_export('keras.layers.RepeatVector')
class RepeatVector(Layer):
  """Repeats the input n times.

  Example:

  ```python
  model = Sequential()
  model.add(Dense(32, input_dim=32))
  # now: model.output_shape == (None, 32)
  # note: `None` is the batch dimension

  model.add(RepeatVector(3))
  # now: model.output_shape == (None, 3, 32)
  ```

  Args:
    n: Integer, repetition factor.

  Input shape:
    2D tensor of shape `(num_samples, features)`.

  Output shape:
    3D tensor of shape `(num_samples, n, features)`.
  """

  def __init__(self, n, **kwargs):
    super(RepeatVector, self).__init__(**kwargs)
    self.n = n
    self.input_spec = InputSpec(ndim=2)

  def compute_output_shape(self, input_shape):
    dims = tensor_shape.TensorShape(input_shape).as_list()
    # A new axis of size `n` is inserted between batch and features.
    return tensor_shape.TensorShape([dims[0], self.n, dims[1]])

  def call(self, inputs):
    repetitions = self.n
    return K.repeat(inputs, repetitions)

  def get_config(self):
    merged = dict(super(RepeatVector, self).get_config())
    merged.update({'n': self.n})
    return merged
753
754
755@keras_export('keras.layers.Lambda')
756class Lambda(Layer):
757  """Wraps arbitrary expressions as a `Layer` object.
758
759  The `Lambda` layer exists so that arbitrary expressions can be used
760  as a `Layer` when constructing `Sequential`
761  and Functional API models. `Lambda` layers are best suited for simple
762  operations or quick experimentation. For more advanced use cases, follow
763  [this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
764  for subclassing `tf.keras.layers.Layer`.
765
766  WARNING: `tf.keras.layers.Lambda` layers have (de)serialization limitations!
767
768  The main reason to subclass `tf.keras.layers.Layer` instead of using a
769  `Lambda` layer is saving and inspecting a Model. `Lambda` layers
770  are saved by serializing the Python bytecode, which is fundamentally
771  non-portable. They should only be loaded in the same environment where
772  they were saved. Subclassed layers can be saved in a more portable way
773  by overriding their `get_config` method. Models that rely on
774  subclassed Layers are also often easier to visualize and reason about.
775
776  Examples:
777
778  ```python
779  # add a x -> x^2 layer
780  model.add(Lambda(lambda x: x ** 2))
781  ```
782  ```python
783  # add a layer that returns the concatenation
784  # of the positive part of the input and
785  # the opposite of the negative part
786
787  def antirectifier(x):
788      x -= K.mean(x, axis=1, keepdims=True)
789      x = K.l2_normalize(x, axis=1)
790      pos = K.relu(x)
791      neg = K.relu(-x)
792      return K.concatenate([pos, neg], axis=1)
793
794  model.add(Lambda(antirectifier))
795  ```
796
797  Variables:
798    While it is possible to use Variables with Lambda layers, this practice is
799    discouraged as it can easily lead to bugs. For instance, consider the
800    following layer:
801
802    ```python
803      scale = tf.Variable(1.)
804      scale_layer = tf.keras.layers.Lambda(lambda x: x * scale)
805    ```
806
807    Because scale_layer does not directly track the `scale` variable, it will
808    not appear in `scale_layer.trainable_weights` and will therefore not be
809    trained if `scale_layer` is used in a Model.
810
811    A better pattern is to write a subclassed Layer:
812
813    ```python
814      class ScaleLayer(tf.keras.layers.Layer):
815        def __init__(self):
816          super(ScaleLayer, self).__init__()
817          self.scale = tf.Variable(1.)
818
819        def call(self, inputs):
820          return inputs * self.scale
821    ```
822
823    In general, Lambda layers can be convenient for simple stateless
824    computation, but anything more complex should use a subclass Layer instead.
825
826  Args:
827    function: The function to be evaluated. Takes input tensor as first
828      argument.
    output_shape: Expected output shape from function. This argument can be
      inferred if not explicitly provided. Can be a tuple or function. If a
      tuple, it only specifies the first dimension onward; the sample
      dimension is assumed to be the same as the input's,
      `output_shape = (input_shape[0], ) + output_shape`, or, if the input
      sample dimension is `None`, the output sample dimension is also `None`:
      `output_shape = (None, ) + output_shape`. If a function, it specifies
      the entire shape as a function of the input shape:
      `output_shape = f(input_shape)`.
838    mask: Either None (indicating no masking) or a callable with the same
839      signature as the `compute_mask` layer method, or a tensor that will be
840      returned as output mask regardless of what the input is.
841    arguments: Optional dictionary of keyword arguments to be passed to the
842      function.
843
844  Input shape:
845    Arbitrary. Use the keyword argument input_shape (tuple of
846    integers, does not include the samples axis) when using this layer as the
847    first layer in a model.
848
849  Output shape:
850    Specified by `output_shape` argument
851  """
852
853  @trackable.no_automatic_dependency_tracking
854  def __init__(self, function, output_shape=None, mask=None, arguments=None,
855               **kwargs):
856    super(Lambda, self).__init__(**kwargs)
857
858    self.arguments = arguments or {}
859    self.function = function
860
861    if mask is not None:
862      self.supports_masking = True
863    self.mask = mask
864    self._output_shape = output_shape
865
866    # Warning on every invocation will be quite irksome in Eager mode.
867    self._already_warned = False
868
869    function_args = tf_inspect.getfullargspec(function).args
870    self._fn_expects_training_arg = 'training' in function_args
871    self._fn_expects_mask_arg = 'mask' in function_args
872
873  @tf_utils.shape_type_conversion
874  def compute_output_shape(self, input_shape):
875    if self._output_shape is None:
876      # Make use of existing autocomputation but provide Lambda-specific
877      # error message. This is always safe to run even when the outer context
878      # is Graph mode because Lambda layers don't have side effects such as
879      # `add_loss`.
880      with context.eager_mode():
881        try:
882          return super(Lambda, self).compute_output_shape(input_shape)
883        except NotImplementedError:
884          raise NotImplementedError(
885              'We could not automatically infer the shape of the Lambda\'s '
886              'output. Please specify `output_shape` for this Lambda.')
887
888    if callable(self._output_shape):
889      output_shapes = self._output_shape(input_shape)
890      return tf_utils.convert_shapes(output_shapes, to_tuples=False)
891
892    # Output shapes are passed directly and don't include batch dimension.
893    input_tensor_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
894    batch_size = nest.flatten(input_tensor_shape)[0][0] if input_shape else None
895
896    def _add_batch(shape):
897      return tensor_shape.TensorShape([batch_size] + shape.as_list())
898
899    output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
900    return nest.map_structure(_add_batch, output_shapes)
901
902  def call(self, inputs, mask=None, training=None):
903    # We must copy for thread safety, but it only needs to be a shallow copy.
904    kwargs = {k: v for k, v in self.arguments.items()}
905    if self._fn_expects_mask_arg:
906      kwargs['mask'] = mask
907    if self._fn_expects_training_arg:
908      kwargs['training'] = training
909
910    created_variables = []
911    def _variable_creator(next_creator, **kwargs):
912      var = next_creator(**kwargs)
913      created_variables.append(var)
914      return var
915
916    with backprop.GradientTape(watch_accessed_variables=True) as tape,\
917        variable_scope.variable_creator_scope(_variable_creator):
918      result = self.function(inputs, **kwargs)
919    self._check_variables(created_variables, tape.watched_variables())
920    return result
921
922  def _check_variables(self, created_variables, accessed_variables):
923    if not created_variables and not accessed_variables:
924      # In the common case that a Lambda layer does not touch a Variable, we
925      # don't want to incur the runtime cost of assembling any state used for
926      # checking only to immediately discard it.
927      return
928
929    tracked_weights = set(v.ref() for v in self.weights)
930    untracked_new_vars = [
931        v for v in created_variables if v.ref() not in tracked_weights
932    ]
933    if untracked_new_vars:
934      variable_str = '\n'.join('  {}'.format(i) for i in untracked_new_vars)
935      error_str = textwrap.dedent(
936          '''
937          The following Variables were created within a Lambda layer ({name})
938          but are not tracked by said layer:
939          {variable_str}
940          The layer cannot safely ensure proper Variable reuse across multiple
941          calls, and consquently this behavior is disallowed for safety. Lambda
942          layers are not well suited to stateful computation; instead, writing a
943          subclassed Layer is the recommend way to define layers with
944          Variables.'''
945      ).format(name=self.name, variable_str=variable_str)
946      raise ValueError(error_str)
947
948    untracked_used_vars = [
949        v for v in accessed_variables if v.ref() not in tracked_weights
950    ]
951    if untracked_used_vars and not self._already_warned:
952      variable_str = '\n'.join('  {}'.format(i) for i in untracked_used_vars)
953      self._warn(textwrap.dedent(
954          '''
955          The following Variables were used a Lambda layer's call ({name}), but
956          are not present in its tracked objects:
957          {variable_str}
958          It is possible that this is intended behavior, but it is more likely
959          an omission. This is a strong indication that this layer should be
960          formulated as a subclassed Layer rather than a Lambda layer.'''
961      ).format(name=self.name, variable_str=variable_str))
962      self._already_warned = True
963
964  def _warn(self, msg):
965    # This method will be overridden in a unit test to raise an error, because
966    # self.assertWarns is not universally implemented.
967    return tf_logging.warn(msg)
968
969  def compute_mask(self, inputs, mask=None):
970    if callable(self.mask):
971      return self.mask(inputs, mask)
972    return self.mask
973
974  def get_config(self):
975    function_config = self._serialize_function_to_config(self.function)
976    output_shape_config = self._serialize_function_to_config(self._output_shape,
977                                                             allow_raw=True)
978    config = {
979        'function': function_config[0],
980        'function_type': function_config[1],
981        'module': function_config[2],
982        'output_shape': output_shape_config[0],
983        'output_shape_type': output_shape_config[1],
984        'output_shape_module': output_shape_config[2],
985    }
986    if self.mask is not None:
987      mask_config = self._serialize_function_to_config(self.mask)
988      config.update({
989          'mask': mask_config[0],
990          'mask_type': mask_config[1],
991          'mask_module': mask_config[2]
992      })
993    config['arguments'] = self.arguments
994
995    base_config = super(Lambda, self).get_config()
996    return dict(list(base_config.items()) + list(config.items()))
997
998  def _serialize_function_to_config(self, inputs, allow_raw=False):
999    if isinstance(inputs, python_types.LambdaType):
1000      output = generic_utils.func_dump(inputs)
1001      output_type = 'lambda'
1002      module = inputs.__module__
1003    elif callable(inputs):
1004      output = inputs.__name__
1005      output_type = 'function'
1006      module = inputs.__module__
1007    elif allow_raw:
1008      output = inputs
1009      output_type = 'raw'
1010      module = None
1011    else:
1012      raise ValueError(
1013          'Invalid input for serialization, type: %s ' % type(inputs))
1014
1015    return output, output_type, module
1016
1017  @classmethod
1018  def from_config(cls, config, custom_objects=None):
1019    config = config.copy()
1020    function = cls._parse_function_from_config(
1021        config, custom_objects, 'function', 'module', 'function_type')
1022
1023    output_shape = cls._parse_function_from_config(
1024        config, custom_objects, 'output_shape', 'output_shape_module',
1025        'output_shape_type')
1026    if 'mask' in config:
1027      mask = cls._parse_function_from_config(
1028          config, custom_objects, 'mask', 'mask_module', 'mask_type')
1029    else:
1030      mask = None
1031
1032    config['function'] = function
1033    config['output_shape'] = output_shape
1034    config['mask'] = mask
1035
1036    # If arguments were numpy array, they have been saved as
1037    # list. We need to recover the ndarray
1038    if 'arguments' in config:
1039      for key in config['arguments']:
1040        if isinstance(config['arguments'][key], dict):
1041          arg_dict = config['arguments'][key]
1042          if 'type' in arg_dict and arg_dict['type'] == 'ndarray':
1043            # Overwrite the argument with its numpy translation
1044            config['arguments'][key] = np.array(arg_dict['value'])
1045
1046    return cls(**config)
1047
1048  @classmethod
1049  def _parse_function_from_config(
1050      cls, config, custom_objects, func_attr_name, module_attr_name,
1051      func_type_attr_name):
1052    globs = globals().copy()
1053    module = config.pop(module_attr_name, None)
1054    if module in sys.modules:
1055      globs.update(sys.modules[module].__dict__)
1056    elif module is not None:
1057      # Note: we don't know the name of the function if it's a lambda.
1058      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
1059                    'It may cause errors.'.format(module)
1060                    , UserWarning)
1061    if custom_objects:
1062      globs.update(custom_objects)
1063    function_type = config.pop(func_type_attr_name)
1064    if function_type == 'function':
1065      # Simple lookup in custom objects
1066      function = generic_utils.deserialize_keras_object(
1067          config[func_attr_name],
1068          custom_objects=custom_objects,
1069          printable_module_name='function in Lambda layer')
1070    elif function_type == 'lambda':
1071      # Unsafe deserialization from bytecode
1072      function = generic_utils.func_load(
1073          config[func_attr_name], globs=globs)
1074    elif function_type == 'raw':
1075      function = config[func_attr_name]
1076    else:
1077      raise TypeError('Unknown function type:', function_type)
1078    return function
1079
1080
@keras_export('keras.layers.Dense')
class Dense(Layer):
  """Just your regular densely-connected NN layer.

  `Dense` implements the operation:
  `output = activation(dot(input, kernel) + bias)`
  where `activation` is the element-wise activation function
  passed as the `activation` argument, `kernel` is a weights matrix
  created by the layer, and `bias` is a bias vector created by the layer
  (only applicable if `use_bias` is `True`).

  Note: If the input to the layer has a rank greater than 2, then `Dense`
  computes the dot product between the `inputs` and the `kernel` along the
  last axis of the `inputs` and axis 0 of the `kernel` (using `tf.tensordot`).
  For example, if input has dimensions `(batch_size, d0, d1)`,
  then we create a `kernel` with shape `(d1, units)`, and the `kernel` operates
  along axis 2 of the `input`, on every sub-tensor of shape `(1, 1, d1)`
  (there are `batch_size * d0` such sub-tensors).
  The output in this case will have shape `(batch_size, d0, units)`.

  Besides, layer attributes cannot be modified after the layer has been called
  once (except the `trainable` attribute).

  Example:

  >>> # Create a `Sequential` model and add a Dense layer as the first layer.
  >>> model = tf.keras.models.Sequential()
  >>> model.add(tf.keras.Input(shape=(16,)))
  >>> model.add(tf.keras.layers.Dense(32, activation='relu'))
  >>> # Now the model will take as input arrays of shape (None, 16)
  >>> # and output arrays of shape (None, 32).
  >>> # Note that after the first layer, you don't need to specify
  >>> # the size of the input anymore:
  >>> model.add(tf.keras.layers.Dense(32))
  >>> model.output_shape
  (None, 32)

  Args:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Input shape:
    N-D tensor with shape: `(batch_size, ..., input_dim)`.
    The most common situation would be
    a 2D input with shape `(batch_size, input_dim)`.

  Output shape:
    N-D tensor with shape: `(batch_size, ..., units)`.
    For instance, for a 2D input with shape `(batch_size, input_dim)`,
    the output would have shape `(batch_size, units)`.
  """

  def __init__(self,
               units,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(Dense, self).__init__(
        activity_regularizer=activity_regularizer, **kwargs)

    # Coerce non-`int` values (e.g. numpy integers) to a Python `int`.
    self.units = int(units) if not isinstance(units, int) else units
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.input_spec = InputSpec(min_ndim=2)
    self.supports_masking = True

  def build(self, input_shape):
    """Creates the `kernel` and, if `use_bias`, the `bias` weights.

    Args:
      input_shape: Shape of the input; its last dimension must be known.

    Raises:
      TypeError: if the layer dtype is neither floating point nor complex.
      ValueError: if the last dimension of `input_shape` is `None`.
    """
    dtype = dtypes.as_dtype(self.dtype or K.floatx())
    if not (dtype.is_floating or dtype.is_complex):
      raise TypeError('Unable to build `Dense` layer with non-floating point '
                      'dtype %s' % (dtype,))

    input_shape = tensor_shape.TensorShape(input_shape)
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    if last_dim is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')
    # Pin the last axis so later inputs are validated against the built size.
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
    self.kernel = self.add_weight(
        'kernel',
        shape=[last_dim, self.units],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    if self.use_bias:
      self.bias = self.add_weight(
          'bias',
          shape=[self.units,],
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    """Applies `core_ops.dense` using the layer's compute dtype."""
    return core_ops.dense(
        inputs,
        self.kernel,
        self.bias,
        self.activation,
        dtype=self._compute_dtype_object)

  def compute_output_shape(self, input_shape):
    """Returns `input_shape` with its last dimension replaced by `units`.

    Raises:
      ValueError: if the last dimension of `input_shape` is undefined (or,
        via `with_rank_at_least`, if its rank is below 2).
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    input_shape = input_shape.with_rank_at_least(2)
    if tensor_shape.dimension_value(input_shape[-1]) is None:
      raise ValueError(
          'The innermost dimension of input_shape must be defined, but saw: %s'
          % input_shape)
    return input_shape[:-1].concatenate(self.units)

  def get_config(self):
    """Returns the base config extended with the constructor arguments."""
    config = super(Dense, self).get_config()
    config.update({
        'units':
            self.units,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint)
    })
    return config
1249
1250
@keras_export('keras.layers.ActivityRegularization')
class ActivityRegularization(Layer):
  """Layer that adds an input-activity-based penalty to the cost function.

  The penalty is an `L1L2` activity regularizer built from `l1` and `l2`.

  Args:
    l1: L1 regularization factor (positive float).
    l2: L2 regularization factor (positive float).

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.
  """

  def __init__(self, l1=0., l2=0., **kwargs):
    activity_penalty = regularizers.L1L2(l1=l1, l2=l2)
    super(ActivityRegularization, self).__init__(
        activity_regularizer=activity_penalty, **kwargs)
    self.supports_masking = True
    # Keep the raw factors so `get_config` can round-trip them.
    self.l1 = l1
    self.l2 = l2

  def compute_output_shape(self, input_shape):
    """The output shape equals the input shape."""
    return input_shape

  def get_config(self):
    """Returns the base config extended with `l1` and `l2`."""
    base_config = super(ActivityRegularization, self).get_config()
    return dict(base_config, l1=self.l1, l2=self.l2)
1282
1283
class TFOpLambda(Layer):
  """Wraps TF API symbols in a `Layer` object.

  It is inserted by the Functional API construction whenever users call
  a supported TF symbol on KerasTensors.

  Like Lambda layers, this layer tries to raise warnings when it detects users
  explicitly use variables in the call. (To let them know
  that the layer will not capture the variables).

  This is useful in the case where users do something like:
  x = keras.Input(...)
  y = tf.Variable(...)
  out = x * y
  """

  @trackable.no_automatic_dependency_tracking
  def __init__(self, function, **kwargs):
    """Creates the layer.

    Args:
      function: The TF API callable to wrap.
      **kwargs: Forwarded to the base `Layer` constructor. When `name` is
        absent, a unique name is generated from the canonical API symbol
        (or from `function.__name__` if no symbol is found).
    """
    self.function = function
    self.symbol = (
        get_canonical_name_for_symbol(self.function,
                                      add_prefix_to_v1_names=True) or
        get_canonical_name_for_symbol(self.function,
                                      api_name='keras',
                                      add_prefix_to_v1_names=True))
    if 'name' not in kwargs:
      # Generate a name.
      # TFOpLambda layers avoid already-observed names,
      # because users cannot easily control the generated names.
      # Without this avoidance, users would be more likely to run
      # into unavoidable duplicate layer name collisions.
      # (For standard layers users could just set `name` when creating the
      # layer to work around a collision, but they can't do that for
      # auto-generated layers)
      if self.symbol:
        name = 'tf.' + self.symbol
      else:
        name = self.function.__name__
      kwargs['name'] = K.unique_object_name(
          name, zero_based=True, avoid_observed_names=True)
    kwargs['autocast'] = False

    # Decorate the function to produce this layer's call method.
    def _call_wrapper(*call_args, **call_kwargs):
      return self._call_wrapper(*call_args, **call_kwargs)
    self.call = tf_decorator.make_decorator(function, _call_wrapper)

    # Do not individually trace op layers in the SavedModel.
    self._must_restore_from_config = True

    super(TFOpLambda, self).__init__(**kwargs)

    # Preserve all argument data structures when saving/loading a config
    # (e.g., don't unnest lists that contain one element)
    self._preserve_input_structure_in_config = True

    # Warning on every invocation will be quite irksome in Eager mode.
    self._already_warned = False

    self._expects_training_arg = False
    self._expects_mask_arg = False

  def _call_wrapper(self, *args, **kwargs):
    """Runs `self.function`, dropping `name` and checking variable usage."""
    created_variables = []
    def _variable_creator(next_creator, **creator_kwargs):
      var = next_creator(**creator_kwargs)
      created_variables.append(var)
      return var

    with backprop.GradientTape(watch_accessed_variables=True) as tape, \
        variable_scope.variable_creator_scope(_variable_creator):
      # We explicitly drop `name` arguments here,
      # to guard against the case where an op explicitly has a
      # `name` passed (which is susceptible to producing
      # multiple ops w/ the same name when the layer is reused)
      kwargs.pop('name', None)
      result = self.function(*args, **kwargs)
    self._check_variables(created_variables, tape.watched_variables())
    return result

  def _check_variables(self, created_variables, accessed_variables):
    """Rejects untracked variable creation and warns on untracked use.

    Raises:
      ValueError: if any created variable is not among `self.weights`.
    """
    if not created_variables and not accessed_variables:
      # In the common case that a Lambda layer does not touch a Variable, we
      # don't want to incur the runtime cost of assembling any state used for
      # checking only to immediately discard it.
      return

    tracked_weights = {v.ref() for v in self.weights}
    untracked_new_vars = [
        v for v in created_variables if v.ref() not in tracked_weights
    ]
    if untracked_new_vars:
      variable_str = '\n'.join('  {}'.format(i) for i in untracked_new_vars)
      error_str = textwrap.dedent(
          '''
          The following Variables were created within a Lambda layer ({name})
          but are not tracked by said layer:
          {variable_str}
          The layer cannot safely ensure proper Variable reuse across multiple
          calls, and consequently this behavior is disallowed for safety.
          Lambda layers are not well suited to stateful computation; instead,
          writing a subclassed Layer is the recommended way to define layers
          with Variables.'''
      ).format(name=self.name, variable_str=variable_str)
      raise ValueError(error_str)

    untracked_used_vars = [
        v for v in accessed_variables if v.ref() not in tracked_weights
    ]
    if untracked_used_vars and not self._already_warned:
      variable_str = '\n'.join('  {}'.format(i) for i in untracked_used_vars)
      self._warn(textwrap.dedent(
          '''
          The following Variables were used in a Lambda layer's call ({name}),
          but are not present in its tracked objects:
          {variable_str}
          It is possible that this is intended behavior, but it is more likely
          an omission. This is a strong indication that this layer should be
          formulated as a subclassed Layer rather than a Lambda layer.'''
      ).format(name=self.name, variable_str=variable_str))
      self._already_warned = True

  def _warn(self, msg):
    # This method will be overridden in a unit test to raise an error, because
    # self.assertWarns is not universally implemented.
    return tf_logging.warn(msg)

  def get_config(self):
    """Returns the base config plus the wrapped op's canonical symbol name.

    Raises:
      ValueError: if the wrapped function has no exported TF API symbol.
    """
    if not self.symbol:
      raise ValueError('This Keras op layer was generated from %s, a method '
                       'that is not exposed in the TensorFlow API. This '
                       'may have happened if the method was explicitly '
                       'decorated to add dispatching support, and it was used '
                       'during Functional model construction. '
                       'To ensure cross-version compatibility of Keras models '
                       'that use op layers, only op layers produced from '
                       'exported TF API symbols can be serialized.'
                       % self.function)
    config = {
        'function': self.symbol
    }

    base_config = super(TFOpLambda, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer by resolving the stored TF API symbol name.

    Raises:
      ValueError: if the symbol cannot be found.
    """
    config = config.copy()
    symbol_name = config['function']
    function = get_symbol_from_name(symbol_name)
    if not function:
      raise ValueError(
          'TF symbol `tf.%s` could not be found.' % symbol_name)

    config['function'] = function

    return cls(**config)
1441
1442
class KerasOpDispatcher(dispatch.GlobalOpDispatcher):
  """A global dispatcher that allows building a functional model with TF Ops.

  When any argument of a dispatched op is a `KerasTensor`, the op is wrapped
  in a `TFOpLambda` layer, and that layer is called on the same arguments.
  """

  def handle(self, op, args, kwargs):
    """Handle the specified operation with the specified arguments.

    Returns:
      The output of a `TFOpLambda(op)` layer if any element of
      `nest.flatten([args, kwargs])` is a `KerasTensor`; otherwise
      `self.NOT_SUPPORTED`, signalling this dispatcher declines the call.
    """
    flat_inputs = nest.flatten([args, kwargs])
    if not any(isinstance(t, keras_tensor.KerasTensor) for t in flat_inputs):
      return self.NOT_SUPPORTED
    return TFOpLambda(op)(*args, **kwargs)


KerasOpDispatcher().register()
1456
1457
1458def _slice_to_dict(x):
1459  if isinstance(x, slice):
1460    return {'start': x.start, 'stop': x.stop, 'step': x.step}
1461  return x
1462
1463
1464def _dict_to_slice(x):
1465  if isinstance(x, dict):
1466    return slice(x['start'], x['stop'], x['step'])
1467  return x
1468
1469
class SlicingOpLambda(TFOpLambda):
  """Wraps TF slicing ops (e.g. `_slice_helper`, `boolean_mask`) in a Layer.

  It is inserted by the Functional API construction, via
  `TFSlicingOpDispatcher`, whenever a registered slicing op is applied to
  KerasTensors.

  `slice` objects among the op's arguments arrive encoded as
  `{'start', 'stop', 'step'}` dicts (see `_slice_to_dict`); this layer's
  call wrapper decodes them back into `slice` objects before invoking the
  op. (Presumably the encoding keeps the arguments representable in the
  layer config — confirm.)

  Like `TFOpLambda`, this layer warns when it detects variables explicitly
  used in the call, since the layer will not capture them.
  """

  @trackable.no_automatic_dependency_tracking
  def __init__(self, function, **kwargs):
    super(SlicingOpLambda, self).__init__(function, **kwargs)

    original_call = self.call

    def _restore_slices(value):
      """Decodes a slice dict, or a flat list/tuple of them, into `slice`s.

      This conversion cannot use nest.flatten/map_structure, because dicts
      are flattened by nest while slices aren't, so map_structure would only
      see the individual elements in the dict. It can't use
      map_structure_up_to either, because the 'shallowness' of the shallow
      tree would have to vary depending on whether one or multiple dims are
      being sliced. Only one level of list/tuple nesting is inspected, and
      such sequences are always returned as lists.
      """
      value = _dict_to_slice(value)
      if isinstance(value, (list, tuple)):
        value = [_dict_to_slice(item) for item in value]
      return value

    # Decorate the function to produce this layer's call method.
    def _call_wrapper(*call_args, **call_kwargs):
      new_args = [_restore_slices(arg) for arg in call_args]
      new_kwargs = {key: _restore_slices(value)
                    for key, value in call_kwargs.items()}
      return original_call(*new_args, **new_kwargs)
    self.call = tf_decorator.make_decorator(original_call, _call_wrapper)
1524
1525
class TFSlicingOpDispatcher(dispatch.OpDispatcher):
  """Per-op dispatcher that allows building a functional model with TF slicing.

  One instance is registered for each slicing op listed below. It wraps
  calls involving KerasTensors in a `SlicingOpLambda` layer.
  """

  def __init__(self, op):
    # The slicing op this dispatcher is registered for.
    self.op = op

  def handle(self, args, kwargs):
    """Handle the specified operation with the specified arguments.

    `slice` objects in `args`/`kwargs` are first encoded as dicts with
    `_slice_to_dict`. If any element of the flattened arguments is then a
    `KerasTensor`, the op is wrapped in a `SlicingOpLambda` layer and called;
    otherwise `self.NOT_SUPPORTED` is returned.
    """
    args = nest.map_structure(_slice_to_dict, args)
    kwargs = nest.map_structure(_slice_to_dict, kwargs)
    flat_inputs = nest.flatten([args, kwargs])
    if not any(isinstance(t, keras_tensor.KerasTensor) for t in flat_inputs):
      return self.NOT_SUPPORTED
    return SlicingOpLambda(self.op)(*args, **kwargs)


for slicing_op in (array_ops._slice_helper,  # pylint: disable=protected-access
                   array_ops.boolean_mask,
                   array_ops.boolean_mask_v2):
  TFSlicingOpDispatcher(slicing_op).register(slicing_op)
1547
1548
class InstanceProperty(Layer):
  """Wraps an instance property access (e.g. `x.foo`) in a Keras Layer.

  The layer is constructed with an attribute name `attr_name`; calling it on
  an input `obj` returns `getattr(obj, attr_name)`.

  KerasTensors specialized for specific extension types use it to represent
  instance property accesses on the represented object in the case where the
  property must be accessed dynamically rather than statically computed from
  the typespec, e.g.

  x = keras.Input(..., ragged=True)
  out = x.flat_values
  """

  @trackable.no_automatic_dependency_tracking
  def __init__(self, attr_name, **kwargs):
    self.attr_name = attr_name

    # Generated names skip already-observed names (`avoid_observed_names`),
    # as done for the other auto-inserted op layers.
    if 'name' not in kwargs:
      default_name = 'input.' + self.attr_name
      kwargs['name'] = K.unique_object_name(
          default_name, zero_based=True, avoid_observed_names=True)
    kwargs['autocast'] = False

    # Do not individually trace op layers in the SavedModel.
    self._must_restore_from_config = True

    super(InstanceProperty, self).__init__(**kwargs)

    # Preserve all argument data structures when saving/loading a config
    # (e.g., don't unnest lists that contain one element).
    self._preserve_input_structure_in_config = True

  def call(self, obj):
    """Returns the `attr_name` attribute of `obj`."""
    attr_name = self.attr_name
    return getattr(obj, attr_name)

  def get_config(self):
    """Returns the base config extended with `attr_name`."""
    base_config = super(InstanceProperty, self).get_config()
    return dict(base_config, attr_name=self.attr_name)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the layer from `config`; `custom_objects` is unused."""
    del custom_objects  # Unused.
    return cls(**config)
1595
1596
class InstanceMethod(InstanceProperty):
  """Wraps an instance method access (e.g. `x.foo(arg)`) in a Keras Layer.

  This layer takes an attribute name `attr_name` in the constructor and,
  when called on input tensor `obj` with additional arguments `args` and
  `kwargs`, returns `obj.attr_name(*args, **kwargs)`.

  KerasTensors specialized for specific extension types use it to
  represent dynamic instance method calls on the represented object, e.g.

  x = keras.Input(..., ragged=True)
  new_values = keras.Input(...)
  out = x.with_values(new_values)
  """

  def call(self, obj, args, kwargs):
    """Invokes `obj.<attr_name>(*args, **kwargs)`.

    Args:
      obj: The object whose method is looked up and called.
      args: Positional arguments forwarded (unpacked) to the method.
      kwargs: Keyword arguments forwarded (unpacked) to the method.

    Returns:
      Whatever the looked-up method returns.
    """
    method = getattr(obj, self.attr_name)
    return method(*args, **kwargs)
1615
1616
def _delegate_property(keras_tensor_cls, property_name):  # pylint: disable=invalid-name
  """Expose `property_name` on `keras_tensor_cls` as a delegating property.

  Repeated calls with the same arguments just re-install an equivalent
  property, so they are effectively a no-op.

  Reading the installed property builds an `InstanceProperty` layer and
  applies it to the KerasTensor. The represented intermediate value's
  attribute is then fetched inside the model.

  Args:
    keras_tensor_cls: The KerasTensor subclass that should expose the property.
    property_name: The name of the property to expose and delegate to the
      represented (Composite)Tensor.
  """
  def _getter(self):
    # Keras layers cannot be created at import time because of dynamic layer
    # class versioning, so the layer is only built when the property is read.
    layer = InstanceProperty(property_name)
    return layer(self)

  setattr(keras_tensor_cls, property_name, property(_getter))
1635
1636
def _delegate_method(keras_tensor_cls, method_name):  # pylint: disable=invalid-name
  """Register method on a KerasTensor class.

  Calling this function multiple times with the same arguments should be a
  no-op.

  This method exposes an instance method on the KerasTensor class that will use
  an `InstanceMethod` layer to run the desired method on the represented
  intermediate values in the model.

  Args:
    keras_tensor_cls: The KerasTensor subclass that should expose the method.
    method_name: The name of the method to expose and delegate to the
      represented (Composite)Tensor.
  """
  def delegate(self, *args, **kwargs):
    # A fresh `InstanceMethod` layer is built per call; positional args are
    # handed to the layer as a tuple and keyword args as a dict.
    return InstanceMethod(method_name)(self, args, kwargs)

  # Name the generated function after the delegated method so introspection
  # and tracebacks show e.g. `to_tensor` rather than the generic `delegate`.
  delegate.__name__ = method_name
  delegate.__qualname__ = '%s.%s' % (keras_tensor_cls.__name__, method_name)
  setattr(keras_tensor_cls, method_name, delegate)
1654
# We do not support the `uniform_row_length` property because it
# returns either `None` or an int tensor, and code that relies on it tends
# to check `is None` directly. Delegating it here would always return a
# `KerasTensor`, regardless of what can be statically inferred. This would
# never equal `None`, breaking code that expects it to be partially-static
# in unpredictable ways.

# RaggedTensor properties delegated through `InstanceProperty` layers.
for ragged_property in (
    'values',
    'flat_values',
    'row_splits',
    'nested_row_splits',
):
  _delegate_property(keras_tensor.RaggedKerasTensor, ragged_property)

# RaggedTensor methods delegated through `InstanceMethod` layers.
for ragged_method_name in (
    'value_rowids',
    'nested_value_rowids',
    'nrows',
    'row_starts',
    'row_limits',
    'row_lengths',
    'nested_row_lengths',
    'bounding_shape',
    'with_values',
    'with_flat_values',
    'with_row_splits_dtype',
    'merge_dims',
    'to_tensor',
    'to_sparse',
):
  _delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name)

# SparseTensor properties delegated through `InstanceProperty` layers.
for sparse_property in ('indices', 'values'):
  _delegate_property(keras_tensor.SparseKerasTensor, sparse_property)

# SparseTensor methods delegated through `InstanceMethod` layers.
for sparse_method in ('with_values',):
  _delegate_method(keras_tensor.SparseKerasTensor, sparse_method)
1697
1698
class ClassMethod(Layer):
  """Wraps a TF API Class's class method in a `Layer` object.

  It is inserted by the Functional API construction whenever users call
  a supported TF Class's class method on KerasTensors.

  This is useful in the case where users do something like:
  x = keras.Input(...)
  y = keras.Input(...)
  out = tf.RaggedTensor.from_row_splits(x, y)
  """

  @trackable.no_automatic_dependency_tracking
  def __init__(self, cls_ref, method_name, **kwargs):
    self.cls_ref = cls_ref
    self.method_name = method_name
    # Exported API name of the class (tf namespace first, then keras). This
    # may be falsy when the class is not exported; see `get_config`.
    self.cls_symbol = (
        get_canonical_name_for_symbol(self.cls_ref,
                                      add_prefix_to_v1_names=True) or
        get_canonical_name_for_symbol(self.cls_ref,
                                      api_name='keras',
                                      add_prefix_to_v1_names=True))
    if 'name' not in kwargs:
      if self.cls_symbol:
        name = 'tf.' + self.cls_symbol + '.' + self.method_name
      else:
        # Unexported class: fall back to its Python name instead of failing
        # with a TypeError on string concatenation with a missing symbol.
        name = self.cls_ref.__name__ + '.' + self.method_name
      kwargs['name'] = K.unique_object_name(
          name, zero_based=True, avoid_observed_names=True)
    kwargs['autocast'] = False

    # Do not individually trace op layers in the SavedModel.
    self._must_restore_from_config = True

    super(ClassMethod, self).__init__(**kwargs)

    # Preserve all argument data structures when saving/loading a config
    # (e.g., don't unnest lists that contain one element)
    self._preserve_input_structure_in_config = True

    self._expects_training_arg = False
    self._expects_mask_arg = False

  def call(self, args, kwargs):
    """Calls `cls_ref.<method_name>(*args, **kwargs)` and returns the result."""
    return getattr(self.cls_ref, self.method_name)(*args, **kwargs)

  def get_config(self):
    """Returns the layer config.

    Raises:
      ValueError: If the wrapped class is not an exported TF/Keras API symbol,
        since it could not be looked up again when deserializing.
    """
    if not self.cls_symbol:
      raise ValueError('This Keras class method conversion tried to convert '
                       'a method belonging to class %s, a class '
                       'that is not exposed in the TensorFlow API. '
                       'To ensure cross-version compatibility of Keras models '
                       'that use op layers, only op layers produced from '
                       'exported TF API symbols can be serialized.'
                       % self.cls_ref)
    config = {
        'cls_symbol': self.cls_symbol,
        'method_name': self.method_name
    }

    base_config = super(ClassMethod, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the layer, resolving `cls_symbol` back to the class object.

    Raises:
      ValueError: If the stored symbol name cannot be resolved.
    """
    config = config.copy()
    symbol_name = config.pop('cls_symbol')
    cls_ref = get_symbol_from_name(symbol_name)
    if not cls_ref:
      raise ValueError(
          'TF symbol `tf.%s` could not be found.' % symbol_name)

    config['cls_ref'] = cls_ref

    return cls(**config)
1771
1772
class TFClassMethodDispatcher(dispatch.OpDispatcher):
  """Turns TF class method calls on KerasTensors into `ClassMethod` layers.

  Allows building a functional model with TF class methods such as
  `tf.RaggedTensor.from_row_splits`.
  """

  def __init__(self, cls, method_name):
    self.cls = cls
    self.method_name = method_name

  def handle(self, args, kwargs):
    """Handle the specified operation with the specified arguments."""
    flat_inputs = nest.flatten([args, kwargs])
    if not any(isinstance(x, keras_tensor.KerasTensor) for x in flat_inputs):
      return self.NOT_SUPPORTED
    # The first positional argument is dropped. `ClassMethod.call` invokes the
    # bound class method, so it is presumably the class itself — confirm
    # against the dispatch registration.
    return ClassMethod(self.cls, self.method_name)(args[1:], kwargs)
1788
# Register a dispatcher on each RaggedTensor factory class method so that
# calling it with KerasTensor arguments produces a `ClassMethod` layer.
for ragged_class_method in (
    'from_value_rowids',
    'from_row_splits',
    'from_row_lengths',
    'from_row_starts',
    'from_row_limits',
    'from_uniform_row_length',
    'from_nested_value_rowids',
    'from_nested_row_splits',
    'from_nested_row_lengths',
    'from_tensor',
    'from_sparse',
):
  TFClassMethodDispatcher(ragged_tensor.RaggedTensor,
                          ragged_class_method).register(
                              getattr(ragged_tensor.RaggedTensor,
                                      ragged_class_method))
1805