# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@six.add_metaclass(abc.ABCMeta)
@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
                                  'train.experimental.LossScale')
@tf_export(
    'mixed_precision.experimental.LossScale',
    'train.experimental.LossScale',
    v1=[
        'mixed_precision.LossScale',
        'mixed_precision.experimental.LossScale',
        'train.experimental.LossScale'
    ])
class LossScale(trackable.Trackable):
  """Base class for all TF1 loss scales.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will
  only be accessible as `tf.compat.v1.mixed_precision.LossScale`. Additionally
  in 2.5, you will no longer be able to pass a `LossScale` to a
  `tf.keras.mixed_precision.Policy`. All the functionality in this class has
  been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this
  class is no longer needed.

  This is an abstract base class, so you cannot instantiate it directly.
  Instead, use one of its concrete subclasses:
    * `tf.compat.v1.mixed_precision.DynamicLossScale`
    * `tf.compat.v1.mixed_precision.FixedLossScale`

  Loss scaling is a process that multiplies the loss by a multiplier called
  the loss scale, and divides each gradient by the same multiplier. The
  pseudocode for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling instances of this
  class returns the loss scale as a scalar float32 tensor, while method
  `update()` updates the loss scale depending on the values of the gradients.
  Optimizers use instances of this class to scale loss and gradients.

  In most functions that accept a LossScale, you can also pass an int (such as
  8) to create a `FixedLossScale` or the string `"dynamic"` to create a
  dynamic loss scale.
  """

  def __init__(self):
    """Initializes the loss scale class."""
    self._weights = {}

  @abc.abstractmethod
  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    pass

  @abc.abstractmethod
  def update(self, grads):
    """Updates the value of the loss scale.

    The loss scale will be potentially updated, based on the value of `grads`.
    The tensor returned by calling this class is only updated when this
    function is evaluated.

    In eager mode, this directly updates the loss scale, so that calling
    `__call__` will return the newly updated loss scale. In graph mode,
    this returns an op that, when evaluated, updates the loss scale.

    This function also returns a `should_apply_gradients` bool. If False,
    gradients should not be applied to the variables that step, as nonfinite
    gradients were found, and the loss scale has been updated to reduce the
    chance of finding nonfinite gradients in the next step. Some loss scale
    classes will always return True, as they cannot adjust themselves in
    response to nonfinite gradients.

    When a DistributionStrategy is used, this function may only be called in
    a cross-replica context.

    Args:
      grads: A nested structure of unscaled gradients, each of which is the
        gradient of the loss with respect to a weight. The gradients should
        have already been divided by the loss scale before being passed to
        this function. 'None' gradients are accepted, and are ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    pass

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added.
    """
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access

    key = (name, graph_key)
    if self._weights.get(key, None) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    return variable

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific weights to save."""
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    weights = []
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights.append(trackable.TrackableReference(name=name, ref=v))
    return super(LossScale, self)._checkpoint_dependencies + weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(LossScale, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    return self._weights.get((name, graph_key), None)

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of this loss scale."""
    pass

  @classmethod
  def from_config(cls, config):
    """Creates the LossScale from its config."""
    return cls(**config)


@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
                                  'train.experimental.FixedLossScale')
@tf_export(
    'mixed_precision.experimental.FixedLossScale',
    'train.experimental.FixedLossScale',
    v1=[
        'mixed_precision.FixedLossScale',
        'mixed_precision.experimental.FixedLossScale',
        'train.experimental.FixedLossScale'
    ])
class FixedLossScale(LossScale):
  """Loss scale with a fixed value.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will
  only be accessible as `tf.compat.v1.mixed_precision.FixedLossScale`.
  Additionally in 2.5, you will no longer be able to pass a `FixedLossScale`
  to a `tf.keras.mixed_precision.Policy`. All the functionality in this class
  has been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this
  class is no longer needed.

  The loss scale is not updated for the lifetime of instances of this class.
  A given instance of this class always returns the same number when called.
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'FixedLossScale')
  def __init__(self, loss_scale_value):
    """Creates the fixed loss scale.

    Args:
      loss_scale_value: A Python float. Its ideal value varies depending on
        the model being run. Choosing a loss_scale that is too small might
        affect model quality; one that is too big might cause inf or nan
        gradients. There is no single right loss_scale to apply. There is no
        harm in choosing a relatively big number as long as no nan or inf is
        encountered in training.

    Raises:
      ValueError: If loss_scale_value is less than 1.
    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, six.integer_types + (float,)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might be on a different device, or in a different tf.function,
    # than where the tensor is used. This would hurt performance. Therefore,
    # we do not create a tensor from loss_scale_value, but instead leave it
    # as a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}


def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  is_finite_per_grad = [
      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
  ]
  return math_ops.reduce_all(is_finite_per_grad)


def _op_in_graph_mode(tensor):
  """Returns the tensor's op in graph mode, or the tensor in eager mode.

  This is useful because sometimes an op is needed in graph mode instead of a
  tensor. In eager mode, there are no ops.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode. The tensor in eager mode.
  """
  if context.executing_eagerly():
    return tensor
  return tensor.op


def _assign_if_finite(var, value):
  """Assigns a value to a variable if the value is finite."""
  return control_flow_ops.cond(
      math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
      control_flow_ops.no_op)


@deprecation.deprecated_endpoints(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale')
@tf_export(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale',
    v1=[
        'mixed_precision.DynamicLossScale',
        'mixed_precision.experimental.DynamicLossScale',
        'train.experimental.DynamicLossScale'
    ])
class DynamicLossScale(LossScale):
  """Loss scale that dynamically adjusts itself.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will
  only be accessible as `tf.compat.v1.mixed_precision.DynamicLossScale`.
  Additionally in 2.5, you will no longer be able to pass a `DynamicLossScale`
  to a `tf.keras.mixed_precision.Policy`. All the functionality in this class
  has been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this
  class is no longer needed.

  Dynamic loss scaling works by adjusting the loss scale as training
  progresses. The goal is to keep the loss scale as high as possible without
  overflowing the gradients. As long as the gradients do not overflow, raising
  the loss scale never hurts.

  The algorithm starts by setting the loss scale to an initial value. Every N
  steps that the gradients are finite, the loss scale is increased by some
  factor. However, if a NaN or Inf gradient is found, the gradients for that
  step are not applied, and the loss scale is decreased by the factor. This
  process tends to keep the loss scale as high as possible without gradients
  overflowing.
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'DynamicLossScale')
  def __init__(self,
               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
               increment_period=2000,
               multiplier=2.):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: A Python float. The loss scale to use at the
        beginning. It's better to start this at a very high number, because a
        loss scale that is too high gets lowered far more quickly than a loss
        scale that is too low gets raised. The default is 2 ** 15, which is
        approximately half the maximum float16 value.
      increment_period: Increases loss scale every `increment_period`
        consecutive steps that finite gradients are encountered. If a
        nonfinite gradient is encountered, the count is reset back to zero.
      multiplier: The multiplier to use when increasing or decreasing the
        loss scale.
    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)

    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale.
    self._num_good_steps = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Updates the loss scale based on whether the gradients are finite."""
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return control_flow_ops.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period,
               self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period,
               self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,
        'increment_period': self.increment_period,
        'multiplier': self.multiplier,
    }


def get(identifier):
  """Get a loss scale object."""
  if isinstance(identifier, six.integer_types + (float,)):
    return FixedLossScale(identifier)
  if identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     identifier)
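
# Example usage: a minimal sketch, assuming this module is run directly under
# a TensorFlow 2.x build with eager execution enabled. The gradient values are
# made up purely to show how `DynamicLossScale` raises the scale after
# `increment_period` consecutive finite steps and lowers it when a nonfinite
# gradient appears, and how `get()` converts identifiers into LossScale
# instances.
if __name__ == '__main__':
  if context.executing_eagerly():
    loss_scale = DynamicLossScale(
        initial_loss_scale=4., increment_period=2, multiplier=2.)

    finite_grads = [ops.convert_to_tensor([1.0, 2.0])]
    for _ in range(2):
      loss_scale.update(finite_grads)
    # Two consecutive finite steps: the scale doubles from 4 to 8.
    print('After finite steps:', loss_scale().numpy())

    inf_grads = [ops.convert_to_tensor([float('inf'), 1.0])]
    _, should_apply = loss_scale.update(inf_grads)
    # A nonfinite gradient: the caller should skip applying it, and the scale
    # is halved back to 4.
    print('Apply gradients?', bool(should_apply))
    print('After inf step:', loss_scale().numpy())

    # `get()` accepts an int, the string 'dynamic', or a LossScale instance.
    assert isinstance(get(8), FixedLossScale)
    assert isinstance(get('dynamic'), DynamicLossScale)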