1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains the loss scaling optimizer class.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.distribute import collective_all_reduce_strategy 21from tensorflow.python.distribute import distribution_strategy_context 22from tensorflow.python.distribute import mirrored_strategy 23from tensorflow.python.distribute import one_device_strategy 24from tensorflow.python.distribute import reduce_util 25from tensorflow.python.distribute import tpu_strategy 26from tensorflow.python.eager import backprop 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import smart_cond 31from tensorflow.python.keras import backend 32from tensorflow.python.keras import optimizers 33from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module 34from tensorflow.python.keras.optimizer_v2 import optimizer_v2 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import variable_scope 38from tensorflow.python.ops import variables 39from tensorflow.python.platform import tf_logging 40from tensorflow.python.training.experimental import loss_scale as loss_scale_module 41from tensorflow.python.training.experimental import mixed_precision 42from tensorflow.python.training.tracking import base as trackable 43from tensorflow.python.util import nest 44from tensorflow.python.util.tf_export import keras_export 45 46 47class _UnwrapPreventer(object): 48 """Wrapper that DistributionStrategy will not unwrap. 49 50 Typically, DistributionStrategy will unwrap values when going from a cross- 51 replica context to a replica context via `call_for_each_replica`. This class 52 is a wrapper that DistributionStrategy will not unwrap, so it can be used to 53 prevent it from unwrapping a value. 54 55 TODO(reedwm): Find/implement a better way of preventing values from being 56 unwrapped by DistributionStrategy 57 """ 58 59 __slots__ = ['value'] 60 61 def __init__(self, value): 62 self.value = value 63 64 65class _DelegatingTrackableMixin(object): 66 """A mixin that delegates all Trackable methods to another trackable object. 67 68 This class must be used with multiple inheritance. A class that subclasses 69 Trackable can also subclass this class, which causes all Trackable methods to 70 be delegated to the trackable object passed in the constructor. 71 72 A subclass can use this mixin to appear as if it were the trackable passed to 73 the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this 74 mixin, so that the checkpoint format for a LossScaleOptimizer is identical to 75 the checkpoint format for a normal optimizer. This allows a model to be saved 76 with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa. 77 The only difference in checkpoint format is that the loss scale is also saved 78 with a LossScaleOptimizer. 79 """ 80 81 def __init__(self, trackable_obj): 82 self._trackable = trackable_obj 83 84 # pylint: disable=protected-access 85 @property 86 def _setattr_tracking(self): 87 return self._trackable._setattr_tracking 88 89 @_setattr_tracking.setter 90 def _setattr_tracking(self, value): 91 self._trackable._setattr_tracking = value 92 93 @property 94 def _update_uid(self): 95 return self._trackable._update_uid 96 97 @_update_uid.setter 98 def _update_uid(self, value): 99 self._trackable._update_uid = value 100 101 @property 102 def _unconditional_checkpoint_dependencies(self): 103 return self._trackable._unconditional_checkpoint_dependencies 104 105 @property 106 def _unconditional_dependency_names(self): 107 return self._trackable._unconditional_dependency_names 108 109 @property 110 def _name_based_restores(self): 111 return self._trackable._name_based_restores 112 113 def _maybe_initialize_trackable(self): 114 return self._trackable._maybe_initialize_trackable() 115 116 @property 117 def _object_identifier(self): 118 return self._trackable._object_identifier 119 120 @property 121 def _tracking_metadata(self): 122 return self._trackable._tracking_metadata 123 124 def _no_dependency(self, value): 125 return self._trackable._no_dependency(value) 126 127 def _name_based_attribute_restore(self, checkpoint): 128 return self._trackable._name_based_attribute_restore(checkpoint) 129 130 @property 131 def _checkpoint_dependencies(self): 132 return self._trackable._checkpoint_dependencies 133 134 @property 135 def _deferred_dependencies(self): 136 return self._trackable._deferred_dependencies 137 138 def _lookup_dependency(self, name): 139 self._trackable._lookup_dependency(name) 140 141 def _add_variable_with_custom_getter(self, 142 name, 143 shape=None, 144 dtype=dtypes.float32, 145 initializer=None, 146 getter=None, 147 overwrite=False, 148 **kwargs_for_getter): 149 return self._trackable._add_variable_with_custom_getter( 150 name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter) 151 152 def _preload_simple_restoration(self, name): 153 return self._trackable._preload_simple_restoration(name) 154 155 def _track_trackable(self, trackable, name, overwrite=False): # pylint: disable=redefined-outer-name 156 return self._trackable._track_trackable(trackable, name, overwrite) 157 158 def _handle_deferred_dependencies(self, name, trackable): # pylint: disable=redefined-outer-name 159 return self._trackable._handle_deferred_dependencies(name, trackable) 160 161 def _restore_from_checkpoint_position(self, checkpoint_position): 162 return self._trackable._restore_from_checkpoint_position( 163 checkpoint_position) 164 165 def _single_restoration_from_checkpoint_position(self, checkpoint_position, 166 visit_queue): 167 return self._trackable._single_restoration_from_checkpoint_position( 168 checkpoint_position, visit_queue) 169 170 def _gather_saveables_for_checkpoint(self): 171 return self._trackable._gather_saveables_for_checkpoint() 172 173 def _list_extra_dependencies_for_serialization(self, serialization_cache): 174 return self._trackable._list_extra_dependencies_for_serialization( 175 serialization_cache) 176 177 def _list_functions_for_serialization(self, serialization_cache): 178 return self._trackable._list_functions_for_serialization( 179 serialization_cache) 180 # pylint: enable=protected-access 181 182 183def _is_all_finite(grads): 184 """Returns a scalar boolean tensor indicating if all gradients are finite.""" 185 is_finite_per_grad = [ 186 math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None 187 ] 188 return math_ops.reduce_all(is_finite_per_grad) 189 190 191def _op_in_graph_mode(tensor): 192 """Returns the tensor's op in graph mode, or the tensor in eager mode. 193 194 This is useful because sometimes an op is needed in graph mode instead of a 195 tensor. In eager mode, there are no ops. 196 197 Args: 198 tensor: A tensor. 199 200 Returns: 201 The tensor's op in graph mode. The tensor in eager mode. 202 """ 203 if context.executing_eagerly(): 204 return tensor 205 return tensor.op 206 207 208def _assign_if_finite(var, value): 209 """Assigns a value to a variable if the value is finite.""" 210 return control_flow_ops.cond( 211 math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)), 212 control_flow_ops.no_op) 213 214 215class _DynamicLossScaleState(trackable.Trackable): 216 """The state of a dynamic loss scale.""" 217 218 def __init__(self, 219 initial_loss_scale, 220 growth_steps, 221 multiplier): 222 """Creates the dynamic loss scale.""" 223 super(_DynamicLossScaleState, self).__init__() 224 self._initial_loss_scale = float(initial_loss_scale) 225 self._growth_steps = int(growth_steps) 226 self._multiplier = float(multiplier) 227 228 self._weights = {} 229 self._current_loss_scale = self._add_weight( 230 name='current_loss_scale', 231 dtype=dtypes.float32, 232 initial_value=self._initial_loss_scale) 233 # The number of consecutive steps with finite gradients since the last 234 # nonfinite gradient or change in loss scale. The name is 'good_steps' for 235 # backwards compatibility with older checkpoints. 236 self._counter = self._add_weight( 237 name='good_steps', dtype=dtypes.int64, initial_value=0) 238 239 def _add_weight(self, name, initial_value, dtype=None): 240 """Adds a weight to this loss scale. 241 242 Args: 243 name: Variable name. 244 initial_value: The variable's initial value. 245 dtype: The type of the variable. 246 247 Returns: 248 A variable. 249 250 Raises: 251 RuntimeError: If a weight with `name` has already been added. 252 """ 253 variable = variable_scope.variable( 254 initial_value=initial_value, 255 name=name, 256 dtype=dtype, 257 trainable=False, 258 use_resource=True, 259 synchronization=variables.VariableSynchronization.AUTO, 260 # Set aggregation to NONE, as loss scaling variables should never be 261 # aggregated. 262 aggregation=variables.VariableAggregation.NONE) 263 if context.executing_eagerly(): 264 graph_key = None 265 else: 266 graph = ops.get_default_graph() 267 graph_key = graph._graph_key # pylint: disable=protected-access 268 269 key = (name, graph_key) 270 self._weights[key] = variable 271 self._handle_deferred_dependencies(name=name, trackable=variable) 272 backend.track_variable(variable) 273 return variable 274 275 @property 276 def _checkpoint_dependencies(self): 277 """From Trackable. Gather graph-specific weights to save.""" 278 if context.executing_eagerly(): 279 graph_key = None 280 else: 281 graph = ops.get_default_graph() 282 graph_key = graph._graph_key # pylint: disable=protected-access 283 weights = [] 284 for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]): 285 if g == graph_key: 286 weights.append(trackable.TrackableReference(name=name, ref=v)) 287 return (super(_DynamicLossScaleState, self)._checkpoint_dependencies + 288 weights) 289 290 def _lookup_dependency(self, name): 291 """From Trackable. Find a weight in the current graph.""" 292 unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name) 293 if unconditional is not None: 294 return unconditional 295 if context.executing_eagerly(): 296 graph_key = None 297 else: 298 graph = ops.get_default_graph() 299 graph_key = graph._graph_key # pylint: disable=protected-access 300 return self._weights.get((name, graph_key), None) 301 302 @property 303 def initial_loss_scale(self): 304 return self._initial_loss_scale 305 306 @property 307 def growth_steps(self): 308 return self._growth_steps 309 310 @property 311 def multiplier(self): 312 return self._multiplier 313 314 @property 315 def current_loss_scale(self): 316 """Returns the current loss scale as a float32 `tf.Variable`.""" 317 return self._current_loss_scale 318 319 @property 320 def counter(self): 321 """Returns the counter as a float32 `tf.Variable`.""" 322 return self._counter 323 324 def __call__(self): 325 """Returns the current loss scale as a scalar `float32` tensor.""" 326 return ops.convert_to_tensor(self._current_loss_scale) 327 328 def update(self, grads): 329 """Updates the value of the loss scale. 330 331 Args: 332 grads: A nested structure of unscaled gradients, each which is the 333 gradient of the loss with respect to a weight. 334 335 Returns: 336 update_op: In eager mode, None. In graph mode, an op to update the loss 337 scale. 338 should_apply_gradients: Either a bool or a scalar boolean tensor. If 339 False, the caller should skip applying `grads` to the variables this 340 step. 341 """ 342 grads = nest.flatten(grads) 343 if distribution_strategy_context.has_strategy(): 344 distribution = distribution_strategy_context.get_strategy() 345 346 def get_is_finite(grads): 347 is_finite = _is_all_finite(grads) 348 # We cast to float, because we cannot reduce booleans with 349 # DistributionStrategy. 350 return math_ops.cast(is_finite, dtypes.float32) 351 352 is_finite_float = distribution.extended.call_for_each_replica( 353 get_is_finite, args=(grads,)) 354 reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM, 355 is_finite_float, axis=None) 356 is_finite = math_ops.equal(reduced_is_finite_float, 357 distribution.num_replicas_in_sync) 358 else: 359 is_finite = _is_all_finite(grads) 360 361 def update_if_finite_grads(): 362 """Update assuming the gradients are finite.""" 363 364 def incr_loss_scale(): 365 new_loss_scale = self.current_loss_scale * self.multiplier 366 return control_flow_ops.group( 367 _assign_if_finite(self.current_loss_scale, new_loss_scale), 368 self.counter.assign(0)) 369 370 return control_flow_ops.cond( 371 self.counter + 1 >= self.growth_steps, 372 incr_loss_scale, 373 lambda: _op_in_graph_mode(self.counter.assign_add(1))) 374 375 def update_if_not_finite_grads(): 376 """Update assuming the gradients are nonfinite.""" 377 378 new_loss_scale = math_ops.maximum( 379 self.current_loss_scale / self.multiplier, 1) 380 return control_flow_ops.group( 381 self.counter.assign(0), 382 self.current_loss_scale.assign(new_loss_scale)) 383 384 update_op = control_flow_ops.cond(is_finite, update_if_finite_grads, 385 update_if_not_finite_grads) 386 should_apply_gradients = is_finite 387 return update_op, should_apply_gradients 388 389 390# See LossScaleOptimizer docstring for why this is so big 391_DEFAULT_INITIAL_SCALE = 2 ** 15 392_DEFAULT_GROWTH_STEPS = 2000 393 394 395# pylint: disable=g-classes-have-attributes 396@keras_export('keras.mixed_precision.LossScaleOptimizer') 397class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): 398 """An optimizer that applies loss scaling to prevent numeric underflow. 399 400 Loss scaling is a technique to prevent numeric underflow in intermediate 401 gradients when float16 is used. To prevent underflow, the loss is multiplied 402 (or "scaled") by a certain factor called the "loss scale", which causes 403 intermediate gradients to be scaled by the loss scale as well. The final 404 gradients are divided (or "unscaled") by the loss scale to bring them back to 405 their original value. 406 407 `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it. 408 By default, the loss scale is dynamically updated over time so you do not have 409 to choose the loss scale. The `minimize` method automatically scales the loss, 410 unscales the gradients, and updates the loss scale so all you have to do is 411 wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For 412 example: 413 414 >>> opt = tf.keras.optimizers.SGD(0.25) 415 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 416 >>> var = tf.Variable(1.) 417 >>> loss_fn = lambda: var ** 2 418 >>> # 'minimize' applies loss scaling and updates the loss sale. 419 >>> opt.minimize(loss_fn, var_list=var) 420 >>> var.numpy() 421 0.5 422 423 If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you 424 must scale the loss and gradients manually. This can be done with the 425 `LossScaleOptimizer.get_scaled_loss` and 426 `LossScaleOptimizer.get_unscaled_gradients` methods. For example: 427 428 >>> with tf.GradientTape() as tape: 429 ... loss = loss_fn() 430 ... scaled_loss = opt.get_scaled_loss(loss) 431 >>> scaled_grad = tape.gradient(scaled_loss, var) 432 >>> (grad,) = opt.get_unscaled_gradients([scaled_grad]) 433 >>> opt.apply_gradients([(grad, var)]) # Loss scale is updated here 434 >>> var.numpy() 435 0.25 436 437 Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients` 438 (or both) when using a `tf.GradientTape`, the model will likely converge to a 439 worse quality. Please make sure you call each function exactly once. 440 441 When mixed precision with float16 is used, there is typically no risk of 442 underflow affecting model quality if loss scaling is properly used. See 443 [the mixed precision guide]( 444 https://www.tensorflow.org/guide/keras/mixed_precision) for more information 445 on how to use mixed precision. 446 447 Args: 448 inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap. 449 dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to 450 True. If True, the loss scale will be dynamically updated over time using 451 an algorithm that keeps the loss scale at approximately its optimal value. 452 If False, a single fixed loss scale is used and `initial_scale` must be 453 specified, which is used as the loss scale. Recommended to keep as True, 454 as choosing a fixed loss scale can be tricky. Currently, there is a small 455 performance overhead to dynamic loss scaling compared to fixed loss 456 scaling. 457 initial_scale: The initial loss scale. If `dynamic` is True, this defaults 458 to `2 ** 15`. If `dynamic` is False, this must be specified and acts as 459 the sole loss scale, as the loss scale does not change over time. When 460 dynamic loss scaling is used, is better for this to be a very high number, 461 because a loss scale that is too high gets lowered far more quickly than a 462 loss scale that is too low gets raised. 463 dynamic_growth_steps: With dynamic loss scaling, every 464 `dynamic_growth_steps` steps with finite gradients, the loss scale is 465 doubled. Defaults to 2000. If a nonfinite gradient is encountered, the 466 count is reset back to zero, gradients are skipped that step, and the loss 467 scale is halved. The count can be queried with 468 `LossScaleOptimizer.dynamic_counter`. This argument can only be specified 469 if `dynamic` is True. 470 471 `LossScaleOptimizer` will occasionally skip applying gradients to the 472 variables, in which case the trainable variables will not change that step. 473 This is done because the dynamic loss scale will sometimes be raised too 474 high, causing overflow in the gradients. Typically, the first 2 to 15 steps of 475 the model are skipped as the initial loss scale is very high, but afterwards 476 steps will only be skipped on average 0.05% of the time (the fraction of steps 477 skipped is `1 / dynamic_growth_steps`). 478 479 `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner 480 optimizer. Additionally, in methods `minimize` and `get_gradients, it scales 481 the loss and unscales the gradients. In methods `minimize` and 482 `apply_gradients`, it additionally updates the loss scale and skips applying 483 gradients if any gradient has a nonfinite value. 484 485 ### Hyperparameters 486 487 Hyperparameters can be accessed and set on the LossScaleOptimizer, which will 488 be delegated to the wrapped optimizer. 489 490 >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) 491 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 492 >>> opt.beta_1 # Equivalent to `opt.inner_optimizer.beta_1` 493 0.8 494 >>> opt.beta_1 = 0.7 # Equivalent to `opt.inner_optimizer.beta_1 = 0.7` 495 >>> opt.beta_1 496 0.7 497 >>> opt.inner_optimizer.beta_1 498 0.7 499 500 However, accessing or setting non-hyperparameters is not delegated to the 501 LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but 502 `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on 503 `beta_1`. 504 505 >>> opt.inner_optimizer.epsilon 506 1e-5 507 >>> opt.epsilon 508 Traceback (most recent call last): 509 ... 510 AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon' 511 >>> opt.epsilon = 1e-4 # This does NOT set epsilon on `opt.inner_optimizer` 512 >>> opt.inner_optimizer.epsilon 513 >>> 1e-5 514 515 In the above example, despite epsilon being set on the LossScaleOptimizer, the 516 old epsilon value will still be used when training as epsilon was not set on 517 the inner optimizer. 518 """ 519 520 _HAS_AGGREGATE_GRAD = True 521 522 def __init__(self, inner_optimizer, dynamic=True, initial_scale=None, 523 dynamic_growth_steps=None): 524 if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2): 525 raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, ' 526 'but got: %s' % inner_optimizer) 527 if not isinstance(dynamic, bool): 528 # Catch errors if a user incorrectly passes a string or float to the 529 # second argument argument, as this is commonly done for 530 # LossScaleOptimizerV1. 531 raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must ' 532 'be a bool, but got: %r' % (dynamic,)) 533 self._raise_if_strategy_unsupported() 534 self._optimizer = inner_optimizer 535 536 # We don't call super().__init__, since we do not want to call OptimizerV2's 537 # constructor. 538 _DelegatingTrackableMixin.__init__(self, self._optimizer) 539 540 if dynamic: 541 if initial_scale is None: 542 initial_scale = _DEFAULT_INITIAL_SCALE 543 if dynamic_growth_steps is None: 544 dynamic_growth_steps = _DEFAULT_GROWTH_STEPS 545 self._loss_scale = _DynamicLossScaleState( 546 initial_scale, dynamic_growth_steps, multiplier=2) 547 self._track_trackable(self._loss_scale, 'loss_scale') 548 else: 549 if initial_scale is None: 550 raise ValueError('"initial_scale" must be specified if "dynamic" is ' 551 'False') 552 self._loss_scale = float(initial_scale) 553 if dynamic_growth_steps is not None: 554 raise ValueError('"dynamic_growth_steps" must be None if "dynamic" ' 555 'is False, but got: %s' % (dynamic_growth_steps,)) 556 557 # To support restoring TensorFlow 2.2 checkpoints. 558 self._track_trackable(FakeOptimizerForRestoration(self._optimizer), 559 'base_optimizer') 560 561 @property 562 def dynamic(self): 563 """Bool indicating whether dynamic loss scaling is used.""" 564 return isinstance(self._loss_scale, _DynamicLossScaleState) 565 566 @property 567 def loss_scale(self): 568 """The current loss scale as a float32 scalar tensor.""" 569 if isinstance(self._loss_scale, _DynamicLossScaleState): 570 return ops.convert_to_tensor(self._loss_scale.current_loss_scale) 571 else: 572 return ops.convert_to_tensor(self._loss_scale) 573 574 @property 575 def dynamic_counter(self): 576 """The number of steps since the loss scale was last increased or decreased. 577 578 This is None if `LossScaleOptimizer.dynamic` is False. 579 580 The counter is incremented every step. Once it reaches 581 `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled 582 and the counter will be reset back to zero. If nonfinite gradients are 583 encountered, the loss scale will be halved and the counter will be reset 584 back to zero. 585 """ 586 if isinstance(self._loss_scale, _DynamicLossScaleState): 587 return self._loss_scale.counter 588 else: 589 return None 590 591 @property 592 def initial_scale(self): 593 """The initial loss scale. 594 595 If `LossScaleOptimizer.dynamic` is False, this is the same number as 596 `LossScaleOptimizer.loss_scale`, as the loss scale never changes. 597 """ 598 if isinstance(self._loss_scale, _DynamicLossScaleState): 599 return self._loss_scale.initial_loss_scale 600 else: 601 return self._loss_scale 602 603 @property 604 def dynamic_growth_steps(self): 605 """The number of steps it takes to increase the loss scale. 606 607 This is None if `LossScaleOptimizer.dynamic` is False. 608 609 Every `dynamic_growth_steps` consecutive steps with finite gradients, the 610 loss scale is increased. 611 """ 612 if isinstance(self._loss_scale, _DynamicLossScaleState): 613 return self._loss_scale.growth_steps 614 else: 615 return None 616 617 @property 618 def inner_optimizer(self): 619 """The optimizer that this LossScaleOptimizer is wrapping.""" 620 return self._optimizer 621 622 def get_scaled_loss(self, loss): 623 """Scales the loss by the loss scale. 624 625 This method is only needed if you compute gradients manually, e.g. with 626 `tf.GradientTape`. In that case, call this method to scale the loss before 627 passing the loss to `tf.GradientTape`. If you use 628 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 629 scaling is automatically applied and this method is unneeded. 630 631 If this method is called, `get_unscaled_gradients` should also be called. 632 See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for 633 an example. 634 635 Args: 636 loss: The loss, which will be multiplied by the loss scale. Can either be 637 a tensor or a callable returning a tensor. 638 639 Returns: 640 `loss` multiplied by `LossScaleOptimizer.loss_scale`. 641 """ 642 if callable(loss): 643 def new_loss(): 644 loss_val = loss() 645 return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype) 646 return new_loss 647 else: 648 return loss * math_ops.cast(self.loss_scale, loss.dtype) 649 650 def get_unscaled_gradients(self, grads): 651 """Unscales the gradients by the loss scale. 652 653 This method is only needed if you compute gradients manually, e.g. with 654 `tf.GradientTape`. In that case, call this method to unscale the gradients 655 after computing them with `tf.GradientTape`. If you use 656 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 657 scaling is automatically applied and this method is unneeded. 658 659 If this method is called, `get_scaled_loss` should also be called. See 660 the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an 661 example. 662 663 Args: 664 grads: A list of tensors, each which will be divided by the loss scale. 665 Can have None values, which are ignored. 666 667 Returns: 668 A new list the same size as `grads`, where every non-None value in `grads` 669 is divided by `LossScaleOptimizer.loss_scale`. 670 """ 671 loss_scale_reciprocal = 1. / self.loss_scale 672 return [ 673 _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None 674 for g in grads 675 ] 676 677 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): 678 tape = backprop.GradientTape() if tape is None else tape 679 with tape: 680 loss = self.get_scaled_loss(loss) 681 grads_and_vars = self._optimizer._compute_gradients( # pylint: disable=protected-access 682 loss, 683 var_list, 684 grad_loss, 685 tape=tape) 686 grads = [g for g, _ in grads_and_vars] 687 weights = [v for _, v in grads_and_vars] 688 unscaled_grads = self.get_unscaled_gradients(grads) 689 return list(zip(unscaled_grads, weights)) 690 691 def get_gradients(self, loss, params): 692 loss = self.get_scaled_loss(loss) 693 grads = self._optimizer.get_gradients(loss, params) 694 return self.get_unscaled_gradients(grads) 695 696 def _create_all_weights(self, var_list): 697 self._optimizer._create_all_weights(var_list) # pylint: disable=protected-access 698 699 def apply_gradients(self, 700 grads_and_vars, 701 name=None, 702 experimental_aggregate_gradients=True): 703 if distribution_strategy_context.in_cross_replica_context(): 704 raise ValueError('apply_gradients() must be called in a replica context.') 705 # We check for the strategy here despite already checking in the constructor 706 # as frequently the optimizer is created outside the strategy's scope. 707 self._raise_if_strategy_unsupported() 708 709 grads_and_vars = tuple(grads_and_vars) 710 return distribution_strategy_context.get_replica_context().merge_call( 711 self._apply_gradients_cross_replica, 712 args=(grads_and_vars, name, experimental_aggregate_gradients)) 713 714 def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name, 715 experimental_aggregate_gradients): 716 grads = [g for g, _ in grads_and_vars] 717 if isinstance(self._loss_scale, _DynamicLossScaleState): 718 loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads) 719 else: 720 loss_scale_update_op = control_flow_ops.no_op() 721 should_apply_grads = True 722 723 def apply_fn(): 724 # We do not want DistributionStrategy to unwrap any MirroredVariables in 725 # grads_and_vars, because even in a replica context, the wrapped optimizer 726 # expects mirrored variables. So we wrap the variables with an 727 # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the 728 # MirroredVariables. 729 wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars]) 730 return distribution.extended.call_for_each_replica( 731 self._apply_gradients, 732 args=(grads, wrapped_vars, name, experimental_aggregate_gradients)) 733 734 def do_not_apply_fn(): 735 # Normally self._optimizer.iterations is incremented in 736 # self._optimizer.apply_gradients(). Since that is not called in this 737 # branch, we increment it here instead. 738 return self._optimizer.iterations.assign_add(1, read_value=False) 739 740 # Note: We must call this cond() in a cross-replica context. 741 # DistributionStrategy does not support having a cond in a replica context 742 # with a branch that calls `merge_call`, and self._optimizer.apply_gradients 743 # calls `merge_call`. 744 maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, 745 do_not_apply_fn) 746 return control_flow_ops.group(maybe_apply_op, loss_scale_update_op) 747 748 def _apply_gradients(self, grads, wrapped_vars, name, 749 experimental_aggregate_gradients): 750 # TODO(reedwm): This will raise a fairly cryptic error message if 751 # self._optimizer.apply_gradients does not take 752 # experimental_aggregate_gradients. 753 return self._optimizer.apply_gradients( 754 list(zip(grads, wrapped_vars.value)), name, 755 experimental_aggregate_gradients=experimental_aggregate_gradients) 756 757 def get_config(self): 758 serialized_optimizer = optimizers.serialize(self._optimizer) 759 return { 760 'inner_optimizer': serialized_optimizer, 761 'dynamic': self.dynamic, 762 'initial_scale': self.initial_scale, 763 'dynamic_growth_steps': self.dynamic_growth_steps, 764 } 765 766 @classmethod 767 def from_config(cls, config, custom_objects=None): 768 config = config.copy() # Make a copy, since we mutate config 769 if 'loss_scale' in config: 770 # If loss_scale is in config, we assume we are deserializing a 771 # LossScaleOptimizer from TF 2.3 or below. We convert the config so it 772 # can be deserialized in the current LossScaleOptimizer. 773 loss_scale = keras_loss_scale_module.deserialize( 774 config.pop('loss_scale')) 775 if isinstance(loss_scale, loss_scale_module.FixedLossScale): 776 config['dynamic'] = False 777 config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access 778 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): 779 config['dynamic'] = True 780 config['initial_scale'] = loss_scale.initial_loss_scale 781 config['dynamic_growth_steps'] = loss_scale.increment_period 782 if loss_scale.multiplier != 2: 783 raise ValueError('Cannot deserialize LossScaleOptimizer with a ' 784 'DynamicLossScale whose multiplier is not 2. Got ' 785 'DynamicLossScale: %s' % (loss_scale,)) 786 else: 787 raise ValueError( 788 'Serialized LossScaleOptimizers with a LossScale that is neither a ' 789 'FixedLossScale nor a DynamicLossScale can no longer be ' 790 'deserialized') 791 config['inner_optimizer'] = config.pop('optimizer') 792 config['inner_optimizer'] = optimizers.deserialize( 793 config['inner_optimizer'], custom_objects=custom_objects) 794 return cls(**config) 795 796 def _raise_if_strategy_unsupported(self): 797 if not strategy_supports_loss_scaling(): 798 strategy = distribution_strategy_context.get_strategy() 799 if isinstance(strategy, 800 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1, 801 tpu_strategy.TPUStrategyV2)): 802 raise ValueError( 803 'Loss scaling is not supported with TPUStrategy. Loss scaling is ' 804 'unnecessary with TPUs, since they support bfloat16 instead of ' 805 'float16 and bfloat16 does not require loss scaling. You should ' 806 'remove the use of the LossScaleOptimizer when TPUs are used.') 807 else: 808 raise ValueError('Loss scaling is not supported with the ' 809 'tf.distribute.Strategy: %s. Try using a different ' 810 'Strategy, e.g. a MirroredStrategy' % 811 strategy.__class__.__name__) 812 813 # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer 814 # below. 815 816 @property 817 def iterations(self): 818 return self._optimizer.iterations 819 820 @iterations.setter 821 def iterations(self, variable): 822 self._optimizer.iterations = variable 823 824 def get_slot_names(self): 825 return self._optimizer.get_slot_names() 826 827 def variables(self): 828 return self._optimizer.variables() 829 830 @property 831 def weights(self): 832 return self._optimizer.weights 833 834 def get_weights(self): 835 return self._optimizer.get_weights() 836 837 def set_weights(self, weights): 838 return self._optimizer.set_weights(weights) 839 840 @property 841 def clipnorm(self): 842 return self._optimizer.clipnorm 843 844 @clipnorm.setter 845 def clipnorm(self, val): 846 self._optimizer.clipnorm = val 847 848 @property 849 def global_clipnorm(self): 850 return self._optimizer.global_clipnorm 851 852 @global_clipnorm.setter 853 def global_clipnorm(self, val): 854 self._optimizer.global_clipnorm = val 855 856 @property 857 def clipvalue(self): 858 return self._optimizer.clipvalue 859 860 @clipvalue.setter 861 def clipvalue(self, val): 862 self._optimizer.clipvalue = val 863 864 def _aggregate_gradients(self, grads_and_vars): 865 return self._optimizer._aggregate_gradients(grads_and_vars) # pylint: disable=protected-access 866 867 def _restore_slot_variable(self, slot_name, variable, slot_variable): 868 return self._optimizer._restore_slot_variable(slot_name, variable, # pylint: disable=protected-access 869 slot_variable) 870 871 def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, 872 variable): 873 return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access 874 slot_variable_position, slot_name, variable) 875 876 def get_slot(self, var, slot_name): 877 return self._optimizer.get_slot(var, slot_name) 878 879 def add_slot(self, var, slot_name, initializer='zeros'): 880 return self._optimizer.add_slot(var, slot_name, initializer) 881 882 def __getattribute__(self, name): 883 try: 884 return object.__getattribute__(self, name) 885 except AttributeError as e: 886 if name == '_optimizer' or name == '_hyper': 887 # Avoid infinite recursion 888 raise e 889 890 # Delegate hyperparameter accesses to inner optimizer. 891 if name == 'lr': 892 name = 'learning_rate' 893 if name in self._optimizer._hyper: 894 return self._optimizer._get_hyper(name) 895 raise e 896 897 def __dir__(self): 898 result = set(super(LossScaleOptimizer, self).__dir__()) 899 if '_optimizer' in result: 900 result |= self._optimizer._hyper.keys() 901 if 'learning_rate' in self._optimizer._hyper.keys(): 902 result.add('lr') 903 return list(result) 904 905 def __setattr__(self, name, value): 906 if name == 'lr': 907 name = 'learning_rate' 908 # Delegate setting hyperparameter to inner optimizer if the attribute does 909 # not exist on the LossScaleOptimizer 910 try: 911 # We cannot check for the 'iterations' attribute as it cannot be set after 912 # it is accessed. 913 if name != 'iterations': 914 object.__getattribute__(self, name) 915 has_attribute = True 916 except AttributeError: 917 has_attribute = False 918 if (name != '_optimizer' and name in self._optimizer._hyper 919 and not has_attribute): 920 self._optimizer._set_hyper(name, value) 921 else: 922 super(LossScaleOptimizer, self).__setattr__(name, value) 923 924 # We do not override some OptimizerV2 methods. For each, we describe why we do 925 # not delegate them to self._optimizer: 926 # * get_updates: get_updates() calls get_gradients(). Since we override 927 # get_gradients(), we cannot delegate get_updates() to self._optimizer, 928 # otherwise the overridden get_gradients() method would not be called. 929 # Luckily, get_updates() does not access any OptimizerV2 fields, so 930 # inheriting the OptimizerV2 version works fine. 931 # * minimize: We don't delegate for a similar as get_updates(): it calls 932 # both self._compute_gradients() and self.apply_gradients(), and both need 933 # to have the LossScaleOptimizer version called. 934 935 # TODO(reedwm): Maybe throw an error if mixed precision is used without this 936 # optimizer being used. 937 938 939@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') 940class LossScaleOptimizerV1(LossScaleOptimizer): 941 """An deprecated optimizer that applies loss scaling. 942 943 Warning: This class is deprecated and will be removed in TensorFlow 2.5. 944 Please use the non-experimental class 945 `tf.keras.mixed_precision.LossScaleOptimizer` instead. 946 947 This class is identical to the non-experimental 948 `keras.mixed_precision.LossScaleOptimizer` except its constructor takes 949 different arguments. For this class (the experimental version), the 950 constructor takes a `loss_scale` argument. For the non-experimental class, 951 the constructor encodes the loss scaling information in multiple arguments. 952 Note that unlike this class, the non-experimental class does not accept a 953 `tf.compat.v1.mixed_precision.LossScale`, which is deprecated. 954 955 If you currently use this class, you should switch to the non-experimental 956 `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several 957 examples of converting the use of the experimental class to the equivalent 958 non-experimental class. 959 960 >>> # In all of the the examples below, `opt1` and `opt2` are identical 961 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 962 ... tf.keras.optimizers.SGD(), loss_scale='dynamic') 963 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 964 ... tf.keras.optimizers.SGD()) 965 >>> assert opt1.get_config() == opt2.get_config() 966 967 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 968 ... tf.keras.optimizers.SGD(), loss_scale=123) 969 >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123 970 >>> # refers to the initial loss scale, which is the single fixed loss scale 971 >>> # when dynamic=False. 972 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 973 ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123) 974 >>> assert opt1.get_config() == opt2.get_config() 975 976 >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale( 977 ... initial_loss_scale=2048, increment_period=500) 978 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 979 ... tf.keras.optimizers.SGD(), loss_scale=loss_scale) 980 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 981 ... tf.keras.optimizers.SGD(), initial_scale=2048, 982 ... dynamic_growth_steps=500) 983 >>> assert opt1.get_config() == opt2.get_config() 984 985 Make sure to also switch from this class to the non-experimental class in 986 isinstance checks, if you have any. If you do not do this, your model may run 987 into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses 988 the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to 989 switch isinstance checks to the non-experimental `LossScaleOptimizer` even 990 before using the non-experimental `LossScaleOptimizer`. 991 992 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 993 ... tf.keras.optimizers.SGD(), loss_scale='dynamic') 994 >>> # The experimental class subclasses the non-experimental class 995 >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer) 996 True 997 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 998 ... tf.keras.optimizers.SGD()) 999 >>> # The non-experimental class does NOT subclass the experimental class. 1000 >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer) 1001 False 1002 1003 Args: 1004 optimizer: The Optimizer instance to wrap. 1005 loss_scale: The loss scale to scale the loss and gradients. This can 1006 either be an int/float to use a fixed loss scale, the string "dynamic" 1007 to use dynamic loss scaling, or an instance of a LossScale. The string 1008 "dynamic" equivalent to passing `DynamicLossScale()`, and passing an 1009 int/float is equivalent to passing a FixedLossScale with the given loss 1010 scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must 1011 be 2 (the default). 1012 """ 1013 1014 def __init__(self, optimizer, loss_scale): 1015 warn_msg_prefix = ( 1016 'tf.keras.mixed_precision.experimental.LossScaleOptimizer is ' 1017 'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer ' 1018 'instead. ') 1019 1020 if isinstance(loss_scale, dict): 1021 loss_scale = keras_loss_scale_module.deserialize(loss_scale) 1022 1023 if isinstance(loss_scale, (int, float)): 1024 tf_logging.warn( 1025 warn_msg_prefix + 'For example\n' 1026 ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' 1027 'opt, dynamic=False, initial_scale={})'.format(loss_scale)) 1028 super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, 1029 initial_scale=loss_scale) 1030 elif isinstance(loss_scale, loss_scale_module.FixedLossScale): 1031 ls_val = loss_scale._loss_scale_value # pylint: disable=protected-access 1032 tf_logging.warn( 1033 warn_msg_prefix + 'For example\n' 1034 ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' 1035 'opt, dynamic=False, initial_scale={})'.format(ls_val)) 1036 super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, 1037 initial_scale=ls_val) 1038 elif loss_scale == 'dynamic': 1039 tf_logging.warn( 1040 warn_msg_prefix + 'For example\n' 1041 ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' 1042 'opt)') 1043 super(LossScaleOptimizerV1, self).__init__(optimizer) 1044 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): 1045 kwargs = {} 1046 extra_arguments = '' 1047 if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE: 1048 kwargs['initial_scale'] = loss_scale.initial_loss_scale 1049 extra_arguments += (', initial_scale=%s' % 1050 loss_scale.initial_loss_scale) 1051 if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS: 1052 kwargs['dynamic_growth_steps'] = loss_scale.increment_period 1053 extra_arguments += (', dynamic_growth_steps=%s' % 1054 loss_scale.increment_period) 1055 if loss_scale.multiplier != 2: 1056 raise ValueError('When passing a DynamicLossScale to "loss_scale", ' 1057 'DynamicLossScale.multiplier must be 2. Got: %s' 1058 % (loss_scale,)) 1059 tf_logging.warn( 1060 warn_msg_prefix + 1061 'Note that the non-experimental LossScaleOptimizer does not take a ' 1062 'DynamicLossScale but instead takes the dynamic configuration ' 1063 'directly in the constructor. For example:\n' 1064 ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(' 1065 'opt{})\n'.format(extra_arguments)) 1066 super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs) 1067 elif isinstance(loss_scale, loss_scale_module.LossScale): 1068 raise TypeError('Passing a LossScale that is not a FixedLossScale or a ' 1069 'DynamicLossScale is no longer supported. Got: {}' 1070 .format(loss_scale)) 1071 else: 1072 raise ValueError('Invalid value passed to loss_scale. loss_scale ' 1073 'must be the string "dynamic" (recommended), an int, ' 1074 'a float, a FixedLossScale, or a DynamicLossScale. Got ' 1075 'value: {}'.format(loss_scale)) 1076 1077 @classmethod 1078 def from_config(cls, config, custom_objects=None): 1079 config = config.copy() # Make a copy, since we mutate config 1080 1081 # If loss_scale is in config, we assume we are deserializing a 1082 # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are 1083 # deserializing a LossScaleOptimizer from TF 2.4 or above. 1084 if 'loss_scale' in config: 1085 config['loss_scale'] = keras_loss_scale_module.deserialize( 1086 config['loss_scale']) 1087 if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale) 1088 and config['loss_scale'].multiplier != 2): 1089 raise ValueError('Cannot deserialize LossScaleOptimizer with a ' 1090 'DynamicLossScale whose multiplier is not 2. Got ' 1091 'DynamicLossScale: %s' % (config['loss_scale'],)) 1092 config['optimizer'] = optimizers.deserialize( 1093 config['optimizer'], custom_objects=custom_objects) 1094 return cls(**config) 1095 1096 # We convert the config, as generated by LossScaleOptimizer.get_config, to a 1097 # version that can be passed to LossScaleOptimizerV1.__init__ 1098 if config['dynamic']: 1099 config['loss_scale'] = loss_scale_module.DynamicLossScale( 1100 config['initial_scale'], config['dynamic_growth_steps'], multiplier=2) 1101 else: 1102 config['loss_scale'] = loss_scale_module.FixedLossScale( 1103 config['initial_scale']) 1104 1105 del config['dynamic'] 1106 del config['initial_scale'] 1107 del config['dynamic_growth_steps'] 1108 config['optimizer'] = optimizers.deserialize( 1109 config.pop('inner_optimizer'), custom_objects=custom_objects) 1110 return cls(**config) 1111 1112 1113class FakeOptimizerForRestoration(trackable.Trackable): 1114 """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints. 1115 1116 The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class 1117 exists to support restoring TF 2.2 checkpoints in newer version of TensorFlow. 1118 1119 In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the 1120 following in LossScaleOptimizer.__init__ 1121 1122 ``` 1123 self._track_trackable(self._optimizer, 'base_optimizer') 1124 ``` 1125 1126 This means a dependency from the LossScaleOptimizer to the wrapped optimizer 1127 would be stored in the checkpoint. However now, the checkpoint format with a 1128 LossScaleOptimizer is the same as the format without a LossScaleOptimizer, 1129 except the loss scale is also stored. This means there is no dependency from 1130 the LossScaleOptimizer to the wrapped optimizer. Instead, the 1131 LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's 1132 perspective, by overriding all Trackable methods and delegating them to the 1133 wrapped optimizer. 1134 1135 To allow restoring TF 2.2. checkpoints, LossScaleOptimizer adds a dependency 1136 on this class instead of the inner optimizer. When restored, this class will 1137 instead restore the slot variables of the inner optimizer. Since this class 1138 has no variables, it does not affect the checkpoint when saved. 1139 """ 1140 1141 def __init__(self, optimizer): 1142 self._optimizer = optimizer 1143 1144 def get_slot_names(self): 1145 return self._optimizer.get_slot_names() 1146 1147 def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, 1148 variable): 1149 return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access 1150 slot_variable_position, slot_name, variable) 1151 1152 1153# pylint: disable=protected-access 1154mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2, 1155 LossScaleOptimizerV1) 1156 1157 1158def _multiply_gradient(gradient, scale): 1159 """Multiply a (possibly sparse) gradient by the given scale factor.""" 1160 scale = math_ops.cast(scale, gradient.dtype) 1161 if isinstance(gradient, ops.IndexedSlices): 1162 return ops.IndexedSlices( 1163 gradient.values * scale, 1164 gradient.indices, 1165 dense_shape=gradient.dense_shape) 1166 else: 1167 return gradient * scale 1168 1169 1170def strategy_supports_loss_scaling(): 1171 """Returns True if the current Strategy supports loss scaling.""" 1172 if not distribution_strategy_context.has_strategy(): 1173 return True 1174 strategy = distribution_strategy_context.get_strategy() 1175 # Strategies are supported if either there is only one replica or if variables 1176 # are replicated per device. Otherwise, the current model.fit() implementation 1177 # and most custom training loops incorrectly unscale the gradients. Currently, 1178 # gradients are unscaled once per compute replica, but they should be unscaled 1179 # once per variable replica. When there is one variable replica for each 1180 # compute replica, this works fine, but otherwise issues will occur. 1181 # TODO(reedwm): Support all strategies. 1182 return isinstance(strategy, ( 1183 collective_all_reduce_strategy.CollectiveAllReduceStrategy, 1184 collective_all_reduce_strategy.CollectiveAllReduceStrategyV1, 1185 one_device_strategy.OneDeviceStrategy, 1186 one_device_strategy.OneDeviceStrategyV1, 1187 mirrored_strategy.MirroredStrategy, 1188 mirrored_strategy.MirroredStrategyV1, 1189 )) 1190