1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Normalization layers."""
16# pylint: disable=g-classes-have-attributes
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import distribution_strategy_context
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.keras import backend as K
27from tensorflow.python.keras import constraints
28from tensorflow.python.keras import initializers
29from tensorflow.python.keras import regularizers
30from tensorflow.python.keras.engine.base_layer import Layer
31from tensorflow.python.keras.engine.input_spec import InputSpec
32from tensorflow.python.keras.utils import control_flow_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import init_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import nn
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variables as tf_variables
39from tensorflow.python.ops.control_flow_ops import get_enclosing_xla_context
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.util.tf_export import keras_export
42
43
44class BatchNormalizationBase(Layer):
45  r"""Layer that normalizes its inputs.
46
47  Batch normalization applies a transformation that maintains the mean output
48  close to 0 and the output standard deviation close to 1.
49
50  Importantly, batch normalization works differently during training and
51  during inference.
52
53  **During training** (i.e. when using `fit()` or when calling the layer/model
54  with the argument `training=True`), the layer normalizes its output using
55  the mean and standard deviation of the current batch of inputs. That is to
56  say, for each channel being normalized, the layer returns
57  `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:
58
  - `epsilon` is a small constant (configurable as part of the constructor
  arguments)
61  - `gamma` is a learned scaling factor (initialized as 1), which
62  can be disabled by passing `scale=False` to the constructor.
63  - `beta` is a learned offset factor (initialized as 0), which
64  can be disabled by passing `center=False` to the constructor.
65
  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
  calling the layer/model with the argument `training=False`, which is the
  default), the layer normalizes its output using a moving average of the
69  mean and standard deviation of the batches it has seen during training. That
70  is to say, it returns
71  `gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta`.
72
73  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as follows:
75
76  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
77  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`
78
  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has statistics similar to the
  inference data*.
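
  For example, a minimal usage sketch (the variable names, shapes, and random
  inputs below are illustrative only):

  ```python
  layer = tf.keras.layers.BatchNormalization()
  x = tf.random.normal((4, 3))
  y_train = layer(x, training=True)   # Uses the batch mean/variance and
                                      # updates the moving statistics.
  y_infer = layer(x, training=False)  # Uses the moving statistics.
  ```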
82
83  Args:
    axis: Integer or a list of integers, the axis that should be normalized
      (typically the features axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
87    momentum: Momentum for the moving average.
88    epsilon: Small float added to variance to avoid dividing by zero.
89    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
90      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (this also applies to e.g. `nn.relu`), this can
      be disabled since the scaling will be done by the next layer.
94    beta_initializer: Initializer for the beta weight.
95    gamma_initializer: Initializer for the gamma weight.
96    moving_mean_initializer: Initializer for the moving mean.
97    moving_variance_initializer: Initializer for the moving variance.
98    beta_regularizer: Optional regularizer for the beta weight.
99    gamma_regularizer: Optional regularizer for the gamma weight.
100    beta_constraint: Optional constraint for the beta weight.
101    gamma_constraint: Optional constraint for the gamma weight.
    renorm: Whether to use [Batch Renormalization](
      https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
105    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
106      scalar `Tensors` used to clip the renorm correction. The correction `(r,
107      d)` is used as `corrected_value = normalized_value * r + d`, with `r`
108      clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
109      dmax are set to inf, 0, inf, respectively.
110    renorm_momentum: Momentum used to update the moving means and standard
111      deviations with renorm. Unlike `momentum`, this affects training and
112      should be neither too small (which would add noise) nor too large (which
113      would give stale estimates). Note that `momentum` is still applied to get
114      the means and variances for inference.
    fused: if `True`, use a faster, fused implementation, or raise a ValueError
      if the fused implementation cannot be used. If `None`, use the faster
      implementation if possible. If `False`, do not use the fused
      implementation.
      Note that in TensorFlow 1.x, the meaning of `fused=True` is different:
      if the fused implementation cannot be used, the layer silently falls back
      to the non-fused implementation instead of raising an error.
121    trainable: Boolean, if `True` the variables will be marked as trainable.
122    virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
123      which means batch normalization is performed across the whole batch. When
      `virtual_batch_size` is not `None`, instead perform "Ghost Batch
      Normalization", which creates virtual sub-batches that are each
      normalized separately (with shared gamma, beta, and moving statistics).
127      Must divide the actual batch size during execution.
    adjustment: A function taking the `Tensor` containing the (dynamic) shape of
      the input tensor and returning a pair (scale, bias) to apply to the
      normalized values (before gamma and beta), only during training. For
      example, if `axis=-1`,
      `adjustment = lambda shape: (
        tf.random.uniform(shape[-1:], 0.93, 1.07),
        tf.random.uniform(shape[-1:], -0.1, 0.1))`
      will scale the normalized value by up to 7% up or down, then shift the
      result by up to 0.1 (with independent scaling and bias for each feature
      but shared across all examples), and finally apply gamma and/or beta. If
      `None`, no adjustment is applied. Cannot be specified if
      `virtual_batch_size` is specified.
140
141  Call arguments:
142    inputs: Input tensor (of any rank).
143    training: Python boolean indicating whether the layer should behave in
144      training mode or in inference mode.
145      - `training=True`: The layer will normalize its inputs using the mean and
146        variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the moving
        mean and moving variance learned during training.
149
150  Input shape:
151    Arbitrary. Use the keyword argument `input_shape` (tuple of
152    integers, does not include the samples axis) when using this layer as the
153    first layer in a model.
154
155  Output shape:
156    Same shape as input.
157
158  Reference:
159    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
160  """
161
162  # By default, the base class uses V2 behavior. The BatchNormalization V1
163  # subclass sets this to False to use the V1 behavior.
164  _USE_V2_BEHAVIOR = True
165
166  def __init__(self,
167               axis=-1,
168               momentum=0.99,
169               epsilon=1e-3,
170               center=True,
171               scale=True,
172               beta_initializer='zeros',
173               gamma_initializer='ones',
174               moving_mean_initializer='zeros',
175               moving_variance_initializer='ones',
176               beta_regularizer=None,
177               gamma_regularizer=None,
178               beta_constraint=None,
179               gamma_constraint=None,
180               renorm=False,
181               renorm_clipping=None,
182               renorm_momentum=0.99,
183               fused=None,
184               trainable=True,
185               virtual_batch_size=None,
186               adjustment=None,
187               name=None,
188               **kwargs):
189    super(BatchNormalizationBase, self).__init__(name=name, **kwargs)
190    if isinstance(axis, (list, tuple)):
191      self.axis = axis[:]
192    elif isinstance(axis, int):
193      self.axis = axis
194    else:
195      raise TypeError('Expected an int or a list/tuple of ints for the '
196                      'argument \'axis\', but received: %r' % axis)
197    self.momentum = momentum
198    self.epsilon = epsilon
199    self.center = center
200    self.scale = scale
201    self.beta_initializer = initializers.get(beta_initializer)
202    self.gamma_initializer = initializers.get(gamma_initializer)
203    self.moving_mean_initializer = initializers.get(moving_mean_initializer)
204    self.moving_variance_initializer = initializers.get(
205        moving_variance_initializer)
206    self.beta_regularizer = regularizers.get(beta_regularizer)
207    self.gamma_regularizer = regularizers.get(gamma_regularizer)
208    self.beta_constraint = constraints.get(beta_constraint)
209    self.gamma_constraint = constraints.get(gamma_constraint)
210    self.renorm = renorm
211    self.virtual_batch_size = virtual_batch_size
212    self.adjustment = adjustment
213    if self._USE_V2_BEHAVIOR:
214      if fused:
215        self._raise_if_fused_cannot_be_used()
216      # We leave fused as None if self._fused_can_be_used()==True, since we
217      # still may set it to False in self.build() if the input rank is not 4.
218      elif fused is None and not self._fused_can_be_used():
219        fused = False
220    elif fused is None:
221      fused = True
222    self.supports_masking = True
223
224    self.fused = fused
225    self._bessels_correction_test_only = True
226    self.trainable = trainable
227
228    if renorm:
229      renorm_clipping = renorm_clipping or {}
230      keys = ['rmax', 'rmin', 'dmax']
231      if set(renorm_clipping) - set(keys):
232        raise ValueError('renorm_clipping %s contains keys not in %s' %
233                         (renorm_clipping, keys))
234      self.renorm_clipping = renorm_clipping
235      self.renorm_momentum = renorm_momentum
236
237  def _raise_if_fused_cannot_be_used(self):
238    """Raises a ValueError if fused implementation cannot be used.
239
    In addition to the checks done in this function, the input tensor's rank
    must be 4. The input rank check can only be done once the input shape is
    known.
242    """
243    # Note the ValueErrors in this function are caught and not reraised in
244    # _fused_can_be_used(). No other exception besides ValueError should be
245    # raised here.
246
247    # Currently fused batch norm doesn't support renorm. It also only supports a
248    # channel dimension on axis 1 or 3, when no virtual batch size or adjustment
249    # is used.
250    if self.renorm:
251      raise ValueError('Passing both `fused=True` and `renorm=True` is '
252                       'not supported')
253    axis = [self.axis] if isinstance(self.axis, int) else self.axis
254    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
255    # input rank is required to be 4 (which is checked later).
256    # TODO(b/173253101): Once the input rank can be 5, update this check.
257    if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
258      raise ValueError('Passing `fused=True` is only supported when axis is 1 '
259                       'or 3. Got axis %s' % (axis,))
260    if self.virtual_batch_size is not None:
261      raise ValueError('Passing `fused=True` is not supported when '
262                       '`virtual_batch_size` is specified.')
263    if self.adjustment is not None:
264      raise ValueError('Passing `fused=True` is not supported when '
265                       '`adjustment` is specified.')
266    # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check.
267    if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None):
268      raise ValueError(
269          'Passing `fused=True` is only supported when the compute '
270          'dtype is float16, bfloat16, or float32. Got dtype: %s' %
271          (self._compute_dtype,))
272
273  def _fused_can_be_used(self):
274    try:
275      self._raise_if_fused_cannot_be_used()
276      return True
277    except ValueError:
278      return False
279
280  @property
281  def trainable(self):
282    return self._trainable
283
284  @trainable.setter
285  def trainable(self, value):
286    self._trainable = value
287
288  @property
289  def _param_dtype(self):
290    # Raise parameters of fp16 batch norm to fp32
291    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
292      return dtypes.float32
293    else:
294      return self.dtype or dtypes.float32
295
296  def _support_zero_size_input(self):
297    return distribution_strategy_context.has_strategy() and getattr(
298        distribution_strategy_context.get_strategy().extended,
299        'experimental_enable_get_next_as_optional', False)
300
301  def build(self, input_shape):
302    input_shape = tensor_shape.TensorShape(input_shape)
303    if not input_shape.ndims:
304      raise ValueError('Input has undefined rank.')
305    ndims = len(input_shape)
306
307    # Convert axis to list and resolve negatives
308    if isinstance(self.axis, int):
309      self.axis = [self.axis]
310
311    for idx, x in enumerate(self.axis):
312      if x < 0:
313        self.axis[idx] = ndims + x
314
315    # Validate axes
316    for x in self.axis:
317      if x < 0 or x >= ndims:
318        raise ValueError('Invalid axis: %s' % (self.axis,))
319    if len(self.axis) != len(set(self.axis)):
320      raise ValueError('Duplicate axis: %s' % (self.axis,))
321
322    if self.virtual_batch_size is not None:
323      if self.virtual_batch_size <= 0:
324        raise ValueError('virtual_batch_size must be a positive integer that '
325                         'divides the true batch size of the input tensor')
326      # If using virtual batches, the first dimension must be the batch
327      # dimension and cannot be the batch norm axis
328      if 0 in self.axis:
329        raise ValueError('When using virtual_batch_size, the batch dimension '
330                         'must be 0 and thus axis cannot include 0. '
331                         'Received axis=%s' % (self.axis,))
332      if self.adjustment is not None:
333        raise ValueError('When using virtual_batch_size, adjustment cannot '
334                         'be specified')
335
336    if self.fused in (None, True):
337      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
338      # output back to its original shape accordingly.
339      if self._USE_V2_BEHAVIOR:
        # TODO(b/173253101): Using fused in the 5D case is currently disabled
        # due to a regression on UNet, so it is currently only supported in
        # the 4D case.
343        if self.fused is None:
344          self.fused = ndims == 4
345        elif self.fused and ndims != 4:
346          raise ValueError('Batch normalization layers with `fused=True` only '
                           'support 4D input tensors. '
348                           'Received tensor with shape: %s' %
349                           (tuple(input_shape),))
350      else:
351        assert self.fused is not None
352        self.fused = (ndims == 4 and self._fused_can_be_used())
353      # TODO(chrisying): fused batch norm is currently not supported for
354      # multi-axis batch norm and by extension virtual batches. In some cases,
355      # it might be possible to use fused batch norm but would require reshaping
356      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
357      # particularly tricky. A compromise might be to just support the most
358      # common use case (turning 5D w/ virtual batch to NCHW)
359
360    if self.fused:
361      if self.axis == [1] and ndims == 4:
362        self._data_format = 'NCHW'
363      elif self.axis == [1] and ndims == 5:
364        self._data_format = 'NCDHW'
365      elif self.axis == [3] and ndims == 4:
366        self._data_format = 'NHWC'
367      elif self.axis == [4] and ndims == 5:
368        self._data_format = 'NDHWC'
369      elif ndims == 5:
        # 5D tensors can be passed in, but they should not use fused batch norm
        # due to an unsupported axis.
372        self.fused = False
373      else:
374        if ndims == 4:
375          raise ValueError(
376              'Unsupported axis. The use of `fused=True` is only possible with '
377              '`axis=1` or `axis=3` for 4D input tensors. Received '
378              'axis=%s' % (self.axis,))
379        else:
380          raise ValueError(
381              'Unsupported axis. The use of `fused=True` is only possible with '
382              '`axis=1` or `axis=4` for 5D input tensors. Received '
383              'axis=%s' % (self.axis,))
384
385    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
386    for x in axis_to_dim:
387      if axis_to_dim[x] is None:
388        raise ValueError('Input has undefined `axis` dimension. Received input '
389                         'with shape %s. Axis value: %s' %
390                         (tuple(input_shape), self.axis))
391    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)
392
393    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
394      # Single axis batch norm (most common/default use-case)
395      param_shape = (list(axis_to_dim.values())[0],)
396    else:
397      # Parameter shape is the original shape but with 1 in all non-axis dims
398      param_shape = [
399          axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
400      ]
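      # For example, a 4D input with axis=[1, 3] yields a parameter shape of
      # [1, C1, 1, C3], where C1 and C3 are the sizes of axes 1 and 3.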
401      if self.virtual_batch_size is not None:
402        # When using virtual batches, add an extra dim at index 1
403        param_shape.insert(1, 1)
404        for idx, x in enumerate(self.axis):
405          self.axis[idx] = x + 1  # Account for added dimension
406
407    if self.scale:
408      self.gamma = self.add_weight(
409          name='gamma',
410          shape=param_shape,
411          dtype=self._param_dtype,
412          initializer=self.gamma_initializer,
413          regularizer=self.gamma_regularizer,
414          constraint=self.gamma_constraint,
415          trainable=True,
416          experimental_autocast=False)
417    else:
418      self.gamma = None
419      if self.fused:
420        self._gamma_const = K.constant(
421            1.0, dtype=self._param_dtype, shape=param_shape)
422
423    if self.center:
424      self.beta = self.add_weight(
425          name='beta',
426          shape=param_shape,
427          dtype=self._param_dtype,
428          initializer=self.beta_initializer,
429          regularizer=self.beta_regularizer,
430          constraint=self.beta_constraint,
431          trainable=True,
432          experimental_autocast=False)
433    else:
434      self.beta = None
435      if self.fused:
436        self._beta_const = K.constant(
437            0.0, dtype=self._param_dtype, shape=param_shape)
438
439    try:
440      # Disable variable partitioning when creating the moving mean and variance
441      if hasattr(self, '_scope') and self._scope:
442        partitioner = self._scope.partitioner
443        self._scope.set_partitioner(None)
444      else:
445        partitioner = None
446      self.moving_mean = self.add_weight(
447          name='moving_mean',
448          shape=param_shape,
449          dtype=self._param_dtype,
450          initializer=self.moving_mean_initializer,
451          synchronization=tf_variables.VariableSynchronization.ON_READ,
452          trainable=False,
453          aggregation=tf_variables.VariableAggregation.MEAN,
454          experimental_autocast=False)
455
456      self.moving_variance = self.add_weight(
457          name='moving_variance',
458          shape=param_shape,
459          dtype=self._param_dtype,
460          initializer=self.moving_variance_initializer,
461          synchronization=tf_variables.VariableSynchronization.ON_READ,
462          trainable=False,
463          aggregation=tf_variables.VariableAggregation.MEAN,
464          experimental_autocast=False)
465
466      if self.renorm:
467        # In batch renormalization we track the inference moving stddev instead
468        # of the moving variance to more closely align with the paper.
469        def moving_stddev_initializer(*args, **kwargs):
470          return math_ops.sqrt(
471              self.moving_variance_initializer(*args, **kwargs))
472
473        with distribution_strategy_context.get_strategy(
474        ).extended.colocate_vars_with(self.moving_variance):
475          self.moving_stddev = self.add_weight(
476              name='moving_stddev',
477              shape=param_shape,
478              dtype=self._param_dtype,
479              initializer=moving_stddev_initializer,
480              synchronization=tf_variables.VariableSynchronization.ON_READ,
481              trainable=False,
482              aggregation=tf_variables.VariableAggregation.MEAN,
483              experimental_autocast=False)
484
485        # Create variables to maintain the moving mean and standard deviation.
486        # These are used in training and thus are different from the moving
487        # averages above. The renorm variables are colocated with moving_mean
488        # and moving_stddev.
489        # NOTE: below, the outer `with device` block causes the current device
490        # stack to be cleared. The nested ones use a `lambda` to set the desired
491        # device and ignore any devices that may be set by the custom getter.
492        def _renorm_variable(name,
493                             shape,
494                             initializer=init_ops.zeros_initializer()):
495          """Create a renorm variable."""
496          var = self.add_weight(
497              name=name,
498              shape=shape,
499              dtype=self._param_dtype,
500              initializer=initializer,
501              synchronization=tf_variables.VariableSynchronization.ON_READ,
502              trainable=False,
503              aggregation=tf_variables.VariableAggregation.MEAN,
504              experimental_autocast=False)
505          return var
506
507        with distribution_strategy_context.get_strategy(
508        ).extended.colocate_vars_with(self.moving_mean):
509          self.renorm_mean = _renorm_variable('renorm_mean', param_shape,
510                                              self.moving_mean_initializer)
511        with distribution_strategy_context.get_strategy(
512        ).extended.colocate_vars_with(self.moving_stddev):
513          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape,
514                                                moving_stddev_initializer)
515    finally:
516      if partitioner:
517        self._scope.set_partitioner(partitioner)
518    self.built = True
519
520  def _assign_moving_average(self, variable, value, momentum, inputs_size):
521    with K.name_scope('AssignMovingAvg') as scope:
522      with ops.colocate_with(variable):
523        decay = ops.convert_to_tensor_v2_with_dispatch(
524            1.0 - momentum, name='decay')
525        if decay.dtype != variable.dtype.base_dtype:
526          decay = math_ops.cast(decay, variable.dtype.base_dtype)
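        # Subtracting `(variable - value) * decay` below is algebraically the
        # same as the moving-average update in the class docstring:
        #   variable = variable * momentum + value * (1 - momentum).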
527        update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
528        if inputs_size is not None:
529          update_delta = array_ops.where(inputs_size > 0, update_delta,
530                                         K.zeros_like(update_delta))
531        return state_ops.assign_sub(variable, update_delta, name=scope)
532
533  def _assign_new_value(self, variable, value):
534    with K.name_scope('AssignNewValue') as scope:
535      with ops.colocate_with(variable):
536        return state_ops.assign(variable, value, name=scope)
537
538  def _fused_batch_norm(self, inputs, training):
539    """Returns the output of fused batch norm."""
540    beta = self.beta if self.center else self._beta_const
541    gamma = self.gamma if self.scale else self._gamma_const
542
543    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
544    # code as well.
545    if self._support_zero_size_input():
546      # Keras assumes that batch dimension is the first dimension for Batch
547      # Normalization.
548      input_batch_size = array_ops.shape(inputs)[0]
549    else:
550      input_batch_size = None
551
552    # TODO(rmlarsen): Support using fused avg updates for non-eager execution
553    # after fixing graph pattern matching and enabling fused_batch_norm to
554    # take exponential_avg_factor as a tensor input.
555    use_fused_avg_updates = (
556        ops.executing_eagerly_outside_functions() and
557        isinstance(self.momentum, (float, int)) and
558        get_enclosing_xla_context() is None)
559    if use_fused_avg_updates:
560      exponential_avg_factor = 1.0 - self.momentum
561    else:
562      exponential_avg_factor = None
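    # When fused average updates are used, fused_batch_norm itself blends the
    # batch statistics into the moving averages using exponential_avg_factor,
    # so the returned mean/variance can simply be assigned below instead of
    # applying a separate moving-average update.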
563
564    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
565      r"""Add or remove Bessel's correction."""
566      # Removes Bessel's correction if remove == True, adds it otherwise.
567      # This is to be consistent with non-fused batch norm. Note that the
568      # variance computed by fused batch norm is with Bessel's correction.
569      # This is only used in legacy V1 batch norm tests.
570      if self._bessels_correction_test_only:
571        return variance
572      sample_size = math_ops.cast(
573          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
574      if remove:
575        factor = (sample_size -
576                  math_ops.cast(1.0, variance.dtype)) / sample_size
577      else:
578        factor = sample_size / (
579            sample_size - math_ops.cast(1.0, variance.dtype))
580      return variance * factor
581
582    def _fused_batch_norm_training():
583      return nn.fused_batch_norm(
584          inputs,
585          gamma,
586          beta,
587          mean=self.moving_mean,
588          variance=_maybe_add_or_remove_bessels_correction(
589              self.moving_variance, remove=False),
590          epsilon=self.epsilon,
591          is_training=True,
592          data_format=self._data_format,
593          exponential_avg_factor=exponential_avg_factor)
594
595    def _fused_batch_norm_training_empty():
596      return inputs, self.moving_mean, self.moving_variance
597
598    def _fused_batch_norm_inference():
599      return nn.fused_batch_norm(
600          inputs,
601          gamma,
602          beta,
603          mean=self.moving_mean,
604          variance=self.moving_variance,
605          epsilon=self.epsilon,
606          is_training=False,
607          data_format=self._data_format)
608
609    train_op = _fused_batch_norm_training
610    if use_fused_avg_updates and input_batch_size is not None:
611      # pylint: disable=g-long-lambda
612      train_op = lambda: control_flow_util.smart_cond(
613          input_batch_size > 0, _fused_batch_norm_training,
614          _fused_batch_norm_training_empty)
615      # pylint: enable=g-long-lambda
616
617    output, mean, variance = control_flow_util.smart_cond(
618        training, train_op, _fused_batch_norm_inference)
619    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)
620
621    training_value = control_flow_util.constant_value(training)
622    if training_value or training_value is None:
623      if not use_fused_avg_updates:
624        if training_value is None:
625          momentum = control_flow_util.smart_cond(training,
626                                                  lambda: self.momentum,
627                                                  lambda: 1.0)
628        else:
629          momentum = ops.convert_to_tensor_v2_with_dispatch(self.momentum)
630
631      def mean_update():
632        """Update self.moving_mean with the most recent data point."""
633        if use_fused_avg_updates:
634          return self._assign_new_value(self.moving_mean, mean)
635        else:
636          return self._assign_moving_average(self.moving_mean, mean, momentum,
637                                             input_batch_size)
638
639      def variance_update():
640        """Update self.moving_variance with the most recent data point."""
641        if use_fused_avg_updates:
642          return self._assign_new_value(self.moving_variance, variance)
643        else:
644          return self._assign_moving_average(self.moving_variance, variance,
645                                             momentum, input_batch_size)
646
647      self.add_update(mean_update)
648      self.add_update(variance_update)
649
650    return output
651
652  def _renorm_correction_and_moments(self, mean, variance, training,
653                                     inputs_size):
654    """Returns the correction and update values for renorm."""
655    stddev = math_ops.sqrt(variance + self.epsilon)
656    # Compute the average mean and standard deviation, as if they were
657    # initialized with this batch's moments.
658    renorm_mean = self.renorm_mean
659    # Avoid divide by zero early on in training.
660    renorm_stddev = math_ops.maximum(self.renorm_stddev,
661                                     math_ops.sqrt(self.epsilon))
662    # Compute the corrections for batch renorm.
663    r = stddev / renorm_stddev
664    d = (mean - renorm_mean) / renorm_stddev
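    # With these corrections, normalizing by the batch statistics and then
    # applying `x * r + d` is equivalent to normalizing by the (pre-update)
    # renorm moving averages: ((x - mean) / stddev) * r + d
    #   == (x - renorm_mean) / renorm_stddev.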
665    # Ensure the corrections use pre-update moving averages.
666    with ops.control_dependencies([r, d]):
667      mean = array_ops.identity(mean)
668      stddev = array_ops.identity(stddev)
669    rmin, rmax, dmax = [
670        self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax']
671    ]
672    if rmin is not None:
673      r = math_ops.maximum(r, rmin)
674    if rmax is not None:
675      r = math_ops.minimum(r, rmax)
676    if dmax is not None:
677      d = math_ops.maximum(d, -dmax)
678      d = math_ops.minimum(d, dmax)
679    # When not training, use r=1, d=0.
680    r = control_flow_util.smart_cond(training, lambda: r,
681                                     lambda: array_ops.ones_like(r))
682    d = control_flow_util.smart_cond(training, lambda: d,
683                                     lambda: array_ops.zeros_like(d))
684
685    def _update_renorm_variable(var, value, inputs_size):
686      """Updates a moving average and weight, returns the unbiased value."""
687      value = array_ops.identity(value)
688
689      def _do_update():
690        """Updates the var, returns the updated value."""
691        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
692                                              inputs_size)
693        return new_var
694
695      def _fake_update():
696        return array_ops.identity(var)
697
698      return control_flow_util.smart_cond(training, _do_update, _fake_update)
699
700    # TODO(yuefengz): colocate the operations
701    update_new_mean = _update_renorm_variable(self.renorm_mean, mean,
702                                              inputs_size)
703    update_new_stddev = _update_renorm_variable(self.renorm_stddev, stddev,
704                                                inputs_size)
705
706    # Update the inference mode moving averages with the batch value.
707    with ops.control_dependencies([update_new_mean, update_new_stddev]):
708      out_mean = array_ops.identity(mean)
709      out_variance = array_ops.identity(variance)
710
711    return (r, d, out_mean, out_variance)
712
713  def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
714    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
715
716  def _moments(self, inputs, reduction_axes, keep_dims):
717    mean, variance = self._calculate_mean_and_var(inputs, reduction_axes,
718                                                  keep_dims)
719    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
720    # code as well.
721    if self._support_zero_size_input():
722      input_batch_size = array_ops.shape(inputs)[0]
723      mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
724      variance = array_ops.where(input_batch_size > 0, variance,
725                                 K.zeros_like(variance))
726    return mean, variance
727
728  def _get_training_value(self, training=None):
729    if training is None:
730      training = K.learning_phase()
731    if self._USE_V2_BEHAVIOR:
732      if isinstance(training, int):
733        training = bool(training)
734      if not self.trainable:
735        # When the layer is not trainable, it overrides the value passed from
736        # model.
737        training = False
738    return training
739
740  def call(self, inputs, training=None):
741    training = self._get_training_value(training)
742
743    if self.virtual_batch_size is not None:
744      # Virtual batches (aka ghost batches) can be simulated by reshaping the
745      # Tensor and reusing the existing batch norm implementation
746      original_shape = array_ops.shape(inputs)
747      original_shape = array_ops.concat(
748          [constant_op.constant([-1]), original_shape[1:]], axis=0)
749      expanded_shape = array_ops.concat([
750          constant_op.constant([self.virtual_batch_size, -1]),
751          original_shape[1:]
752      ],
753                                        axis=0)
754
755      # Will cause errors if virtual_batch_size does not divide the batch size
756      inputs = array_ops.reshape(inputs, expanded_shape)
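      # For example, a batch of 12 examples with virtual_batch_size=4 is
      # reshaped from [12, ...] to [4, 3, ...]: three virtual sub-batches of
      # four examples each, normalized separately.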
757
758      def undo_virtual_batching(outputs):
759        outputs = array_ops.reshape(outputs, original_shape)
760        return outputs
761
762    if self.fused:
763      outputs = self._fused_batch_norm(inputs, training=training)
764      if self.virtual_batch_size is not None:
765        # Currently never reaches here since fused_batch_norm does not support
766        # virtual batching
767        outputs = undo_virtual_batching(outputs)
768      return outputs
769
770    inputs_dtype = inputs.dtype.base_dtype
771    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
772      # Do all math in float32 if given 16-bit inputs for numeric stability.
773      # In particular, it's very easy for variance to overflow in float16 and
774      # for safety we also choose to cast bfloat16 to float32.
775      inputs = math_ops.cast(inputs, dtypes.float32)
776
777    # Compute the axes along which to reduce the mean / variance
778    input_shape = inputs.shape
779    ndims = len(input_shape)
780    reduction_axes = [i for i in range(ndims) if i not in self.axis]
781    if self.virtual_batch_size is not None:
782      del reduction_axes[1]  # Do not reduce along virtual batch dim
783
784    # Broadcasting only necessary for single-axis batch norm where the axis is
785    # not the last dimension
786    broadcast_shape = [1] * ndims
787    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
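    # For example, channels-first 4D inputs (axis=[1]) yield a broadcast shape
    # of [1, C, 1, 1].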
788
789    def _broadcast(v):
790      if (v is not None and len(v.shape) != ndims and
791          reduction_axes != list(range(ndims - 1))):
792        return array_ops.reshape(v, broadcast_shape)
793      return v
794
795    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
796
797    def _compose_transforms(scale, offset, then_scale, then_offset):
798      if then_scale is not None:
799        scale *= then_scale
800        offset *= then_scale
801      if then_offset is not None:
802        offset += then_offset
803      return (scale, offset)
804
805    # Determine a boolean value for `training`: could be True, False, or None.
806    training_value = control_flow_util.constant_value(training)
807    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
808      mean, variance = self.moving_mean, self.moving_variance
809    else:
810      if self.adjustment:
811        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
812        # Adjust only during training.
813        adj_scale = control_flow_util.smart_cond(
814            training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale))
815        adj_bias = control_flow_util.smart_cond(
816            training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias))
817        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)
818
      # Some of the computations here are not necessary when training==False
      # but training is not a constant. However, this makes the code simpler.
821      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
822      mean, variance = self._moments(
823          math_ops.cast(inputs, self._param_dtype),
824          reduction_axes,
825          keep_dims=keep_dims)
826
827      moving_mean = self.moving_mean
828      moving_variance = self.moving_variance
829
830      mean = control_flow_util.smart_cond(
831          training, lambda: mean,
832          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_mean))
833      variance = control_flow_util.smart_cond(
834          training, lambda: variance,
835          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_variance))
836
837      if self.virtual_batch_size is not None:
838        # This isn't strictly correct since in ghost batch norm, you are
839        # supposed to sequentially update the moving_mean and moving_variance
840        # with each sub-batch. However, since the moving statistics are only
841        # used during evaluation, it is more efficient to just update in one
842        # step and should not make a significant difference in the result.
843        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
844        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
845      else:
846        new_mean, new_variance = mean, variance
847
848      if self._support_zero_size_input():
849        # Keras assumes that batch dimension is the first dimension for Batch
850        # Normalization.
851        input_batch_size = array_ops.shape(inputs)[0]
852      else:
853        input_batch_size = None
854
855      if self.renorm:
856        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
857            new_mean, new_variance, training, input_batch_size)
858        # When training, the normalized values (say, x) will be transformed as
859        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
860        # = x * (r * gamma) + (d * gamma + beta) with renorm.
861        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
862        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
863        scale, offset = _compose_transforms(r, d, scale, offset)
864
865      def _do_update(var, value):
866        """Compute the updates for mean and variance."""
867        return self._assign_moving_average(var, value, self.momentum,
868                                           input_batch_size)
869
870      def mean_update():
871        true_branch = lambda: _do_update(self.moving_mean, new_mean)
872        false_branch = lambda: self.moving_mean
873        return control_flow_util.smart_cond(training, true_branch, false_branch)
874
875      def variance_update():
876        """Update the moving variance."""
877
878        def true_branch_renorm():
879          # We apply epsilon as part of the moving_stddev to mirror the training
880          # code path.
881          moving_stddev = _do_update(self.moving_stddev,
882                                     math_ops.sqrt(new_variance + self.epsilon))
883          return self._assign_new_value(
884              self.moving_variance,
885              # Apply relu in case floating point rounding causes it to go
886              # negative.
887              K.relu(moving_stddev * moving_stddev - self.epsilon))
888
889        if self.renorm:
890          true_branch = true_branch_renorm
891        else:
892          true_branch = lambda: _do_update(self.moving_variance, new_variance)
893
894        false_branch = lambda: self.moving_variance
895        return control_flow_util.smart_cond(training, true_branch, false_branch)
896
897      self.add_update(mean_update)
898      self.add_update(variance_update)
899
900    mean = math_ops.cast(mean, inputs.dtype)
901    variance = math_ops.cast(variance, inputs.dtype)
902    if offset is not None:
903      offset = math_ops.cast(offset, inputs.dtype)
904    if scale is not None:
905      scale = math_ops.cast(scale, inputs.dtype)
906    outputs = nn.batch_normalization(inputs, _broadcast(mean),
907                                     _broadcast(variance), offset, scale,
908                                     self.epsilon)
909    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
910      outputs = math_ops.cast(outputs, inputs_dtype)
911
912    # If some components of the shape got lost due to adjustments, fix that.
913    outputs.set_shape(input_shape)
914
915    if self.virtual_batch_size is not None:
916      outputs = undo_virtual_batching(outputs)
917    return outputs
918
919  def compute_output_shape(self, input_shape):
920    return input_shape
921
922  def get_config(self):
923    config = {
924        'axis':
925            self.axis,
926        'momentum':
927            self.momentum,
928        'epsilon':
929            self.epsilon,
930        'center':
931            self.center,
932        'scale':
933            self.scale,
934        'beta_initializer':
935            initializers.serialize(self.beta_initializer),
936        'gamma_initializer':
937            initializers.serialize(self.gamma_initializer),
938        'moving_mean_initializer':
939            initializers.serialize(self.moving_mean_initializer),
940        'moving_variance_initializer':
941            initializers.serialize(self.moving_variance_initializer),
942        'beta_regularizer':
943            regularizers.serialize(self.beta_regularizer),
944        'gamma_regularizer':
945            regularizers.serialize(self.gamma_regularizer),
946        'beta_constraint':
947            constraints.serialize(self.beta_constraint),
948        'gamma_constraint':
949            constraints.serialize(self.gamma_constraint)
950    }
951    # Only add TensorFlow-specific parameters if they are set, so as to preserve
952    # model compatibility with external Keras.
953    if self.renorm:
954      config['renorm'] = True
955      config['renorm_clipping'] = self.renorm_clipping
956      config['renorm_momentum'] = self.renorm_momentum
957    if self.virtual_batch_size is not None:
958      config['virtual_batch_size'] = self.virtual_batch_size
959    # Note: adjustment is not serializable.
960    if self.adjustment is not None:
961      logging.warning('The `adjustment` function of this `BatchNormalization` '
962                      'layer cannot be serialized and has been omitted from '
963                      'the layer config. It will not be included when '
964                      're-creating the layer from the saved config.')
965    base_config = super(BatchNormalizationBase, self).get_config()
966    return dict(list(base_config.items()) + list(config.items()))
967
968
969# pylint: disable=missing-docstring
970@keras_export(v1=['keras.layers.BatchNormalization'])
971class BatchNormalization(BatchNormalizationBase):
972  _USE_V2_BEHAVIOR = False
973
974
975@keras_export('keras.layers.LayerNormalization')
976class LayerNormalization(Layer):
977  """Layer normalization layer (Ba et al., 2016).
978
979  Normalize the activations of the previous layer for each given example in a
980  batch independently, rather than across a batch like Batch Normalization.
  That is, it applies a transformation that maintains the mean activation
  within each example close to 0 and the activation standard deviation close
  to 1.
983
984  Given a tensor `inputs`, moments are calculated and normalization
985  is performed across the axes specified in `axis`.
986
987  Example:
988
989  >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
990  >>> print(data)
991  tf.Tensor(
992  [[ 0. 10.]
993   [20. 30.]
994   [40. 50.]
995   [60. 70.]
996   [80. 90.]], shape=(5, 2), dtype=float32)
997
998  >>> layer = tf.keras.layers.LayerNormalization(axis=1)
999  >>> output = layer(data)
1000  >>> print(output)
1001  tf.Tensor(
1002  [[-1. 1.]
1003   [-1. 1.]
1004   [-1. 1.]
1005   [-1. 1.]
1006   [-1. 1.]], shape=(5, 2), dtype=float32)
1007
1008  Notice that with Layer Normalization the normalization happens across the
1009  axes *within* each example, rather than across different examples in the
1010  batch.
1011
1012  If `scale` or `center` are enabled, the layer will scale the normalized
1013  outputs by broadcasting them with a trainable variable `gamma`, and center
1014  the outputs by broadcasting with a trainable variable `beta`. `gamma` will
1015  default to a ones tensor and `beta` will default to a zeros tensor, so that
1016  centering and scaling are no-ops before training has begun.
1017
1018  So, with scaling and centering enabled the normalization equations
1019  are as follows:
1020
  Let the intermediate activations for a mini-batch be the `inputs`.
1022
1023  For each sample `x_i` in `inputs` with `k` features, we compute the mean and
1024  variance of the sample:
1025
1026  ```python
1027  mean_i = sum(x_i[j] for j in range(k)) / k
1028  var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
1029  ```
1030
1031  and then compute a normalized `x_i_normalized`, including a small factor
1032  `epsilon` for numerical stability.
1033
1034  ```python
1035  x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
1036  ```
1037
  And finally `x_i_normalized` is linearly transformed by `gamma` and `beta`,
1039  which are learned parameters:
1040
1041  ```python
1042  output_i = x_i_normalized * gamma + beta
1043  ```
1044
1045  `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
1046  this part of the inputs' shape must be fully defined.
1047
1048  For example:
1049
1050  >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
1051  >>> layer.build([5, 20, 30, 40])
1052  >>> print(layer.beta.shape)
1053  (20, 30, 40)
1054  >>> print(layer.gamma.shape)
1055  (20, 30, 40)
1056
1057  Note that other implementations of layer normalization may choose to define
1058  `gamma` and `beta` over a separate set of axes from the axes being
1059  normalized across. For example, Group Normalization
1060  ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
1061  corresponds to a Layer Normalization that normalizes across height, width,
1062  and channel and has `gamma` and `beta` span only the channel dimension.
1063  So, this Layer Normalization implementation will not match a Group
1064  Normalization layer with group size set to 1.
1065
1066  Args:
1067    axis: Integer or List/Tuple. The axis or axes to normalize across. Typically
1068      this is the features axis/axes. The left-out axes are typically the batch
1069      axis/axes. This argument defaults to `-1`, the last dimension in the
1070      input.
    epsilon: Small float added to variance to avoid dividing by zero. Defaults
      to 1e-3.
1073    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
1074      is ignored. Defaults to True.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. Defaults
      to True. When the next layer is linear (this also applies to e.g.
      `nn.relu`), this can be disabled since the scaling will be done by the
      next layer.
1078    beta_initializer: Initializer for the beta weight. Defaults to zeros.
1079    gamma_initializer: Initializer for the gamma weight. Defaults to ones.
1080    beta_regularizer: Optional regularizer for the beta weight. None by default.
1081    gamma_regularizer: Optional regularizer for the gamma weight. None by
1082      default.
1083    beta_constraint: Optional constraint for the beta weight. None by default.
1084    gamma_constraint: Optional constraint for the gamma weight. None by default.
1085
1086  Input shape:
1087    Arbitrary. Use the keyword argument `input_shape` (tuple of
1088    integers, does not include the samples axis) when using this layer as the
1089    first layer in a model.
1090
1091  Output shape:
1092    Same shape as input.
1093
1094  Reference:
1095    - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
1096  """
1097
1098  def __init__(self,
1099               axis=-1,
1100               epsilon=1e-3,
1101               center=True,
1102               scale=True,
1103               beta_initializer='zeros',
1104               gamma_initializer='ones',
1105               beta_regularizer=None,
1106               gamma_regularizer=None,
1107               beta_constraint=None,
1108               gamma_constraint=None,
1109               **kwargs):
1110    super(LayerNormalization, self).__init__(**kwargs)
1111    if isinstance(axis, (list, tuple)):
1112      self.axis = axis[:]
1113    elif isinstance(axis, int):
1114      self.axis = axis
1115    else:
1116      raise TypeError('Expected an int or a list/tuple of ints for the '
1117                      'argument \'axis\', but received: %r' % axis)
1118
1119    self.epsilon = epsilon
1120    self.center = center
1121    self.scale = scale
1122    self.beta_initializer = initializers.get(beta_initializer)
1123    self.gamma_initializer = initializers.get(gamma_initializer)
1124    self.beta_regularizer = regularizers.get(beta_regularizer)
1125    self.gamma_regularizer = regularizers.get(gamma_regularizer)
1126    self.beta_constraint = constraints.get(beta_constraint)
1127    self.gamma_constraint = constraints.get(gamma_constraint)
1128
1129    self.supports_masking = True
1130
    # Indicates whether a faster fused implementation can be used. This will be
    # set to True or False in build().
1133    self._fused = None
1134
1135  def _fused_can_be_used(self, ndims):
    """Returns True if the fused implementation can be used, False otherwise.

    Checks whether the normalized axes are contiguous and can be collapsed into
    the last axis. `self.axis` is assumed to have no duplicates.
    """
1141    axis = sorted(self.axis)
1142    can_use_fused = False
1143
1144    if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
1145      can_use_fused = True
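    # For example, with ndims=4, axis=[2, 3] is contiguous and ends at the last
    # dimension, so the fused path can be used; axis=[1, 2] (does not end at
    # the last axis) and axis=[1, 3] (not contiguous) cannot.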
1146
    # fused_batch_norm will silently raise epsilon to be at least 1.001e-5, so
    # we cannot use the fused version if epsilon is below that value. Also, the
    # variable dtype must be float32, as fused_batch_norm only supports float32
    # variables.
1151    if self.epsilon < 1.001e-5 or self.dtype != 'float32':
1152      can_use_fused = False
1153
1154    return can_use_fused
1155
1156  def build(self, input_shape):
1157    ndims = len(input_shape)
1158    if ndims is None:
1159      raise ValueError('Input shape %s has undefined rank.' % input_shape)
1160
1161    # Convert axis to list and resolve negatives
1162    if isinstance(self.axis, int):
1163      self.axis = [self.axis]
1164    elif isinstance(self.axis, tuple):
1165      self.axis = list(self.axis)
1166    for idx, x in enumerate(self.axis):
1167      if x < 0:
1168        self.axis[idx] = ndims + x
1169
1170    # Validate axes
1171    for x in self.axis:
1172      if x < 0 or x >= ndims:
1173        raise ValueError('Invalid axis: %d' % x)
1174    if len(self.axis) != len(set(self.axis)):
1175      raise ValueError('Duplicate axis: {}'.format(tuple(self.axis)))
1176
1177    param_shape = [input_shape[dim] for dim in self.axis]
1178    if self.scale:
1179      self.gamma = self.add_weight(
1180          name='gamma',
1181          shape=param_shape,
1182          initializer=self.gamma_initializer,
1183          regularizer=self.gamma_regularizer,
1184          constraint=self.gamma_constraint,
1185          trainable=True,
1186          experimental_autocast=False)
1187    else:
1188      self.gamma = None
1189
1190    if self.center:
1191      self.beta = self.add_weight(
1192          name='beta',
1193          shape=param_shape,
1194          initializer=self.beta_initializer,
1195          regularizer=self.beta_regularizer,
1196          constraint=self.beta_constraint,
1197          trainable=True,
1198          experimental_autocast=False)
1199    else:
1200      self.beta = None
1201
1202    self._fused = self._fused_can_be_used(ndims)
1203
1204    self.built = True
1205
1206  def call(self, inputs):
1207    # Compute the axes along which to reduce the mean / variance
1208    input_shape = inputs.shape
1209    ndims = len(input_shape)
1210
    # Broadcasting is only necessary when the normalization axes are not just
    # the last dimension
1213    broadcast_shape = [1] * ndims
1214    for dim in self.axis:
1215      broadcast_shape[dim] = input_shape.dims[dim].value
1216
1217    def _broadcast(v):
1218      if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]):
1219        return array_ops.reshape(v, broadcast_shape)
1220      return v
1221
1222    if not self._fused:
1223      input_dtype = inputs.dtype
1224      if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
1225        # If mixed precision is used, cast inputs to float32 so that this is at
1226        # least as numerically stable as the fused version.
1227        inputs = math_ops.cast(inputs, 'float32')
1228
1229      # Calculate the moments on the last axis (layer activations).
1230      mean, variance = nn.moments(inputs, self.axis, keep_dims=True)
1231
1232      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
1233
1234      # Compute layer normalization using the batch_normalization function.
1235      outputs = nn.batch_normalization(
1236          inputs,
1237          mean,
1238          variance,
1239          offset=offset,
1240          scale=scale,
1241          variance_epsilon=self.epsilon)
1242      outputs = math_ops.cast(outputs, input_dtype)
1243    else:
1244      # Collapse dims before self.axis, and dims in self.axis
1245      pre_dim, in_dim = (1, 1)
1246      axis = sorted(self.axis)
1247      tensor_shape = array_ops.shape(inputs)
1248      for dim in range(0, ndims):
1249        dim_tensor = tensor_shape[dim]
1250        if dim < axis[0]:
1251          pre_dim = pre_dim * dim_tensor
1252        else:
1253          assert dim in axis
1254          in_dim = in_dim * dim_tensor
1255
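      # For example, inputs of shape [N, H, W, C] with axis=[1, 2, 3] collapse
      # to squeezed_shape = [1, N, H*W*C, 1].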
1256      squeezed_shape = [1, pre_dim, in_dim, 1]
1257      # This fused operation requires reshaped inputs to be NCHW.
1258      data_format = 'NCHW'
1259
1260      inputs = array_ops.reshape(inputs, squeezed_shape)
1261
1262      def _set_const_tensor(val, dtype, shape):
1263        return array_ops.fill(shape, constant_op.constant(val, dtype=dtype))
1264
      # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
      # we cannot pass them as the scale and offset parameters. Therefore, we
      # create two constant tensors in the correct shapes for fused_batch_norm
      # and apply gamma and beta separately afterwards.
1269      scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
1270      offset = _set_const_tensor(0.0, self.dtype, [pre_dim])
1271
1272      # Compute layer normalization using the fused_batch_norm function.
1273      outputs, _, _ = nn.fused_batch_norm(
1274          inputs,
1275          scale=scale,
1276          offset=offset,
1277          epsilon=self.epsilon,
1278          data_format=data_format)
1279
1280      outputs = array_ops.reshape(outputs, tensor_shape)
1281
1282      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
1283
1284      if scale is not None:
1285        outputs = outputs * math_ops.cast(scale, outputs.dtype)
1286      if offset is not None:
1287        outputs = outputs + math_ops.cast(offset, outputs.dtype)
1288
1289    # If some components of the shape got lost due to adjustments, fix that.
1290    outputs.set_shape(input_shape)
1291
1292    return outputs
1293
1294  def compute_output_shape(self, input_shape):
1295    return input_shape
1296
1297  def get_config(self):
1298    config = {
1299        'axis': self.axis,
1300        'epsilon': self.epsilon,
1301        'center': self.center,
1302        'scale': self.scale,
1303        'beta_initializer': initializers.serialize(self.beta_initializer),
1304        'gamma_initializer': initializers.serialize(self.gamma_initializer),
1305        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
1306        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
1307        'beta_constraint': constraints.serialize(self.beta_constraint),
1308        'gamma_constraint': constraints.serialize(self.gamma_constraint)
1309    }
1310    base_config = super(LayerNormalization, self).get_config()
1311    return dict(list(base_config.items()) + list(config.items()))
1312