# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalization layers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


class BatchNormalizationBase(Layer):
  """Base class of Batch normalization layer (Ioffe and Szegedy, 2015).

  Normalize the activations of the previous layer at each batch,
  i.e. applies a transformation that maintains the mean activation
  close to 0 and the activation standard deviation close to 1.

  Arguments:
    axis: Integer, the axis that should be normalized
      (typically the features axis).
      For instance, after a `Conv2D` layer with
      `data_format="channels_first"`,
      set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor.
      If False, `beta` is ignored.
    scale: If True, multiply by `gamma`.
      If False, `gamma` is not used.
      When the next layer is linear (also e.g. `nn.relu`),
      this can be disabled since the scaling
      will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.
    renorm: Whether to use Batch Renormalization
      (https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction
      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_momentum: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training
      and should be neither too small (which would add noise) nor too large
      (which would give stale estimates). Note that `momentum` is still applied
      to get the means and variances for inference.
    fused: if `True`, use a faster, fused implementation, or raise a ValueError
      if the fused implementation cannot be used. If `None`, use the faster
      implementation if possible. If False, do not use the fused
      implementation.
    trainable: Boolean, if `True` the variables will be marked as trainable.
    virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
      which means batch normalization is performed across the whole batch. When
      `virtual_batch_size` is not `None`, instead perform "Ghost Batch
      Normalization", which creates virtual sub-batches which are each
      normalized separately (with shared gamma, beta, and moving statistics).
      Must divide the actual batch size during execution.
    adjustment: A function taking the `Tensor` containing the (dynamic) shape of
      the input tensor and returning a pair (scale, bias) to apply to the
      normalized values (before gamma and beta), only during training. For
      example, if axis==-1,
        `adjustment = lambda shape: (
          tf.random_uniform(shape[-1:], 0.93, 1.07),
          tf.random_uniform(shape[-1:], -0.1, 0.1))`
      will scale the normalized value by up to 7% up or down, then shift the
      result by up to 0.1 (with independent scaling and bias for each feature
      but shared across all examples), and finally apply gamma and/or beta. If
      `None`, no adjustment is applied. Cannot be specified if
      virtual_batch_size is specified.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the
        mean and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the
        mean and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  References:
    - [Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
  """

  # By default, the base class uses V2 behavior. The BatchNormalization V1
  # subclass sets this to False to use the V1 behavior.
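  # With V2 behavior, `fused` is left as `None` at construction time and only
  # resolved in `build()`, where it defaults to the fused kernel whenever the
  # input is 4D; with V1 behavior, `fused` defaults to `True` at construction
  # and is downgraded in `build()` if the fused kernel cannot be used.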
  _USE_V2_BEHAVIOR = True

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               renorm=False,
               renorm_clipping=None,
               renorm_momentum=0.99,
               fused=None,
               trainable=True,
               virtual_batch_size=None,
               adjustment=None,
               name=None,
               **kwargs):
    super(BatchNormalizationBase, self).__init__(
        name=name, trainable=trainable, **kwargs)
    if isinstance(axis, list):
      self.axis = axis[:]
    elif isinstance(axis, int):
      self.axis = axis
    else:
      raise TypeError('axis must be int or list, type given: %s'
                      % type(axis))
    self.momentum = momentum
    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.moving_mean_initializer = initializers.get(moving_mean_initializer)
    self.moving_variance_initializer = initializers.get(
        moving_variance_initializer)
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)
    self.renorm = renorm
    self.virtual_batch_size = virtual_batch_size
    self.adjustment = adjustment
    if self._USE_V2_BEHAVIOR:
      if fused:
        self._raise_if_fused_cannot_be_used()
      # We leave fused as None if self._fused_can_be_used()==True, since we
      # still may set it to False in self.build() if the input rank is not 4.
      elif fused is None and not self._fused_can_be_used():
        fused = False
    elif fused is None:
      fused = True
    self.supports_masking = True

    self.fused = fused
    self._bessels_correction_test_only = True

    if renorm:
      renorm_clipping = renorm_clipping or {}
      keys = ['rmax', 'rmin', 'dmax']
      if set(renorm_clipping) - set(keys):
        raise ValueError('renorm_clipping %s contains keys not in %s' %
                         (renorm_clipping, keys))
      self.renorm_clipping = renorm_clipping
      self.renorm_momentum = renorm_momentum

  def _raise_if_fused_cannot_be_used(self):
    """Raises a ValueError if fused implementation cannot be used.

    In addition to the checks done in this function, the input tensor's rank
    must be 4. The input rank check can only be done once the input shape is
    known.
    """
    # Currently fused batch norm doesn't support renorm. It also only supports
    # a channel dimension on axis 1 or 3, when no virtual batch size or
    # adjustment is used.
    if self.renorm:
      raise ValueError('Passing both fused=True and renorm=True is '
                       'unsupported')
    axis = [self.axis] if isinstance(self.axis, int) else self.axis
    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
    # input rank is required to be 4 (which is checked later).
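    # For example, a `Conv2D` layer with `data_format="channels_first"`
    # followed by `BatchNormalization(axis=1)` (or axis=-3) can use the fused
    # kernel, while any other axis falls through to the error below.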
    if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
      raise ValueError('Passing fused=True is only supported when axis is 1 '
                       'or 3')
    if self.virtual_batch_size is not None:
      raise ValueError('Passing fused=True is unsupported when '
                       'virtual_batch_size is specified.')
    if self.adjustment is not None:
      raise ValueError('Passing fused=True is unsupported when '
                       'adjustment is specified.')

  def _fused_can_be_used(self):
    try:
      self._raise_if_fused_cannot_be_used()
      return True
    except ValueError:
      return False

  @property
  def _param_dtype(self):
    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      return dtypes.float32
    else:
      return self.dtype or dtypes.float32

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % self.axis)

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input Tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0')
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused in (None, True):
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      if self._USE_V2_BEHAVIOR:
        if self.fused is None:
          self.fused = (ndims == 4)
        elif self.fused and ndims != 4:
          raise ValueError('Batch normalization layers with fused=True only '
                           'support 4D input tensors.')
      else:
        assert self.fused is not None
        self.fused = (ndims == 4 and self._fused_can_be_used())
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require
      # reshaping the Tensor to 4D with the axis in 1 or 3 (preferred 1) which
      # is particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1]:
        self._data_format = 'NCHW'
      elif self.axis == [3]:
        self._data_format = 'NHWC'
      else:
        raise ValueError('Unsupported axis, fused batch norm only supports '
                         'axis == [1] or axis == [3]')

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                         input_shape)
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [axis_to_dim[i] if i in axis_to_dim
                     else 1 for i in range(ndims)]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1  # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = K.constant(
            1.0, dtype=self._param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = K.constant(
            0.0, dtype=self._param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and
      # variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      if self.renorm:
        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_variance.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the
        # desired device and ignore any devices that may be set by the custom
        # getter.
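        # All four renorm variables below start from zeros; the `*_weight`
        # variables track the total weight accumulated by the exponential
        # moving averages, so dividing renorm_mean / renorm_stddev by them in
        # `_update_renorm_variable` recovers debiased estimates.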
        def _renorm_variable(name, shape):
          """Create a renorm variable."""
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=self._param_dtype,
              initializer=init_ops.zeros_initializer(),
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)
          return var

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
        # renorm_stddev_weight. This allows us to (1) mix the average
        # stddev with the minibatch stddev early in training, and (2) compute
        # the unbiased average stddev by dividing renorm_stddev by the weight.
        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_variance):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
          self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                       ())
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True

  def _assign_moving_average(self, variable, value, momentum):
    with ops.name_scope(None, 'AssignMovingAvg',
                        [variable, value, momentum]) as scope:
      with ops.colocate_with(variable):
        decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
        if decay.dtype != variable.dtype.base_dtype:
          decay = math_ops.cast(decay, variable.dtype.base_dtype)
        update_delta = (
            variable - math_ops.cast(value, variable.dtype)) * decay
        return state_ops.assign_sub(variable, update_delta, name=scope)

  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
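      # The fused kernel returns the unbiased (Bessel-corrected) variance, so
      # scaling by (sample_size - 1) / sample_size below converts it back to
      # the biased estimate produced by the non-fused path.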
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()
        mean_update = strategy.extended.update(
            self.moving_mean, self._assign_moving_average,
            (mean, self.momentum))
        variance_update = strategy.extended.update(
            self.moving_variance, self._assign_moving_average,
            (variance, self.momentum))
      else:
        mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                  momentum)
        variance_update = self._assign_moving_average(self.moving_variance,
                                                      variance, momentum)
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    return output

  def _renorm_correction_and_moments(self, mean, variance, training):
    """Returns the correction and update values for renorm."""
    stddev = math_ops.sqrt(variance + self.epsilon)
    # Compute the average mean and standard deviation, as if they were
    # initialized with this batch's moments.
    mixed_renorm_mean = (self.renorm_mean +
                         (1. - self.renorm_mean_weight) * mean)
    mixed_renorm_stddev = (self.renorm_stddev +
                           (1. - self.renorm_stddev_weight) * stddev)
    # Compute the corrections for batch renorm.
    r = stddev / mixed_renorm_stddev
    d = (mean - mixed_renorm_mean) / mixed_renorm_stddev
    # Ensure the corrections use pre-update moving averages.
    with ops.control_dependencies([r, d]):
      mean = array_ops.identity(mean)
      stddev = array_ops.identity(stddev)
    rmin, rmax, dmax = [self.renorm_clipping.get(key)
                        for key in ['rmin', 'rmax', 'dmax']]
    if rmin is not None:
      r = math_ops.maximum(r, rmin)
    if rmax is not None:
      r = math_ops.minimum(r, rmax)
    if dmax is not None:
      d = math_ops.maximum(d, -dmax)
      d = math_ops.minimum(d, dmax)
    # When not training, use r=1, d=0.
    r = tf_utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r))
    d = tf_utils.smart_cond(training,
                            lambda: d,
                            lambda: array_ops.zeros_like(d))

    def _update_renorm_variable(var, weight, value):
      """Updates a moving average and weight, returns the unbiased value."""
      value = array_ops.identity(value)
      def _do_update():
        """Updates the var and weight, returns their updated ratio."""
        # Update the variables without zero debiasing. The debiasing will be
        # accomplished by dividing the exponential moving average by the
        # weight. For example, after a single update, the moving average would
        # be (1-decay) * value, and the weight will be 1-decay, with their
        # ratio giving the value.
        # Make sure the weight is not updated until after r and d are computed.
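        # For example, with renorm_momentum=0.99 the first update stores
        # 0.01 * value in `var` and 0.01 in `weight`, so `new_var / new_weight`
        # still recovers `value` exactly.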
        with ops.control_dependencies([value]):
          weight_value = array_ops.constant(1., dtype=weight.dtype)
          new_var = self._assign_moving_average(var, value,
                                                self.renorm_momentum)
          new_weight = self._assign_moving_average(weight, weight_value,
                                                   self.renorm_momentum)
        # TODO(yuefengz): the updates to var and weight cannot be batched
        # together if we fetch their updated values here. Consider calculating
        # new values and delaying the updates.
        return new_var / new_weight

      def _fake_update():
        return array_ops.identity(var)
      return tf_utils.smart_cond(training, _do_update, _fake_update)

    # TODO(yuefengz): colocate the operations
    new_mean = _update_renorm_variable(self.renorm_mean,
                                       self.renorm_mean_weight, mean)
    new_stddev = _update_renorm_variable(self.renorm_stddev,
                                         self.renorm_stddev_weight, stddev)
    # Make sqrt(moving_variance + epsilon) = new_stddev.
    new_variance = math_ops.square(new_stddev) - self.epsilon

    return (r, d, new_mean, new_variance)

  def _moments(self, inputs, reduction_axes, keep_dims):
    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    in_eager_mode = context.executing_eagerly()
    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and
          len(v.get_shape()) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
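        # The two smart_conds below make the adjustment an identity transform
        # (a scale of ones and a bias of zeros) in inference mode, so
        # `adjustment` only perturbs the normalized values during training.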
        adj_scale = tf_utils.smart_cond(training,
                                        lambda: adj_scale,
                                        lambda: array_ops.ones_like(adj_scale))
        adj_bias = tf_utils.smart_cond(training,
                                       lambda: adj_bias,
                                       lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: moving_mean)
      variance = tf_utils.smart_cond(training,
                                     lambda: variance,
                                     lambda: moving_variance)

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          if in_eager_mode and not self.trainable:
            return
          return strategy.extended.update(
              var, self._assign_moving_average, (value, self.momentum),
              group=False)
        # We need to unwrap the moving_mean or moving_variance in the case of
        # training being false to match the output of true_fn and false_fn
        # in the smart cond.
        mean_update = tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_mean, new_mean),
            lambda: strategy.unwrap(self.moving_mean))
        variance_update = tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_variance, new_variance),
            lambda: strategy.unwrap(self.moving_variance))
      else:
        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          if in_eager_mode and not self.trainable:
            return
          return self._assign_moving_average(var, value, self.momentum)
        mean_update = tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_mean, new_mean),
            lambda: self.moving_mean)
        variance_update = tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_variance, new_variance),
            lambda: self.moving_variance)
      if not context.executing_eagerly():
        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'axis': self.axis,
        'momentum': self.momentum,
        'epsilon': self.epsilon,
        'center': self.center,
        'scale': self.scale,
        'beta_initializer': initializers.serialize(self.beta_initializer),
        'gamma_initializer': initializers.serialize(self.gamma_initializer),
        'moving_mean_initializer':
            initializers.serialize(self.moving_mean_initializer),
        'moving_variance_initializer':
            initializers.serialize(self.moving_variance_initializer),
        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
        'beta_constraint': constraints.serialize(self.beta_constraint),
        'gamma_constraint': constraints.serialize(self.gamma_constraint)
    }
    # Only add TensorFlow-specific parameters if they are set, so as to
    # preserve model compatibility with external Keras.
    if self.renorm:
      config['renorm'] = True
      config['renorm_clipping'] = self.renorm_clipping
      config['renorm_momentum'] = self.renorm_momentum
    if self.virtual_batch_size is not None:
      config['virtual_batch_size'] = self.virtual_batch_size
    # Note: adjustment is not serializable.
    if self.adjustment is not None:
      logging.warning('The `adjustment` function of this `BatchNormalization` '
                      'layer cannot be serialized and has been omitted from '
                      'the layer config. It will not be included when '
                      're-creating the layer from the saved config.')
    base_config = super(BatchNormalizationBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _replace_in_base_docstring(old, new):
  string = BatchNormalizationBase.__doc__
  if old not in string:
    raise ValueError('Could not find following string in BatchNormalizationBase'
                     ' docstring: "{}"'.format(old))
  return string.replace(old, new)


@keras_export(v1=['keras.layers.BatchNormalization'])  # pylint: disable=missing-docstring
class BatchNormalization(BatchNormalizationBase):

  __doc__ = _replace_in_base_docstring(
      '''
    fused: if `True`, use a faster, fused implementation, or raise a ValueError
      if the fused implementation cannot be used. If `None`, use the faster
      implementation if possible. If False, do not use the fused
      implementation.''',

      '''
    fused: if `None` or `True`, use a faster, fused implementation if possible.
      If `False`, use the system recommended implementation.''')

  _USE_V2_BEHAVIOR = False


@keras_export('keras.layers.experimental.LayerNormalization')
class LayerNormalization(Layer):
  """Layer normalization layer (Ba et al., 2016).

  Normalize the activations of the previous layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  i.e. applies a transformation that maintains the mean activation within each
  example close to 0 and the activation standard deviation close to 1.

  Given a tensor `inputs` of rank `R`, moments are calculated and normalization
  is performed over all axes in `norm_axis`. Scaling and centering,
  if requested, is performed over all axes in `params_axis`.

  By default, normalization is performed over all but the first axis
  (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable
  parameters are calculated for the rightmost axis (the `C` if `inputs` is
  `NHWC`). Scaling and recentering is performed via broadcast of the
  `beta` and `gamma` parameters with the normalized tensor.

  The shapes of `beta` and `gamma` are
  `[inputs.shape[i] for i in (param axes)]`,
  and this part of the inputs' shape must be fully defined.

  Arguments:
    norm_axis: Integer or List. Normalization will be
      performed along these dimensions. If unspecified (None), it will default
      to the dimensions `1 : rank(inputs)`, i.e. all axes but the first.
    params_axis: Integer or List. The (beta, gamma) dimensions: the scale
      and centering parameters will take their shapes from these axes and
      will be broadcast with the normalized inputs accordingly. If unspecified
      (None), it will default to the last dimension.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor.
      If False, `beta` is ignored.
    scale: If True, multiply by `gamma`.
      If False, `gamma` is not used.
      When the next layer is linear (also e.g. `nn.relu`),
      this can be disabled since the scaling
      will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.
    trainable: Boolean, if `True` the variables will be marked as trainable.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  References:
    - [Layer Normalization](https://arxiv.org/abs/1607.06450)
  """

  def __init__(self,
               norm_axis=None,
               params_axis=-1,
               epsilon=1e-12,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               trainable=True,
               name=None,
               **kwargs):
    super(LayerNormalization, self).__init__(
        name=name, trainable=trainable, **kwargs)
    if isinstance(norm_axis, list):
      self.norm_axis = norm_axis[:]
    elif isinstance(norm_axis, int):
      self.norm_axis = norm_axis
    elif norm_axis is None:
      self.norm_axis = None
    else:
      raise TypeError('norm_axis must be int or list or None, type given: %s'
                      % type(norm_axis))

    if isinstance(params_axis, list):
      self.params_axis = params_axis[:]
    elif isinstance(params_axis, int):
      self.params_axis = params_axis
    else:
      raise TypeError('params_axis must be int or list, type given: %s'
                      % type(params_axis))

    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)

    self.supports_masking = True

  def build(self, input_shape):
    ndims = len(input_shape)
    if ndims is None:
      raise ValueError('Input shape %s has undefined rank.' % input_shape)

    # Handle an unspecified norm_axis
    if self.norm_axis is None:
      self.norm_axis = list(range(1, ndims))

    # Convert axes to lists and resolve negatives
    if isinstance(self.norm_axis, int):
      self.norm_axis = [self.norm_axis]
    for idx, x in enumerate(self.norm_axis):
      if x < 0:
        self.norm_axis[idx] = ndims + x

    if isinstance(self.params_axis, int):
      self.params_axis = [self.params_axis]
    for idx, x in enumerate(self.params_axis):
      if x < 0:
        self.params_axis[idx] = ndims + x

    # Validate axes
    for x in self.norm_axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.norm_axis) != len(set(self.norm_axis)):
      raise ValueError('Duplicate axis: %s' % self.norm_axis)

    for x in self.params_axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.params_axis) != len(set(self.params_axis)):
      raise ValueError('Duplicate axis: %s' % self.params_axis)

    param_shape = [input_shape[dim] for dim in self.params_axis]

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None

  def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)

    # Calculate the moments over the normalization axes (layer activations).
    mean, variance = nn.moments(inputs, self.norm_axis, keep_dims=True)

    # Broadcasting only necessary for norm where the params axes aren't just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.params_axis:
      broadcast_shape[dim] = input_shape.dims[dim].value
    def _broadcast(v):
      if (v is not None and
          len(v.get_shape()) != ndims and
          self.params_axis != [ndims - 1]):
        return array_ops.reshape(v, broadcast_shape)
      return v
    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    # Compute layer normalization using the batch_normalization function.
    outputs = nn.batch_normalization(
        inputs,
        mean,
        variance,
        offset=offset,
        scale=scale,
        variance_epsilon=self.epsilon)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'norm_axis': self.norm_axis,
        'params_axis': self.params_axis,
        'epsilon': self.epsilon,
        'center': self.center,
        'scale': self.scale,
        'beta_initializer': initializers.serialize(self.beta_initializer),
        'gamma_initializer': initializers.serialize(self.gamma_initializer),
        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
        'beta_constraint': constraints.serialize(self.beta_constraint),
        'gamma_constraint': constraints.serialize(self.gamma_constraint)
    }
    base_config = super(LayerNormalization, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
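# Usage sketch (illustration only, not part of the library): these layers are
# normally constructed through the public Keras API rather than imported from
# this module directly, e.g.
#
#   bn = tf.keras.layers.BatchNormalization(axis=1)  # channel axis for NCHW
#   ln = tf.keras.layers.experimental.LayerNormalization()
#   y = bn(x, training=True)   # batch statistics; updates moving averages
#   y = bn(x, training=False)  # moving-average statistics only
#
# The export names come from the `keras_export` decorators in this file; the
# exact public path of LayerNormalization may differ between releases.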