1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Normalization layers.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import distribution_strategy_context
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.keras import backend as K
27from tensorflow.python.keras import constraints
28from tensorflow.python.keras import initializers
29from tensorflow.python.keras import regularizers
30from tensorflow.python.keras.engine.base_layer import Layer
31from tensorflow.python.keras.engine.input_spec import InputSpec
32from tensorflow.python.keras.utils import tf_utils
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import init_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import nn
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variables as tf_variables
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.util.tf_export import keras_export
41
42
43class BatchNormalizationBase(Layer):
44  """Base class of Batch normalization layer (Ioffe and Szegedy, 2014).
45
46  Normalize the activations of the previous layer at each batch,
47  i.e. applies a transformation that maintains the mean activation
48  close to 0 and the activation standard deviation close to 1.
49
50  Arguments:
51    axis: Integer, the axis that should be normalized
52      (typically the features axis).
53      For instance, after a `Conv2D` layer with
54      `data_format="channels_first"`,
55      set `axis=1` in `BatchNormalization`.
56    momentum: Momentum for the moving average.
57    epsilon: Small float added to variance to avoid dividing by zero.
58    center: If True, add offset of `beta` to normalized tensor.
59      If False, `beta` is ignored.
60    scale: If True, multiply by `gamma`.
61      If False, `gamma` is not used.
62      When the next layer is linear (also e.g. `nn.relu`),
63      this can be disabled since the scaling
64      will be done by the next layer.
65    beta_initializer: Initializer for the beta weight.
66    gamma_initializer: Initializer for the gamma weight.
67    moving_mean_initializer: Initializer for the moving mean.
68    moving_variance_initializer: Initializer for the moving variance.
69    beta_regularizer: Optional regularizer for the beta weight.
70    gamma_regularizer: Optional regularizer for the gamma weight.
71    beta_constraint: Optional constraint for the beta weight.
72    gamma_constraint: Optional constraint for the gamma weight.
73    renorm: Whether to use Batch Renormalization
74      (https://arxiv.org/abs/1702.03275). This adds extra variables during
75      training. The inference is the same for either value of this parameter.
76    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
77      scalar `Tensors` used to clip the renorm correction. The correction
78      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
79      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
80      dmax are set to inf, 0, inf, respectively.
81    renorm_momentum: Momentum used to update the moving means and standard
82      deviations with renorm. Unlike `momentum`, this affects training
83      and should be neither too small (which would add noise) nor too large
84      (which would give stale estimates). Note that `momentum` is still applied
85      to get the means and variances for inference.
86    fused: if `True`, use a faster, fused implementation, or raise a ValueError
87      if the fused implementation cannot be used. If `None`, use the faster
88      implementation if possible. If False, do not used the fused
89      implementation.
90    trainable: Boolean, if `True` the variables will be marked as trainable.
91    virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
92      which means batch normalization is performed across the whole batch. When
93      `virtual_batch_size` is not `None`, instead perform "Ghost Batch
94      Normalization", which creates virtual sub-batches which are each
95      normalized separately (with shared gamma, beta, and moving statistics).
96      Must divide the actual batch size during execution.
97    adjustment: A function taking the `Tensor` containing the (dynamic) shape of
98      the input tensor and returning a pair (scale, bias) to apply to the
99      normalized values (before gamma and beta), only during training. For
100      example, if axis==-1,
101        `adjustment = lambda shape: (
102          tf.random_uniform(shape[-1:], 0.93, 1.07),
103          tf.random_uniform(shape[-1:], -0.1, 0.1))`
104      will scale the normalized value by up to 7% up or down, then shift the
105      result by up to 0.1 (with independent scaling and bias for each feature
106      but shared across all examples), and finally apply gamma and/or beta. If
107      `None`, no adjustment is applied. Cannot be specified if
108      virtual_batch_size is specified.
109
110  Call arguments:
111    inputs: Input tensor (of any rank).
112    training: Python boolean indicating whether the layer should behave in
113      training mode or in inference mode.
114      - `training=True`: The layer will normalize its inputs using the
115        mean and variance of the current batch of inputs.
116      - `training=False`: The layer will normalize its inputs using the
117        mean and variance of its moving statistics, learned during training.
118
119  Input shape:
120    Arbitrary. Use the keyword argument `input_shape`
121    (tuple of integers, does not include the samples axis)
122    when using this layer as the first layer in a model.
123
124  Output shape:
125    Same shape as input.
126
127  References:
128    - [Batch Normalization: Accelerating Deep Network Training by Reducing
129      Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
130  """
131
132  # By default, the base class uses V2 behavior. The BatchNormalization V1
133  # subclass sets this to False to use the V1 behavior.
134  _USE_V2_BEHAVIOR = True
135
136  def __init__(self,
137               axis=-1,
138               momentum=0.99,
139               epsilon=1e-3,
140               center=True,
141               scale=True,
142               beta_initializer='zeros',
143               gamma_initializer='ones',
144               moving_mean_initializer='zeros',
145               moving_variance_initializer='ones',
146               beta_regularizer=None,
147               gamma_regularizer=None,
148               beta_constraint=None,
149               gamma_constraint=None,
150               renorm=False,
151               renorm_clipping=None,
152               renorm_momentum=0.99,
153               fused=None,
154               trainable=True,
155               virtual_batch_size=None,
156               adjustment=None,
157               name=None,
158               **kwargs):
159    super(BatchNormalizationBase, self).__init__(
160        name=name, trainable=trainable, **kwargs)
161    if isinstance(axis, list):
162      self.axis = axis[:]
163    elif isinstance(axis, int):
164      self.axis = axis
165    else:
166      raise TypeError('axis must be int or list, type given: %s'
167                      % type(self.axis))
168    self.momentum = momentum
169    self.epsilon = epsilon
170    self.center = center
171    self.scale = scale
172    self.beta_initializer = initializers.get(beta_initializer)
173    self.gamma_initializer = initializers.get(gamma_initializer)
174    self.moving_mean_initializer = initializers.get(moving_mean_initializer)
175    self.moving_variance_initializer = initializers.get(
176        moving_variance_initializer)
177    self.beta_regularizer = regularizers.get(beta_regularizer)
178    self.gamma_regularizer = regularizers.get(gamma_regularizer)
179    self.beta_constraint = constraints.get(beta_constraint)
180    self.gamma_constraint = constraints.get(gamma_constraint)
181    self.renorm = renorm
182    self.virtual_batch_size = virtual_batch_size
183    self.adjustment = adjustment
184    if self._USE_V2_BEHAVIOR:
185      if fused:
186        self._raise_if_fused_cannot_be_used()
187      # We leave fused as None if self._fused_can_be_used()==True, since we
188      # still may set it to False in self.build() if the input rank is not 4.
189      elif fused is None and not self._fused_can_be_used():
190        fused = False
191    elif fused is None:
192      fused = True
193    self.supports_masking = True
194
195    self.fused = fused
196    self._bessels_correction_test_only = True
197
198    if renorm:
199      renorm_clipping = renorm_clipping or {}
200      keys = ['rmax', 'rmin', 'dmax']
201      if set(renorm_clipping) - set(keys):
202        raise ValueError('renorm_clipping %s contains keys not in %s' %
203                         (renorm_clipping, keys))
204      self.renorm_clipping = renorm_clipping
205      self.renorm_momentum = renorm_momentum
206
207  def _raise_if_fused_cannot_be_used(self):
208    """Raises a ValueError if fused implementation cannot be used.
209
210    In addition to the checks done in this function, the input tensors rank must
211    be 4. The input rank check can only be done once the input shape is known.
212    """
213    # Currently fused batch norm doesn't support renorm. It also only supports a
214    # channel dimension on axis 1 or 3, when no virtual batch size or adjustment
215    # is used.
216    if self.renorm:
217      raise ValueError('Passing both fused=True and renorm=True is '
218                       'unsupported')
219    axis = [self.axis] if isinstance(self.axis, int) else self.axis
220    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
221    # input rank is required to be 4 (which is checked later).
222    if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
223      raise ValueError('Passing fused=True is only supported when axis is 1 '
224                       'or 3')
225    if self.virtual_batch_size is not None:
226      raise ValueError('Passing fused=True is unsupported when '
227                       'virtual_batch_size is specified.')
228    if self.adjustment is not None:
229      raise ValueError('Passing fused=True is unsupported when '
230                       'adjustment is specified.')
231
232  def _fused_can_be_used(self):
233    try:
234      self._raise_if_fused_cannot_be_used()
235      return True
236    except ValueError:
237      return False
238
239  @property
240  def _param_dtype(self):
241    # Raise parameters of fp16 batch norm to fp32
242    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
243      return dtypes.float32
244    else:
245      return self.dtype or dtypes.float32
246
247  def build(self, input_shape):
248    input_shape = tensor_shape.TensorShape(input_shape)
249    if not input_shape.ndims:
250      raise ValueError('Input has undefined rank:', input_shape)
251    ndims = len(input_shape)
252
253    # Convert axis to list and resolve negatives
254    if isinstance(self.axis, int):
255      self.axis = [self.axis]
256
257    for idx, x in enumerate(self.axis):
258      if x < 0:
259        self.axis[idx] = ndims + x
260
261    # Validate axes
262    for x in self.axis:
263      if x < 0 or x >= ndims:
264        raise ValueError('Invalid axis: %d' % x)
265    if len(self.axis) != len(set(self.axis)):
266      raise ValueError('Duplicate axis: %s' % self.axis)
267
268    if self.virtual_batch_size is not None:
269      if self.virtual_batch_size <= 0:
270        raise ValueError('virtual_batch_size must be a positive integer that '
271                         'divides the true batch size of the input Tensor')
272      # If using virtual batches, the first dimension must be the batch
273      # dimension and cannot be the batch norm axis
274      if 0 in self.axis:
275        raise ValueError('When using virtual_batch_size, the batch dimension '
276                         'must be 0 and thus axis cannot include 0')
277      if self.adjustment is not None:
278        raise ValueError('When using virtual_batch_size, adjustment cannot '
279                         'be specified')
280
281    if self.fused in (None, True):
282      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
283      # output back to its original shape accordingly.
284      if self._USE_V2_BEHAVIOR:
285        if self.fused is None:
286          self.fused = (ndims == 4)
287        elif self.fused and ndims != 4:
288          raise ValueError('Batch normalization layers with fused=True only '
289                           'support 4D input tensors.')
290      else:
291        assert self.fused is not None
292        self.fused = (ndims == 4 and self._fused_can_be_used())
293      # TODO(chrisying): fused batch norm is currently not supported for
294      # multi-axis batch norm and by extension virtual batches. In some cases,
295      # it might be possible to use fused batch norm but would require reshaping
296      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
297      # particularly tricky. A compromise might be to just support the most
298      # common use case (turning 5D w/ virtual batch to NCHW)
299
300    if self.fused:
301      if self.axis == [1]:
302        self._data_format = 'NCHW'
303      elif self.axis == [3]:
304        self._data_format = 'NHWC'
305      else:
306        raise ValueError('Unsupported axis, fused batch norm only supports '
307                         'axis == [1] or axis == [3]')
308
309    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
310    for x in axis_to_dim:
311      if axis_to_dim[x] is None:
312        raise ValueError('Input has undefined `axis` dimension. Input shape: ',
313                         input_shape)
314    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)
315
316    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
317      # Single axis batch norm (most common/default use-case)
318      param_shape = (list(axis_to_dim.values())[0],)
319    else:
320      # Parameter shape is the original shape but with 1 in all non-axis dims
321      param_shape = [axis_to_dim[i] if i in axis_to_dim
322                     else 1 for i in range(ndims)]
323      if self.virtual_batch_size is not None:
324        # When using virtual batches, add an extra dim at index 1
325        param_shape.insert(1, 1)
326        for idx, x in enumerate(self.axis):
327          self.axis[idx] = x + 1      # Account for added dimension
328
329    if self.scale:
330      self.gamma = self.add_weight(
331          name='gamma',
332          shape=param_shape,
333          dtype=self._param_dtype,
334          initializer=self.gamma_initializer,
335          regularizer=self.gamma_regularizer,
336          constraint=self.gamma_constraint,
337          trainable=True,
338          experimental_autocast=False)
339    else:
340      self.gamma = None
341      if self.fused:
342        self._gamma_const = K.constant(
343            1.0, dtype=self._param_dtype, shape=param_shape)
344
345    if self.center:
346      self.beta = self.add_weight(
347          name='beta',
348          shape=param_shape,
349          dtype=self._param_dtype,
350          initializer=self.beta_initializer,
351          regularizer=self.beta_regularizer,
352          constraint=self.beta_constraint,
353          trainable=True,
354          experimental_autocast=False)
355    else:
356      self.beta = None
357      if self.fused:
358        self._beta_const = K.constant(
359            0.0, dtype=self._param_dtype, shape=param_shape)
360
361    try:
362      # Disable variable partitioning when creating the moving mean and variance
363      if hasattr(self, '_scope') and self._scope:
364        partitioner = self._scope.partitioner
365        self._scope.set_partitioner(None)
366      else:
367        partitioner = None
368      self.moving_mean = self.add_weight(
369          name='moving_mean',
370          shape=param_shape,
371          dtype=self._param_dtype,
372          initializer=self.moving_mean_initializer,
373          synchronization=tf_variables.VariableSynchronization.ON_READ,
374          trainable=False,
375          aggregation=tf_variables.VariableAggregation.MEAN,
376          experimental_autocast=False)
377
378      self.moving_variance = self.add_weight(
379          name='moving_variance',
380          shape=param_shape,
381          dtype=self._param_dtype,
382          initializer=self.moving_variance_initializer,
383          synchronization=tf_variables.VariableSynchronization.ON_READ,
384          trainable=False,
385          aggregation=tf_variables.VariableAggregation.MEAN,
386          experimental_autocast=False)
387
388      if self.renorm:
389        # Create variables to maintain the moving mean and standard deviation.
390        # These are used in training and thus are different from the moving
391        # averages above. The renorm variables are colocated with moving_mean
392        # and moving_variance.
393        # NOTE: below, the outer `with device` block causes the current device
394        # stack to be cleared. The nested ones use a `lambda` to set the desired
395        # device and ignore any devices that may be set by the custom getter.
396        def _renorm_variable(name, shape):
397          """Create a renorm variable."""
398          var = self.add_weight(
399              name=name,
400              shape=shape,
401              dtype=self._param_dtype,
402              initializer=init_ops.zeros_initializer(),
403              synchronization=tf_variables.VariableSynchronization.ON_READ,
404              trainable=False,
405              aggregation=tf_variables.VariableAggregation.MEAN,
406              experimental_autocast=False)
407          return var
408
409        with distribution_strategy_context.get_strategy(
410        ).extended.colocate_vars_with(self.moving_mean):
411          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
412          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
413        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
414        # renorm_stddev_weight. This allows us to (1) mix the average
415        # stddev with the minibatch stddev early in training, and (2) compute
416        # the unbiased average stddev by dividing renorm_stddev by the weight.
417        with distribution_strategy_context.get_strategy(
418        ).extended.colocate_vars_with(self.moving_variance):
419          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
420          self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
421                                                       ())
422    finally:
423      if partitioner:
424        self._scope.set_partitioner(partitioner)
425    self.built = True
426
427  def _assign_moving_average(self, variable, value, momentum):
428    with ops.name_scope(None, 'AssignMovingAvg',
429                        [variable, value, momentum]) as scope:
430      with ops.colocate_with(variable):
431        decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
432        if decay.dtype != variable.dtype.base_dtype:
433          decay = math_ops.cast(decay, variable.dtype.base_dtype)
434        update_delta = (
435            variable - math_ops.cast(value, variable.dtype)) * decay
436        return state_ops.assign_sub(variable, update_delta, name=scope)
437
438  def _fused_batch_norm(self, inputs, training):
439    """Returns the output of fused batch norm."""
440    beta = self.beta if self.center else self._beta_const
441    gamma = self.gamma if self.scale else self._gamma_const
442
443    def _fused_batch_norm_training():
444      return nn.fused_batch_norm(
445          inputs,
446          gamma,
447          beta,
448          epsilon=self.epsilon,
449          data_format=self._data_format)
450
451    def _fused_batch_norm_inference():
452      return nn.fused_batch_norm(
453          inputs,
454          gamma,
455          beta,
456          mean=self.moving_mean,
457          variance=self.moving_variance,
458          epsilon=self.epsilon,
459          is_training=False,
460          data_format=self._data_format)
461
462    output, mean, variance = tf_utils.smart_cond(
463        training, _fused_batch_norm_training, _fused_batch_norm_inference)
464    if not self._bessels_correction_test_only:
465      # Remove Bessel's correction to be consistent with non-fused batch norm.
466      # Note that the variance computed by fused batch norm is
467      # with Bessel's correction.
468      sample_size = math_ops.cast(
469          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
470      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
471      variance *= factor
472
473    training_value = tf_utils.constant_value(training)
474    if training_value is None:
475      momentum = tf_utils.smart_cond(training,
476                                     lambda: self.momentum,
477                                     lambda: 1.0)
478    else:
479      momentum = ops.convert_to_tensor(self.momentum)
480    if training_value or training_value is None:
481      if distribution_strategy_context.in_cross_replica_context():
482        strategy = distribution_strategy_context.get_strategy()
483        mean_update = strategy.extended.update(
484            self.moving_mean, self._assign_moving_average,
485            (mean, self.momentum))
486        variance_update = strategy.extended.update(
487            self.moving_variance, self._assign_moving_average,
488            (variance, self.momentum))
489      else:
490        mean_update = self._assign_moving_average(self.moving_mean, mean,
491                                                  momentum)
492        variance_update = self._assign_moving_average(self.moving_variance,
493                                                      variance, momentum)
494      self.add_update(mean_update, inputs=True)
495      self.add_update(variance_update, inputs=True)
496
497    return output
498
499  def _renorm_correction_and_moments(self, mean, variance, training):
500    """Returns the correction and update values for renorm."""
501    stddev = math_ops.sqrt(variance + self.epsilon)
502    # Compute the average mean and standard deviation, as if they were
503    # initialized with this batch's moments.
504    mixed_renorm_mean = (self.renorm_mean +
505                         (1. - self.renorm_mean_weight) * mean)
506    mixed_renorm_stddev = (self.renorm_stddev +
507                           (1. - self.renorm_stddev_weight) * stddev)
508    # Compute the corrections for batch renorm.
509    r = stddev / mixed_renorm_stddev
510    d = (mean - mixed_renorm_mean) / mixed_renorm_stddev
511    # Ensure the corrections use pre-update moving averages.
512    with ops.control_dependencies([r, d]):
513      mean = array_ops.identity(mean)
514      stddev = array_ops.identity(stddev)
515    rmin, rmax, dmax = [self.renorm_clipping.get(key)
516                        for key in ['rmin', 'rmax', 'dmax']]
517    if rmin is not None:
518      r = math_ops.maximum(r, rmin)
519    if rmax is not None:
520      r = math_ops.minimum(r, rmax)
521    if dmax is not None:
522      d = math_ops.maximum(d, -dmax)
523      d = math_ops.minimum(d, dmax)
524    # When not training, use r=1, d=0.
525    r = tf_utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r))
526    d = tf_utils.smart_cond(training,
527                            lambda: d,
528                            lambda: array_ops.zeros_like(d))
529
530    def _update_renorm_variable(var, weight, value):
531      """Updates a moving average and weight, returns the unbiased value."""
532      value = array_ops.identity(value)
533      def _do_update():
534        """Updates the var and weight, returns their updated ratio."""
535        # Update the variables without zero debiasing. The debiasing will be
536        # accomplished by dividing the exponential moving average by the weight.
537        # For example, after a single update, the moving average would be
538        # (1-decay) * value. and the weight will be 1-decay, with their ratio
539        # giving the value.
540        # Make sure the weight is not updated until before r and d computation.
541        with ops.control_dependencies([value]):
542          weight_value = array_ops.constant(1., dtype=weight.dtype)
543        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
544        new_weight = self._assign_moving_average(weight, weight_value,
545                                                 self.renorm_momentum)
546        # TODO(yuefengz): the updates to var and weighted can not be batched
547        # together if we fetch their updated values here. Consider calculating
548        # new values and delaying the updates.
549        return new_var / new_weight
550
551      def _fake_update():
552        return array_ops.identity(var)
553      return tf_utils.smart_cond(training, _do_update, _fake_update)
554
555    # TODO(yuefengz): colocate the operations
556    new_mean = _update_renorm_variable(self.renorm_mean,
557                                       self.renorm_mean_weight, mean)
558    new_stddev = _update_renorm_variable(self.renorm_stddev,
559                                         self.renorm_stddev_weight, stddev)
560    # Make sqrt(moving_variance + epsilon) = new_stddev.
561    new_variance = math_ops.square(new_stddev) - self.epsilon
562
563    return (r, d, new_mean, new_variance)
564
565  def _moments(self, inputs, reduction_axes, keep_dims):
566    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
567
568  def call(self, inputs, training=None):
569    if training is None:
570      training = K.learning_phase()
571
572    in_eager_mode = context.executing_eagerly()
573    if self.virtual_batch_size is not None:
574      # Virtual batches (aka ghost batches) can be simulated by reshaping the
575      # Tensor and reusing the existing batch norm implementation
576      original_shape = [-1] + inputs.shape.as_list()[1:]
577      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
578
579      # Will cause errors if virtual_batch_size does not divide the batch size
580      inputs = array_ops.reshape(inputs, expanded_shape)
581
582      def undo_virtual_batching(outputs):
583        outputs = array_ops.reshape(outputs, original_shape)
584        return outputs
585
586    if self.fused:
587      outputs = self._fused_batch_norm(inputs, training=training)
588      if self.virtual_batch_size is not None:
589        # Currently never reaches here since fused_batch_norm does not support
590        # virtual batching
591        outputs = undo_virtual_batching(outputs)
592      return outputs
593
594    # Compute the axes along which to reduce the mean / variance
595    input_shape = inputs.get_shape()
596    ndims = len(input_shape)
597    reduction_axes = [i for i in range(ndims) if i not in self.axis]
598    if self.virtual_batch_size is not None:
599      del reduction_axes[1]     # Do not reduce along virtual batch dim
600
601    # Broadcasting only necessary for single-axis batch norm where the axis is
602    # not the last dimension
603    broadcast_shape = [1] * ndims
604    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
605    def _broadcast(v):
606      if (v is not None and
607          len(v.get_shape()) != ndims and
608          reduction_axes != list(range(ndims - 1))):
609        return array_ops.reshape(v, broadcast_shape)
610      return v
611
612    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
613
614    def _compose_transforms(scale, offset, then_scale, then_offset):
615      if then_scale is not None:
616        scale *= then_scale
617        offset *= then_scale
618      if then_offset is not None:
619        offset += then_offset
620      return (scale, offset)
621
622    # Determine a boolean value for `training`: could be True, False, or None.
623    training_value = tf_utils.constant_value(training)
624    if training_value is not False:
625      if self.adjustment:
626        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
627        # Adjust only during training.
628        adj_scale = tf_utils.smart_cond(training,
629                                        lambda: adj_scale,
630                                        lambda: array_ops.ones_like(adj_scale))
631        adj_bias = tf_utils.smart_cond(training,
632                                       lambda: adj_bias,
633                                       lambda: array_ops.zeros_like(adj_bias))
634        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)
635
636      # Some of the computations here are not necessary when training==False
637      # but not a constant. However, this makes the code simpler.
638      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
639      mean, variance = self._moments(
640          math_ops.cast(inputs, self._param_dtype),
641          reduction_axes,
642          keep_dims=keep_dims)
643
644      moving_mean = self.moving_mean
645      moving_variance = self.moving_variance
646
647      mean = tf_utils.smart_cond(training,
648                                 lambda: mean,
649                                 lambda: moving_mean)
650      variance = tf_utils.smart_cond(training,
651                                     lambda: variance,
652                                     lambda: moving_variance)
653
654      if self.virtual_batch_size is not None:
655        # This isn't strictly correct since in ghost batch norm, you are
656        # supposed to sequentially update the moving_mean and moving_variance
657        # with each sub-batch. However, since the moving statistics are only
658        # used during evaluation, it is more efficient to just update in one
659        # step and should not make a significant difference in the result.
660        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
661        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
662      else:
663        new_mean, new_variance = mean, variance
664
665      if self.renorm:
666        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
667            new_mean, new_variance, training)
668        # When training, the normalized values (say, x) will be transformed as
669        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
670        # = x * (r * gamma) + (d * gamma + beta) with renorm.
671        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
672        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
673        scale, offset = _compose_transforms(r, d, scale, offset)
674
675      if distribution_strategy_context.in_cross_replica_context():
676        strategy = distribution_strategy_context.get_strategy()
677
678        def _do_update(var, value):
679          """Compute the updates for mean and variance."""
680          if in_eager_mode and not self.trainable:
681            return
682          return strategy.extended.update(
683              var, self._assign_moving_average, (value, self.momentum),
684              group=False)
685        # We need to unwrap the moving_mean or moving_variance in the case of
686        # training being false to match the output of true_fn and false_fn
687        # in the smart cond.
688        mean_update = tf_utils.smart_cond(
689            training,
690            lambda: _do_update(self.moving_mean, new_mean),
691            lambda: strategy.unwrap(self.moving_mean))
692        variance_update = tf_utils.smart_cond(
693            training,
694            lambda: _do_update(self.moving_variance, new_variance),
695            lambda: strategy.unwrap(self.moving_variance))
696      else:
697        def _do_update(var, value):
698          """Compute the updates for mean and variance."""
699          if in_eager_mode and not self.trainable:
700            return
701          return self._assign_moving_average(var, value, self.momentum)
702        mean_update = tf_utils.smart_cond(
703            training,
704            lambda: _do_update(self.moving_mean, new_mean),
705            lambda: self.moving_mean)
706        variance_update = tf_utils.smart_cond(
707            training,
708            lambda: _do_update(self.moving_variance, new_variance),
709            lambda: self.moving_variance)
710      if not context.executing_eagerly():
711        self.add_update(mean_update, inputs=True)
712        self.add_update(variance_update, inputs=True)
713
714    else:
715      mean, variance = self.moving_mean, self.moving_variance
716
717    mean = math_ops.cast(mean, inputs.dtype)
718    variance = math_ops.cast(variance, inputs.dtype)
719    if offset is not None:
720      offset = math_ops.cast(offset, inputs.dtype)
721    if scale is not None:
722      scale = math_ops.cast(scale, inputs.dtype)
723    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
724    # math in float16 hurts validation accuracy of popular models like resnet.
725    outputs = nn.batch_normalization(inputs,
726                                     _broadcast(mean),
727                                     _broadcast(variance),
728                                     offset,
729                                     scale,
730                                     self.epsilon)
731    # If some components of the shape got lost due to adjustments, fix that.
732    outputs.set_shape(input_shape)
733
734    if self.virtual_batch_size is not None:
735      outputs = undo_virtual_batching(outputs)
736    return outputs
737
738  def compute_output_shape(self, input_shape):
739    return input_shape
740
741  def get_config(self):
742    config = {
743        'axis': self.axis,
744        'momentum': self.momentum,
745        'epsilon': self.epsilon,
746        'center': self.center,
747        'scale': self.scale,
748        'beta_initializer': initializers.serialize(self.beta_initializer),
749        'gamma_initializer': initializers.serialize(self.gamma_initializer),
750        'moving_mean_initializer':
751            initializers.serialize(self.moving_mean_initializer),
752        'moving_variance_initializer':
753            initializers.serialize(self.moving_variance_initializer),
754        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
755        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
756        'beta_constraint': constraints.serialize(self.beta_constraint),
757        'gamma_constraint': constraints.serialize(self.gamma_constraint)
758    }
759    # Only add TensorFlow-specific parameters if they are set, so as to preserve
760    # model compatibility with external Keras.
761    if self.renorm:
762      config['renorm'] = True
763      config['renorm_clipping'] = self.renorm_clipping
764      config['renorm_momentum'] = self.renorm_momentum
765    if self.virtual_batch_size is not None:
766      config['virtual_batch_size'] = self.virtual_batch_size
767    # Note: adjustment is not serializable.
768    if self.adjustment is not None:
769      logging.warning('The `adjustment` function of this `BatchNormalization` '
770                      'layer cannot be serialized and has been omitted from '
771                      'the layer config. It will not be included when '
772                      're-creating the layer from the saved config.')
773    base_config = super(BatchNormalizationBase, self).get_config()
774    return dict(list(base_config.items()) + list(config.items()))
775
776
777def _replace_in_base_docstring(old, new):
778  string = BatchNormalizationBase.__doc__
779  if old not in string:
780    raise ValueError('Could not find following string in BatchNormalizationBase'
781                     ' docstring: "{}"'.format(old))
782  return string.replace(old, new)
783
784
785@keras_export(v1=['keras.layers.BatchNormalization'])  # pylint: disable=missing-docstring
786class BatchNormalization(BatchNormalizationBase):
787
788  __doc__ = _replace_in_base_docstring(
789      '''
790    fused: if `True`, use a faster, fused implementation, or raise a ValueError
791      if the fused implementation cannot be used. If `None`, use the faster
792      implementation if possible. If False, do not used the fused
793      implementation.''',
794
795      '''
796    fused: if `None` or `True`, use a faster, fused implementation if possible.
797      If `False`, use the system recommended implementation.''')
798
799  _USE_V2_BEHAVIOR = False
800
801
802@keras_export('keras.layers.experimental.LayerNormalization')
803class LayerNormalization(Layer):
804  """Layer normalization layer (Ba et al., 2016).
805
806  Normalize the activations of the previous layer for each given example in a
807  batch independently, rather than across a batch like Batch Normalization.
808  i.e. applies a transformation that maintains the mean activation within each
809  example close to 0 and the activation standard deviation close to 1.
810
811  Given a tensor `inputs` of rank `R`, moments are calculated and normalization
812  is performed over all axes in norm_axis.  Scaling and centering,
813  if requested, is performed over all axes in params_axis.
814
815  By default, normalization is performed over all but the first axis
816  (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable
817  parameters are calculated for the rightmost axis (the `C` if `inputs` is
818  `NHWC`).  Scaling and recentering is performed via broadcast of the
819  `beta` and `gamma` parameters with the normalized tensor.
820
821  The shapes of `beta` and `gamma` are
822  `[inputs.shape[i] for i in (param axes)]`,
823  and this part of the inputs' shape must be fully defined.
824
825  Arguments:
826    norm_axis: Integer or List. normalization will be
827      performed along these dimensions. If unspecified (None), it will default
828      to the dimensions `begin_norm_axis : rank(inputs)`
829    params_axis: Integer or List. The (beta, gamma) dimensions: scale
830      and centering parameters will have take their shapes from these axes and
831      will be broadcast with the normalized inputs accordingly. If unspecified
832      (None), it will default to the last dimension
833    epsilon: Small float added to variance to avoid dividing by zero.
834    center: If True, add offset of `beta` to normalized tensor.
835        If False, `beta` is ignored.
836    scale: If True, multiply by `gamma`.
837      If False, `gamma` is not used.
838      When the next layer is linear (also e.g. `nn.relu`),
839      this can be disabled since the scaling
840      will be done by the next layer.
841    beta_initializer: Initializer for the beta weight.
842    gamma_initializer: Initializer for the gamma weight.
843    beta_regularizer: Optional regularizer for the beta weight.
844    gamma_regularizer: Optional regularizer for the gamma weight.
845    beta_constraint: Optional constraint for the beta weight.
846    gamma_constraint: Optional constraint for the gamma weight.
847    trainable: Boolean, if `True` the variables will be marked as trainable.
848
849  Input shape:
850    Arbitrary. Use the keyword argument `input_shape`
851    (tuple of integers, does not include the samples axis)
852    when using this layer as the first layer in a model.
853
854  Output shape:
855    Same shape as input.
856
857  References:
858    - [Layer Normalization](https://arxiv.org/abs/1607.06450)
859  """
860
861  def __init__(self,
862               norm_axis=None,
863               params_axis=-1,
864               epsilon=1e-12,
865               center=True,
866               scale=True,
867               beta_initializer='zeros',
868               gamma_initializer='ones',
869               beta_regularizer=None,
870               gamma_regularizer=None,
871               beta_constraint=None,
872               gamma_constraint=None,
873               trainable=True,
874               name=None,
875               **kwargs):
876    super(LayerNormalization, self).__init__(
877        name=name, trainable=trainable, **kwargs)
878    if isinstance(norm_axis, list):
879      self.norm_axis = norm_axis[:]
880    elif isinstance(norm_axis, int):
881      self.norm_axis = norm_axis
882    elif norm_axis is None:
883      self.norm_axis = None
884    else:
885      raise TypeError('norm_axis must be int or list or None, type given: %s'
886                      % type(norm_axis))
887
888    if isinstance(params_axis, list):
889      self.params_axis = params_axis[:]
890    elif isinstance(params_axis, int):
891      self.params_axis = params_axis
892    else:
893      raise TypeError('params_axis must be int or list, type given: %s'
894                      % type(params_axis))
895
896    self.epsilon = epsilon
897    self.center = center
898    self.scale = scale
899    self.beta_initializer = initializers.get(beta_initializer)
900    self.gamma_initializer = initializers.get(gamma_initializer)
901    self.beta_regularizer = regularizers.get(beta_regularizer)
902    self.gamma_regularizer = regularizers.get(gamma_regularizer)
903    self.beta_constraint = constraints.get(beta_constraint)
904    self.gamma_constraint = constraints.get(gamma_constraint)
905
906    self.supports_masking = True
907
908  def build(self, input_shape):
909    ndims = len(input_shape)
910    if ndims is None:
911      raise ValueError('Input shape %s has undefined rank.' % input_shape)
912
913    # Handle an unspecified norm_axis
914    if self.norm_axis is None:
915      self.norm_axis = list(range(1, ndims))
916
917    # Convert axes to lists and resolve negatives
918    if isinstance(self.norm_axis, int):
919      self.norm_axis = [self.norm_axis]
920    for idx, x in enumerate(self.norm_axis):
921      if x < 0:
922        self.norm_axis[idx] = ndims + x
923
924    if isinstance(self.params_axis, int):
925      self.params_axis = [self.params_axis]
926    for idx, x in enumerate(self.params_axis):
927      if x < 0:
928        self.params_axis[idx] = ndims + x
929
930    # Validate axes
931    for x in self.norm_axis:
932      if x < 0 or x >= ndims:
933        raise ValueError('Invalid axis: %d' % x)
934    if len(self.norm_axis) != len(set(self.norm_axis)):
935      raise ValueError('Duplicate axis: %s' % self.norm_axis)
936
937    for x in self.params_axis:
938      if x < 0 or x >= ndims:
939        raise ValueError('Invalid axis: %d' % x)
940    if len(self.params_axis) != len(set(self.params_axis)):
941      raise ValueError('Duplicate axis: %s' % self.params_axis)
942
943    param_shape = [input_shape[dim] for dim in self.params_axis]
944
945    if self.scale:
946      self.gamma = self.add_weight(
947          name='gamma',
948          shape=param_shape,
949          initializer=self.gamma_initializer,
950          regularizer=self.gamma_regularizer,
951          constraint=self.gamma_constraint,
952          trainable=True,
953          experimental_autocast=False)
954    else:
955      self.gamma = None
956
957    if self.center:
958      self.beta = self.add_weight(
959          name='beta',
960          shape=param_shape,
961          initializer=self.beta_initializer,
962          regularizer=self.beta_regularizer,
963          constraint=self.beta_constraint,
964          trainable=True,
965          experimental_autocast=False)
966    else:
967      self.beta = None
968
969  def call(self, inputs):
970    # Compute the axes along which to reduce the mean / variance
971    input_shape = inputs.get_shape()
972    ndims = len(input_shape)
973
974    # Calculate the moments on the last axis (layer activations).
975    mean, variance = nn.moments(inputs, self.norm_axis, keep_dims=True)
976
977    # Broadcasting only necessary for norm where the params axes aren't just
978    # the last dimension
979    broadcast_shape = [1] * ndims
980    for dim in self.params_axis:
981      broadcast_shape[dim] = input_shape.dims[dim].value
982    def _broadcast(v):
983      if (v is not None and
984          len(v.get_shape()) != ndims and
985          self.params_axis != [ndims - 1]):
986        return array_ops.reshape(v, broadcast_shape)
987      return v
988    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
989
990    # Compute layer normalization using the batch_normalization function.
991    outputs = nn.batch_normalization(
992        inputs,
993        mean,
994        variance,
995        offset=offset,
996        scale=scale,
997        variance_epsilon=self.epsilon)
998
999    # If some components of the shape got lost due to adjustments, fix that.
1000    outputs.set_shape(input_shape)
1001
1002    return outputs
1003
1004  def compute_output_shape(self, input_shape):
1005    return input_shape
1006
1007  def get_config(self):
1008    config = {
1009        'norm_axis': self.norm_axis,
1010        'params_axis': self.params_axis,
1011        'epsilon': self.epsilon,
1012        'center': self.center,
1013        'scale': self.scale,
1014        'beta_initializer': initializers.serialize(self.beta_initializer),
1015        'gamma_initializer': initializers.serialize(self.gamma_initializer),
1016        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
1017        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
1018        'beta_constraint': constraints.serialize(self.beta_constraint),
1019        'gamma_constraint': constraints.serialize(self.gamma_constraint)
1020    }
1021    base_config = super(LayerNormalization, self).get_config()
1022    return dict(list(base_config.items()) + list(config.items()))
1023