# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The V2 implementation of Normalization layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


# pylint: disable=g-classes-have-attributes
@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[])
class SyncBatchNormalization(normalization.BatchNormalizationBase):
  r"""Normalize and scale inputs or activations synchronously across replicas.

  Applies batch normalization to the activations of the previous layer at each
  batch by synchronizing the global batch statistics across all devices that
  are training the model. For details about batch normalization itself, please
  refer to the `tf.keras.layers.BatchNormalization` layer docs.

  When this layer is used with a `tf.distribute` strategy to train models
  across devices/workers, an allreduce call is issued to aggregate batch
  statistics across all replicas at every training step. Without a
  `tf.distribute` strategy, this layer behaves as a regular
  `tf.keras.layers.BatchNormalization` layer.

  Example usage:

  ```python
  strategy = tf.distribute.MirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(16))
    model.add(tf.keras.layers.experimental.SyncBatchNormalization())
  ```

  Args:
    axis: Integer, the axis that should be normalized (typically the features
      axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean
        and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean
        and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does
    not include the samples axis) when using this layer as the first layer in
    a model.

  Output shape:
    Same shape as input.
  """

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    if kwargs.pop('fused', None):
      raise ValueError(
          '`fused` argument cannot be True for SyncBatchNormalization.')

    # Currently we only support aggregating over the global batch size.
    super(SyncBatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        fused=False,
        **kwargs)

  def _calculate_mean_and_var(self, x, axes, keep_dims):

    with backend.name_scope('moments'):
      # The dynamic range of fp16 is too limited to support the collection of
      # sufficient statistics. As a workaround we simply perform the
      # operations on 32-bit floats before converting the mean and variance
      # back to fp16.
      y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
      replica_ctx = ds.get_replica_context()
      if replica_ctx:
        local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
        local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
                                                keepdims=True)
        batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
        # TODO(b/163099951): batch the all-reduces once we sort out the
        # ordering issue for NCCL. We don't have a mechanism to launch NCCL in
        # the same order in each replica nowadays, so we limit NCCL to batched
        # all-reduces.
        y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
        y_squared_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_squared_sum)
        global_batch_size = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                                   batch_size)

        axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
        multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                   dtypes.float32)
        multiplier = multiplier * global_batch_size

        mean = y_sum / multiplier
        y_squared_mean = y_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        variance = y_squared_mean - math_ops.square(mean)
      else:
        # Compute true mean while keeping the dims for proper broadcasting.
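        # (There is no replica context on this path, so these are ordinary
        # local batch moments; no cross-replica aggregation takes place.)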
        mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
        # Sample variance, not unbiased variance.
        # Note: stop_gradient does not change the gradient that gets
        # backpropagated to the mean from the variance calculation, because
        # that gradient is zero.
        variance = math_ops.reduce_mean(
            math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
            axes,
            keepdims=True,
            name='variance')
      if not keep_dims:
        mean = array_ops.squeeze(mean, axes)
        variance = array_ops.squeeze(variance, axes)
      if x.dtype == dtypes.float16:
        return (math_ops.cast(mean, dtypes.float16),
                math_ops.cast(variance, dtypes.float16))
      else:
        return (mean, variance)


@keras_export('keras.layers.BatchNormalization', v1=[])
class BatchNormalization(normalization.BatchNormalizationBase):
  """Layer that normalizes its inputs.

  Batch normalization applies a transformation that maintains the mean output
  close to 0 and the output standard deviation close to 1.

  Importantly, batch normalization works differently during training and
  during inference.

  **During training** (i.e. when using `fit()` or when calling the layer/model
  with the argument `training=True`), the layer normalizes its output using
  the mean and standard deviation of the current batch of inputs. That is to
  say, for each channel being normalized, the layer returns
  `(batch - mean(batch)) / sqrt(var(batch) + epsilon) * gamma + beta`, where:

  - `epsilon` is a small constant (configurable as part of the constructor
    arguments)
  - `gamma` is a learned scaling factor (initialized as 1), which
    can be disabled by passing `scale=False` to the constructor.
  - `beta` is a learned offset factor (initialized as 0), which
    can be disabled by passing `center=False` to the constructor.

  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
  calling the layer/model with the argument `training=False`, which is the
  default), the layer normalizes its output using a moving average of the
  mean and standard deviation of the batches it has seen during training. That
  is to say, it returns
  `(batch - self.moving_mean) / sqrt(self.moving_var + epsilon) * gamma + beta`.

  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as such:

  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`

  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has similar statistics as the
  inference data*.

  Args:
    axis: Integer, the axis that should be normalized (typically the features
      axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean
        and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean
        and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does
    not include the samples axis) when using this layer as the first layer in
    a model.

  Output shape:
    Same shape as input.

  Reference:
    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).

  **About setting `layer.trainable = False` on a `BatchNormalization` layer:**

  The meaning of setting `layer.trainable = False` is to freeze the layer,
  i.e. its internal state will not change during training: its trainable
  weights will not be updated during `fit()` or `train_on_batch()`, and its
  state updates will not be run.

  Usually, this does not necessarily mean that the layer is run in inference
  mode (which is normally controlled by the `training` argument that can be
  passed when calling a layer). "Frozen state" and "inference mode" are two
  separate concepts.

  However, in the case of the `BatchNormalization` layer, **setting
  `trainable = False` on the layer means that the layer will be subsequently
  run in inference mode** (meaning that it will use the moving mean and the
  moving variance to normalize the current batch, rather than using the mean
  and variance of the current batch).

  This behavior has been introduced in TensorFlow 2.0, in order to enable
  `layer.trainable = False` to produce the most commonly expected behavior in
  the convnet fine-tuning use case.

  Note that:

  - Setting `trainable` on a model containing other layers will recursively
    set the `trainable` value of all inner layers.
  - If the value of the `trainable` attribute is changed after calling
    `compile()` on a model, the new value doesn't take effect for this model
    until `compile()` is called again.
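
  Example (an illustrative sketch; the layer sizes, data, and training
  settings below are arbitrary):

  ```python
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, input_shape=(8,)),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.Activation('relu'),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='adam', loss='mse')

  # Training: the layer normalizes with per-batch statistics and updates
  # `moving_mean` / `moving_variance`.
  model.fit(tf.random.normal((64, 8)), tf.random.normal((64, 1)), epochs=1)

  # Freeze the layer (here `model.layers[1]` is the BatchNormalization layer):
  # subsequent calls run in inference mode and normalize with the moving
  # statistics. Re-compile for the change to take effect.
  model.layers[1].trainable = False
  model.compile(optimizer='adam', loss='mse')
  ```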
300 """ 301 _USE_V2_BEHAVIOR = True 302 303 def __init__(self, 304 axis=-1, 305 momentum=0.99, 306 epsilon=1e-3, 307 center=True, 308 scale=True, 309 beta_initializer='zeros', 310 gamma_initializer='ones', 311 moving_mean_initializer='zeros', 312 moving_variance_initializer='ones', 313 beta_regularizer=None, 314 gamma_regularizer=None, 315 beta_constraint=None, 316 gamma_constraint=None, 317 **kwargs): 318 super(BatchNormalization, self).__init__( 319 axis=axis, 320 momentum=momentum, 321 epsilon=epsilon, 322 center=center, 323 scale=scale, 324 beta_initializer=beta_initializer, 325 gamma_initializer=gamma_initializer, 326 moving_mean_initializer=moving_mean_initializer, 327 moving_variance_initializer=moving_variance_initializer, 328 beta_regularizer=beta_regularizer, 329 gamma_regularizer=gamma_regularizer, 330 beta_constraint=beta_constraint, 331 gamma_constraint=gamma_constraint, 332 **kwargs) 333