# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The V2 implementation of Normalization layers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


# pylint: disable=g-classes-have-attributes
@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[])
class SyncBatchNormalization(normalization.BatchNormalizationBase):
  r"""Normalize and scale inputs or activations synchronously across replicas.

  Applies batch normalization to activations of the previous layer at each batch
  by synchronizing the global batch statistics across all devices that are
  training the model. For specific details about batch normalization please
  refer to the `tf.keras.layers.BatchNormalization` layer docs.

  When this layer is used with a `tf.distribute` strategy to train models
  across devices/workers, an allreduce call is made to aggregate the batch
  statistics over all replicas at every training step. Without a
  `tf.distribute` strategy, this layer behaves as a regular
  `tf.keras.layers.BatchNormalization` layer.

  Example usage:

  ```python
  strategy = tf.distribute.MirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(16))
    model.add(tf.keras.layers.experimental.SyncBatchNormalization())
  ```

  Args:
    axis: Integer, the axis that should be normalized
      (typically the features axis).
      For instance, after a `Conv2D` layer with
      `data_format="channels_first"`,
      set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor.
      If False, `beta` is ignored.
    scale: If True, multiply by `gamma`.
      If False, `gamma` is not used.
      When the next layer is linear (also e.g. `nn.relu`),
      this can be disabled since the scaling
      will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the
        mean and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the
        mean and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  """

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    if kwargs.pop('fused', None):
      raise ValueError(
          '`fused` argument cannot be True for SyncBatchNormalization.')

    # Currently we only support aggregating over the global batch size.
    super(SyncBatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        fused=False,
        **kwargs)

  def _calculate_mean_and_var(self, x, axes, keep_dims):

    with backend.name_scope('moments'):
      # The dynamic range of fp16 is too limited to support the collection of
      # sufficient statistics. As a workaround we simply perform the operations
      # on 32-bit floats before converting the mean and variance back to fp16
      y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
      replica_ctx = ds.get_replica_context()
      if replica_ctx:
        local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
        local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
                                                keepdims=True)
        batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
        # TODO(b/163099951): batch the all-reduces once we sort out the ordering
        # issue for NCCL. We don't have a mechanism to launch NCCL in the same
        # order in each replica nowadays, so we limit NCCL to batch all-reduces.
        y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
        y_squared_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_squared_sum)
        global_batch_size = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                                   batch_size)

        # The sums above are over the global batch, so normalize by the total
        # number of elements reduced per feature: the sizes of the reduction
        # axes other than the batch axis, times the global batch size.
        axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
        multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                   dtypes.float32)
        multiplier = multiplier * global_batch_size

        mean = y_sum / multiplier
        y_squared_mean = y_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        variance = y_squared_mean - math_ops.square(mean)
      else:
        # Compute true mean while keeping the dims for proper broadcasting.
        mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
        # sample variance, not unbiased variance
        # Note: stop_gradient does not change the gradient that gets
        #       backpropagated to the mean from the variance calculation,
        #       because that gradient is zero
        variance = math_ops.reduce_mean(
            math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
            axes,
            keepdims=True,
            name='variance')
      if not keep_dims:
        mean = array_ops.squeeze(mean, axes)
        variance = array_ops.squeeze(variance, axes)
      if x.dtype == dtypes.float16:
        return (math_ops.cast(mean, dtypes.float16),
                math_ops.cast(variance, dtypes.float16))
      else:
        return (mean, variance)


@keras_export('keras.layers.BatchNormalization', v1=[])
class BatchNormalization(normalization.BatchNormalizationBase):
  """Layer that normalizes its inputs.

  Batch normalization applies a transformation that maintains the mean output
  close to 0 and the output standard deviation close to 1.

  Importantly, batch normalization works differently during training and
  during inference.

  **During training** (i.e. when using `fit()` or when calling the layer/model
  with the argument `training=True`), the layer normalizes its output using
  the mean and standard deviation of the current batch of inputs. That is to
  say, for each channel being normalized, the layer returns
  `(batch - mean(batch)) / sqrt(var(batch) + epsilon) * gamma + beta`, where:

  - `epsilon` is a small constant (configurable as part of the constructor
  arguments).
  - `gamma` is a learned scaling factor (initialized as 1), which
  can be disabled by passing `scale=False` to the constructor.
  - `beta` is a learned offset factor (initialized as 0), which
  can be disabled by passing `center=False` to the constructor.

  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
  calling the layer/model with the argument `training=False`, which is the
  default), the layer normalizes its output using a moving average of the
  mean and standard deviation of the batches it has seen during training. That
  is to say, it returns
  `(batch - self.moving_mean) / sqrt(self.moving_var + epsilon) * gamma + beta`.

  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as such:

  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`

  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has statistics similar to those of
  the inference data*.
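
  For example, here is a minimal sketch of the behavior described above (the
  layer configuration and input values are arbitrary and chosen only for
  illustration):

  ```python
  bn = tf.keras.layers.BatchNormalization()
  x = tf.constant([[1.0], [2.0], [3.0]])
  # With `training=True` the statistics of the current batch are used, so the
  # output is approximately zero-mean and unit-variance along the feature axis.
  y = bn(x, training=True)
  # With `training=False` the moving statistics are used instead; since they
  # are still close to their initial values (mean 0, variance 1) at this point,
  # the output stays close to the input.
  z = bn(x, training=False)
  ```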

  Args:
    axis: Integer, the axis that should be normalized (typically the features
      axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
      next layer is linear (also e.g. `nn.relu`), this can be disabled since the
      scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean and
        variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean and
        variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape:
    Same shape as input.

  Reference:
    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).

  **About setting `layer.trainable = False` on a `BatchNormalization` layer:**

  The meaning of setting `layer.trainable = False` is to freeze the layer,
  i.e. its internal state will not change during training:
  its trainable weights will not be updated
  during `fit()` or `train_on_batch()`, and its state updates will not be run.

  This does not necessarily mean that the layer is run in inference
  mode (which is normally controlled by the `training` argument that can
  be passed when calling a layer). "Frozen state" and "inference mode"
  are two separate concepts.

  However, in the case of the `BatchNormalization` layer, **setting
  `trainable = False` on the layer means that the layer will be
  subsequently run in inference mode** (meaning that it will use
  the moving mean and the moving variance to normalize the current batch,
  rather than using the mean and variance of the current batch).

  This behavior was introduced in TensorFlow 2.0 in order
  to enable `layer.trainable = False` to produce the most commonly
  expected behavior in the convnet fine-tuning use case.

  Note that:
    - Setting `trainable` on a model containing other layers will
      recursively set the `trainable` value of all inner layers.
    - If the value of the `trainable`
      attribute is changed after calling `compile()` on a model,
      the new value doesn't take effect for this model
      until `compile()` is called again.
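
  As a sketch of the fine-tuning pattern this enables (the layer and model
  below are purely illustrative):

  ```python
  # Freeze a `BatchNormalization` layer: its weights stop being updated and,
  # as described above, the layer runs in inference mode, normalizing with
  # its moving mean and moving variance.
  bn = tf.keras.layers.BatchNormalization()
  bn.trainable = False

  inputs = tf.keras.Input(shape=(4,))
  outputs = bn(inputs)
  model = tf.keras.Model(inputs, outputs)
  # If `trainable` is changed after `compile()`, call `compile()` again so
  # the change takes effect.
  model.compile(optimizer='sgd', loss='mse')
  ```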
  """
  _USE_V2_BEHAVIOR = True

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    super(BatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        **kwargs)
