1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the base Layer class, from which all layers inherit."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import functools
24import itertools
25import threading
26import warnings
27import weakref
28
29import numpy as np
30import six
31from six.moves import zip  # pylint: disable=redefined-builtin
32
33from google.protobuf import json_format
34from tensorflow.core.framework import node_def_pb2
35from tensorflow.python import tf2
36from tensorflow.python.autograph.core import ag_ctx
37from tensorflow.python.autograph.impl import api as autograph
38from tensorflow.python.distribute import distribution_strategy_context as ds_context
39from tensorflow.python.eager import context
40from tensorflow.python.eager import def_function
41from tensorflow.python.eager import execute
42from tensorflow.python.eager import monitoring
43from tensorflow.python.framework import constant_op
44from tensorflow.python.framework import dtypes
45from tensorflow.python.framework import errors
46from tensorflow.python.framework import func_graph
47from tensorflow.python.framework import ops
48from tensorflow.python.framework import sparse_tensor
49from tensorflow.python.framework import tensor_spec
50from tensorflow.python.framework import tensor_util
51from tensorflow.python.keras import backend
52from tensorflow.python.keras import constraints
53from tensorflow.python.keras import initializers
54from tensorflow.python.keras import regularizers
55from tensorflow.python.keras.engine import base_layer_utils
56from tensorflow.python.keras.engine import input_spec
57from tensorflow.python.keras.engine import keras_tensor
58from tensorflow.python.keras.engine import node as node_module
59from tensorflow.python.keras.mixed_precision import autocast_variable
60from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
61from tensorflow.python.keras.mixed_precision import policy
62from tensorflow.python.keras.saving.saved_model import layer_serialization
63from tensorflow.python.keras.utils import generic_utils
64from tensorflow.python.keras.utils import layer_utils
65from tensorflow.python.keras.utils import object_identity
66from tensorflow.python.keras.utils import tf_inspect
67from tensorflow.python.keras.utils import tf_utils
68from tensorflow.python.keras.utils import version_utils
69# A module that only depends on `keras.layers` import these from here.
70from tensorflow.python.keras.utils.generic_utils import to_snake_case  # pylint: disable=unused-import
71from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list  # pylint: disable=unused-import
72
73from tensorflow.python.module import module
74from tensorflow.python.ops import array_ops
75from tensorflow.python.ops import math_ops
76from tensorflow.python.ops import variables as tf_variables
77from tensorflow.python.ops.numpy_ops import np_arrays
78from tensorflow.python.ops.ragged import ragged_tensor
79from tensorflow.python.platform import tf_logging
80from tensorflow.python.training.tracking import base as trackable
81from tensorflow.python.training.tracking import data_structures
82from tensorflow.python.training.tracking import tracking
83from tensorflow.python.util import compat
84from tensorflow.python.util import nest
85from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
86from tensorflow.python.util.tf_export import keras_export
87from tensorflow.tools.docs import doc_controls
88
# `tensorflow.python.keras.metrics` is wrapped in `generic_utils.LazyLoader`
# instead of being imported at the top of the module.
# pylint: disable=g-inconsistent-quotes
metrics_mod = generic_utils.LazyLoader(
    "metrics_mod", globals(), "tensorflow.python.keras.metrics")
# pylint: enable=g-inconsistent-quotes

# Prefix that is added to the TF op layer names.
_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_'

# Tensor-like types presumably eligible for autocasting to a layer's compute
# dtype (the consumer of this tuple is outside this excerpt -- confirm).
# TODO(mdan): Should we have a single generic type for types that can be passed
# to tf.cast?
_AUTOCAST_TYPES = (
    ops.Tensor,
    sparse_tensor.SparseTensor,
    ragged_tensor.RaggedTensor,
)

# Boolean usage gauges. `Layer._instrument_layer_creation` sets cells on the
# api, layers and models gauges whenever a layer is constructed.
keras_layers_gauge = monitoring.BoolGauge(
    '/tensorflow/api/keras/layers', 'keras layers usage', 'method')
keras_models_gauge = monitoring.BoolGauge(
    '/tensorflow/api/keras/models', 'keras model usage', 'method')
keras_api_gauge = monitoring.BoolGauge(
    '/tensorflow/api/keras', 'keras api usage', 'method')
keras_premade_model_gauge = monitoring.BoolGauge(
    '/tensorflow/api/keras/premade_models', 'premade keras model usage',
    'type')
111
112
113@keras_export('keras.layers.Layer')
114class Layer(module.Module, version_utils.LayerVersionSelector):
115  """This is the class from which all layers inherit.
116
117  A layer is a callable object that takes as input one or more tensors and
118  that outputs one or more tensors. It involves *computation*, defined
119  in the `call()` method, and a *state* (weight variables), defined
120  either in the constructor `__init__()` or in the `build()` method.
121
122  Users will just instantiate a layer and then treat it as a callable.
123
124  Args:
125    trainable: Boolean, whether the layer's variables should be trainable.
126    name: String name of the layer.
127    dtype: The dtype of the layer's computations and weights. Can also be a
128      `tf.keras.mixed_precision.Policy`, which allows the computation and weight
129      dtype to differ. Default of `None` means to use
130      `tf.keras.mixed_precision.global_policy()`, which is a float32 policy
131      unless set to different value.
132    dynamic: Set this to `True` if your layer should only be run eagerly, and
133      should not be used to generate a static computation graph.
134      This would be the case for a Tree-RNN or a recursive network,
135      for example, or generally for any layer that manipulates tensors
136      using Python control flow. If `False`, we assume that the layer can
137      safely be used to generate a static computation graph.
138
139  Attributes:
140    name: The name of the layer (string).
141    dtype: The dtype of the layer's weights.
142    variable_dtype: Alias of `dtype`.
143    compute_dtype: The dtype of the layer's computations. Layers automatically
144      cast inputs to this dtype which causes the computations and output to also
145      be in this dtype. When mixed precision is used with a
146      `tf.keras.mixed_precision.Policy`, this will be different than
147      `variable_dtype`.
148    dtype_policy: The layer's dtype policy. See the
149      `tf.keras.mixed_precision.Policy` documentation for details.
150    trainable_weights: List of variables to be included in backprop.
151    non_trainable_weights: List of variables that should not be
152      included in backprop.
153    weights: The concatenation of the lists trainable_weights and
154      non_trainable_weights (in this order).
155    trainable: Whether the layer should be trained (boolean), i.e. whether
156      its potentially-trainable weights should be returned as part of
157      `layer.trainable_weights`.
158    input_spec: Optional (list of) `InputSpec` object(s) specifying the
159      constraints on inputs that can be accepted by the layer.
160
161  We recommend that descendants of `Layer` implement the following methods:
162
163  * `__init__()`: Defines custom layer attributes, and creates layer state
164    variables that do not depend on input shapes, using `add_weight()`.
165  * `build(self, input_shape)`: This method can be used to create weights that
166    depend on the shape(s) of the input(s), using `add_weight()`. `__call__()`
167    will automatically build the layer (if it has not been built yet) by
168    calling `build()`.
169  * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making
170    sure `build()` has been called. `call()` performs the logic of applying the
171    layer to the input tensors (which should be passed in as argument).
172    Two reserved keyword arguments you can optionally use in `call()` are:
173      - `training` (boolean, whether the call is in inference mode or training
174        mode). See more details in [the layer/model subclassing guide](
175        https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method)
176      - `mask` (boolean tensor encoding masked timesteps in the input, used
177        in RNN layers). See more details in [the layer/model subclassing guide](
178        https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method)
    A typical signature for this method is `call(self, inputs)`, and users
    can optionally add `training` and `mask` if the layer needs them. `*args`
    and `**kwargs` are only useful for future extension, when more input
    parameters are planned to be added.
  * `get_config(self)`: Returns a dictionary containing the configuration used
    to initialize this layer. If the keys differ from the arguments
    in `__init__`, then override `from_config(cls, config)` as well.
    This method is used when saving the layer or a model that contains this
    layer.
188
189  Examples:
190
191  Here's a basic example: a layer with two variables, `w` and `b`,
192  that returns `y = w . x + b`.
193  It shows how to implement `build()` and `call()`.
194  Variables set as attributes of a layer are tracked as weights
195  of the layers (in `layer.weights`).
196
197  ```python
198  class SimpleDense(Layer):
199
200    def __init__(self, units=32):
201        super(SimpleDense, self).__init__()
202        self.units = units
203
204    def build(self, input_shape):  # Create the state of the layer (weights)
205      w_init = tf.random_normal_initializer()
206      self.w = tf.Variable(
207          initial_value=w_init(shape=(input_shape[-1], self.units),
208                               dtype='float32'),
209          trainable=True)
210      b_init = tf.zeros_initializer()
211      self.b = tf.Variable(
212          initial_value=b_init(shape=(self.units,), dtype='float32'),
213          trainable=True)
214
215    def call(self, inputs):  # Defines the computation from inputs to outputs
216        return tf.matmul(inputs, self.w) + self.b
217
218  # Instantiates the layer.
219  linear_layer = SimpleDense(4)
220
221  # This will also call `build(input_shape)` and create the weights.
222  y = linear_layer(tf.ones((2, 2)))
223  assert len(linear_layer.weights) == 2
224
225  # These weights are trainable, so they're listed in `trainable_weights`:
226  assert len(linear_layer.trainable_weights) == 2
227  ```
228
229  Note that the method `add_weight()` offers a shortcut to create weights:
230
231  ```python
232  class SimpleDense(Layer):
233
234    def __init__(self, units=32):
235        super(SimpleDense, self).__init__()
236        self.units = units
237
238    def build(self, input_shape):
239        self.w = self.add_weight(shape=(input_shape[-1], self.units),
240                                 initializer='random_normal',
241                                 trainable=True)
242        self.b = self.add_weight(shape=(self.units,),
243                                 initializer='random_normal',
244                                 trainable=True)
245
246    def call(self, inputs):
247        return tf.matmul(inputs, self.w) + self.b
248  ```
249
250  Besides trainable weights, updated via backpropagation during training,
251  layers can also have non-trainable weights. These weights are meant to
  be updated manually during `call()`. Here's an example layer that computes
253  the running sum of its inputs:
254
255  ```python
256  class ComputeSum(Layer):
257
258    def __init__(self, input_dim):
259        super(ComputeSum, self).__init__()
260        # Create a non-trainable weight.
261        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
262                                 trainable=False)
263
264    def call(self, inputs):
265        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
266        return self.total
267
268  my_sum = ComputeSum(2)
269  x = tf.ones((2, 2))
270
271  y = my_sum(x)
272  print(y.numpy())  # [2. 2.]
273
274  y = my_sum(x)
275  print(y.numpy())  # [4. 4.]
276
277  assert my_sum.weights == [my_sum.total]
278  assert my_sum.non_trainable_weights == [my_sum.total]
279  assert my_sum.trainable_weights == []
280  ```
281
282  For more information about creating layers, see the guide
283  [Making new Layers and Models via subclassing](
284    https://www.tensorflow.org/guide/keras/custom_layers_and_models)
285  """
286
287  # See tf.Module for the usage of this property.
288  # The key for _obj_reference_counts_dict is a Trackable, which could be a
289  # variable or layer etc. tf.Module._flatten will fail to flatten the key
290  # since it is trying to convert Trackable to a string. This attribute can be
291  # ignored even after the fix of nest lib, since the trackable object should
292  # already been available as individual attributes. _obj_reference_counts_dict
293  # just contains a copy of them.
294  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
295      ('_obj_reference_counts_dict',),
296      module.Module._TF_MODULE_IGNORED_PROPERTIES
297  ))
298
299  # When loading from a SavedModel, Layers typically can be revived into a
300  # generic Layer wrapper. Sometimes, however, layers may implement methods
301  # that go beyond this wrapper, as in the case of PreprocessingLayers'
302  # `adapt` method. When this is the case, layer implementers can override
303  # must_restore_from_config to return True; layers with this property must
304  # be restored into their actual objects (and will fail if the object is
305  # not available to the restoration code).
306  _must_restore_from_config = False
307
308  def _get_cell_name(self):
309    canonical_name = get_canonical_name_for_symbol(
310        self.__class__, api_name='keras', add_prefix_to_v1_names=True)
311    if canonical_name is not None:
312      return 'tf.{}'.format(canonical_name)
313    return self.__class__.__module__ + '.' + self.__class__.__name__
314
315  def _instrument_layer_creation(self):
316    self._instrumented_keras_api = False
317    self._instrumented_keras_layer_class = False
318    self._instrumented_keras_model_class = False
319    if not getattr(self, '_disable_keras_instrumentation', False):
320      keras_api_gauge.get_cell('layer').set(True)
321      self._instrumented_keras_api = True
322      if getattr(self, '_is_model_for_instrumentation', False):
323        keras_models_gauge.get_cell(self._get_cell_name()).set(True)
324        self._instrumented_keras_model_class = True
325      else:
326        keras_layers_gauge.get_cell(self._get_cell_name()).set(True)
327        self._instrumented_keras_layer_class = True
328
  @trackable.no_automatic_dependency_tracking
  def __init__(self,
               trainable=True,
               name=None,
               dtype=None,
               dynamic=False,
               **kwargs):
    """Initializes the state shared by every Keras layer.

    Args:
      trainable: Boolean, whether the layer's variables should be trainable.
        Stored as `self._trainable`.
      name: Optional layer name, handed to `self._init_set_name`.
      dtype: dtype or dtype policy, handed to `self._set_dtype_policy`.
      dynamic: Boolean, stored as `self._dynamic`; see the class docstring.
      **kwargs: Optional extra arguments. Only the keys listed in
        `allowed_kwargs` below are accepted; they are checked with
        `generic_utils.validate_kwargs` (presumably raising on unknown keys --
        see that helper). `input_dim`, `input_shape`, `batch_input_shape` and
        `batch_size` determine `self._batch_input_shape`; `weights` is kept as
        `self._initial_weights`; `activity_regularizer` and `autocast` set the
        matching private attributes. `implementation` is accepted but not read
        here (presumably consumed by subclasses -- confirm).
    """
    self._instrument_layer_creation()

    # These properties should be set by the user via keyword arguments.
    # Note that the shape-related keywords ('input_dim', 'input_shape',
    # 'batch_input_shape', 'batch_size') describe the input of the layer that
    # receives a model's input: they are recorded as `_batch_input_shape`
    # below so that an input layer can later be inserted before this layer.
    # Do not pass them to layers deeper in a model.
    allowed_kwargs = {
        'input_dim',
        'input_shape',
        'batch_input_shape',
        'batch_size',
        'weights',
        'activity_regularizer',
        'autocast',
        'implementation',
    }
    # Validate optional keyword arguments.
    generic_utils.validate_kwargs(kwargs, allowed_kwargs)

    # Mutable properties
    # Indicates whether the layer's weights are updated during training
    # and whether the layer's updates are run during training.
    self._trainable = trainable
    # A stateful layer is a layer whose updates are run during inference too,
    # for instance stateful RNNs.
    self._stateful = False
    # Indicates whether `build` needs to be called upon layer call, to create
    # the layer's weights.
    self.built = False
    # Provides information about which inputs are compatible with the layer.
    self._input_spec = None

    # SavedModel-related attributes.
    # Record the build input shape for loading purposes.
    # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
    # submitted.
    self._build_input_shape = None
    self._saved_model_inputs_spec = None

    # `Layer.compute_mask` will be called at the end of `Layer.__call__` if
    # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets
    # `self.supports_masking=True`.
    self._supports_masking = not generic_utils.is_default(self.compute_mask)

    self._init_set_name(name)
    # `activity_regularizer` is popped so it is not seen again in `kwargs`.
    self._activity_regularizer = regularizers.get(
        kwargs.pop('activity_regularizer', None))
    # Only created if not already present (e.g. set earlier by a subclass).
    self._maybe_create_attribute('_trainable_weights', [])
    self._maybe_create_attribute('_non_trainable_weights', [])
    self._updates = []
    # Object to store all thread local layer properties.
    self._thread_local = threading.local()
    # A list of zero-argument lambdas which return Tensors, used for variable
    # regularizers.
    self._callable_losses = []
    # A list of symbolic Tensors containing activity regularizers and losses
    # manually added through `add_loss` in graph-building mode.
    self._losses = []
    # A list of metric instances corresponding to the symbolic metric tensors
    # added using the `add_metric` API.
    self._metrics = []
    # Ensures the same metric is not added multiple times in `MirroredStrategy`.
    self._metrics_lock = threading.Lock()

    # Both graph and subclassed networks have a dtype policy. For graph
    # networks, the policy's compute and variable dtypes are ignored. Such
    # networks only use the policy if it is a PolicyV1, in which case it uses
    # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
    # subclassed networks, the compute and variable dtypes are used as like any
    # ordinary layer.
    self._set_dtype_policy(dtype)
    # Boolean indicating whether the layer automatically casts its inputs to the
    # layer's compute_dtype. Defaults to whether V2 dtype behavior is enabled.
    self._autocast = kwargs.get('autocast',
                                base_layer_utils.v2_dtype_behavior_enabled())

    # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s.
    # Ordered by when the object was assigned as an attr.
    # Entries are unique.
    self._maybe_create_attribute('_self_tracked_trackables', [])

    # These lists will be filled via successive calls
    # to self._add_inbound_node().
    # Used in symbolic mode only, only in conjunction with graph-networks
    self._inbound_nodes_value = []
    self._outbound_nodes_value = []

    self._init_call_fn_args()

    # Whether the `call` method can be used to build a TF graph without issues.
    # This attribute has no effect if the model is created using the Functional
    # API. Instead, `model.dynamic` is determined based on the internal layers.
    self._dynamic = dynamic

    # Manage input shape information if passed.
    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
      kwargs['input_shape'] = (kwargs['input_dim'],)
    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
      # In this case we will later create an input layer
      # to insert before the current layer.
      # An explicit 'batch_input_shape' takes precedence over 'input_shape'.
      if 'batch_input_shape' in kwargs:
        batch_input_shape = tuple(kwargs['batch_input_shape'])
      elif 'input_shape' in kwargs:
        # Prepend the (possibly unknown, i.e. None) batch dimension.
        if 'batch_size' in kwargs:
          batch_size = kwargs['batch_size']
        else:
          batch_size = None
        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
      self._batch_input_shape = batch_input_shape

    # Manage initial weight values if passed.
    self._initial_weights = kwargs.get('weights', None)

    # Whether the layer will track any layers that are set as attributes on
    # itself as sub-layers, in which case the weights from the sub-layers will
    # be included in the parent layer's variables() as well.
    # Default to True, which means auto tracking is turned on. Certain subclass
    # might want to turn it off, like Sequential model.
    self._auto_track_sub_layers = True

    # For backwards compat reasons, most built-in layers do not guarantee
    # That they will 100% preserve the structure of input args when saving
    # / loading configs. E.g. they may un-nest an arg that is
    # a list with one element.
    self._preserve_input_structure_in_config = False
462
463  @trackable.no_automatic_dependency_tracking
464  @generic_utils.default
465  def build(self, input_shape):
466    """Creates the variables of the layer (optional, for subclass implementers).
467
468    This is a method that implementers of subclasses of `Layer` or `Model`
469    can override if they need a state-creation step in-between
470    layer instantiation and layer call.
471
472    This is typically used to create the weights of `Layer` subclasses.
473
474    Args:
475      input_shape: Instance of `TensorShape`, or list of instances of
476        `TensorShape` if the layer expects a list of inputs
477        (one instance per input).
478    """
479    # Only record the build input shapes of overridden build methods.
480    if not hasattr(self.build, '_is_default'):
481      self._build_input_shape = input_shape
482    self.built = True
483
484  @doc_controls.for_subclass_implementers
485  def call(self, inputs, *args, **kwargs):  # pylint: disable=unused-argument
486    """This is where the layer's logic lives.
487
488    Note here that `call()` method in `tf.keras` is little bit different
489    from `keras` API. In `keras` API, you can pass support masking for
490    layers as additional arguments. Whereas `tf.keras` has `compute_mask()`
491    method to support masking.
492
493    Args:
494        inputs: Input tensor, or list/tuple of input tensors.
495        *args: Additional positional arguments. Currently unused.
496        **kwargs: Additional keyword arguments. Currently unused.
497
498    Returns:
499        A tensor or list/tuple of tensors.
500    """
501    return inputs
502
503  @doc_controls.for_subclass_implementers
504  def _add_trackable(self, trackable_object, trainable):
505    """Adds a Trackable object to this layer's state.
506
507    Args:
508      trackable_object: The tf.tracking.Trackable object to add.
509      trainable: Boolean, whether the variable should be part of the layer's
510        "trainable_variables" (e.g. variables, biases) or
511        "non_trainable_variables" (e.g. BatchNorm mean and variance).
512
513    Returns:
514      The TrackableWeightHandler used to track this object.
515    """
516    handler = base_layer_utils.TrackableWeightHandler(trackable_object)
517    if trainable:
518      self._trainable_weights.append(handler)
519    else:
520      self._non_trainable_weights.append(handler)
521    return handler
522
523  @doc_controls.for_subclass_implementers
524  def add_weight(self,
525                 name=None,
526                 shape=None,
527                 dtype=None,
528                 initializer=None,
529                 regularizer=None,
530                 trainable=None,
531                 constraint=None,
532                 use_resource=None,
533                 synchronization=tf_variables.VariableSynchronization.AUTO,
534                 aggregation=tf_variables.VariableAggregation.NONE,
535                 **kwargs):
536    """Adds a new variable to the layer.
537
538    Args:
539      name: Variable name.
540      shape: Variable shape. Defaults to scalar if unspecified.
541      dtype: The type of the variable. Defaults to `self.dtype`.
542      initializer: Initializer instance (callable).
543      regularizer: Regularizer instance (callable).
544      trainable: Boolean, whether the variable should be part of the layer's
545        "trainable_variables" (e.g. variables, biases)
546        or "non_trainable_variables" (e.g. BatchNorm mean and variance).
547        Note that `trainable` cannot be `True` if `synchronization`
548        is set to `ON_READ`.
549      constraint: Constraint instance (callable).
550      use_resource: Whether to use `ResourceVariable`.
551      synchronization: Indicates when a distributed a variable will be
552        aggregated. Accepted values are constants defined in the class
553        `tf.VariableSynchronization`. By default the synchronization is set to
554        `AUTO` and the current `DistributionStrategy` chooses
555        when to synchronize. If `synchronization` is set to `ON_READ`,
556        `trainable` must not be set to `True`.
557      aggregation: Indicates how a distributed variable will be aggregated.
558        Accepted values are constants defined in the class
559        `tf.VariableAggregation`.
560      **kwargs: Additional keyword arguments. Accepted values are `getter`,
561        `collections`, `experimental_autocast` and `caching_device`.
562
563    Returns:
564      The variable created.
565
566    Raises:
567      ValueError: When giving unsupported dtype and no initializer or when
568        trainable has been set to True with synchronization set as `ON_READ`.
569    """
570    if shape is None:
571      shape = ()
572    kwargs.pop('partitioner', None)  # Ignored.
573    # Validate optional keyword arguments.
574    for kwarg in kwargs:
575      if kwarg not in ['collections', 'experimental_autocast',
576                       'caching_device', 'getter']:
577        raise TypeError('Unknown keyword argument:', kwarg)
578    collections_arg = kwargs.pop('collections', None)
579    # 'experimental_autocast' can be set to False by the caller to indicate an
580    # AutoCastVariable should never be created.
581    autocast = kwargs.pop('experimental_autocast', True)
582    # See the docstring for tf.Variable about the details for caching_device.
583    caching_device = kwargs.pop('caching_device', None)
584
585    if dtype is None:
586      dtype = self.dtype or backend.floatx()
587    dtype = dtypes.as_dtype(dtype)
588    if self._dtype_policy.variable_dtype is None:
589      # The policy is "_infer", so we infer the policy from the variable dtype.
590      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
591    initializer = initializers.get(initializer)
592    regularizer = regularizers.get(regularizer)
593    constraint = constraints.get(constraint)
594
595    if synchronization == tf_variables.VariableSynchronization.ON_READ:
596      if trainable:
597        raise ValueError(
598            'Synchronization value can be set to '
599            'VariableSynchronization.ON_READ only for non-trainable variables. '
600            'You have specified trainable=True and '
601            'synchronization=VariableSynchronization.ON_READ.')
602      else:
603        # Set trainable to be false when variable is to be synced on read.
604        trainable = False
605    elif trainable is None:
606      trainable = True
607
608    # Initialize variable when no initializer provided
609    if initializer is None:
610      # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
611      if dtype.is_floating:
612        initializer = initializers.get('glorot_uniform')
613      # If dtype is DT_INT/DT_UINT, provide a default value `zero`
614      # If dtype is DT_BOOL, provide a default value `FALSE`
615      elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
616        initializer = initializers.get('zeros')
617      # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
618      else:
619        raise ValueError('An initializer for variable %s of type %s is required'
620                         ' for layer %s' % (name, dtype.base_dtype, self.name))
621
622    getter = kwargs.pop('getter', base_layer_utils.make_variable)
623    if (autocast and
624        self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
625        and dtype.is_floating):
626      old_getter = getter
627      # Wrap variable constructor to return an AutoCastVariable.
628      def getter(*args, **kwargs):  # pylint: disable=function-redefined
629        variable = old_getter(*args, **kwargs)
630        return autocast_variable.create_autocast_variable(variable)
631      # Also the caching_device does not work with the mixed precision API,
632      # disable it if it is specified.
633      # TODO(b/142020079): Reenable it once the bug is fixed.
634      if caching_device is not None:
635        tf_logging.warn('`caching_device` does not work with mixed precision '
636                        'API. Ignoring user specified `caching_device`.')
637        caching_device = None
638
639    variable = self._add_variable_with_custom_getter(
640        name=name,
641        shape=shape,
642        # TODO(allenl): a `make_variable` equivalent should be added as a
643        # `Trackable` method.
644        getter=getter,
645        # Manage errors in Layer rather than Trackable.
646        overwrite=True,
647        initializer=initializer,
648        dtype=dtype,
649        constraint=constraint,
650        trainable=trainable,
651        use_resource=use_resource,
652        collections=collections_arg,
653        synchronization=synchronization,
654        aggregation=aggregation,
655        caching_device=caching_device)
656    if regularizer is not None:
657      # TODO(fchollet): in the future, this should be handled at the
658      # level of variable creation, and weight regularization losses
659      # should be variable attributes.
660      name_in_scope = variable.name[:variable.name.find(':')]
661      self._handle_weight_regularization(name_in_scope,
662                                         variable,
663                                         regularizer)
664    if base_layer_utils.is_split_variable(variable):
665      for v in variable:
666        backend.track_variable(v)
667        if trainable:
668          self._trainable_weights.append(v)
669        else:
670          self._non_trainable_weights.append(v)
671    else:
672      backend.track_variable(variable)
673      if trainable:
674        self._trainable_weights.append(variable)
675      else:
676        self._non_trainable_weights.append(variable)
677    return variable
678
679  @generic_utils.default
680  def get_config(self):
681    """Returns the config of the layer.
682
683    A layer config is a Python dictionary (serializable)
684    containing the configuration of a layer.
685    The same layer can be reinstantiated later
686    (without its trained weights) from this configuration.
687
688    The config of a layer does not include connectivity
689    information, nor the layer class name. These are handled
690    by `Network` (one layer of abstraction above).
691
692    Note that `get_config()` does not guarantee to return a fresh copy of dict
693    every time it is called. The callers should make a copy of the returned dict
694    if they want to modify it.
695
696    Returns:
697        Python dictionary.
698    """
699    all_args = tf_inspect.getfullargspec(self.__init__).args
700    config = {
701        'name': self.name,
702        'trainable': self.trainable,
703    }
704    if hasattr(self, '_batch_input_shape'):
705      config['batch_input_shape'] = self._batch_input_shape
706    config['dtype'] = policy.serialize(self._dtype_policy)
707    if hasattr(self, 'dynamic'):
708      # Only include `dynamic` in the `config` if it is `True`
709      if self.dynamic:
710        config['dynamic'] = self.dynamic
711      elif 'dynamic' in all_args:
712        all_args.remove('dynamic')
713    expected_args = config.keys()
714    # Finds all arguments in the `__init__` that are not in the config:
715    extra_args = [arg for arg in all_args if arg not in expected_args]
716    # Check that either the only argument in the `__init__` is  `self`,
717    # or that `get_config` has been overridden:
718    if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
719      raise NotImplementedError('Layer %s has arguments in `__init__` and '
720                                'therefore must override `get_config`.' %
721                                self.__class__.__name__)
722    return config
723
724  @classmethod
725  def from_config(cls, config):
726    """Creates a layer from its config.
727
728    This method is the reverse of `get_config`,
729    capable of instantiating the same layer from the config
730    dictionary. It does not handle layer connectivity
731    (handled by Network), nor weights (handled by `set_weights`).
732
733    Args:
734        config: A Python dictionary, typically the
735            output of get_config.
736
737    Returns:
738        A layer instance.
739    """
740    return cls(**config)
741
742  def compute_output_shape(self, input_shape):
743    """Computes the output shape of the layer.
744
745    If the layer has not been built, this method will call `build` on the
746    layer. This assumes that the layer will later be used with inputs that
747    match the input shape provided here.
748
749    Args:
750        input_shape: Shape tuple (tuple of integers)
751            or list of shape tuples (one per output tensor of the layer).
752            Shape tuples can include None for free dimensions,
753            instead of an integer.
754
755    Returns:
756        An input shape tuple.
757    """
758    if context.executing_eagerly():
759      # In this case we build the model first in order to do shape inference.
760      # This is acceptable because the framework only calls
761      # `compute_output_shape` on shape values that the layer would later be
762      # built for. It would however cause issues in case a user attempts to
763      # use `compute_output_shape` manually with shapes that are incompatible
764      # with the shape the Layer will be called on (these users will have to
765      # implement `compute_output_shape` themselves).
766      self._maybe_build(input_shape)
767      with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default():
768        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
769        def _make_placeholder_like(shape):
770          ph = backend.placeholder(shape=shape, dtype=self.dtype)
771          ph._keras_mask = None
772          return ph
773        inputs = nest.map_structure(_make_placeholder_like, input_shape)
774        try:
775          outputs = self(inputs, training=False)
776        except TypeError as e:
777          six.raise_from(
778              NotImplementedError(
779                  'We could not automatically infer the static shape of the '
780                  'layer\'s output. Please implement the '
781                  '`compute_output_shape` method on your layer (%s).' %
782                  self.__class__.__name__), e)
783      return nest.map_structure(lambda t: t.shape, outputs)
784    raise NotImplementedError(
785        'Please run in eager mode or implement the `compute_output_shape` '
786        'method on your layer (%s).' % self.__class__.__name__)
787
788  @doc_controls.for_subclass_implementers
789  def compute_output_signature(self, input_signature):
790    """Compute the output tensor signature of the layer based on the inputs.
791
792    Unlike a TensorShape object, a TensorSpec object contains both shape
793    and dtype information for a tensor. This method allows layers to provide
794    output dtype information if it is different from the input dtype.
795    For any layer that doesn't implement this function,
796    the framework will fall back to use `compute_output_shape`, and will
797    assume that the output dtype matches the input dtype.
798
799    Args:
800      input_signature: Single TensorSpec or nested structure of TensorSpec
801        objects, describing a candidate input for the layer.
802
803    Returns:
804      Single TensorSpec or nested structure of TensorSpec objects, describing
805        how the layer would transform the provided input.
806
807    Raises:
808      TypeError: If input_signature contains a non-TensorSpec object.
809    """
810    def check_type_return_shape(s):
811      if not isinstance(s, tensor_spec.TensorSpec):
812        raise TypeError('Only TensorSpec signature types are supported, '
813                        'but saw signature entry: {}.'.format(s))
814      return s.shape
815    input_shape = nest.map_structure(check_type_return_shape, input_signature)
816    output_shape = self.compute_output_shape(input_shape)
817    dtype = self._compute_dtype
818    if dtype is None:
819      input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
820      # Default behavior when self.dtype is None, is to use the first input's
821      # dtype.
822      dtype = input_dtypes[0]
823    return nest.map_structure(
824        lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
825        output_shape)
826
827  def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs):
828    if self.dynamic:
829      # We will use static shape inference to return symbolic tensors
830      # matching the specifications of the layer outputs.
831      # Since `self.dynamic` is True, we will never attempt to
832      # run the underlying TF graph (which is disconnected).
833      # TODO(fchollet): consider py_func as an alternative, which
834      # would enable us to run the underlying graph if needed.
835      input_signature = nest.map_structure(
836          lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype),
837          inputs)
838      output_signature = self.compute_output_signature(input_signature)
839      return nest.map_structure(keras_tensor.KerasTensor, output_signature)
840    else:
841      return self._infer_output_signature(inputs, args, kwargs, input_masks)
842
  def _infer_output_signature(self, inputs, args, kwargs, input_masks):
    """Infers the `KerasTensor` outputs of calling this layer symbolically.

    Any `KerasTensor`s in `inputs`, `args`, `kwargs` and `input_masks` are
    replaced by placeholders inside a scratch `FuncGraph`. The inputs are
    cast, the layer is built if needed, and `call` (AutoGraph-converted for
    subclassed layers not loaded from a SavedModel) runs on the placeholders.
    Activity regularization and mask metadata are handled, and the resulting
    tensors are converted back with `keras_tensor.keras_tensor_from_tensor`.

    Args:
      inputs: First argument of the layer call; may contain `KerasTensor`s.
      args: Remaining positional arguments for `call`.
      kwargs: Keyword arguments for `call`.
      input_masks: Masks for `inputs`; forwarded to `_set_mask_metadata`.

    Returns:
      The layer outputs, with each tensor converted by
      `keras_tensor.keras_tensor_from_tensor`.
    """

    call_fn = self.call
    # Wrapping `call` function in autograph to allow for dynamic control
    # flow and control dependencies in call. We are limiting this to
    # subclassed layers as autograph is strictly needed only for
    # subclassed layers and models.
    # tf_convert will respect the value of autograph setting in the
    # enclosing tf.function, if any.
    if (base_layer_utils.is_subclassed(self) and
        not base_layer_utils.from_saved_model(self)):
      call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())

    # We enter a scratch graph and build placeholder inputs inside of it that
    # match the input args.
    # We then call the layer inside of the scratch graph to identify the
    # output signatures, then we build KerasTensors corresponding to those
    # outputs.
    scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph')
    with scratch_graph.as_default():
      inputs = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, inputs)
      args = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, args)
      kwargs = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, kwargs)
      input_masks = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, input_masks)

      # Apply the layer's input casting to the placeholders before `call`.
      inputs = self._maybe_cast_inputs(inputs)

      with backend.name_scope(self._name_scope()):
        with autocast_variable.enable_auto_cast_variables(
            self._compute_dtype_object):
          # Build layer if applicable (if the `build` method has been
          # overridden).
          # TODO(kaftan): do we maybe_build here, or have we already done it?
          self._maybe_build(inputs)
          outputs = call_fn(inputs, *args, **kwargs)

        self._handle_activity_regularization(inputs, outputs)
      self._set_mask_metadata(inputs, outputs, input_masks,
                              build_graph=False)
      # Convert the scratch-graph outputs back into KerasTensors.
      outputs = nest.map_structure(
          keras_tensor.keras_tensor_from_tensor, outputs)

    if hasattr(self, '_set_inputs') and not self.inputs:
      # TODO(kaftan): figure out if we need to do this at all
      # Subclassed network: explicitly set metadata normally set by
      # a call to self._set_inputs().
      # NOTE(review): `inputs` here are the cast scratch-graph placeholders
      # created above, not the caller's original KerasTensors.
      self._set_inputs(inputs, outputs)
    # Drop the local reference to the scratch graph.
    del scratch_graph
    return outputs
897
898  @generic_utils.default
899  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
900    """Computes an output mask tensor.
901
902    Args:
903        inputs: Tensor or list of tensors.
904        mask: Tensor or list of tensors.
905
906    Returns:
907        None or a tensor (or list of tensors,
908            one per output tensor of the layer).
909    """
910    if not self._supports_masking:
911      if any(m is not None for m in nest.flatten(mask)):
912        raise TypeError('Layer ' + self.name + ' does not support masking, '
913                        'but was passed an input_mask: ' + str(mask))
914      # masking not explicitly supported: return None as mask.
915      return None
916    # if masking is explicitly supported, by default
917    # carry over the input mask
918    return mask
919
  def __call__(self, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      *args: Positional arguments to be passed to `self.call`.
      **kwargs: Keyword arguments to be passed to `self.call`.

    Returns:
      Output tensor(s).

    Note:
      - The following optional keyword arguments are reserved for specific uses:
        * `training`: Boolean scalar tensor or Python boolean indicating
          whether the `call` is meant for training or inference.
        * `mask`: Boolean input mask.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `input` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).
      - If the layer is not built, the method will call `build`.
      - In functional construction mode (e.g. the layer is called on symbolic
        `KerasTensor`s), the call is delegated to
        `_functional_construction_call`.

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
      RuntimeError: if `super().__init__()` was not called in the constructor.
    """
    # `_thread_local` is used as the marker that the base constructor ran.
    if not hasattr(self, '_thread_local'):
      raise RuntimeError(
          'You must call `super().__init__()` in the layer constructor.')

    # `inputs` (the first arg in the method spec) is special cased in
    # layer call due to historical reasons.
    # This special casing currently takes the form of:
    # - `inputs` must be explicitly passed. A layer cannot have zero arguments,
    #   and inputs cannot have been provided via the default value of a kwarg.
    # - numpy/scalar values in `inputs` get converted to tensors
    # - implicit masks / mask metadata are only collected from `inputs`
    # - Layers are built using shape info from `inputs` only
    # - input_spec compatibility is only checked against `inputs`
    # - mixed precision casting (autocast) is only applied to `inputs`,
    #   not to any other argument.
    # - setting the SavedModel saving spec.
    inputs, args, kwargs = self._split_out_first_arg(args, kwargs)
    input_list = nest.flatten(inputs)

    # Functional Model construction mode is invoked when `Layer`s are called on
    # symbolic `KerasTensor`s, i.e.:
    # >> inputs = tf.keras.Input(10)
    # >> outputs = MyLayer()(inputs)  # Functional construction mode.
    # >> model = tf.keras.Model(inputs, outputs)
    if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
      return self._functional_construction_call(inputs, args, kwargs,
                                                input_list)

    # Maintains info about the `Layer.call` stack.
    call_context = base_layer_utils.call_context()

    # Accept NumPy and scalar inputs by converting to Tensors.
    if any(isinstance(x, (
        np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
      inputs = nest.map_structure(_convert_numpy_or_python_types, inputs)
      input_list = nest.flatten(inputs)

    # Handle `mask` propagation from previous layer to current layer. Masks can
    # be propagated explicitly via the `mask` argument, or implicitly via
    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
    # explicitly take priority.
    input_masks, mask_is_implicit = self._get_input_masks(
        inputs, input_list, args, kwargs)
    if self._expects_mask_arg and mask_is_implicit:
      kwargs['mask'] = input_masks

    # Training mode for `Layer.call` is set via (in order of priority):
    # (1) The `training` argument passed to this `Layer.call`, if it is not None
    # (2) The training mode of an outer `Layer.call`.
    # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
    # (4) Any non-None default value for `training` specified in the call
    #  signature
    # (5) False (treating the layer as if it's in inference)
    args, kwargs, training_mode = self._set_training_mode(
        args, kwargs, call_context)

    # Losses are cleared for all sublayers on the outermost `Layer.call`.
    # Losses are not cleared on inner `Layer.call`s, because sublayers can be
    # called multiple times.
    if not call_context.in_call:
      self._clear_losses()

    # Outside of eager execution the call context is marked as building a
    # graph.
    eager = context.executing_eagerly()
    with call_context.enter(
        layer=self,
        inputs=inputs,
        build_graph=not eager,
        training=training_mode):

      if self._autocast:
        inputs = self._maybe_cast_inputs(inputs, input_list)

      input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
      # Eager: plain `call` under the layer's own name. Graph: possibly
      # AutoGraph-wrapped `call` under `_name_scope()`.
      if eager:
        call_fn = self.call
        name_scope = self._name
      else:
        name_scope = self._name_scope()  # Avoid autoincrementing.
        call_fn = self._autographed_call()

      with ops.name_scope_v2(name_scope):
        if not self.built:
          self._maybe_build(inputs)

        with autocast_variable.enable_auto_cast_variables(
            self._compute_dtype_object):
          outputs = call_fn(inputs, *args, **kwargs)

        if self._activity_regularizer:
          self._handle_activity_regularization(inputs, outputs)
        # Mask metadata is only attached when the layer supports masking.
        if self._supports_masking:
          self._set_mask_metadata(inputs, outputs, input_masks, not eager)
        # The save spec is recorded only once, on the first call.
        if self._saved_model_inputs_spec is None:
          self._set_save_spec(inputs)

        return outputs
1042
  def _functional_construction_call(self, inputs, args, kwargs, input_list):
    """Calls the layer on symbolic inputs during functional model construction.

    Used by `__call__` in functional construction mode. NumPy and Python
    scalar entries of `inputs` are converted to tensors. The `mask` and
    `training` call arguments are filled in by the framework when the caller
    did not supply them. Symbolic outputs are then produced either via
    `_keras_tensor_symbolic_call` (when `keras_tensor.keras_tensors_enabled()`)
    or by calling the layer inside `backend.get_graph()`. Finally, node
    connectivity metadata is recorded on the outputs.

    Args:
      inputs: First argument of the layer call.
      args: Remaining positional arguments for `call`.
      kwargs: Keyword arguments for `call`. This dict may be mutated in place
        (e.g. a framework-provided `mask` entry is added and later popped).
      input_list: The flattened `inputs`, as computed by `__call__`.

    Returns:
      The symbolic output(s) of the layer.

    Raises:
      ValueError: If the layer's `call` returns None.
      TypeError: If `call` uses Python control flow that is not allowed in
        graph mode and the layer was not declared `dynamic`.
    """
    call_context = base_layer_utils.call_context()

    # Accept NumPy and scalar inputs by converting to Tensors.
    if any(isinstance(x, (
        np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):

      def _convert_non_tensor(x):
        # Don't call `ops.convert_to_tensor` on all `inputs` because
        # `SparseTensors` can't be converted to `Tensor`.
        if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
          return ops.convert_to_tensor_v2_with_dispatch(x)
        return x

      inputs = nest.map_structure(_convert_non_tensor, inputs)
      input_list = nest.flatten(inputs)

    # Handle `mask` propagation from previous layer to current layer. Masks can
    # be propagated explicitly via the `mask` argument, or implicitly via
    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
    # explicitly take priority.
    mask_arg_passed_by_framework = False
    input_masks, mask_is_implicit = self._get_input_masks(
        inputs, input_list, args, kwargs)
    if self._expects_mask_arg and mask_is_implicit:
      kwargs['mask'] = input_masks
      mask_arg_passed_by_framework = True

    # If `training` argument is None or not explicitly passed,
    # propagate `training` value from this layer's calling layer.
    training_value = None
    training_arg_passed_by_framework = False
    # Priority 1: `training` was explicitly passed a non-None value.
    if self._call_arg_was_passed('training', args, kwargs):
      training_value = self._get_call_arg_value('training', args, kwargs)
      # `call` does not accept `training`: drop it from the kwargs, but keep
      # its value for the call context entered below.
      if not self._expects_training_arg:
        kwargs.pop('training')

    if training_value is None:
      # Priority 2: `training` was passed to a parent layer.
      if call_context.training is not None:
        training_value = call_context.training
      # Priority 3: `learning_phase()` has been set.
      elif backend.global_learning_phase_is_set():
        training_value = backend.learning_phase()
        # Force the training_value to be bool type which matches to the contract
        # for layer/model call args.
        if tensor_util.is_tf_type(training_value):
          training_value = math_ops.cast(training_value, dtypes.bool)
        else:
          training_value = bool(training_value)
      # Priority 4: trace layer with the default training argument specified
      # in the `call` signature (or in inference mode if the `call` signature
      # specifies no non-None default).
      else:
        training_value = self._default_training_arg
      # In cases (2), (3), (4) the training argument is passed automatically
      # by the framework, and will not be hard-coded into the model.
      if self._expects_training_arg:
        args, kwargs = self._set_call_arg_value('training', training_value,
                                                args, kwargs)
        training_arg_passed_by_framework = True

    if keras_tensor.keras_tensors_enabled():
      with call_context.enter(
          layer=self, inputs=inputs, build_graph=True, training=training_value):
        # Check input assumptions set after layer building, e.g. input shape.
        outputs = self._keras_tensor_symbolic_call(
            inputs, input_masks, args, kwargs)

        if outputs is None:
          raise ValueError('A layer\'s `call` method should return a '
                           'Tensor or a list of Tensors, not None '
                           '(layer: ' + self.name + ').')
        # Strip framework-provided `training`/`mask` values before recording
        # connectivity so they are not hard-coded into the model.
        if training_arg_passed_by_framework:
          args, kwargs = self._set_call_arg_value(
              'training', None, args, kwargs, pop_kwarg_if_none=True)
        if mask_arg_passed_by_framework:
          kwargs.pop('mask')
        # Node connectivity does not special-case the first argument.
        outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
                                                  outputs)
        return outputs

    # Only create Keras history if at least one tensor originates from a
    # `keras.Input`. Otherwise this Layer may be being used outside the Keras
    # framework.
    # TODO(kaftan): make this not special case inputs
    if base_layer_utils.needs_keras_history(inputs):
      base_layer_utils.create_keras_history(inputs)

    with call_context.enter(
        layer=self, inputs=inputs, build_graph=True, training=training_value):
      # Symbolic execution on symbolic tensors. We will attempt to build
      # the corresponding TF subgraph inside `backend.get_graph()`
      # TODO(reedwm): We should assert input compatibility after the inputs
      # are casted, not before.
      input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
      graph = backend.get_graph()
      # Use `self._name_scope()` to avoid auto-incrementing the name.
      with graph.as_default(), backend.name_scope(self._name_scope()):
        # Build layer if applicable (if the `build` method has been
        # overridden).
        self._maybe_build(inputs)
        cast_inputs = self._maybe_cast_inputs(inputs, input_list)

        if not self.dynamic:
          # Wrapping `call` function in autograph to allow for dynamic control
          # flow and control dependencies in call. We are limiting this to
          # subclassed layers as autograph is strictly needed only for
          # subclassed layers and models.
          # tf_convert will respect the value of autograph setting in the
          # enclosing tf.function, if any.
          if (base_layer_utils.is_subclassed(self) and
              not base_layer_utils.from_saved_model(self)):
            call_fn = autograph.tf_convert(self.call,
                                           ag_ctx.control_status_ctx())
          else:
            call_fn = self.call

          try:
            with autocast_variable.enable_auto_cast_variables(
                self._compute_dtype_object):
              outputs = call_fn(cast_inputs, *args, **kwargs)

          except errors.OperatorNotAllowedInGraphError as e:
            raise TypeError('You are attempting to use Python control '
                            'flow in a layer that was not declared to be '
                            'dynamic. Pass `dynamic=True` to the class '
                            'constructor.\nEncountered error:\n"""\n' + str(e) +
                            '\n"""')
        else:
          # We will use static shape inference to return symbolic tensors
          # matching the specifications of the layer outputs.
          # Since `self.dynamic` is True, we will never attempt to
          # run the underlying TF graph (which is disconnected).
          # TODO(fchollet): consider py_func as an alternative, which
          # would enable us to run the underlying graph if needed.
          outputs = self._symbolic_call(inputs)

        if outputs is None:
          raise ValueError('A layer\'s `call` method should return a '
                           'Tensor or a list of Tensors, not None '
                           '(layer: ' + self.name + ').')
        # TODO(kaftan): This should be 'any' and check all args
        if base_layer_utils.have_all_keras_metadata(inputs):
          # Strip framework-provided `training`/`mask` values before recording
          # connectivity so they are not hard-coded into the model.
          if training_arg_passed_by_framework:
            args, kwargs = self._set_call_arg_value(
                'training', None, args, kwargs, pop_kwarg_if_none=True)
          if mask_arg_passed_by_framework:
            kwargs.pop('mask')
          # Node connectivity does not special-case the first argument.
          outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
                                                    outputs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, input_masks, True)
        if hasattr(self, '_set_inputs') and not self.inputs:
          # Subclassed network: explicitly set metadata normally set by
          # a call to self._set_inputs().
          self._set_inputs(cast_inputs, outputs)

    return outputs
1205
1206  def _set_training_mode(self, args, kwargs, call_context):
1207    training_mode = None
1208    if self._expects_training_arg:
1209      # (1) `training` was passed to this `Layer.call`.
1210      if self._call_arg_was_passed('training', args, kwargs):
1211        training_mode = self._get_call_arg_value('training', args, kwargs)
1212      # If no `training` arg was passed, or `None` was explicitly passed,
1213      # the framework will make a decision about the training mode is.
1214      if training_mode is None:
1215        call_ctx_training = call_context.training
1216        # (2) `training` mode is inferred from an outer `Layer.call`.
1217        if call_ctx_training is not None:
1218          training_mode = call_ctx_training
1219        # (3) User set `tf.keras.backend.set_learning_phase`.
1220        elif backend.global_learning_phase_is_set():
1221          training_mode = backend.learning_phase()
1222          # Ensure value is a `bool` or `tf.bool`.
1223          if isinstance(training_mode, bool):
1224            pass
1225          elif tensor_util.is_tf_type(training_mode):
1226            training_mode = math_ops.cast(training_mode, dtypes.bool)
1227          else:
1228            training_mode = bool(training_mode)
1229        # (4) We default to using `call`'s default value for `training`,
1230        # or treating the layer as if it is in inference if no non-None default
1231        # is specified in the `call` signature.
1232        else:
1233          training_mode = self._default_training_arg
1234
1235        # For case (2), (3), (4) `training` arg is passed by framework.
1236        args, kwargs = self._set_call_arg_value('training', training_mode, args,
1237                                                kwargs)
1238    else:
1239      if 'training' in kwargs:
1240        # `training` was passed to this `Layer` but is not needed for
1241        # `Layer.call`. It will set the default mode for inner `Layer.call`s.
1242        training_mode = kwargs.pop('training')
1243      else:
1244        # Grab the current `training` mode from any outer `Layer.call`.
1245        training_mode = call_context.training
1246
1247    return args, kwargs, training_mode
1248
1249  def _autographed_call(self):
1250    # Wrapping `call` function in autograph to allow for dynamic control
1251    # flow and control dependencies in call. We are limiting this to
1252    # subclassed layers as autograph is strictly needed only for
1253    # subclassed layers and models.
1254    # tf_convert will respect the value of autograph setting in the
1255    # enclosing tf.function, if any.
1256    if (base_layer_utils.is_subclassed(self) and
1257        not base_layer_utils.from_saved_model(self)):
1258      return autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
1259    else:
1260      return self.call
1261
1262  @property
1263  def dtype(self):
1264    """The dtype of the layer weights.
1265
1266    This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless
1267    mixed precision is used, this is the same as `Layer.compute_dtype`, the
1268    dtype of the layer's computations.
1269    """
1270    return self._dtype_policy.variable_dtype
1271
1272  @property
1273  def name(self):
1274    """Name of the layer (string), set in the constructor."""
1275    return self._name
1276
1277  @property
1278  def supports_masking(self):
1279    """Whether this layer supports computing a mask using `compute_mask`."""
1280    return self._supports_masking
1281
1282  @supports_masking.setter
1283  def supports_masking(self, value):
1284    self._supports_masking = value
1285
1286  @property
1287  def dynamic(self):
1288    """Whether the layer is dynamic (eager-only); set in the constructor."""
1289    return any(layer._dynamic for layer in self._flatten_layers())
1290
1291  @property
1292  @doc_controls.do_not_doc_inheritable
1293  def stateful(self):
1294    return any(layer._stateful for layer in self._flatten_layers())
1295
1296  @stateful.setter
1297  def stateful(self, value):
1298    self._stateful = value
1299
1300  @property
1301  def trainable(self):
1302    return self._trainable
1303
1304  @trainable.setter
1305  def trainable(self, value):
1306    for layer in self._flatten_layers():
1307      layer._trainable = value
1308
1309  @property
1310  def activity_regularizer(self):
1311    """Optional regularizer function for the output of this layer."""
1312    return self._activity_regularizer
1313
1314  @activity_regularizer.setter
1315  def activity_regularizer(self, regularizer):
1316    """Optional regularizer function for the output of this layer."""
1317    self._activity_regularizer = regularizer
1318
1319  @property
1320  def input_spec(self):
1321    """`InputSpec` instance(s) describing the input format for this layer.
1322
1323    When you create a layer subclass, you can set `self.input_spec` to enable
1324    the layer to run input compatibility checks when it is called.
1325    Consider a `Conv2D` layer: it can only be called on a single input tensor
1326    of rank 4. As such, you can set, in `__init__()`:
1327
1328    ```python
1329    self.input_spec = tf.keras.layers.InputSpec(ndim=4)
1330    ```
1331
1332    Now, if you try to call the layer on an input that isn't rank 4
1333    (for instance, an input of shape `(2,)`, it will raise a nicely-formatted
1334    error:
1335
1336    ```
1337    ValueError: Input 0 of layer conv2d is incompatible with the layer:
1338    expected ndim=4, found ndim=1. Full shape received: [2]
1339    ```
1340
1341    Input checks that can be specified via `input_spec` include:
1342    - Structure (e.g. a single input, a list of 2 inputs, etc)
1343    - Shape
1344    - Rank (ndim)
1345    - Dtype
1346
1347    For more information, see `tf.keras.layers.InputSpec`.
1348
1349    Returns:
1350      A `tf.keras.layers.InputSpec` instance, or nested structure thereof.
1351    """
1352    return self._input_spec
1353
1354  @input_spec.setter
1355  # Must be decorated to prevent tracking, since the input_spec can be nested
1356  # InputSpec objects.
1357  @trackable.no_automatic_dependency_tracking
1358  def input_spec(self, value):
1359    for v in nest.flatten(value):
1360      if v is not None and not isinstance(v, InputSpec):
1361        raise TypeError('Layer input_spec must be an instance of InputSpec. '
1362                        'Got: {}'.format(v))
1363    self._input_spec = value
1364
1365  @property
1366  def trainable_weights(self):
1367    """List of all trainable weights tracked by this layer.
1368
1369    Trainable weights are updated via gradient descent during training.
1370
1371    Returns:
1372      A list of trainable variables.
1373    """
1374    if self.trainable:
1375      children_weights = self._gather_children_attribute('trainable_variables')
1376      return self._dedup_weights(self._trainable_weights + children_weights)
1377    else:
1378      return []
1379
1380  @property
1381  def non_trainable_weights(self):
1382    """List of all non-trainable weights tracked by this layer.
1383
1384    Non-trainable weights are *not* updated during training. They are expected
1385    to be updated manually in `call()`.
1386
1387    Returns:
1388      A list of non-trainable variables.
1389    """
1390    if self.trainable:
1391      children_weights = self._gather_children_attribute(
1392          'non_trainable_variables')
1393      non_trainable_weights = self._non_trainable_weights + children_weights
1394    else:
1395      children_weights = self._gather_children_attribute('variables')
1396      non_trainable_weights = (
1397          self._trainable_weights + self._non_trainable_weights +
1398          children_weights)
1399    return self._dedup_weights(non_trainable_weights)
1400
1401  @property
1402  def weights(self):
1403    """Returns the list of all layer variables/weights.
1404
1405    Returns:
1406      A list of variables.
1407    """
1408    return self.trainable_weights + self.non_trainable_weights
1409
1410  @property
1411  @doc_controls.do_not_generate_docs
1412  def updates(self):
1413    warnings.warn('`layer.updates` will be removed in a future version. '
1414                  'This property should not be used in TensorFlow 2.0, '
1415                  'as `updates` are applied automatically.')
1416    if keras_tensor.keras_tensors_enabled():
1417      return []
1418
1419    collected_updates = []
1420    all_layers = self._flatten_layers()
1421    with backend.get_graph().as_default():
1422      for layer in all_layers:
1423        if not layer.trainable and not layer.stateful:
1424          continue
1425        for u in layer._updates:
1426          if callable(u):
1427            u = u()
1428          collected_updates.append(u)
1429    return collected_updates
1430
1431  @property
1432  def losses(self):
1433    """List of losses added using the `add_loss()` API.
1434
1435    Variable regularization tensors are created when this property is accessed,
1436    so it is eager safe: accessing `losses` under a `tf.GradientTape` will
1437    propagate gradients back to the corresponding variables.
1438
1439    Examples:
1440
1441    >>> class MyLayer(tf.keras.layers.Layer):
1442    ...   def call(self, inputs):
1443    ...     self.add_loss(tf.abs(tf.reduce_mean(inputs)))
1444    ...     return inputs
1445    >>> l = MyLayer()
1446    >>> l(np.ones((10, 1)))
1447    >>> l.losses
1448    [1.0]
1449
1450    >>> inputs = tf.keras.Input(shape=(10,))
1451    >>> x = tf.keras.layers.Dense(10)(inputs)
1452    >>> outputs = tf.keras.layers.Dense(1)(x)
1453    >>> model = tf.keras.Model(inputs, outputs)
1454    >>> # Activity regularization.
1455    >>> len(model.losses)
1456    0
1457    >>> model.add_loss(tf.abs(tf.reduce_mean(x)))
1458    >>> len(model.losses)
1459    1
1460
1461    >>> inputs = tf.keras.Input(shape=(10,))
1462    >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')
1463    >>> x = d(inputs)
1464    >>> outputs = tf.keras.layers.Dense(1)(x)
1465    >>> model = tf.keras.Model(inputs, outputs)
1466    >>> # Weight regularization.
1467    >>> model.add_loss(lambda: tf.reduce_mean(d.kernel))
1468    >>> model.losses
1469    [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
1470
1471    Returns:
1472      A list of tensors.
1473    """
1474    collected_losses = []
1475    for layer in self._flatten_layers():
1476      # If any eager losses are present, we assume the model to be part of an
1477      # eager training loop (either a custom one or the one used when
1478      # `run_eagerly=True`) and so we always return just the eager losses.
1479      if layer._eager_losses:
1480        # Filter placeholder losses that may have been added by revived layers.
1481        # (see base_layer_utils for details).
1482        if (layer._eager_losses[0] is
1483            not base_layer_utils.REVIVED_LOSS_PLACEHOLDER):
1484          collected_losses.extend(layer._eager_losses)
1485      else:
1486        collected_losses.extend(layer._losses)
1487      for regularizer in layer._callable_losses:
1488        loss_tensor = regularizer()
1489        if loss_tensor is not None:
1490          collected_losses.append(loss_tensor)
1491    return collected_losses
1492
  def add_loss(self, losses, **kwargs):
    """Add loss tensor(s), potentially dependent on layer inputs.

    Some losses (for instance, activity regularization losses) may be dependent
    on the inputs passed when calling a layer. Hence, when reusing the same
    layer on different inputs `a` and `b`, some entries in `layer.losses` may
    be dependent on `a` and some on `b`. This method automatically keeps track
    of dependencies.

    This method can be used inside a subclassed layer or model's `call`
    function, in which case `losses` should be a Tensor or list of Tensors.

    Example:

    ```python
    class MyLayer(tf.keras.layers.Layer):
      def call(self, inputs):
        self.add_loss(tf.abs(tf.reduce_mean(inputs)))
        return inputs
    ```

    This method can also be called directly on a Functional Model during
    construction. In this case, any loss Tensors passed to this Model must
    be symbolic and be able to be traced back to the model's `Input`s. These
    losses become part of the model's topology and are tracked in `get_config`.

    Example:

    ```python
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(10)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    # Activity regularization.
    model.add_loss(tf.abs(tf.reduce_mean(x)))
    ```

    If this is not the case for your loss (if, for example, your loss references
    a `Variable` of one of the model's layers), you can wrap your loss in a
    zero-argument lambda. These losses are not tracked as part of the model's
    topology since they can't be serialized.

    Example:

    ```python
    inputs = tf.keras.Input(shape=(10,))
    d = tf.keras.layers.Dense(10)
    x = d(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    # Weight regularization.
    model.add_loss(lambda: tf.reduce_mean(d.kernel))
    ```

    Args:
      losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
        may also be zero-argument callables which create a loss tensor.
        The argument is flattened with `nest.flatten`, and `None` entries are
        ignored. Entries that are neither TF tensors, KerasTensors nor
        callables are converted to tensors of dtype `backend.floatx()`.
      **kwargs: Additional keyword arguments for backward compatibility.
        Accepted values:
          inputs - Deprecated, will be automatically inferred.

    Raises:
      TypeError: If a keyword argument other than `inputs` is passed.
      ValueError: If a non-symbolic (eager) loss tensor is passed while the
        layer is not being called.
    """
    # `inputs` is accepted only for backward compatibility and is ignored.
    kwargs.pop('inputs', None)
    if kwargs:
      raise TypeError('Unknown keyword arguments: %s' % (kwargs.keys(),))

    def _tag_callable(loss):
      """Evaluates `loss` and tags the result as an `_unconditional_loss`.

      `loss` may be a zero-argument callable (invoked here, with variable
      autocasting disabled) or an already-computed value. Returns `None` if
      the evaluated loss is `None`. Otherwise it returns a tensor, converted
      to `backend.floatx()` if it was not already a TF type.
      """
      if callable(loss):
        # We run the loss without autocasting, as regularizers are often
        # numerically unstable in float16.
        with autocast_variable.enable_auto_cast_variables(None):
          loss = loss()
      if loss is None:
        return None  # Will be filtered out when computing the .losses property
      if not tensor_util.is_tf_type(loss):
        loss = ops.convert_to_tensor_v2_with_dispatch(
            loss, dtype=backend.floatx())
      loss._unconditional_loss = True  # pylint: disable=protected-access
      return loss

    losses = nest.flatten(losses)

    # Sort each flattened entry into one of three buckets:
    #   - callable_losses: deferred; the `losses` property re-evaluates them
    #     on every access.
    #   - symbolic_losses: graph tensors or KerasTensors seen outside of a
    #     `tf.function`.
    #   - eager_losses: any other TF tensor, including tensors created inside
    #     a `tf.function`.
    callable_losses = []
    eager_losses = []
    symbolic_losses = []
    for loss in losses:
      if callable(loss):
        callable_losses.append(functools.partial(_tag_callable, loss))
        continue
      if loss is None:
        continue
      if not tensor_util.is_tf_type(loss) and not isinstance(
          loss, keras_tensor.KerasTensor):
        loss = ops.convert_to_tensor_v2_with_dispatch(
            loss, dtype=backend.floatx())
      # TF Functions should take the eager path.
      # NOTE(review): a KerasTensor seen inside a tf.function matches neither
      # branch below and appears to be dropped silently, unless
      # `tensor_util.is_tf_type` accepts KerasTensors; confirm.
      if ((tf_utils.is_symbolic_tensor(loss) or
           isinstance(loss, keras_tensor.KerasTensor)) and
          not base_layer_utils.is_in_tf_function()):
        symbolic_losses.append(loss)
      elif tensor_util.is_tf_type(loss):
        eager_losses.append(loss)

    self._callable_losses.extend(callable_losses)

    # Concrete (eager) loss values can only be attached during `call`;
    # elsewhere they must be wrapped in a callable.
    in_call_context = base_layer_utils.call_context().in_call
    if eager_losses and not in_call_context:
      raise ValueError(
          'Expected a symbolic Tensors or a callable for the loss value. '
          'Please wrap your loss computation in a zero argument `lambda`.')

    self._eager_losses.extend(eager_losses)

    # Inside `call` (with KerasTensors disabled), symbolic losses are stored
    # directly. Otherwise a Functional model wires them into its graph; any
    # other layer stores them directly.
    if in_call_context and not keras_tensor.keras_tensors_enabled():
      for symbolic_loss in symbolic_losses:
        self._losses.append(symbolic_loss)
    else:
      for symbolic_loss in symbolic_losses:
        if getattr(self, '_is_graph_network', False):
          self._graph_network_add_loss(symbolic_loss)
        else:
          # Possible a loss was added in a Layer's `build`.
          self._losses.append(symbolic_loss)
1616
1617  def _clear_losses(self):
1618    """Used every step in eager to reset losses."""
1619    # Set to thread local directly to avoid Layer.__setattr__ overhead.
1620    if not getattr(self, '_self_tracked_trackables',
1621                   None):  # Fast path for single Layer.
1622      self._thread_local._eager_losses = []
1623    else:
1624      for layer in self._flatten_layers():
1625        layer._thread_local._eager_losses = []
1626
1627  @property
1628  def metrics(self):
1629    """List of metrics added using the `add_metric()` API.
1630
1631    Example:
1632
1633    >>> input = tf.keras.layers.Input(shape=(3,))
1634    >>> d = tf.keras.layers.Dense(2)
1635    >>> output = d(input)
1636    >>> d.add_metric(tf.reduce_max(output), name='max')
1637    >>> d.add_metric(tf.reduce_min(output), name='min')
1638    >>> [m.name for m in d.metrics]
1639    ['max', 'min']
1640
1641    Returns:
1642      A list of `Metric` objects.
1643    """
1644    collected_metrics = []
1645    for layer in self._flatten_layers():
1646      with layer._metrics_lock:
1647        collected_metrics.extend(layer._metrics)
1648    return collected_metrics
1649
1650  def add_metric(self, value, name=None, **kwargs):
1651    """Adds metric tensor to the layer.
1652
1653    This method can be used inside the `call()` method of a subclassed layer
1654    or model.
1655
1656    ```python
1657    class MyMetricLayer(tf.keras.layers.Layer):
1658      def __init__(self):
1659        super(MyMetricLayer, self).__init__(name='my_metric_layer')
1660        self.mean = tf.keras.metrics.Mean(name='metric_1')
1661
1662      def call(self, inputs):
1663        self.add_metric(self.mean(x))
1664        self.add_metric(tf.reduce_sum(x), name='metric_2')
1665        return inputs
1666    ```
1667
1668    This method can also be called directly on a Functional Model during
1669    construction. In this case, any tensor passed to this Model must
1670    be symbolic and be able to be traced back to the model's `Input`s. These
1671    metrics become part of the model's topology and are tracked when you
1672    save the model via `save()`.
1673
1674    ```python
1675    inputs = tf.keras.Input(shape=(10,))
1676    x = tf.keras.layers.Dense(10)(inputs)
1677    outputs = tf.keras.layers.Dense(1)(x)
1678    model = tf.keras.Model(inputs, outputs)
1679    model.add_metric(math_ops.reduce_sum(x), name='metric_1')
1680    ```
1681
1682    Note: Calling `add_metric()` with the result of a metric object on a
1683    Functional Model, as shown in the example below, is not supported. This is
1684    because we cannot trace the metric result tensor back to the model's inputs.
1685
1686    ```python
1687    inputs = tf.keras.Input(shape=(10,))
1688    x = tf.keras.layers.Dense(10)(inputs)
1689    outputs = tf.keras.layers.Dense(1)(x)
1690    model = tf.keras.Model(inputs, outputs)
1691    model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')
1692    ```
1693
1694    Args:
1695      value: Metric tensor.
1696      name: String metric name.
1697      **kwargs: Additional keyword arguments for backward compatibility.
1698        Accepted values:
1699        `aggregation` - When the `value` tensor provided is not the result of
1700        calling a `keras.Metric` instance, it will be aggregated by default
1701        using a `keras.Metric.Mean`.
1702    """
1703    kwargs_keys = list(kwargs.keys())
1704    if (len(kwargs_keys) > 1 or
1705        (len(kwargs_keys) == 1 and kwargs_keys[0] != 'aggregation')):
1706      raise TypeError('Unknown keyword arguments: ', str(kwargs.keys()))
1707
1708    from_metric_obj = hasattr(value, '_metric_obj')
1709    if keras_tensor.keras_tensors_enabled():
1710      is_symbolic = isinstance(value, keras_tensor.KerasTensor)
1711    else:
1712      is_symbolic = tf_utils.is_symbolic_tensor(value)
1713    in_call_context = base_layer_utils.call_context().in_call
1714
1715    if name is None and not from_metric_obj:
1716      # Eg. `self.add_metric(math_ops.reduce_sum(x))`
1717      # In eager mode, we use metric name to lookup a metric. Without a name,
1718      # a new Mean metric wrapper will be created on every model/layer call.
1719      # So, we raise an error when no name is provided.
1720      # We will do the same for symbolic mode for consistency although a name
1721      # will be generated if no name is provided.
1722
1723      # We will not raise this error in the foll use case for the sake of
1724      # consistency as name in provided in the metric constructor.
1725      # mean = metrics.Mean(name='my_metric')
1726      # model.add_metric(mean(outputs))
1727      raise ValueError('Please provide a name for your metric like '
1728                       '`self.add_metric(tf.reduce_sum(inputs), '
1729                       'name=\'mean_activation\')`')
1730    elif from_metric_obj:
1731      name = value._metric_obj.name
1732
1733    if not in_call_context and not is_symbolic:
1734      raise ValueError('Expected a symbolic Tensor for the metric value, '
1735                       'received: ' + str(value))
1736
1737    # If a metric was added in a Layer's `call` or `build`.
1738    if in_call_context or not getattr(self, '_is_graph_network', False):
1739      # TF Function path should take the eager path.
1740
1741      # If the given metric is available in `metrics` list we just update state
1742      # on it, otherwise we create a new metric instance and
1743      # add it to the `metrics` list.
1744      metric_obj = getattr(value, '_metric_obj', None)
1745      # Tensors that come from a Metric object already updated the Metric state.
1746      should_update_state = not metric_obj
1747      name = metric_obj.name if metric_obj else name
1748
1749      with self._metrics_lock:
1750        match = self._get_existing_metric(name)
1751        if match:
1752          metric_obj = match
1753        elif metric_obj:
1754          self._metrics.append(metric_obj)
1755        else:
1756          # Build the metric object with the value's dtype if it defines one
1757          metric_obj = metrics_mod.Mean(
1758              name=name, dtype=getattr(value, 'dtype', None))
1759          self._metrics.append(metric_obj)
1760
1761      if should_update_state:
1762        metric_obj(value)
1763    else:
1764      if from_metric_obj:
1765        raise ValueError('Using the result of calling a `Metric` object '
1766                         'when calling `add_metric` on a Functional '
1767                         'Model is not supported. Please pass the '
1768                         'Tensor to monitor directly.')
1769
1770      # Insert layers into the Keras Graph Network.
1771      aggregation = None if from_metric_obj else 'mean'
1772      self._graph_network_add_metric(value, aggregation, name)
1773
1774  @doc_controls.do_not_doc_inheritable
1775  def add_update(self, updates, inputs=None):
1776    """Add update op(s), potentially dependent on layer inputs.
1777
1778    Weight updates (for instance, the updates of the moving mean and variance
1779    in a BatchNormalization layer) may be dependent on the inputs passed
1780    when calling a layer. Hence, when reusing the same layer on
1781    different inputs `a` and `b`, some entries in `layer.updates` may be
1782    dependent on `a` and some on `b`. This method automatically keeps track
1783    of dependencies.
1784
1785    This call is ignored when eager execution is enabled (in that case, variable
1786    updates are run on the fly and thus do not need to be tracked for later
1787    execution).
1788
1789    Args:
1790      updates: Update op, or list/tuple of update ops, or zero-arg callable
1791        that returns an update op. A zero-arg callable should be passed in
1792        order to disable running the updates by setting `trainable=False`
1793        on this Layer, when executing in Eager mode.
1794      inputs: Deprecated, will be automatically inferred.
1795    """
1796    if inputs is not None:
1797      tf_logging.warning(
1798          '`add_update` `inputs` kwarg has been deprecated. You no longer need '
1799          'to pass a value to `inputs` as it is being automatically inferred.')
1800    call_context = base_layer_utils.call_context()
1801    # No need to run updates during Functional API construction.
1802    if call_context.in_keras_graph:
1803      return
1804
1805    # Callable updates are disabled by setting `trainable=False`.
1806    if not call_context.frozen:
1807      for update in nest.flatten(updates):
1808        if callable(update):
1809          update()  # pylint: disable=not-callable
1810
1811  def set_weights(self, weights):
1812    """Sets the weights of the layer, from Numpy arrays.
1813
1814    The weights of a layer represent the state of the layer. This function
1815    sets the weight values from numpy arrays. The weight values should be
1816    passed in the order they are created by the layer. Note that the layer's
1817    weights must be instantiated before calling this function by calling
1818    the layer.
1819
1820    For example, a Dense layer returns a list of two values-- per-output
1821    weights and the bias value. These can be used to set the weights of another
1822    Dense layer:
1823
1824    >>> a = tf.keras.layers.Dense(1,
1825    ...   kernel_initializer=tf.constant_initializer(1.))
1826    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
1827    >>> a.get_weights()
1828    [array([[1.],
1829           [1.],
1830           [1.]], dtype=float32), array([0.], dtype=float32)]
1831    >>> b = tf.keras.layers.Dense(1,
1832    ...   kernel_initializer=tf.constant_initializer(2.))
1833    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
1834    >>> b.get_weights()
1835    [array([[2.],
1836           [2.],
1837           [2.]], dtype=float32), array([0.], dtype=float32)]
1838    >>> b.set_weights(a.get_weights())
1839    >>> b.get_weights()
1840    [array([[1.],
1841           [1.],
1842           [1.]], dtype=float32), array([0.], dtype=float32)]
1843
1844    Args:
1845        weights: a list of Numpy arrays. The number
1846            of arrays and their shape must match
1847            number of the dimensions of the weights
1848            of the layer (i.e. it should match the
1849            output of `get_weights`).
1850
1851    Raises:
1852        ValueError: If the provided weights list does not match the
1853            layer's specifications.
1854    """
1855    params = self.weights
1856
1857    expected_num_weights = 0
1858    for param in params:
1859      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1860        expected_num_weights += param.num_tensors
1861      else:
1862        expected_num_weights += 1
1863
1864    if expected_num_weights != len(weights):
1865      raise ValueError(
1866          'You called `set_weights(weights)` on layer "%s" '
1867          'with a weight list of length %s, but the layer was '
1868          'expecting %s weights. Provided weights: %s...' %
1869          (self.name, len(weights), expected_num_weights, str(weights)[:50]))
1870
1871    weight_index = 0
1872    weight_value_tuples = []
1873    for param in params:
1874      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1875        num_tensors = param.num_tensors
1876        tensors = weights[weight_index:weight_index + num_tensors]
1877        param.set_weights(tensors)
1878        weight_index += num_tensors
1879      else:
1880        weight = weights[weight_index]
1881        ref_shape = param.shape
1882        if not ref_shape.is_compatible_with(weight.shape):
1883          raise ValueError(
1884              'Layer weight shape %s not compatible with provided weight '
1885              'shape %s' % (ref_shape, weight.shape))
1886        weight_value_tuples.append((param, weight))
1887        weight_index += 1
1888
1889    backend.batch_set_value(weight_value_tuples)
1890
1891  def get_weights(self):
1892    """Returns the current weights of the layer.
1893
1894    The weights of a layer represent the state of the layer. This function
1895    returns both trainable and non-trainable weight values associated with this
1896    layer as a list of Numpy arrays, which can in turn be used to load state
1897    into similarly parameterized layers.
1898
1899    For example, a Dense layer returns a list of two values-- per-output
1900    weights and the bias value. These can be used to set the weights of another
1901    Dense layer:
1902
1903    >>> a = tf.keras.layers.Dense(1,
1904    ...   kernel_initializer=tf.constant_initializer(1.))
1905    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
1906    >>> a.get_weights()
1907    [array([[1.],
1908           [1.],
1909           [1.]], dtype=float32), array([0.], dtype=float32)]
1910    >>> b = tf.keras.layers.Dense(1,
1911    ...   kernel_initializer=tf.constant_initializer(2.))
1912    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
1913    >>> b.get_weights()
1914    [array([[2.],
1915           [2.],
1916           [2.]], dtype=float32), array([0.], dtype=float32)]
1917    >>> b.set_weights(a.get_weights())
1918    >>> b.get_weights()
1919    [array([[1.],
1920           [1.],
1921           [1.]], dtype=float32), array([0.], dtype=float32)]
1922
1923    Returns:
1924        Weights values as a list of numpy arrays.
1925    """
1926    weights = self.weights
1927    output_weights = []
1928    for weight in weights:
1929      if isinstance(weight, base_layer_utils.TrackableWeightHandler):
1930        output_weights.extend(weight.get_tensors())
1931      else:
1932        output_weights.append(weight)
1933    return backend.batch_get_value(output_weights)
1934
1935  @doc_controls.do_not_generate_docs
1936  def get_updates_for(self, inputs):
1937    """Deprecated, do NOT use!
1938
1939    Retrieves updates relevant to a specific set of inputs.
1940
1941    Args:
1942      inputs: Input tensor or list/tuple of input tensors.
1943
1944    Returns:
1945      List of update ops of the layer that depend on `inputs`.
1946    """
1947    warnings.warn('`layer.get_updates_for` is deprecated and '
1948                  'will be removed in a future version. '
1949                  'Please use `layer.updates` method instead.')
1950    return self.updates
1951
1952  @doc_controls.do_not_generate_docs
1953  def get_losses_for(self, inputs):
1954    """Deprecated, do NOT use!
1955
1956    Retrieves losses relevant to a specific set of inputs.
1957
1958    Args:
1959      inputs: Input tensor or list/tuple of input tensors.
1960
1961    Returns:
1962      List of loss tensors of the layer that depend on `inputs`.
1963    """
1964    warnings.warn('`layer.get_losses_for` is deprecated and '
1965                  'will be removed in a future version. '
1966                  'Please use `layer.losses` instead.')
1967    return self.losses
1968
1969  @doc_controls.do_not_doc_inheritable
1970  def get_input_mask_at(self, node_index):
1971    """Retrieves the input mask tensor(s) of a layer at a given node.
1972
1973    Args:
1974        node_index: Integer, index of the node
1975            from which to retrieve the attribute.
1976            E.g. `node_index=0` will correspond to the
1977            first time the layer was called.
1978
1979    Returns:
1980        A mask tensor
1981        (or list of tensors if the layer has multiple inputs).
1982    """
1983    inputs = self.get_input_at(node_index)
1984    if isinstance(inputs, list):
1985      return [getattr(x, '_keras_mask', None) for x in inputs]
1986    else:
1987      return getattr(inputs, '_keras_mask', None)
1988
1989  @doc_controls.do_not_doc_inheritable
1990  def get_output_mask_at(self, node_index):
1991    """Retrieves the output mask tensor(s) of a layer at a given node.
1992
1993    Args:
1994        node_index: Integer, index of the node
1995            from which to retrieve the attribute.
1996            E.g. `node_index=0` will correspond to the
1997            first time the layer was called.
1998
1999    Returns:
2000        A mask tensor
2001        (or list of tensors if the layer has multiple outputs).
2002    """
2003    output = self.get_output_at(node_index)
2004    if isinstance(output, list):
2005      return [getattr(x, '_keras_mask', None) for x in output]
2006    else:
2007      return getattr(output, '_keras_mask', None)
2008
2009  @property
2010  @doc_controls.do_not_doc_inheritable
2011  def input_mask(self):
2012    """Retrieves the input mask tensor(s) of a layer.
2013
2014    Only applicable if the layer has exactly one inbound node,
2015    i.e. if it is connected to one incoming layer.
2016
2017    Returns:
2018        Input mask tensor (potentially None) or list of input
2019        mask tensors.
2020
2021    Raises:
2022        AttributeError: if the layer is connected to
2023        more than one incoming layers.
2024    """
2025    inputs = self.input
2026    if isinstance(inputs, list):
2027      return [getattr(x, '_keras_mask', None) for x in inputs]
2028    else:
2029      return getattr(inputs, '_keras_mask', None)
2030
2031  @property
2032  @doc_controls.do_not_doc_inheritable
2033  def output_mask(self):
2034    """Retrieves the output mask tensor(s) of a layer.
2035
2036    Only applicable if the layer has exactly one inbound node,
2037    i.e. if it is connected to one incoming layer.
2038
2039    Returns:
2040        Output mask tensor (potentially None) or list of output
2041        mask tensors.
2042
2043    Raises:
2044        AttributeError: if the layer is connected to
2045        more than one incoming layers.
2046    """
2047    output = self.output
2048    if isinstance(output, list):
2049      return [getattr(x, '_keras_mask', None) for x in output]
2050    else:
2051      return getattr(output, '_keras_mask', None)
2052
2053  @doc_controls.do_not_doc_inheritable
2054  def get_input_shape_at(self, node_index):
2055    """Retrieves the input shape(s) of a layer at a given node.
2056
2057    Args:
2058        node_index: Integer, index of the node
2059            from which to retrieve the attribute.
2060            E.g. `node_index=0` will correspond to the
2061            first time the layer was called.
2062
2063    Returns:
2064        A shape tuple
2065        (or list of shape tuples if the layer has multiple inputs).
2066
2067    Raises:
2068      RuntimeError: If called in Eager mode.
2069    """
2070    return self._get_node_attribute_at_index(node_index, 'input_shapes',
2071                                             'input shape')
2072
2073  @doc_controls.do_not_doc_inheritable
2074  def get_output_shape_at(self, node_index):
2075    """Retrieves the output shape(s) of a layer at a given node.
2076
2077    Args:
2078        node_index: Integer, index of the node
2079            from which to retrieve the attribute.
2080            E.g. `node_index=0` will correspond to the
2081            first time the layer was called.
2082
2083    Returns:
2084        A shape tuple
2085        (or list of shape tuples if the layer has multiple outputs).
2086
2087    Raises:
2088      RuntimeError: If called in Eager mode.
2089    """
2090    return self._get_node_attribute_at_index(node_index, 'output_shapes',
2091                                             'output shape')
2092
2093  @doc_controls.do_not_doc_inheritable
2094  def get_input_at(self, node_index):
2095    """Retrieves the input tensor(s) of a layer at a given node.
2096
2097    Args:
2098        node_index: Integer, index of the node
2099            from which to retrieve the attribute.
2100            E.g. `node_index=0` will correspond to the
2101            first input node of the layer.
2102
2103    Returns:
2104        A tensor (or list of tensors if the layer has multiple inputs).
2105
2106    Raises:
2107      RuntimeError: If called in Eager mode.
2108    """
2109    return self._get_node_attribute_at_index(node_index, 'input_tensors',
2110                                             'input')
2111
2112  @doc_controls.do_not_doc_inheritable
2113  def get_output_at(self, node_index):
2114    """Retrieves the output tensor(s) of a layer at a given node.
2115
2116    Args:
2117        node_index: Integer, index of the node
2118            from which to retrieve the attribute.
2119            E.g. `node_index=0` will correspond to the
2120            first output node of the layer.
2121
2122    Returns:
2123        A tensor (or list of tensors if the layer has multiple outputs).
2124
2125    Raises:
2126      RuntimeError: If called in Eager mode.
2127    """
2128    return self._get_node_attribute_at_index(node_index, 'output_tensors',
2129                                             'output')
2130
2131  @property
2132  def input(self):
2133    """Retrieves the input tensor(s) of a layer.
2134
2135    Only applicable if the layer has exactly one input,
2136    i.e. if it is connected to one incoming layer.
2137
2138    Returns:
2139        Input tensor or list of input tensors.
2140
2141    Raises:
2142      RuntimeError: If called in Eager mode.
2143      AttributeError: If no inbound nodes are found.
2144    """
2145    if not self._inbound_nodes:
2146      raise AttributeError('Layer ' + self.name +
2147                           ' is not connected, no input to return.')
2148    return self._get_node_attribute_at_index(0, 'input_tensors', 'input')
2149
2150  @property
2151  def output(self):
2152    """Retrieves the output tensor(s) of a layer.
2153
2154    Only applicable if the layer has exactly one output,
2155    i.e. if it is connected to one incoming layer.
2156
2157    Returns:
2158      Output tensor or list of output tensors.
2159
2160    Raises:
2161      AttributeError: if the layer is connected to more than one incoming
2162        layers.
2163      RuntimeError: if called in Eager mode.
2164    """
2165    if not self._inbound_nodes:
2166      raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
2167    return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
2168
2169  @property
2170  @doc_controls.do_not_doc_inheritable
2171  def input_shape(self):
2172    """Retrieves the input shape(s) of a layer.
2173
2174    Only applicable if the layer has exactly one input,
2175    i.e. if it is connected to one incoming layer, or if all inputs
2176    have the same shape.
2177
2178    Returns:
2179        Input shape, as an integer shape tuple
2180        (or list of shape tuples, one tuple per input tensor).
2181
2182    Raises:
2183        AttributeError: if the layer has no defined input_shape.
2184        RuntimeError: if called in Eager mode.
2185    """
2186    if not self._inbound_nodes:
2187      raise AttributeError('The layer has never been called '
2188                           'and thus has no defined input shape.')
2189    all_input_shapes = set(
2190        [str(node.input_shapes) for node in self._inbound_nodes])
2191    if len(all_input_shapes) == 1:
2192      return self._inbound_nodes[0].input_shapes
2193    else:
2194      raise AttributeError('The layer "' + str(self.name) +
2195                           ' has multiple inbound nodes, '
2196                           'with different input shapes. Hence '
2197                           'the notion of "input shape" is '
2198                           'ill-defined for the layer. '
2199                           'Use `get_input_shape_at(node_index)` '
2200                           'instead.')
2201
2202  def count_params(self):
2203    """Count the total number of scalars composing the weights.
2204
2205    Returns:
2206        An integer count.
2207
2208    Raises:
2209        ValueError: if the layer isn't yet built
2210          (in which case its weights aren't yet defined).
2211    """
2212    if not self.built:
2213      if getattr(self, '_is_graph_network', False):
2214        with tf_utils.maybe_init_scope(self):
2215          self._maybe_build(self.inputs)
2216      else:
2217        raise ValueError('You tried to call `count_params` on ' + self.name +
2218                         ', but the layer isn\'t built. '
2219                         'You can build it manually via: `' + self.name +
2220                         '.build(batch_input_shape)`.')
2221    return layer_utils.count_params(self.weights)
2222
2223  @property
2224  @doc_controls.do_not_doc_inheritable
2225  def output_shape(self):
2226    """Retrieves the output shape(s) of a layer.
2227
2228    Only applicable if the layer has one output,
2229    or if all outputs have the same shape.
2230
2231    Returns:
2232        Output shape, as an integer shape tuple
2233        (or list of shape tuples, one tuple per output tensor).
2234
2235    Raises:
2236        AttributeError: if the layer has no defined output shape.
2237        RuntimeError: if called in Eager mode.
2238    """
2239    if not self._inbound_nodes:
2240      raise AttributeError('The layer has never been called '
2241                           'and thus has no defined output shape.')
2242    all_output_shapes = set(
2243        [str(node.output_shapes) for node in self._inbound_nodes])
2244    if len(all_output_shapes) == 1:
2245      return self._inbound_nodes[0].output_shapes
2246    else:
2247      raise AttributeError('The layer "%s"'
2248                           ' has multiple inbound nodes, '
2249                           'with different output shapes. Hence '
2250                           'the notion of "output shape" is '
2251                           'ill-defined for the layer. '
2252                           'Use `get_output_shape_at(node_index)` '
2253                           'instead.' % self.name)
2254
2255  @property
2256  @doc_controls.do_not_doc_inheritable
2257  def inbound_nodes(self):
2258    """Deprecated, do NOT use! Only for compatibility with external Keras."""
2259    return self._inbound_nodes
2260
2261  @property
2262  @doc_controls.do_not_doc_inheritable
2263  def outbound_nodes(self):
2264    """Deprecated, do NOT use! Only for compatibility with external Keras."""
2265    return self._outbound_nodes
2266
2267  ##############################################################################
2268  # Methods & attributes below are public aliases of other methods.            #
2269  ##############################################################################
2270
2271  @doc_controls.do_not_doc_inheritable
2272  def apply(self, inputs, *args, **kwargs):
2273    """Deprecated, do NOT use!
2274
2275    This is an alias of `self.__call__`.
2276
2277    Args:
2278      inputs: Input tensor(s).
2279      *args: additional positional arguments to be passed to `self.call`.
2280      **kwargs: additional keyword arguments to be passed to `self.call`.
2281
2282    Returns:
2283      Output tensor(s).
2284    """
2285    warnings.warn('`layer.apply` is deprecated and '
2286                  'will be removed in a future version. '
2287                  'Please use `layer.__call__` method instead.')
2288    return self.__call__(inputs, *args, **kwargs)
2289
2290  @doc_controls.do_not_doc_inheritable
2291  def add_variable(self, *args, **kwargs):
2292    """Deprecated, do NOT use! Alias for `add_weight`."""
2293    warnings.warn('`layer.add_variable` is deprecated and '
2294                  'will be removed in a future version. '
2295                  'Please use `layer.add_weight` method instead.')
2296    return self.add_weight(*args, **kwargs)
2297
2298  @property
2299  @doc_controls.do_not_generate_docs
2300  def variables(self):
2301    """Returns the list of all layer variables/weights.
2302
2303    Alias of `self.weights`.
2304
2305    Note: This will not track the weights of nested `tf.Modules` that are not
2306    themselves Keras layers.
2307
2308    Returns:
2309      A list of variables.
2310    """
2311    return self.weights
2312
2313  @property
2314  @doc_controls.do_not_generate_docs
2315  def trainable_variables(self):
2316    return self.trainable_weights
2317
2318  @property
2319  @doc_controls.do_not_generate_docs
2320  def non_trainable_variables(self):
2321    return self.non_trainable_weights
2322
2323  ##############################################################################
2324  # Methods & attributes below are all private and only used by the framework. #
2325  ##############################################################################
2326
2327  @property
2328  def _inbound_nodes(self):
2329    return self._inbound_nodes_value
2330
2331  @_inbound_nodes.setter
2332  @trackable.no_automatic_dependency_tracking
2333  def _inbound_nodes(self, value):
2334    self._inbound_nodes_value = value
2335
2336  @property
2337  def _outbound_nodes(self):
2338    return self._outbound_nodes_value
2339
2340  @_outbound_nodes.setter
2341  @trackable.no_automatic_dependency_tracking
2342  def _outbound_nodes(self, value):
2343    self._outbound_nodes_value = value
2344
  def _set_dtype_policy(self, dtype):
    """Sets `self._dtype_policy` and caches the compute dtype as a DType.

    Args:
      dtype: One of: a `policy.Policy` (used as-is); a dict (passed to
        `policy.deserialize`); any other truthy value (converted with
        `dtypes.as_dtype` and wrapped in a `policy.Policy` named after it);
        or a falsy value such as None (falls back to `policy.global_policy()`).

    Raises:
      ValueError: If the resulting policy is named 'mixed_float16' and
        `loss_scale_optimizer.strategy_supports_loss_scaling()` returns False.
    """
    if isinstance(dtype, policy.Policy):
      self._dtype_policy = dtype
    elif isinstance(dtype, dict):
      # Presumably a serialized policy config; TODO confirm the expected schema.
      self._dtype_policy = policy.deserialize(dtype)
    elif dtype:
      self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
    else:
      self._dtype_policy = policy.global_policy()
    # Note: `self._dtype_policy` is already assigned at this point, so the
    # ValueError below leaves the new policy set on the layer.
    if (self._dtype_policy.name == 'mixed_float16' and
        not loss_scale_optimizer.strategy_supports_loss_scaling()):
      # Although only loss scaling doesn't support certain strategies, to avoid
      # confusion, we disallow the 'mixed_float16' policy with unsupported
      # strategies. This is because 'mixed_float16' requires loss scaling for
      # numeric stability.
      strategy = ds_context.get_strategy()
      raise ValueError('Mixed precision is not supported with the '
                       'tf.distribute.Strategy: %s. Either stop using mixed '
                       'precision by removing the use of the "%s" policy or '
                       'use a different Strategy, e.g. a MirroredStrategy.' %
                       (strategy.__class__.__name__, self._dtype_policy.name))

    # Performance optimization: cache the compute dtype as a Dtype object or
    # None, so that str to Dtype conversion doesn't happen in Layer.__call__.
    # TODO(b/157486353): Investigate returning DTypes in Policy.
    if self._dtype_policy.compute_dtype:
      self._compute_dtype_object = dtypes.as_dtype(
          self._dtype_policy.compute_dtype)
    else:
      self._compute_dtype_object = None
2376
2377  @property
2378  def dtype_policy(self):
2379    """The dtype policy associated with this layer.
2380
2381    This is an instance of a `tf.keras.mixed_precision.Policy`.
2382    """
2383    return self._dtype_policy
2384
2385  @property
2386  def compute_dtype(self):
2387    """The dtype of the layer's computations.
2388
2389    This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless
2390    mixed precision is used, this is the same as `Layer.dtype`, the dtype of
2391    the weights.
2392
2393    Layers automatically cast their inputs to the compute dtype, which causes
2394    computations and the output to be in the compute dtype as well. This is done
2395    by the base Layer class in `Layer.__call__`, so you do not have to insert
2396    these casts if implementing your own layer.
2397
2398    Layers often perform certain internal computations in higher precision when
2399    `compute_dtype` is float16 or bfloat16 for numeric stability. The output
2400    will still typically be float16 or bfloat16 in such cases.
2401
2402    Returns:
2403      The layer's compute dtype.
2404    """
2405    return self._dtype_policy.compute_dtype
2406
2407  @property
2408  def _compute_dtype(self):
2409    """Deprecated alias of `compute_dtype`."""
2410    return self._dtype_policy.compute_dtype
2411
2412  @property
2413  def variable_dtype(self):
2414    """Alias of `Layer.dtype`, the dtype of the weights."""
2415    return self.dtype
2416
2417  def _maybe_cast_inputs(self, inputs, input_list=None):
2418    """Maybe casts the inputs to the compute dtype.
2419
2420    If self._compute_dtype is floating-point, and self_autocast is True,
2421    floating-point inputs are casted to self._compute_dtype.
2422
2423    Args:
2424      inputs: Input tensor, or structure of input tensors.
2425      input_list: Flat list of input tensors.
2426
2427    Returns:
2428      `inputs`, but tensors may have been casted to self._compute_dtype
2429    """
2430    if not input_list:
2431      input_list = nest.flatten(inputs)
2432
2433    compute_dtype_object = self._compute_dtype_object
2434    should_autocast = (
2435        self._autocast and compute_dtype_object and
2436        compute_dtype_object.is_floating)
2437
2438    if (should_autocast and
2439        any(map(self._should_cast_single_input, input_list))):
2440      # Only perform expensive `nest` operation when needed.
2441      return nest.map_structure(self._cast_single_input, inputs)
2442    else:
2443      return inputs
2444
2445  def _should_cast_single_input(self, x):
2446    if isinstance(x, _AUTOCAST_TYPES):
2447      return (self._compute_dtype_object and
2448              x.dtype != self._compute_dtype_object and x.dtype.is_floating)
2449    return False
2450
2451  def _cast_single_input(self, x):
2452    """Cast a single Tensor or TensorSpec to the compute dtype."""
2453    if self._should_cast_single_input(x):
2454      return math_ops.cast(x, self._compute_dtype_object)
2455    else:
2456      return x
2457
2458  # _dtype used to be an attribute set in the constructor. We still expose it
2459  # because some clients still use it.
2460  # TODO(reedwm): Deprecate, then remove the _dtype property.
2461  @property
2462  def _dtype(self):
2463    # This is equivalent to returning self.dtype . We do not return self.dtype
2464    # as it would cause infinite recursion in a few subclasses, which override
2465    # "dtype" to return self._dtype.
2466    return self._dtype_policy.variable_dtype
2467
2468  @_dtype.setter
2469  def _dtype(self, value):
2470    value = dtypes.as_dtype(value).name
2471    self._set_dtype_policy(policy.Policy(value))
2472
2473  def _name_scope(self):
2474    if not tf2.enabled():
2475      return self.name
2476    name_scope = self.name
2477    current_name_scope = ops.get_name_scope()
2478    if current_name_scope:
2479      name_scope = current_name_scope + '/' + name_scope
2480    if name_scope:
2481      # Note that the trailing `/` prevents autogenerated
2482      # numerical suffixes to get appended. It will also fully reset
2483      # nested name scope (i.e. the outer name scope has no effect).
2484      name_scope += '/'
2485    return name_scope
2486
2487  def _init_set_name(self, name, zero_based=True):
2488    if not name:
2489      self._name = backend.unique_object_name(
2490          generic_utils.to_snake_case(self.__class__.__name__),
2491          zero_based=zero_based)
2492    else:
2493      backend.observe_object_name(name)
2494      self._name = name
2495
2496  def _get_existing_metric(self, name=None):
2497    match = [m for m in self._metrics if m.name == name]
2498    if not match:
2499      return
2500    if len(match) > 1:
2501      raise ValueError(
2502          'Please provide different names for the metrics you have added. '
2503          'We found {} metrics with the name: "{}"'.format(len(match), name))
2504    return match[0]
2505
2506  def _handle_weight_regularization(self, name, variable, regularizer):
2507    """Create lambdas which compute regularization losses."""
2508
2509    def _loss_for_variable(v):
2510      """Creates a regularization loss `Tensor` for variable `v`."""
2511      with backend.name_scope(name + '/Regularizer'):
2512        regularization = regularizer(v)
2513      return regularization
2514
2515    if base_layer_utils.is_split_variable(variable):
2516      for v in variable:
2517        self.add_loss(functools.partial(_loss_for_variable, v))
2518    else:
2519      self.add_loss(functools.partial(_loss_for_variable, variable))
2520
2521  def _handle_activity_regularization(self, inputs, outputs):
2522    # Apply activity regularization.
2523    # Note that it should be applied every time the layer creates a new
2524    # output, since it is output-specific.
2525    if self._activity_regularizer:
2526      output_list = nest.flatten(outputs)
2527      with backend.name_scope('ActivityRegularizer'):
2528        for output in output_list:
2529          activity_loss = self._activity_regularizer(output)
2530          batch_size = math_ops.cast(
2531              array_ops.shape(output)[0], activity_loss.dtype)
2532          # Make activity regularization strength batch-agnostic.
2533          mean_activity_loss = activity_loss / batch_size
2534          self.add_loss(mean_activity_loss)
2535
  def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph):
    """Attaches `compute_mask` results to `outputs` as `_keras_mask`.

    Returns early, without computing masks, when the layer does not support
    masking, or when masks are already present. Masks count as present when
    `_compute_output_and_mask_jointly` is set, or when every flattened output
    already has a non-None `_keras_mask`; this includes an empty output list.
    Also returns early when `compute_mask` yields None.

    Args:
      inputs: The inputs of this layer call, forwarded to `compute_mask`.
      outputs: The outputs of this layer call. Each flattened element that
        accepts attribute assignment receives its mask as `_keras_mask`.
      previous_mask: Input mask(s), forwarded to `compute_mask`.
      build_graph: If True, output masks are flagged via
        `_set_mask_keras_history_checked`.
    """
    # Many `Layer`s don't need to call `compute_mask`.
    # This method is optimized to do as little work as needed for the common
    # case.
    if not self._supports_masking:
      return

    flat_outputs = nest.flatten(outputs)

    mask_already_computed = (
        getattr(self, '_compute_output_and_mask_jointly', False) or
        all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
    if mask_already_computed:
      if build_graph:
        self._set_mask_keras_history_checked(flat_outputs)
      return

    output_masks = self.compute_mask(inputs, previous_mask)
    if output_masks is None:
      return

    flat_masks = nest.flatten(output_masks)
    # `zip` stops at the shorter list if outputs and masks differ in length.
    for tensor, mask in zip(flat_outputs, flat_masks):
      try:
        tensor._keras_mask = mask
      except AttributeError:
        # C Type such as np.ndarray.
        pass

    if build_graph:
      self._set_mask_keras_history_checked(flat_outputs)
2567
2568  def _set_mask_keras_history_checked(self, flat_outputs):
2569    for output in flat_outputs:
2570      if getattr(output, '_keras_mask', None) is not None:
2571        # Do not track masks for `TensorFlowOpLayer` construction.
2572        output._keras_mask._keras_history_checked = True
2573
2574  def _get_input_masks(self, inputs, input_list, args, kwargs):
2575    if not self._supports_masking and not self._expects_mask_arg:
2576      # Input masks only need to be retrieved if they are needed for `call`
2577      # or `compute_mask`.
2578      input_masks = None
2579      implicit_mask = False
2580    elif self._call_arg_was_passed('mask', args, kwargs):
2581      input_masks = self._get_call_arg_value('mask', args, kwargs)
2582      implicit_mask = False
2583    else:
2584      input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
2585      if all(mask is None for mask in input_masks):
2586        input_masks = None
2587        implicit_mask = False
2588      else:
2589        # Only do expensive `nest` op when masking is actually being used.
2590        input_masks = nest.pack_sequence_as(inputs, input_masks)
2591        implicit_mask = True
2592    return input_masks, implicit_mask
2593
2594  def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
2595    # Performance optimization: do no work in most common case.
2596    if not args and not kwargs:
2597      return False
2598
2599    if arg_name in kwargs:
2600      return True
2601    call_fn_args = self._call_fn_args
2602    if not inputs_in_args:
2603      # Ignore `inputs` arg.
2604      call_fn_args = call_fn_args[1:]
2605    return arg_name in dict(zip(call_fn_args, args))
2606
2607  def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
2608    if arg_name in kwargs:
2609      return kwargs[arg_name]
2610    call_fn_args = self._call_fn_args
2611    if not inputs_in_args:
2612      # Ignore `inputs` arg.
2613      call_fn_args = call_fn_args[1:]
2614    args_dict = dict(zip(call_fn_args, args))
2615    return args_dict[arg_name]
2616
2617  def _set_call_arg_value(
2618      self, arg_name, new_value, args,
2619      kwargs, inputs_in_args=False, pop_kwarg_if_none=False):
2620    arg_pos = self._call_fn_arg_positions.get(arg_name, None)
2621    if arg_pos is not None:
2622      if not inputs_in_args:
2623        # Ignore `inputs` arg.
2624        arg_pos = arg_pos - 1
2625      if len(args) > arg_pos:
2626        args = list(args)
2627        args[arg_pos] = new_value
2628        return tuple(args), kwargs
2629    if new_value is None and pop_kwarg_if_none:
2630      kwargs.pop(arg_name, None)
2631    else:
2632      kwargs[arg_name] = new_value
2633    return args, kwargs
2634
2635  def _set_connectivity_metadata(self, args, kwargs, outputs):
2636    # If the layer returns tensors from its inputs unmodified,
2637    # we copy them to avoid loss of KerasHistory metadata.
2638    flat_outputs = nest.flatten(outputs)
2639    flat_inputs = nest.flatten((args, kwargs))
2640    input_ids_set = {id(i) for i in flat_inputs}
2641    outputs_copy = []
2642    for x in flat_outputs:
2643      if id(x) in input_ids_set:
2644        with backend.name_scope(self.name):
2645          x = array_ops.identity(x)
2646      outputs_copy.append(x)
2647    outputs = nest.pack_sequence_as(outputs, outputs_copy)
2648
2649    # Create node, Node wires itself to inbound and outbound layers.
2650    # The Node constructor actually updates this layer's self._inbound_nodes,
2651    # sets _keras_history on the outputs, and adds itself to the
2652    # `_outbound_nodes` of the layers that produced the inputs to this
2653    # layer call.
2654    node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs)
2655    return outputs
2656
2657  def _get_node_attribute_at_index(self, node_index, attr, attr_name):
2658    """Private utility to retrieves an attribute (e.g. inputs) from a node.
2659
2660    This is used to implement the methods:
2661        - get_input_shape_at
2662        - get_output_shape_at
2663        - get_input_at
2664        etc...
2665
2666    Args:
2667        node_index: Integer index of the node from which
2668            to retrieve the attribute.
2669        attr: Exact node attribute name.
2670        attr_name: Human-readable attribute name, for error messages.
2671
2672    Returns:
2673        The layer's attribute `attr` at the node of index `node_index`.
2674
2675    Raises:
2676        RuntimeError: If the layer has no inbound nodes, or if called in Eager
2677        mode.
2678        ValueError: If the index provided does not match any node.
2679    """
2680    if not self._inbound_nodes:
2681      raise RuntimeError('The layer has never been called '
2682                         'and thus has no defined ' + attr_name + '.')
2683    if not len(self._inbound_nodes) > node_index:
2684      raise ValueError('Asked to get ' + attr_name + ' at node ' +
2685                       str(node_index) + ', but the layer has only ' +
2686                       str(len(self._inbound_nodes)) + ' inbound nodes.')
2687    values = getattr(self._inbound_nodes[node_index], attr)
2688    if isinstance(values, list) and len(values) == 1:
2689      return values[0]
2690    else:
2691      return values
2692
  def _maybe_build(self, inputs):
    """Builds the layer on first use and applies any pending initial weights.

    If the layer is not built yet, this:
      * checks `inputs` against `self.input_spec`;
      * adopts the first input's dtype as the policy when no compute dtype
        is set;
      * computes input shapes;
      * calls a user-overridden `build` inside `tf_utils.maybe_init_scope`;
      * always calls `Layer.build` so the layer is marked built.
    Then, whether or not a build happened, a non-None
    `self._initial_weights` is loaded with `set_weights` and cleared.

    Args:
      inputs: Input tensor(s). Judging by the `convert_shapes` fallback,
        presumably also shape-like objects when the inputs lack a `shape`
        attribute; TODO confirm against callers.
    """
    # Check input assumptions set before layer building, e.g. input rank.
    if not self.built:
      input_spec.assert_input_compatibility(
          self.input_spec, inputs, self.name)
      input_list = nest.flatten(inputs)
      if input_list and self._dtype_policy.compute_dtype is None:
        # No compute dtype yet: adopt the first input's dtype. Inputs without
        # a `dtype` (AttributeError) leave the policy untouched.
        try:
          dtype = input_list[0].dtype.base_dtype.name
        except AttributeError:
          pass
        else:
          self._set_dtype_policy(policy.Policy(dtype))
      input_shapes = None
      # Converts Tensors / CompositeTensors to TensorShapes.
      if all(hasattr(x, 'shape') for x in input_list):
        input_shapes = tf_utils.get_shapes(inputs)
      else:
        # Converts input shape to TensorShapes.
        try:
          input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False)
        except ValueError:
          # Shapes could not be determined; `build` receives None.
          pass
      # Only call `build` if the user has manually overridden the build method.
      if not hasattr(self.build, '_is_default'):
        # Any setup work performed only once should happen in an `init_scope`
        # to avoid creating symbolic Tensors that will later pollute any eager
        # operations.
        with tf_utils.maybe_init_scope(self):
          self.build(input_shapes)  # pylint:disable=not-callable
      # We must also ensure that the layer is marked as built, and the build
      # shape is stored since user defined build functions may not be calling
      # `super.build()`
      Layer.build(self, input_shapes)

    # Optionally load weight values specified at layer instantiation.
    if self._initial_weights is not None:
      if ops.executing_eagerly_outside_functions():
        with ops.init_scope():
          # Using `init_scope` since we want variable assignment in
          # `set_weights` to be treated like variable initialization.
          self.set_weights(self._initial_weights)
      else:
        self.set_weights(self._initial_weights)
      # Clear so the weights are only applied once.
      self._initial_weights = None
2738
2739  def _symbolic_call(self, inputs):
2740    input_shapes = nest.map_structure(lambda x: x.shape, inputs)
2741    output_shapes = self.compute_output_shape(input_shapes)
2742    # Convert to TensorShape so that nest.map_structure will not map into
2743    # individual dim of the shape.
2744    output_shapes = tf_utils.convert_shapes(output_shapes, to_tuples=False)
2745
2746    def _make_placeholder_like(shape):
2747      ph = backend.placeholder(shape=shape, dtype=self.dtype)
2748      ph._keras_mask = None
2749      return ph
2750    return nest.map_structure(_make_placeholder_like, output_shapes)
2751
2752  def _get_trainable_state(self):
2753    """Get the `trainable` state of each sublayer.
2754
2755    Returns:
2756      A dict mapping all sublayers to their `trainable` value.
2757    """
2758    trainable_state = weakref.WeakKeyDictionary()
2759    for layer in self._flatten_layers():
2760      trainable_state[layer] = layer.trainable
2761    return trainable_state
2762
2763  def _set_trainable_state(self, trainable_state):
2764    """Set `trainable` state for each sublayer."""
2765    for layer in self._flatten_layers():
2766      if layer in trainable_state:
2767        layer.trainable = trainable_state[layer]
2768
2769  @property
2770  def _obj_reference_counts(self):
2771    """A dictionary counting the number of attributes referencing an object."""
2772    self._maybe_create_attribute('_obj_reference_counts_dict',
2773                                 object_identity.ObjectIdentityDictionary())
2774    return self._obj_reference_counts_dict
2775
2776  @trackable.no_automatic_dependency_tracking
2777  def _maybe_create_attribute(self, name, default_value):
2778    """Create the attribute with the default value if it hasn't been created.
2779
2780    This is useful for fields that is used for tracking purpose,
2781    _trainable_weights, or _layers. Note that user could create a layer subclass
2782    and assign an internal field before invoking the Layer.__init__(), the
2783    __setattr__() need to create the tracking fields and __init__() need to not
2784    override them.
2785
2786    Args:
2787      name: String, the name of the attribute.
2788      default_value: Object, the default value of the attribute.
2789    """
2790    if not hasattr(self, name):
2791      self.__setattr__(name, default_value)
2792
  def __delattr__(self, name):
    """Deletes attribute `name`, untracking its value once unreferenced.

    If the current value is counted in `self._obj_reference_counts`, its count
    is decremented. When that was the last reference, the value is also
    removed from `_self_tracked_trackables` (if it is a `Layer` or has
    weights) and from `_trainable_weights` / `_non_trainable_weights` (if it
    is a variable).

    Args:
      name: Name of the attribute to delete.

    Raises:
      AttributeError: presumably raised by the underlying `__delattr__` when
        the attribute does not exist; `__setattr__` in this class catches it.
    """
    # For any super.__delattr__() call, we will directly use the implementation
    # in Trackable and skip the behavior in AutoTrackable. The Layer was
    # originally use Trackable as base class, the change of using Module as base
    # class forced us to have AutoTrackable in the class hierarchy. Skipping
    # the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
    existing_value = getattr(self, name, None)

    # If this value is replacing an existing object assigned to an attribute, we
    # should clean it out to avoid leaking memory. First we check if there are
    # other attributes referencing it.
    reference_counts = self._obj_reference_counts
    if existing_value not in reference_counts:
      super(tracking.AutoTrackable, self).__delattr__(name)
      return

    reference_count = reference_counts[existing_value]
    if reference_count > 1:
      # There are other remaining references. We can't remove this object from
      # _layers etc.
      reference_counts[existing_value] = reference_count - 1
      super(tracking.AutoTrackable, self).__delattr__(name)
      return
    else:
      # This is the last remaining reference.
      del reference_counts[existing_value]

    super(tracking.AutoTrackable, self).__delattr__(name)

    # Rebuild the tracking lists by identity, going through the base
    # `__setattr__` so that this assignment is not itself tracked.
    if (isinstance(existing_value, Layer)
        or base_layer_utils.has_weights(existing_value)):
      super(tracking.AutoTrackable, self).__setattr__(
          '_self_tracked_trackables',
          [l for l in self._self_tracked_trackables if l is not existing_value])
    if isinstance(existing_value, tf_variables.Variable):
      super(tracking.AutoTrackable, self).__setattr__(
          '_trainable_weights',
          [w for w in self._trainable_weights if w is not existing_value])
      super(tracking.AutoTrackable, self).__setattr__(
          '_non_trainable_weights',
          [w for w in self._non_trainable_weights if w is not existing_value])
2834
  def __setattr__(self, name, value):
    """Sets attribute `name`, tracking sublayers, metrics and variables.

    Tracking is bypassed entirely when `name` is '_self_setattr_tracking',
    when `self._self_setattr_tracking` is False, or when the class already
    defines an attribute of that name (e.g. a `@property` setter).

    Otherwise:
      * the value is wrapped via `data_structures.sticky_attribute_assignment`;
      * its reference count is incremented;
      * any previous value is untracked via `__delattr__`;
      * each `metrics_mod.Metric` in the flattened value is appended to
        `self._metrics`, when that attribute exists;
      * a `module.Module` or an object with weights is appended once to
        `self._self_tracked_trackables`, unless `_auto_track_sub_layers` is
        False;
      * each variable in the flattened value (composites expanded) is
        appended once to `_trainable_weights` or `_non_trainable_weights`,
        according to its `trainable` flag, and registered with
        `backend.track_variable`.

    Args:
      name: Attribute name.
      value: Value to assign.

    Raises:
      AttributeError: If the untracked assignment fails, likely because it
        conflicts with a read-only `@property`.
    """
    if (name == '_self_setattr_tracking' or
        not getattr(self, '_self_setattr_tracking', True) or
        # Exclude @property.setters from tracking
        hasattr(self.__class__, name)):
      try:
        super(tracking.AutoTrackable, self).__setattr__(name, value)
      except AttributeError:
        raise AttributeError(
            ('Can\'t set the attribute "{}", likely because it conflicts with '
             'an existing read-only @property of the object. Please choose a '
             'different name.').format(name))
      return

    # Wraps data structures in `Trackable`, unwraps `NoDependency` objects.
    value = data_structures.sticky_attribute_assignment(
        trackable=self, value=value, name=name)

    # Increment before deleting the old value. If the same object is being
    # reassigned, `__delattr__` then sees a count above 1 and keeps it tracked.
    reference_counts = self._obj_reference_counts
    reference_counts[value] = reference_counts.get(value, 0) + 1

    # Clean out the old attribute, which clears _layers and _trainable_weights
    # if necessary.
    try:
      self.__delattr__(name)
    except AttributeError:
      pass

    # Keep track of metric instance created in subclassed layer.
    for val in nest.flatten(value):
      if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'):
        self._metrics.append(val)

    # Append value to self._self_tracked_trackables if relevant
    if (getattr(self, '_auto_track_sub_layers', True) and
        (isinstance(value, module.Module) or
         base_layer_utils.has_weights(value))):
      self._maybe_create_attribute('_self_tracked_trackables', [])
      # We need to check object identity to avoid de-duplicating empty
      # container types which compare equal.
      if not any((layer is value for layer in self._self_tracked_trackables)):
        self._self_tracked_trackables.append(value)
        if hasattr(value, '_use_resource_variables'):
          # Legacy layers (V1 tf.layers) must always use
          # resource variables.
          value._use_resource_variables = True

    # Append value to list of trainable / non-trainable weights if relevant
    # TODO(b/125122625): This won't pick up on any variables added to a
    # list/dict after creation.
    for val in nest.flatten(value, expand_composites=True):
      if not isinstance(val, tf_variables.Variable):
        continue

      # Users may add extra weights/variables
      # simply by assigning them to attributes (invalid for graph networks)
      self._maybe_create_attribute('_trainable_weights', [])
      self._maybe_create_attribute('_non_trainable_weights', [])
      # Identity checks: skip variables that are already tracked.
      if val.trainable:
        if any(val is w for w in self._trainable_weights):
          continue
        self._trainable_weights.append(val)
      else:
        if any(val is w for w in self._non_trainable_weights):
          continue
        self._non_trainable_weights.append(val)

      backend.track_variable(val)

    # Skip the auto trackable from tf.Module to keep status quo. See the comment
    # at __delattr__.
    super(tracking.AutoTrackable, self).__setattr__(name, value)
2907
2908  def _gather_children_attribute(self, attribute):
2909    assert attribute in {
2910        'variables', 'trainable_variables', 'non_trainable_variables'
2911    }
2912    if hasattr(self, '_self_tracked_trackables'):
2913      nested_layers = self._flatten_modules(include_self=False, recursive=False)
2914      return list(
2915          itertools.chain.from_iterable(
2916              getattr(layer, attribute) for layer in nested_layers))
2917    return []
2918
2919  def _flatten_layers(self, recursive=True, include_self=True):
2920    for m in self._flatten_modules(
2921        recursive=recursive, include_self=include_self):
2922      if isinstance(m, Layer):
2923        yield m
2924
  def _flatten_modules(self, recursive=True, include_self=True):
    """Flattens `tf.Module` instances (excluding `Metrics`).

    Traverses `_self_tracked_trackables` depth-first in pre-order. Each object
    is visited at most once, de-duplicated by `id`. `Metric` instances are
    skipped, and their children are not visited.
    `TrackableDataStructure` containers are never yielded themselves; their
    `_values` are expanded in place, even when `recursive=False`. Other
    objects are ignored.

    Args:
      recursive: Whether to recursively flatten through submodules.
      include_self: Whether to include this `Layer` instance.

    Yields:
      `tf.Module` instance tracked by this `Layer`.
    """
    if include_self:
      yield self

    # Only instantiate set and deque if needed.
    trackables = getattr(self, '_self_tracked_trackables', None)
    if trackables:
      seen_object_ids = set()
      deque = collections.deque(trackables)
      while deque:
        trackable_obj = deque.popleft()
        trackable_id = id(trackable_obj)
        if trackable_id in seen_object_ids:
          continue
        seen_object_ids.add(trackable_id)

        # Metrics are not considered part of the Layer's topology.
        if (isinstance(trackable_obj, module.Module) and
            not isinstance(trackable_obj, metrics_mod.Metric)):
          yield trackable_obj
          # Introspect recursively through sublayers.
          if recursive:
            subtrackables = getattr(trackable_obj, '_self_tracked_trackables',
                                    None)
            if subtrackables:
              # `extendleft(reversed(...))` puts the children at the front in
              # their original order, giving depth-first pre-order traversal.
              deque.extendleft(reversed(subtrackables))
        elif isinstance(trackable_obj, data_structures.TrackableDataStructure):
          # Data structures are introspected even with `recursive=False`.
          tracked_values = trackable_obj._values
          if tracked_values:
            deque.extendleft(reversed(tracked_values))
2965
2966  # This is a hack so that the is_layer (within
2967  # training/trackable/layer_utils.py) check doesn't get the weights attr.
2968  # TODO(b/110718070): Remove when fixed.
2969  def _is_layer(self):
2970    return True
2971
2972  def _init_call_fn_args(self):
2973    # Clear cached call function arguments.
2974    self.__class__._call_full_argspec.fget.cache.pop(self, None)
2975    self.__class__._call_fn_args.fget.cache.pop(self, None)
2976    self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
2977
2978    call_fn_args = self._call_fn_args
2979    self._expects_training_arg = ('training' in call_fn_args or
2980                                  self._call_accepts_kwargs)
2981    # The default training arg will be any (non-None) default specified in the
2982    # method signature, or None if no value is specified.
2983    self._default_training_arg = self._call_fn_arg_defaults.get(
2984        'training')
2985    self._expects_mask_arg = ('mask' in call_fn_args or
2986                              self._call_accepts_kwargs)
2987
2988  @property
2989  @layer_utils.cached_per_instance
2990  def _call_full_argspec(self):
2991    # Argspec inspection is expensive and the call spec is used often, so it
2992    # makes sense to cache the result.
2993    return tf_inspect.getfullargspec(self.call)
2994
2995  @property
2996  @layer_utils.cached_per_instance
2997  def _call_fn_args(self):
2998    all_args = self._call_full_argspec.args
2999    # Scrub `self` that appears if a decorator was applied.
3000    if all_args and all_args[0] == 'self':
3001      return all_args[1:]
3002    return all_args
3003
3004  @property
3005  @layer_utils.cached_per_instance
3006  def _call_fn_arg_defaults(self):
3007    call_fn_args = self._call_fn_args
3008    call_fn_defaults = self._call_full_argspec.defaults or []
3009    defaults = dict()
3010
3011    # The call arg defaults are an n-tuple of the last n elements of the args
3012    # list. (n = # of elements that have a default argument)
3013    for i in range(-1 * len(call_fn_defaults), 0):
3014      defaults[call_fn_args[i]] = call_fn_defaults[i]
3015    return defaults
3016
3017  @property
3018  @layer_utils.cached_per_instance
3019  def _call_fn_arg_positions(self):
3020    call_fn_arg_positions = dict()
3021    for pos, arg in enumerate(self._call_fn_args):
3022      call_fn_arg_positions[arg] = pos
3023    return call_fn_arg_positions
3024
3025  @property
3026  @layer_utils.cached_per_instance
3027  def _call_accepts_kwargs(self):
3028    return self._call_full_argspec.varkw is not None
3029
3030  @property
3031  def _eager_losses(self):
3032    # A list of loss values containing activity regularizers and losses
3033    # manually added through `add_loss` during eager execution. It is cleared
3034    # after every batch.
3035    # Because we plan on eventually allowing a same model instance to be trained
3036    # in eager mode or graph mode alternatively, we need to keep track of
3037    # eager losses and symbolic losses via separate attributes.
3038    if not hasattr(self._thread_local, '_eager_losses'):
3039      self._thread_local._eager_losses = []
3040    return self._thread_local._eager_losses
3041
3042  @_eager_losses.setter
3043  def _eager_losses(self, losses):
3044    self._thread_local._eager_losses = losses
3045
3046  def _dedup_weights(self, weights):
3047    """Dedupe weights while maintaining order as much as possible."""
3048    output, seen_ids = [], set()
3049    for w in weights:
3050      if id(w) not in seen_ids:
3051        output.append(w)
3052        # Track the Variable's identity to avoid __eq__ issues.
3053        seen_ids.add(id(w))
3054
3055    return output
3056
3057  def _split_out_first_arg(self, args, kwargs):
3058    # Grab the argument corresponding to the first argument in the
3059    # layer's `call` method spec. This will either be the first positional
3060    # argument, or it will be provided as a keyword argument.
3061    if args:
3062      inputs = args[0]
3063      args = args[1:]
3064    elif self._call_fn_args[0] in kwargs:
3065      kwargs = copy.copy(kwargs)
3066      inputs = kwargs.pop(self._call_fn_args[0])
3067    else:
3068      raise ValueError(
3069          'The first argument to `Layer.call` must always be passed.')
3070    return inputs, args, kwargs
3071
3072  # SavedModel properties. Please see keras/saving/saved_model for details.
3073
3074  @trackable.no_automatic_dependency_tracking
3075  def _set_save_spec(self, inputs):
3076    if self._saved_model_inputs_spec is not None:
3077      return  # Already set.
3078
3079    self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec,
3080                                                       inputs)
3081
3082  def _get_save_spec(self, dynamic_batch=True):
3083    if self._saved_model_inputs_spec is None:
3084      return None
3085
3086    return nest.map_structure(
3087        lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
3088        self._saved_model_inputs_spec)
3089
3090  @property
3091  def _trackable_saved_model_saver(self):
3092    return layer_serialization.LayerSavedModelSaver(self)
3093
3094  @property
3095  def _object_identifier(self):
3096    return self._trackable_saved_model_saver.object_identifier
3097
3098  @property
3099  def _tracking_metadata(self):
3100    return self._trackable_saved_model_saver.tracking_metadata
3101
3102  def _list_extra_dependencies_for_serialization(self, serialization_cache):
3103    return (self._trackable_saved_model_saver
3104            .list_extra_dependencies_for_serialization(serialization_cache))
3105
3106  def _list_functions_for_serialization(self, serialization_cache):
3107    return (self._trackable_saved_model_saver
3108            .list_functions_for_serialization(serialization_cache))
3109
3110  @property
3111  def _use_input_spec_as_call_signature(self):
3112    # Whether input spec can be used as the call signature when tracing the
3113    # Layer for SavedModel. By default, this is set to `True` for layers
3114    # exported from the Keras library, because the layers more rigidly define
3115    # the `input_specs` property (many custom layers only set the `ndims`)
3116    return get_canonical_name_for_symbol(type(self)) is not None
3117
3118  def __getstate__(self):
3119    # Override to support `copy.deepcopy` and pickling.
3120    # Thread-local objects cannot be copied in Python 3, so pop these.
3121    # Thread-local objects are used to cache losses in MirroredStrategy, and
3122    # so shouldn't be copied.
3123    state = self.__dict__.copy()
3124    state.pop('_thread_local', None)
3125    state.pop('_metrics_lock', None)
3126    return state
3127
3128  def __setstate__(self, state):
3129    state['_thread_local'] = threading.local()
3130    state['_metrics_lock'] = threading.Lock()
3131    # Bypass Trackable logic as `__dict__` already contains this info.
3132    object.__setattr__(self, '__dict__', state)
3133
3134
class TensorFlowOpLayer(Layer):
  """Wraps a TensorFlow Operation in a Layer.

  This class is used internally by the Functional API. When a user
  uses a raw TensorFlow Operation on symbolic tensors originating
  from an `Input` Layer, the resultant operation will be wrapped
  with this Layer object in order to make the operation compatible
  with the Keras API.

  This Layer will create a new, identical operation (except for inputs
  and outputs) every time it is called. If `run_eagerly` is `True`,
  the op creation and calculation will happen inside an Eager function.

  Instances of this Layer are created when `autolambda` is called, which
  is whenever a Layer's `__call__` encounters symbolic inputs that do
  not have Keras metadata, or when a Network's `__init__` encounters
  outputs that do not have Keras metadata.

  Attributes:
    node_def: String, the serialized NodeDef of the Op this layer will wrap.
    name: String, the name of the Layer.
    constants: Dict of NumPy arrays, the values of any Tensors needed for this
      Operation that do not originate from a Keras `Input` Layer. Since all
      placeholders must come from Keras `Input` Layers, these Tensors must be
      treated as constant in the Functional API.
    trainable: Bool, whether this Layer is trainable. Currently Variables are
      not supported, and so this parameter has no effect.
    dtype: The default dtype of this Layer. Inherited from `Layer` and has no
      effect on this class, however is used in `get_config`.
  """

  @trackable.no_automatic_dependency_tracking
  def __init__(self,
               node_def,
               name,
               constants=None,
               trainable=True,
               dtype=None):
    # Pass autocast=False, as if inputs are cast, input types might not match
    # Operation type.
    super(TensorFlowOpLayer, self).__init__(
        name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype,
        autocast=False)
    # `node_def` is either a JSON-style dict (from `get_config`) or a
    # serialized NodeDef given as bytes or str.
    if isinstance(node_def, dict):
      self.node_def = json_format.ParseDict(node_def, node_def_pb2.NodeDef())
    else:
      if not isinstance(node_def, bytes):
        node_def = node_def.encode('utf-8')
      self.node_def = node_def_pb2.NodeDef.FromString(node_def)
    # JSON serialization stringifies keys which are integer input indices.
    self.constants = ({
        int(index): constant for index, constant in constants.items()
    } if constants is not None else {})
    # Layer uses original op unless it is called on new inputs.
    # This means `built` is not set in `__call__`.
    self.built = True

    # Do not individually trace TensorflowOpLayers in the SavedModel.
    self._must_restore_from_config = True

  def call(self, inputs):
    """Recreates the wrapped op; inside a `tf.function` when executing eagerly."""
    if context.executing_eagerly():
      return self._defun_call(inputs)
    return self._make_op(inputs)

  def _make_node_def(self, graph):
    """Returns a copy of `self.node_def` with a name unique within `graph`."""
    node_def = node_def_pb2.NodeDef()
    node_def.CopyFrom(self.node_def)
    # Used in TPUReplicateContext to indicate whether this node has been cloned
    # and to not add TPU attributes.
    node_def.attr['_cloned'].b = True
    node_def.name = graph.unique_name(node_def.name)
    return node_def

  def _make_op(self, inputs):
    """Creates the wrapped op in the graph of `inputs` and returns its outputs.

    Constants are spliced back into the flattened input list at their original
    input indices before the op is created.

    Returns:
      The single output tensor if the op has exactly one output, otherwise the
      op's list of outputs.
    """
    inputs = nest.flatten(inputs)
    graph = inputs[0].graph
    node_def = self._make_node_def(graph)
    with graph.as_default():
      # Insert constants in ascending index order: a `list.insert` position is
      # only correct once every lower index has been filled. `self.constants`
      # may come from a deserialized config whose key order is not numeric
      # (e.g. JSON written with sorted string keys, where "10" precedes "2").
      sorted_constants = sorted(
          self.constants.items(), key=lambda index_and_value: index_and_value[0])
      for index, constant in sorted_constants:
        # Recreate constant in graph to add distribution context.
        value = tensor_util.constant_value(constant)
        if value is not None:
          constant = constant_op.constant(value, name=node_def.input[index])
        inputs.insert(index, constant)
      c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[])
      op = graph._create_op_from_tf_operation(c_op)
      op._control_flow_post_processing()

      # Record the gradient because custom-made ops don't go through the
      # code-gen'd eager call path
      op_type = compat.as_str(op.op_def.name)
      attr_names = [compat.as_str(attr.name) for attr in op.op_def.attr]
      attrs = []
      for attr_name in attr_names:
        attrs.append(attr_name)
        attrs.append(op.get_attr(attr_name))
      attrs = tuple(attrs)
      execute.record_gradient(op_type, op.inputs, attrs, op.outputs)

      if len(op.outputs) == 1:
        return op.outputs[0]
      return op.outputs

  @def_function.function
  def _defun_call(self, inputs):
    """Wraps the op creation method in an Eager function for `run_eagerly`."""
    return self._make_op(inputs)

  def get_config(self):
    """Returns a config from which this layer can be re-instantiated.

    The name prefix added by `__init__` is stripped, the NodeDef is stored as
    a dict, and constants are stored as values via `backend.get_value`.
    """
    config = super(TensorFlowOpLayer, self).get_config()
    config.update({
        # `__init__` prefixes the name. Revert to the constructor argument.
        'name': config['name'][len(_TF_OP_LAYER_NAME_PREFIX):],
        'node_def': json_format.MessageToDict(self.node_def),
        'constants': {
            i: backend.get_value(c) for i, c in self.constants.items()
        }
    })
    return config
3255
3256
class AddLoss(Layer):
  """Adds its inputs as a loss.

  The layer is an identity on its inputs; its only effect is registering them
  through `add_loss`.

  Attributes:
    unconditional: Whether or not the loss should be conditioned on the inputs.
  """

  def __init__(self, unconditional, **kwargs):
    # A loss has no reason to be cast to another dtype, so autocasting is
    # always disabled (overriding any caller-supplied `autocast`).
    layer_kwargs = dict(kwargs, autocast=False)
    super(AddLoss, self).__init__(**layer_kwargs)
    self.unconditional = unconditional

  def call(self, inputs):
    """Registers `inputs` as a loss and returns them unchanged."""
    conditioned_on_inputs = not self.unconditional
    self.add_loss(inputs, inputs=conditioned_on_inputs)
    return inputs

  def get_config(self):
    """Base layer config plus the `unconditional` flag."""
    config = super(AddLoss, self).get_config()
    config['unconditional'] = self.unconditional
    return config
3279
3280
class AddMetric(Layer):
  """Adds its inputs as a metric.

  The layer is an identity on its inputs; its only effect is registering them
  through `add_metric`.

  Attributes:
    aggregation: 'mean' or None. How the inputs should be aggregated.
    metric_name: The name to use for this metric.
  """

  def __init__(self, aggregation=None, metric_name=None, **kwargs):
    super(AddMetric, self).__init__(**kwargs)
    self.aggregation = aggregation
    self.metric_name = metric_name

  def call(self, inputs):
    """Registers `inputs` as a metric and returns them unchanged."""
    self.add_metric(
        inputs, aggregation=self.aggregation, name=self.metric_name)
    return inputs

  def get_config(self):
    """Base layer config plus `aggregation` and `metric_name`."""
    config = super(AddMetric, self).get_config()
    config['aggregation'] = self.aggregation
    config['metric_name'] = self.metric_name
    return config
3305
3306
def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list):  # pylint: disable=unused-argument
  """Check the arguments to see if we are constructing a functional model.

  Args:
    layer: The layer being called.
    inputs: The first positional argument of the call.
    args: Remaining positional arguments.
    kwargs: Keyword arguments.
    input_list: Flattened `inputs`.

  Returns:
    True if this call should be treated as functional-model construction.

  Raises:
    ValueError: When executing eagerly without KerasTensors, if a subclassed
      layer receives some symbolic tensors but not all of `input_list` is
      symbolic.
  """
  if keras_tensor.keras_tensors_enabled():
    # Any KerasTensor anywhere in the call arguments signals functional
    # construction.
    flat_call_args = nest.flatten([inputs, args, kwargs])
    return any(isinstance(item, keras_tensor.KerasTensor)
               for item in flat_call_args)

  if not context.executing_eagerly():
    return (base_layer_utils.is_in_keras_graph() or
            all(hasattr(item, '_keras_history') for item in input_list))

  all_inputs_symbolic = all(
      tf_utils.is_symbolic_tensor(item) for item in input_list)
  partially_symbolic_subclass_call = (
      base_layer_utils.is_subclassed(layer) and
      any(tf_utils.is_symbolic_tensor(item)
          for item in nest.flatten([inputs, args, kwargs])) and
      not all_inputs_symbolic)
  if partially_symbolic_subclass_call:
    raise ValueError('It appears you are trying to construct a '
                     'functional model, but not all of the inputs in '
                     'the first positional argument of your layer call '
                     'are symbolic tensors. '
                     '(Input objects, or the output of another layer) '
                     'Functional models cannot correctly track custom '
                     'layers unless all values in the first call argument '
                     'are symbolic.')
  return all_inputs_symbolic
3334
3335
def _convert_numpy_or_python_types(x):
  """Converts array-likes and Python scalars to tensors; passes others through.

  NumPy (and TF-NumPy) arrays, floats and ints are converted with
  `ops.convert_to_tensor_v2_with_dispatch`. Since `bool` subclasses `int`,
  Python booleans are converted as well.
  """
  convertible_types = (np_arrays.ndarray, np.ndarray, float, int)
  if not isinstance(x, convertible_types):
    return x
  return ops.convert_to_tensor_v2_with_dispatch(x)
3340
3341
# Avoid breaking users who directly import this symbol from this file.
# This is a backward-compatibility alias; the class itself lives in
# `input_spec.InputSpec`.
# TODO(fchollet): remove this.
InputSpec = input_spec.InputSpec  # pylint:disable=invalid-name
3345