# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalization layers."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.control_flow_ops import get_enclosing_xla_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


class BatchNormalizationBase(Layer):
  r"""Layer that normalizes its inputs.

  Batch normalization applies a transformation that maintains the mean output
  close to 0 and the output standard deviation close to 1.

  Importantly, batch normalization works differently during training and
  during inference.

  **During training** (i.e. when using `fit()` or when calling the layer/model
  with the argument `training=True`), the layer normalizes its output using
  the mean and standard deviation of the current batch of inputs. That is to
  say, for each channel being normalized, the layer returns
  `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:

  - `epsilon` is a small constant (configurable as part of the constructor
    arguments)
  - `gamma` is a learned scaling factor (initialized as 1), which
    can be disabled by passing `scale=False` to the constructor.
  - `beta` is a learned offset factor (initialized as 0), which
    can be disabled by passing `center=False` to the constructor.

  **During inference** (i.e. when using `evaluate()` or `predict()` or when
  calling the layer/model with the argument `training=False` (which is the
  default), the layer normalizes its output using a moving average of the
  mean and standard deviation of the batches it has seen during training. That
  is to say, it returns
  `gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta`.
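
  As a rough, self-contained NumPy sketch of the two modes (illustrative only;
  the array `x` and the `gamma`, `beta`, `epsilon` and moving-statistics
  values below are made up and are not part of this layer's API):

  ```python
  import numpy as np

  x = np.random.randn(32, 10).astype('float32')  # a batch of 32 examples
  gamma, beta, epsilon = 1.0, 0.0, 1e-3

  # Training: normalize with the statistics of the current batch.
  batch_mean, batch_var = x.mean(axis=0), x.var(axis=0)
  y_train = gamma * (x - batch_mean) / np.sqrt(batch_var + epsilon) + beta

  # Inference: normalize with the moving statistics accumulated in training.
  moving_mean, moving_var = np.zeros(10), np.ones(10)
  y_infer = gamma * (x - moving_mean) / np.sqrt(moving_var + epsilon) + beta
  ```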

  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as such:

  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`

  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has similar statistics as the
  inference data*.

  Args:
    axis: Integer or a list of integers, the axis that should be normalized
      (typically the features axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.
    renorm: Whether to use [Batch Renormalization](
      https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction `(r,
      d)` is used as `corrected_value = normalized_value * r + d`, with `r`
      clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_momentum: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training and
      should be neither too small (which would add noise) nor too large (which
      would give stale estimates). Note that `momentum` is still applied to
      get the means and variances for inference.
    fused: if `True`, use a faster, fused implementation, or raise a
      ValueError if the fused implementation cannot be used. If `None`, use
      the faster implementation if possible. If False, do not use the fused
      implementation. Note that in TensorFlow 1.x, the meaning of `fused=True`
      is different: if `False`, the layer uses the system-recommended
      implementation.
    trainable: Boolean, if `True` the variables will be marked as trainable.
    virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
      which means batch normalization is performed across the whole batch.
      When `virtual_batch_size` is not `None`, instead perform "Ghost Batch
      Normalization", which creates virtual sub-batches which are each
      normalized separately (with shared gamma, beta, and moving statistics).
      Must divide the actual batch size during execution.
    adjustment: A function taking the `Tensor` containing the (dynamic) shape
      of the input tensor and returning a pair (scale, bias) to apply to the
      normalized values (before gamma and beta), only during training. For
      example, if `axis=-1`,
      `adjustment = lambda shape: (
        tf.random.uniform(shape[-1:], 0.93, 1.07),
        tf.random.uniform(shape[-1:], -0.1, 0.1))`
      will scale the normalized value by up to 7% up or down, then shift the
      result by up to 0.1 (with independent scaling and bias for each feature
      but shared across all examples), and finally apply gamma and/or beta. If
      `None`, no adjustment is applied. Cannot be specified if
      virtual_batch_size is specified.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean
        and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean
        and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape:
    Same shape as input.

  Reference:
    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
  """

  # By default, the base class uses V2 behavior. The BatchNormalization V1
  # subclass sets this to False to use the V1 behavior.
  _USE_V2_BEHAVIOR = True

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               renorm=False,
               renorm_clipping=None,
               renorm_momentum=0.99,
               fused=None,
               trainable=True,
               virtual_batch_size=None,
               adjustment=None,
               name=None,
               **kwargs):
    super(BatchNormalizationBase, self).__init__(name=name, **kwargs)
    if isinstance(axis, (list, tuple)):
      self.axis = axis[:]
    elif isinstance(axis, int):
      self.axis = axis
    else:
      raise TypeError('Expected an int or a list/tuple of ints for the '
                      'argument \'axis\', but received: %r' % axis)
    self.momentum = momentum
    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.moving_mean_initializer = initializers.get(moving_mean_initializer)
    self.moving_variance_initializer = initializers.get(
        moving_variance_initializer)
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)
    self.renorm = renorm
    self.virtual_batch_size = virtual_batch_size
    self.adjustment = adjustment
    if self._USE_V2_BEHAVIOR:
      if fused:
        self._raise_if_fused_cannot_be_used()
      # We leave fused as None if self._fused_can_be_used()==True, since we
      # still may set it to False in self.build() if the input rank is not 4.
      elif fused is None and not self._fused_can_be_used():
        fused = False
    elif fused is None:
      fused = True
    self.supports_masking = True

    self.fused = fused
    self._bessels_correction_test_only = True
    self.trainable = trainable

    if renorm:
      renorm_clipping = renorm_clipping or {}
      keys = ['rmax', 'rmin', 'dmax']
      if set(renorm_clipping) - set(keys):
        raise ValueError('renorm_clipping %s contains keys not in %s' %
                         (renorm_clipping, keys))
      self.renorm_clipping = renorm_clipping
      self.renorm_momentum = renorm_momentum

  def _raise_if_fused_cannot_be_used(self):
    """Raises a ValueError if fused implementation cannot be used.

    In addition to the checks done in this function, the input tensor's rank
    must be 4. The input rank check can only be done once the input shape is
    known.
    """
    # Note the ValueErrors in this function are caught and not reraised in
    # _fused_can_be_used(). No other exception besides ValueError should be
    # raised here.

    # Currently fused batch norm doesn't support renorm. It also only supports
    # a channel dimension on axis 1 or 3, when no virtual batch size or
    # adjustment is used.
    if self.renorm:
      raise ValueError('Passing both `fused=True` and `renorm=True` is '
                       'not supported')
    axis = [self.axis] if isinstance(self.axis, int) else self.axis
    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
    # input rank is required to be 4 (which is checked later).
    # TODO(b/173253101): Once the input rank can be 5, update this check.
    if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
      raise ValueError('Passing `fused=True` is only supported when axis is 1 '
                       'or 3. Got axis %s' % (axis,))
    if self.virtual_batch_size is not None:
      raise ValueError('Passing `fused=True` is not supported when '
                       '`virtual_batch_size` is specified.')
    if self.adjustment is not None:
      raise ValueError('Passing `fused=True` is not supported when '
                       '`adjustment` is specified.')
    # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check.
    if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None):
      raise ValueError(
          'Passing `fused=True` is only supported when the compute '
          'dtype is float16, bfloat16, or float32. Got dtype: %s' %
          (self._compute_dtype,))

  def _fused_can_be_used(self):
    try:
      self._raise_if_fused_cannot_be_used()
      return True
    except ValueError:
      return False

  @property
  def trainable(self):
    return self._trainable

  @trainable.setter
  def trainable(self, value):
    self._trainable = value

  @property
  def _param_dtype(self):
    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      return dtypes.float32
    else:
      return self.dtype or dtypes.float32

  def _support_zero_size_input(self):
    return distribution_strategy_context.has_strategy() and getattr(
        distribution_strategy_context.get_strategy().extended,
        'experimental_enable_get_next_as_optional', False)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank.')
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %s' % (self.axis,))
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % (self.axis,))

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0. '
                         'Received axis=%s' % (self.axis,))
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused in (None, True):
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      if self._USE_V2_BEHAVIOR:
        # TODO(b/173253101): Using fused in the 5D case is currently disabled
        # due to a regression on UNet, so it is currently only supported in
        # the 4D case.
        if self.fused is None:
          self.fused = ndims == 4
        elif self.fused and ndims != 4:
          raise ValueError('Batch normalization layers with `fused=True` only '
                           'support 4D input tensors. '
                           'Received tensor with shape: %s' %
                           (tuple(input_shape),))
      else:
        assert self.fused is not None
        self.fused = (ndims == 4 and self._fused_can_be_used())
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require
      # reshaping the Tensor to 4D with the axis in 1 or 3 (preferred 1) which
      # is particularly tricky.
      # A compromise might be to just support the most common use case
      # (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1] and ndims == 4:
        self._data_format = 'NCHW'
      elif self.axis == [1] and ndims == 5:
        self._data_format = 'NCDHW'
      elif self.axis == [3] and ndims == 4:
        self._data_format = 'NHWC'
      elif self.axis == [4] and ndims == 5:
        self._data_format = 'NDHWC'
      elif ndims == 5:
        # 5D tensors that can be passed in but should not use fused batch norm
        # due to unsupported axis.
        self.fused = False
      else:
        if ndims == 4:
          raise ValueError(
              'Unsupported axis. The use of `fused=True` is only possible '
              'with `axis=1` or `axis=3` for 4D input tensors. Received '
              'axis=%s' % (self.axis,))
        else:
          raise ValueError(
              'Unsupported axis. The use of `fused=True` is only possible '
              'with `axis=1` or `axis=4` for 5D input tensors. Received '
              'axis=%s' % (self.axis,))

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Received '
                         'input with shape %s. Axis value: %s' %
                         (tuple(input_shape), self.axis))
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [
          axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
      ]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1  # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = K.constant(
            1.0, dtype=self._param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = K.constant(
            0.0, dtype=self._param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and
      # variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      if self.renorm:
        # In batch renormalization we track the inference moving stddev
        # instead of the moving variance to more closely align with the paper.
        def moving_stddev_initializer(*args, **kwargs):
          return math_ops.sqrt(
              self.moving_variance_initializer(*args, **kwargs))

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_variance):
          self.moving_stddev = self.add_weight(
              name='moving_stddev',
              shape=param_shape,
              dtype=self._param_dtype,
              initializer=moving_stddev_initializer,
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)

        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_stddev.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the
        # desired device and ignore any devices that may be set by the custom
        # getter.
        def _renorm_variable(name,
                             shape,
                             initializer=init_ops.zeros_initializer()):
          """Create a renorm variable."""
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=self._param_dtype,
              initializer=initializer,
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)
          return var

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape,
                                              self.moving_mean_initializer)
        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_stddev):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape,
                                                moving_stddev_initializer)
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True

  def _assign_moving_average(self, variable, value, momentum, inputs_size):
    with K.name_scope('AssignMovingAvg') as scope:
      with ops.colocate_with(variable):
        decay = ops.convert_to_tensor_v2_with_dispatch(
            1.0 - momentum, name='decay')
        if decay.dtype != variable.dtype.base_dtype:
          decay = math_ops.cast(decay, variable.dtype.base_dtype)
        update_delta = (
            variable - math_ops.cast(value, variable.dtype)) * decay
        if inputs_size is not None:
          update_delta = array_ops.where(inputs_size > 0, update_delta,
                                         K.zeros_like(update_delta))
        return state_ops.assign_sub(variable, update_delta, name=scope)

  def _assign_new_value(self, variable, value):
    with K.name_scope('AssignNewValue') as scope:
      with ops.colocate_with(variable):
        return state_ops.assign(variable, value, name=scope)

  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
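    # Under tf.distribute, a replica may receive an empty (zero-size) batch.
    # Capture the dynamic batch size here so the moving-statistics updates
    # below can be skipped when the local batch is empty.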
    if self._support_zero_size_input():
      # Keras assumes that batch dimension is the first dimension for Batch
      # Normalization.
      input_batch_size = array_ops.shape(inputs)[0]
    else:
      input_batch_size = None

    # TODO(rmlarsen): Support using fused avg updates for non-eager execution
    # after fixing graph pattern matching and enabling fused_batch_norm to
    # take exponential_avg_factor as a tensor input.
    use_fused_avg_updates = (
        ops.executing_eagerly_outside_functions() and
        isinstance(self.momentum, (float, int)) and
        get_enclosing_xla_context() is None)
    if use_fused_avg_updates:
      exponential_avg_factor = 1.0 - self.momentum
    else:
      exponential_avg_factor = None

    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
      r"""Add or remove Bessel's correction."""
      # Removes Bessel's correction if remove == True, adds it otherwise.
      # This is to be consistent with non-fused batch norm. Note that the
      # variance computed by fused batch norm is with Bessel's correction.
      # This is only used in legacy V1 batch norm tests.
      if self._bessels_correction_test_only:
        return variance
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      if remove:
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
      else:
        factor = sample_size / (
            sample_size - math_ops.cast(1.0, variance.dtype))
      return variance * factor

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=_maybe_add_or_remove_bessels_correction(
              self.moving_variance, remove=False),
          epsilon=self.epsilon,
          is_training=True,
          data_format=self._data_format,
          exponential_avg_factor=exponential_avg_factor)

    def _fused_batch_norm_training_empty():
      return inputs, self.moving_mean, self.moving_variance

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    train_op = _fused_batch_norm_training
    if use_fused_avg_updates and input_batch_size is not None:
      # pylint: disable=g-long-lambda
      train_op = lambda: control_flow_util.smart_cond(
          input_batch_size > 0, _fused_batch_norm_training,
          _fused_batch_norm_training_empty)
      # pylint: enable=g-long-lambda

    output, mean, variance = control_flow_util.smart_cond(
        training, train_op, _fused_batch_norm_inference)
    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

    training_value = control_flow_util.constant_value(training)
    if training_value or training_value is None:
      if not use_fused_avg_updates:
        if training_value is None:
          momentum = control_flow_util.smart_cond(training,
                                                  lambda: self.momentum,
                                                  lambda: 1.0)
        else:
          momentum = ops.convert_to_tensor_v2_with_dispatch(self.momentum)

      def mean_update():
        """Update self.moving_mean with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_mean, mean)
        else:
          return self._assign_moving_average(self.moving_mean, mean, momentum,
                                             input_batch_size)

      def variance_update():
        """Update self.moving_variance with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_variance, variance)
        else:
          return self._assign_moving_average(self.moving_variance, variance,
                                             momentum, input_batch_size)

      self.add_update(mean_update)
      self.add_update(variance_update)

    return output

  def _renorm_correction_and_moments(self, mean, variance, training,
                                     inputs_size):
    """Returns the correction and update values for renorm."""
    stddev = math_ops.sqrt(variance + self.epsilon)
    # Compute the average mean and standard deviation, as if they were
    # initialized with this batch's moments.
    renorm_mean = self.renorm_mean
    # Avoid divide by zero early on in training.
    renorm_stddev = math_ops.maximum(self.renorm_stddev,
                                     math_ops.sqrt(self.epsilon))
    # Compute the corrections for batch renorm.
    r = stddev / renorm_stddev
    d = (mean - renorm_mean) / renorm_stddev
    # Ensure the corrections use pre-update moving averages.
    with ops.control_dependencies([r, d]):
      mean = array_ops.identity(mean)
      stddev = array_ops.identity(stddev)
    rmin, rmax, dmax = [
        self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax']
    ]
    if rmin is not None:
      r = math_ops.maximum(r, rmin)
    if rmax is not None:
      r = math_ops.minimum(r, rmax)
    if dmax is not None:
      d = math_ops.maximum(d, -dmax)
      d = math_ops.minimum(d, dmax)
    # When not training, use r=1, d=0.
    r = control_flow_util.smart_cond(training, lambda: r,
                                     lambda: array_ops.ones_like(r))
    d = control_flow_util.smart_cond(training, lambda: d,
                                     lambda: array_ops.zeros_like(d))

    def _update_renorm_variable(var, value, inputs_size):
      """Updates a moving average and weight, returns the unbiased value."""
      value = array_ops.identity(value)

      def _do_update():
        """Updates the var, returns the updated value."""
        new_var = self._assign_moving_average(var, value,
                                              self.renorm_momentum,
                                              inputs_size)
        return new_var

      def _fake_update():
        return array_ops.identity(var)

      return control_flow_util.smart_cond(training, _do_update, _fake_update)

    # TODO(yuefengz): colocate the operations
    update_new_mean = _update_renorm_variable(self.renorm_mean, mean,
                                              inputs_size)
    update_new_stddev = _update_renorm_variable(self.renorm_stddev, stddev,
                                                inputs_size)

    # Update the inference mode moving averages with the batch value.
    with ops.control_dependencies([update_new_mean, update_new_stddev]):
      out_mean = array_ops.identity(mean)
      out_variance = array_ops.identity(variance)

    return (r, d, out_mean, out_variance)

  def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

  def _moments(self, inputs, reduction_axes, keep_dims):
    mean, variance = self._calculate_mean_and_var(inputs, reduction_axes,
                                                  keep_dims)
    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
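    # Under a distribution strategy a replica can receive a zero-size batch;
    # in that case the moments computed above are not well defined, so fall
    # back to zeros for both statistics.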
    if self._support_zero_size_input():
      input_batch_size = array_ops.shape(inputs)[0]
      mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
      variance = array_ops.where(input_batch_size > 0, variance,
                                 K.zeros_like(variance))
    return mean, variance

  def _get_training_value(self, training=None):
    if training is None:
      training = K.learning_phase()
    if self._USE_V2_BEHAVIOR:
      if isinstance(training, int):
        training = bool(training)
      if not self.trainable:
        # When the layer is not trainable, it overrides the value passed from
        # the model.
        training = False
    return training

  def call(self, inputs, training=None):
    training = self._get_training_value(training)

    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = array_ops.shape(inputs)
      original_shape = array_ops.concat(
          [constant_op.constant([-1]), original_shape[1:]], axis=0)
      expanded_shape = array_ops.concat([
          constant_op.constant([self.virtual_batch_size, -1]),
          original_shape[1:]
      ], axis=0)

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    inputs_dtype = inputs.dtype.base_dtype
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
      # Do all math in float32 if given 16-bit inputs for numeric stability.
      # In particular, it's very easy for variance to overflow in float16 and
      # for safety we also choose to cast bfloat16 to float32.
      inputs = math_ops.cast(inputs, dtypes.float32)

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = control_flow_util.constant_value(training)
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      mean, variance = self.moving_mean, self.moving_variance
    else:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = control_flow_util.smart_cond(
            training, lambda: adj_scale,
            lambda: array_ops.ones_like(adj_scale))
        adj_bias = control_flow_util.smart_cond(
            training, lambda: adj_bias,
            lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = control_flow_util.smart_cond(
          training, lambda: mean,
          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_mean))
      variance = control_flow_util.smart_cond(
          training, lambda: variance,
          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_variance))

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self._support_zero_size_input():
        # Keras assumes that batch dimension is the first dimension for Batch
        # Normalization.
        input_batch_size = array_ops.shape(inputs)[0]
      else:
        input_batch_size = None

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training, input_batch_size)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return self._assign_moving_average(var, value, self.momentum,
                                           input_batch_size)

      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: self.moving_mean
        return control_flow_util.smart_cond(training, true_branch,
                                            false_branch)

      def variance_update():
        """Update the moving variance."""

        def true_branch_renorm():
          # We apply epsilon as part of the moving_stddev to mirror the
          # training code path.
          moving_stddev = _do_update(
              self.moving_stddev,
              math_ops.sqrt(new_variance + self.epsilon))
          return self._assign_new_value(
              self.moving_variance,
              # Apply relu in case floating point rounding causes it to go
              # negative.
              K.relu(moving_stddev * moving_stddev - self.epsilon))

        if self.renorm:
          true_branch = true_branch_renorm
        else:
          true_branch = lambda: _do_update(self.moving_variance, new_variance)

        false_branch = lambda: self.moving_variance
        return control_flow_util.smart_cond(training, true_branch,
                                            false_branch)

      self.add_update(mean_update)
      self.add_update(variance_update)

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                     _broadcast(variance), offset, scale,
                                     self.epsilon)
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
      outputs = math_ops.cast(outputs, inputs_dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'axis':
            self.axis,
        'momentum':
            self.momentum,
        'epsilon':
            self.epsilon,
        'center':
            self.center,
        'scale':
            self.scale,
        'beta_initializer':
            initializers.serialize(self.beta_initializer),
        'gamma_initializer':
            initializers.serialize(self.gamma_initializer),
        'moving_mean_initializer':
            initializers.serialize(self.moving_mean_initializer),
        'moving_variance_initializer':
            initializers.serialize(self.moving_variance_initializer),
        'beta_regularizer':
            regularizers.serialize(self.beta_regularizer),
        'gamma_regularizer':
            regularizers.serialize(self.gamma_regularizer),
        'beta_constraint':
            constraints.serialize(self.beta_constraint),
        'gamma_constraint':
            constraints.serialize(self.gamma_constraint)
    }
    # Only add TensorFlow-specific parameters if they are set, so as to
    # preserve model compatibility with external Keras.
    if self.renorm:
      config['renorm'] = True
      config['renorm_clipping'] = self.renorm_clipping
      config['renorm_momentum'] = self.renorm_momentum
    if self.virtual_batch_size is not None:
      config['virtual_batch_size'] = self.virtual_batch_size
    # Note: adjustment is not serializable.
    if self.adjustment is not None:
      logging.warning('The `adjustment` function of this `BatchNormalization` '
                      'layer cannot be serialized and has been omitted from '
                      'the layer config. It will not be included when '
                      're-creating the layer from the saved config.')
    base_config = super(BatchNormalizationBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


# pylint: disable=missing-docstring
@keras_export(v1=['keras.layers.BatchNormalization'])
class BatchNormalization(BatchNormalizationBase):
  _USE_V2_BEHAVIOR = False


@keras_export('keras.layers.LayerNormalization')
class LayerNormalization(Layer):
  """Layer normalization layer (Ba et al., 2016).

  Normalize the activations of the previous layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  i.e. applies a transformation that maintains the mean activation within each
  example close to 0 and the activation standard deviation close to 1.

  Given a tensor `inputs`, moments are calculated and normalization
  is performed across the axes specified in `axis`.

  Example:

  >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
  >>> print(data)
  tf.Tensor(
  [[ 0. 10.]
   [20. 30.]
   [40. 50.]
   [60. 70.]
   [80. 90.]], shape=(5, 2), dtype=float32)

  >>> layer = tf.keras.layers.LayerNormalization(axis=1)
  >>> output = layer(data)
  >>> print(output)
  tf.Tensor(
  [[-1. 1.]
   [-1. 1.]
   [-1. 1.]
   [-1. 1.]
   [-1. 1.]], shape=(5, 2), dtype=float32)

  Notice that with Layer Normalization the normalization happens across the
  axes *within* each example, rather than across different examples in the
  batch.

  If `scale` or `center` are enabled, the layer will scale the normalized
  outputs by broadcasting them with a trainable variable `gamma`, and center
  the outputs by broadcasting with a trainable variable `beta`. `gamma` will
  default to a ones tensor and `beta` will default to a zeros tensor, so that
  centering and scaling are no-ops before training has begun.

  So, with scaling and centering enabled the normalization equations
  are as follows:

  Let the intermediate activations for a mini-batch be the `inputs`.

  For each sample `x_i` in `inputs` with `k` features, we compute the mean and
  variance of the sample:

  ```python
  mean_i = sum(x_i[j] for j in range(k)) / k
  var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
  ```

  and then compute a normalized `x_i_normalized`, including a small factor
  `epsilon` for numerical stability.

  ```python
  x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
  ```

  And finally `x_i_normalized` is linearly transformed by `gamma` and `beta`,
  which are learned parameters:

  ```python
  output_i = x_i_normalized * gamma + beta
  ```

  `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
  this part of the inputs' shape must be fully defined.

  For example:

  >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
  >>> layer.build([5, 20, 30, 40])
  >>> print(layer.beta.shape)
  (20, 30, 40)
  >>> print(layer.gamma.shape)
  (20, 30, 40)

  Note that other implementations of layer normalization may choose to define
  `gamma` and `beta` over a separate set of axes from the axes being
  normalized across. For example, Group Normalization
  ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
  corresponds to a Layer Normalization that normalizes across height, width,
  and channel and has `gamma` and `beta` span only the channel dimension.
  So, this Layer Normalization implementation will not match a Group
  Normalization layer with group size set to 1.

  Args:
    axis: Integer or List/Tuple. The axis or axes to normalize across.
      Typically this is the features axis/axes. The left-out axes are
      typically the batch axis/axes. This argument defaults to `-1`, the last
      dimension in the input.
    epsilon: Small float added to variance to avoid dividing by zero. Defaults
      to 1e-3.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored. Defaults to True.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used.
      Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
      this can be disabled since the scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight. Defaults to zeros.
    gamma_initializer: Initializer for the gamma weight. Defaults to ones.
    beta_regularizer: Optional regularizer for the beta weight. None by
      default.
    gamma_regularizer: Optional regularizer for the gamma weight. None by
      default.
    beta_constraint: Optional constraint for the beta weight. None by default.
    gamma_constraint: Optional constraint for the gamma weight. None by
      default.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape:
    Same shape as input.

  Reference:
    - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
  """

  def __init__(self,
               axis=-1,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    super(LayerNormalization, self).__init__(**kwargs)
    if isinstance(axis, (list, tuple)):
      self.axis = axis[:]
    elif isinstance(axis, int):
      self.axis = axis
    else:
      raise TypeError('Expected an int or a list/tuple of ints for the '
                      'argument \'axis\', but received: %r' % axis)

    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)

    self.supports_masking = True

    # Indicates whether a faster fused implementation can be used. This will
    # be set to True or False in build().
    self._fused = None

  def _fused_can_be_used(self, ndims):
    """Returns False if the fused implementation cannot be used.

    Check if the axis is contiguous and can be collapsed into the last axis.
    The self.axis is assumed to have no duplicates.
    """
    axis = sorted(self.axis)
    can_use_fused = False

    if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
      can_use_fused = True

    # fused_batch_norm will silently raise epsilon to be at least 1.001e-5, so
    # we cannot use the fused version if epsilon is below that value. Also,
    # the variable dtype must be float32, as fused_batch_norm only supports
    # float32 variables.
    if self.epsilon < 1.001e-5 or self.dtype != 'float32':
      can_use_fused = False

    return can_use_fused

  def build(self, input_shape):
    ndims = len(input_shape)
    if ndims is None:
      raise ValueError('Input shape %s has undefined rank.' % input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]
    elif isinstance(self.axis, tuple):
      self.axis = list(self.axis)
    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: {}'.format(tuple(self.axis)))

    param_shape = [input_shape[dim] for dim in self.axis]
    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None

    self._fused = self._fused_can_be_used(ndims)

    self.built = True

  def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Broadcasting only necessary for norm when the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape.dims[dim].value

    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          self.axis != [ndims - 1]):
        return array_ops.reshape(v, broadcast_shape)
      return v

    if not self._fused:
      input_dtype = inputs.dtype
      if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
        # If mixed precision is used, cast inputs to float32 so that this is
        # at least as numerically stable as the fused version.
        inputs = math_ops.cast(inputs, 'float32')

      # Calculate the moments on the last axis (layer activations).
      mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      # Compute layer normalization using the batch_normalization function.
      outputs = nn.batch_normalization(
          inputs,
          mean,
          variance,
          offset=offset,
          scale=scale,
          variance_epsilon=self.epsilon)
      outputs = math_ops.cast(outputs, input_dtype)
    else:
      # Collapse dims before self.axis, and dims in self.axis
      pre_dim, in_dim = (1, 1)
      axis = sorted(self.axis)
      tensor_shape = array_ops.shape(inputs)
      for dim in range(0, ndims):
        dim_tensor = tensor_shape[dim]
        if dim < axis[0]:
          pre_dim = pre_dim * dim_tensor
        else:
          assert dim in axis
          in_dim = in_dim * dim_tensor

      squeezed_shape = [1, pre_dim, in_dim, 1]
      # This fused operation requires reshaped inputs to be NCHW.
      data_format = 'NCHW'

      inputs = array_ops.reshape(inputs, squeezed_shape)

      def _set_const_tensor(val, dtype, shape):
        return array_ops.fill(shape, constant_op.constant(val, dtype=dtype))

      # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
      # we cannot pass them as the scale and offset parameters.
      # Therefore, we create two constant tensors in correct shapes for
      # fused_batch_norm and later construct a separate calculation on the
      # scale and offset.
      scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
      offset = _set_const_tensor(0.0, self.dtype, [pre_dim])

      # Compute layer normalization using the fused_batch_norm function.
      outputs, _, _ = nn.fused_batch_norm(
          inputs,
          scale=scale,
          offset=offset,
          epsilon=self.epsilon,
          data_format=data_format)

      outputs = array_ops.reshape(outputs, tensor_shape)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      if scale is not None:
        outputs = outputs * math_ops.cast(scale, outputs.dtype)
      if offset is not None:
        outputs = outputs + math_ops.cast(offset, outputs.dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'axis': self.axis,
        'epsilon': self.epsilon,
        'center': self.center,
        'scale': self.scale,
        'beta_initializer': initializers.serialize(self.beta_initializer),
        'gamma_initializer': initializers.serialize(self.gamma_initializer),
        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
        'beta_constraint': constraints.serialize(self.beta_constraint),
        'gamma_constraint': constraints.serialize(self.gamma_constraint)
    }
    base_config = super(LayerNormalization, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))