1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the loss scaling optimizer class."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.distribute import collective_all_reduce_strategy
21from tensorflow.python.distribute import distribution_strategy_context
22from tensorflow.python.distribute import mirrored_strategy
23from tensorflow.python.distribute import one_device_strategy
24from tensorflow.python.distribute import reduce_util
25from tensorflow.python.distribute import tpu_strategy
26from tensorflow.python.eager import backprop
27from tensorflow.python.eager import context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import smart_cond
31from tensorflow.python.keras import backend
32from tensorflow.python.keras import optimizers
33from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module
34from tensorflow.python.keras.optimizer_v2 import optimizer_v2
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import variable_scope
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import tf_logging
40from tensorflow.python.training.experimental import loss_scale as loss_scale_module
41from tensorflow.python.training.experimental import mixed_precision
42from tensorflow.python.training.tracking import base as trackable
43from tensorflow.python.util import nest
44from tensorflow.python.util.tf_export import keras_export
45
46
class _UnwrapPreventer(object):
  """Opaque holder whose contents DistributionStrategy leaves untouched.

  When `call_for_each_replica` moves from a cross-replica context into a
  replica context, DistributionStrategy normally unwraps the distributed values
  passed as arguments. Putting a value inside this wrapper hides it from that
  unwrapping; the callee retrieves the original object through `.value`.

  TODO(reedwm): Find/implement a better way of preventing values from being
  unwrapped by DistributionStrategy
  """

  # Only `value` may be set; no per-instance __dict__ is created.
  __slots__ = ('value',)

  def __init__(self, value):
    self.value = value
63
64
class _DelegatingTrackableMixin(object):
  """A mixin that delegates all Trackable methods to another trackable object.

  This class must be used with multiple inheritance. A class that subclasses
  Trackable can also subclass this class, which causes all Trackable methods to
  be delegated to the trackable object passed in the constructor.

  A subclass can use this mixin to appear as if it were the trackable passed to
  the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
  mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
  the checkpoint format for a normal optimizer. This allows a model to be saved
  with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
  The only difference in checkpoint format is that the loss scale is also saved
  with a LossScaleOptimizer.

  Every delegating method forwards the wrapped object's return value; dropping
  it would make callers observe None instead of the wrapped object's result.
  """

  def __init__(self, trackable_obj):
    """Stores `trackable_obj` as the target of all delegated calls."""
    self._trackable = trackable_obj

  # pylint: disable=protected-access
  @property
  def _setattr_tracking(self):
    return self._trackable._setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._trackable._setattr_tracking = value

  @property
  def _update_uid(self):
    return self._trackable._update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._trackable._update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    return self._trackable._unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    return self._trackable._unconditional_dependency_names

  @property
  def _name_based_restores(self):
    return self._trackable._name_based_restores

  def _maybe_initialize_trackable(self):
    return self._trackable._maybe_initialize_trackable()

  @property
  def _object_identifier(self):
    return self._trackable._object_identifier

  @property
  def _tracking_metadata(self):
    return self._trackable._tracking_metadata

  def _no_dependency(self, value):
    return self._trackable._no_dependency(value)

  def _name_based_attribute_restore(self, checkpoint):
    return self._trackable._name_based_attribute_restore(checkpoint)

  @property
  def _checkpoint_dependencies(self):
    return self._trackable._checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    return self._trackable._deferred_dependencies

  def _lookup_dependency(self, name):
    # The result must be returned: callers test it against None to decide
    # whether a dependency named `name` already exists. Previously the value
    # was discarded, so this always returned None.
    return self._trackable._lookup_dependency(name)

  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    return self._trackable._add_variable_with_custom_getter(
        name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)

  def _preload_simple_restoration(self, name):
    return self._trackable._preload_simple_restoration(name)

  def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
    return self._trackable._track_trackable(trackable, name, overwrite)

  def _handle_deferred_dependencies(self, name, trackable):  # pylint: disable=redefined-outer-name
    return self._trackable._handle_deferred_dependencies(name, trackable)

  def _restore_from_checkpoint_position(self, checkpoint_position):
    return self._trackable._restore_from_checkpoint_position(
        checkpoint_position)

  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
                                                   visit_queue):
    return self._trackable._single_restoration_from_checkpoint_position(
        checkpoint_position, visit_queue)

  def _gather_saveables_for_checkpoint(self):
    return self._trackable._gather_saveables_for_checkpoint()

  def _list_extra_dependencies_for_serialization(self, serialization_cache):
    return self._trackable._list_extra_dependencies_for_serialization(
        serialization_cache)

  def _list_functions_for_serialization(self, serialization_cache):
    return self._trackable._list_functions_for_serialization(
        serialization_cache)
  # pylint: enable=protected-access
181
182
def _is_all_finite(grads):
  """Returns a scalar boolean tensor that is True iff every gradient is finite.

  Entries of `grads` that are None are skipped.
  """
  per_grad_finite = []
  for grad in grads:
    if grad is None:
      continue
    per_grad_finite.append(math_ops.reduce_all(math_ops.is_finite(grad)))
  return math_ops.reduce_all(per_grad_finite)
189
190
def _op_in_graph_mode(tensor):
  """Returns `tensor.op` when building a graph, else `tensor` itself.

  Graph-mode callers sometimes need an op rather than a tensor, but eager
  execution has no ops, so the tensor is passed through unchanged there.

  Args:
    tensor: A tensor.

  Returns:
    `tensor.op` in graph mode; `tensor` in eager mode.
  """
  return tensor if context.executing_eagerly() else tensor.op
206
207
def _assign_if_finite(var, value):
  """Assigns `value` to `var` only when `value` is finite; otherwise a no-op."""
  def _do_assign():
    return _op_in_graph_mode(var.assign(value))

  return control_flow_ops.cond(math_ops.is_finite(value), _do_assign,
                               control_flow_ops.no_op)
213
214
class _DynamicLossScaleState(trackable.Trackable):
  """The state of a dynamic loss scale.

  Owns two non-trainable variables: the current loss scale (float32) and a
  counter (int64) of consecutive steps with finite gradients. Variables are
  stored per graph, keyed by `(name, graph_key)` where `graph_key` is None when
  executing eagerly, and only the current graph's variables are reported as
  checkpoint dependencies.
  """

  def __init__(self,
               initial_loss_scale,
               growth_steps,
               multiplier):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: Initial value of the loss scale. Converted to float.
      growth_steps: Number of consecutive steps with finite gradients after
        which the loss scale is multiplied by `multiplier`. Converted to int.
      multiplier: Factor used to grow the loss scale, and divisor used to
        shrink it when nonfinite gradients are seen. Converted to float.
    """
    super(_DynamicLossScaleState, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._growth_steps = int(growth_steps)
    self._multiplier = float(multiplier)

    self._weights = {}
    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale. The name is 'good_steps' for
    # backwards compatibility with older checkpoints.
    self._counter = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @staticmethod
  def _current_graph_key():
    """Returns the default graph's key, or None when executing eagerly."""
    if context.executing_eagerly():
      return None
    graph = ops.get_default_graph()
    return graph._graph_key  # pylint: disable=protected-access

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added in the
        current graph.
    """
    key = (name, self._current_graph_key())
    # Check before creating the variable so a duplicate name cannot silently
    # replace (and orphan) the variable that is already tracked under `key`.
    if key in self._weights:
      raise RuntimeError('Duplicate variables detected. %s' % (key,))
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    backend.track_variable(variable)
    return variable

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific weights to save."""
    graph_key = self._current_graph_key()
    weights = []
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights.append(trackable.TrackableReference(name=name, ref=v))
    return (super(_DynamicLossScaleState, self)._checkpoint_dependencies +
            weights)

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    return self._weights.get((name, self._current_graph_key()), None)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def growth_steps(self):
    return self._growth_steps

  @property
  def multiplier(self):
    return self._multiplier

  @property
  def current_loss_scale(self):
    """Returns the current loss scale as a float32 `tf.Variable`."""
    return self._current_loss_scale

  @property
  def counter(self):
    """Returns the counter as an int64 `tf.Variable`."""
    return self._counter

  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Updates the value of the loss scale.

    If all gradients are finite, the counter is incremented; once it reaches
    `growth_steps` the loss scale is multiplied by `multiplier` (only if the
    result is finite) and the counter is reset. Otherwise the loss scale is
    divided by `multiplier` (but not below 1) and the counter is reset.

    Args:
      grads: A nested structure of unscaled gradients, each which is the
        gradient of the loss with respect to a weight.

    Returns:
      update_op: In graph mode, an op to update the loss scale. In eager mode
        the update has already been applied when this returns, and the
        returned value should not be relied upon.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_strategy()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      # Gradients are finite only if they are finite on every replica.
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self.current_loss_scale * self.multiplier
        return control_flow_ops.group(
            _assign_if_finite(self.current_loss_scale, new_loss_scale),
            self.counter.assign(0))

      return control_flow_ops.cond(
          self.counter + 1 >= self.growth_steps,
          incr_loss_scale,
          lambda: _op_in_graph_mode(self.counter.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self.current_loss_scale / self.multiplier, 1)
      return control_flow_ops.group(
          self.counter.assign(0),
          self.current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients
388
389
# Defaults for dynamic loss scaling. The initial scale is deliberately large:
# an overly high loss scale is lowered much faster than an overly low one is
# raised (see the LossScaleOptimizer docstring).
_DEFAULT_INITIAL_SCALE = 2 ** 15  # == 32768
_DEFAULT_GROWTH_STEPS = 2000
393
394
395# pylint: disable=g-classes-have-attributes
396@keras_export('keras.mixed_precision.LossScaleOptimizer')
397class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
398  """An optimizer that applies loss scaling to prevent numeric underflow.
399
400  Loss scaling is a technique to prevent numeric underflow in intermediate
401  gradients when float16 is used. To prevent underflow, the loss is multiplied
402  (or "scaled") by a certain factor called the "loss scale", which causes
403  intermediate gradients to be scaled by the loss scale as well. The final
404  gradients are divided (or "unscaled") by the loss scale to bring them back to
405  their original value.
406
407  `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it.
408  By default, the loss scale is dynamically updated over time so you do not have
409  to choose the loss scale. The `minimize` method automatically scales the loss,
410  unscales the gradients, and updates the loss scale so all you have to do is
411  wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For
412  example:
413
414  >>> opt = tf.keras.optimizers.SGD(0.25)
415  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
416  >>> var = tf.Variable(1.)
417  >>> loss_fn = lambda: var ** 2
  >>> # 'minimize' applies loss scaling and updates the loss scale.
419  >>> opt.minimize(loss_fn, var_list=var)
420  >>> var.numpy()
421  0.5
422
423  If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you
424  must scale the loss and gradients manually. This can be done with the
425  `LossScaleOptimizer.get_scaled_loss` and
426  `LossScaleOptimizer.get_unscaled_gradients` methods. For example:
427
428  >>> with tf.GradientTape() as tape:
429  ...   loss = loss_fn()
430  ...   scaled_loss = opt.get_scaled_loss(loss)
431  >>> scaled_grad = tape.gradient(scaled_loss, var)
432  >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
433  >>> opt.apply_gradients([(grad, var)])  # Loss scale is updated here
434  >>> var.numpy()
435  0.25
436
437  Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients`
438  (or both) when using a `tf.GradientTape`, the model will likely converge to a
439  worse quality. Please make sure you call each function exactly once.
440
441  When mixed precision with float16 is used, there is typically no risk of
442  underflow affecting model quality if loss scaling is properly used. See
443  [the mixed precision guide](
444  https://www.tensorflow.org/guide/keras/mixed_precision) for more information
445  on how to use mixed precision.
446
447  Args:
448    inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap.
449    dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to
450      True. If True, the loss scale will be dynamically updated over time using
451      an algorithm that keeps the loss scale at approximately its optimal value.
452      If False, a single fixed loss scale is used and `initial_scale` must be
453      specified, which is used as the loss scale. Recommended to keep as True,
454      as choosing a fixed loss scale can be tricky. Currently, there is a small
455      performance overhead to dynamic loss scaling compared to fixed loss
456      scaling.
457    initial_scale: The initial loss scale. If `dynamic` is True, this defaults
458      to `2 ** 15`. If `dynamic` is False, this must be specified and acts as
      the sole loss scale, as the loss scale does not change over time. When
      dynamic loss scaling is used, it is better for this to be a very high
      number, because a loss scale that is too high gets lowered far more
      quickly than a loss scale that is too low gets raised.
463    dynamic_growth_steps: With dynamic loss scaling, every
464      `dynamic_growth_steps` steps with finite gradients, the loss scale is
465      doubled. Defaults to 2000. If a nonfinite gradient is encountered, the
466      count is reset back to zero, gradients are skipped that step, and the loss
467      scale is halved. The count can be queried with
468      `LossScaleOptimizer.dynamic_counter`. This argument can only be specified
469      if `dynamic` is True.
470
471  `LossScaleOptimizer` will occasionally skip applying gradients to the
472  variables, in which case the trainable variables will not change that step.
473  This is done because the dynamic loss scale will sometimes be raised too
474  high, causing overflow in the gradients. Typically, the first 2 to 15 steps of
475  the model are skipped as the initial loss scale is very high, but afterwards
476  steps will only be skipped on average 0.05% of the time (the fraction of steps
477  skipped is `1 / dynamic_growth_steps`).
478
479  `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner
  optimizer. Additionally, in methods `minimize` and `get_gradients`, it scales
481  the loss and unscales the gradients. In methods `minimize` and
482  `apply_gradients`, it additionally updates the loss scale and skips applying
483  gradients if any gradient has a nonfinite value.
484
485  ### Hyperparameters
486
487  Hyperparameters can be accessed and set on the LossScaleOptimizer, which will
488  be delegated to the wrapped optimizer.
489
490  >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5)
491  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
492  >>> opt.beta_1  # Equivalent to `opt.inner_optimizer.beta_1`
493  0.8
494  >>> opt.beta_1 = 0.7  # Equivalent to `opt.inner_optimizer.beta_1 = 0.7`
495  >>> opt.beta_1
496  0.7
497  >>> opt.inner_optimizer.beta_1
498  0.7
499
500  However, accessing or setting non-hyperparameters is not delegated to the
501  LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but
502  `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on
503  `beta_1`.
504
505  >>> opt.inner_optimizer.epsilon
506  1e-5
507  >>> opt.epsilon
508  Traceback (most recent call last):
509  ...
510  AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon'
511  >>> opt.epsilon = 1e-4  # This does NOT set epsilon on `opt.inner_optimizer`
512  >>> opt.inner_optimizer.epsilon
  1e-5
514
515  In the above example, despite epsilon being set on the LossScaleOptimizer, the
516  old epsilon value will still be used when training as epsilon was not set on
517  the inner optimizer.
518  """
519
520  _HAS_AGGREGATE_GRAD = True
521
522  def __init__(self, inner_optimizer, dynamic=True, initial_scale=None,
523               dynamic_growth_steps=None):
524    if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
525      raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, '
526                      'but got: %s' % inner_optimizer)
527    if not isinstance(dynamic, bool):
528      # Catch errors if a user incorrectly passes a string or float to the
529      # second argument argument, as this is commonly done for
530      # LossScaleOptimizerV1.
531      raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must '
532                      'be a bool, but got: %r' % (dynamic,))
533    self._raise_if_strategy_unsupported()
534    self._optimizer = inner_optimizer
535
536    # We don't call super().__init__, since we do not want to call OptimizerV2's
537    # constructor.
538    _DelegatingTrackableMixin.__init__(self, self._optimizer)
539
540    if dynamic:
541      if initial_scale is None:
542        initial_scale = _DEFAULT_INITIAL_SCALE
543      if dynamic_growth_steps is None:
544        dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
545      self._loss_scale = _DynamicLossScaleState(
546          initial_scale, dynamic_growth_steps, multiplier=2)
547      self._track_trackable(self._loss_scale, 'loss_scale')
548    else:
549      if initial_scale is None:
550        raise ValueError('"initial_scale" must be specified if "dynamic" is '
551                         'False')
552      self._loss_scale = float(initial_scale)
553      if dynamic_growth_steps is not None:
554        raise ValueError('"dynamic_growth_steps" must be None if "dynamic" '
555                         'is False, but got: %s' % (dynamic_growth_steps,))
556
557    # To support restoring TensorFlow 2.2 checkpoints.
558    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
559                          'base_optimizer')
560
561  @property
562  def dynamic(self):
563    """Bool indicating whether dynamic loss scaling is used."""
564    return isinstance(self._loss_scale, _DynamicLossScaleState)
565
566  @property
567  def loss_scale(self):
568    """The current loss scale as a float32 scalar tensor."""
569    if isinstance(self._loss_scale, _DynamicLossScaleState):
570      return ops.convert_to_tensor(self._loss_scale.current_loss_scale)
571    else:
572      return ops.convert_to_tensor(self._loss_scale)
573
574  @property
575  def dynamic_counter(self):
576    """The number of steps since the loss scale was last increased or decreased.
577
578    This is None if `LossScaleOptimizer.dynamic` is False.
579
580    The counter is incremented every step. Once it reaches
581    `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled
582    and the counter will be reset back to zero. If nonfinite gradients are
583    encountered, the loss scale will be halved and the counter will be reset
584    back to zero.
585    """
586    if isinstance(self._loss_scale, _DynamicLossScaleState):
587      return self._loss_scale.counter
588    else:
589      return None
590
591  @property
592  def initial_scale(self):
593    """The initial loss scale.
594
595    If `LossScaleOptimizer.dynamic` is False, this is the same number as
596    `LossScaleOptimizer.loss_scale`, as the loss scale never changes.
597    """
598    if isinstance(self._loss_scale, _DynamicLossScaleState):
599      return self._loss_scale.initial_loss_scale
600    else:
601      return self._loss_scale
602
603  @property
604  def dynamic_growth_steps(self):
605    """The number of steps it takes to increase the loss scale.
606
607    This is None if `LossScaleOptimizer.dynamic` is False.
608
609    Every `dynamic_growth_steps` consecutive steps with finite gradients, the
610    loss scale is increased.
611    """
612    if isinstance(self._loss_scale, _DynamicLossScaleState):
613      return self._loss_scale.growth_steps
614    else:
615      return None
616
617  @property
618  def inner_optimizer(self):
619    """The optimizer that this LossScaleOptimizer is wrapping."""
620    return self._optimizer
621
622  def get_scaled_loss(self, loss):
623    """Scales the loss by the loss scale.
624
625    This method is only needed if you compute gradients manually, e.g. with
626    `tf.GradientTape`. In that case, call this method to scale the loss before
627    passing the loss to `tf.GradientTape`. If you use
628    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
629    scaling is automatically applied and this method is unneeded.
630
631    If this method is called, `get_unscaled_gradients` should also be called.
632    See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for
633    an example.
634
635    Args:
636      loss: The loss, which will be multiplied by the loss scale. Can either be
637        a tensor or a callable returning a tensor.
638
639    Returns:
640      `loss` multiplied by `LossScaleOptimizer.loss_scale`.
641    """
642    if callable(loss):
643      def new_loss():
644        loss_val = loss()
645        return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype)
646      return new_loss
647    else:
648      return loss * math_ops.cast(self.loss_scale, loss.dtype)
649
650  def get_unscaled_gradients(self, grads):
651    """Unscales the gradients by the loss scale.
652
653    This method is only needed if you compute gradients manually, e.g. with
654    `tf.GradientTape`. In that case, call this method to unscale the gradients
655    after computing them with `tf.GradientTape`. If you use
656    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
657    scaling is automatically applied and this method is unneeded.
658
659    If this method is called, `get_scaled_loss` should also be called. See
660    the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an
661    example.
662
663    Args:
664      grads: A list of tensors, each which will be divided by the loss scale.
665        Can have None values, which are ignored.
666
667    Returns:
668      A new list the same size as `grads`, where every non-None value in `grads`
669      is divided by `LossScaleOptimizer.loss_scale`.
670    """
671    loss_scale_reciprocal = 1. / self.loss_scale
672    return [
673        _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None
674        for g in grads
675    ]
676
677  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
678    tape = backprop.GradientTape() if tape is None else tape
679    with tape:
680      loss = self.get_scaled_loss(loss)
681    grads_and_vars = self._optimizer._compute_gradients(  # pylint: disable=protected-access
682        loss,
683        var_list,
684        grad_loss,
685        tape=tape)
686    grads = [g for g, _ in grads_and_vars]
687    weights = [v for _, v in grads_and_vars]
688    unscaled_grads = self.get_unscaled_gradients(grads)
689    return list(zip(unscaled_grads, weights))
690
691  def get_gradients(self, loss, params):
692    loss = self.get_scaled_loss(loss)
693    grads = self._optimizer.get_gradients(loss, params)
694    return self.get_unscaled_gradients(grads)
695
696  def _create_all_weights(self, var_list):
697    self._optimizer._create_all_weights(var_list)    # pylint: disable=protected-access
698
699  def apply_gradients(self,
700                      grads_and_vars,
701                      name=None,
702                      experimental_aggregate_gradients=True):
703    if distribution_strategy_context.in_cross_replica_context():
704      raise ValueError('apply_gradients() must be called in a replica context.')
705    # We check for the strategy here despite already checking in the constructor
706    # as frequently the optimizer is created outside the strategy's scope.
707    self._raise_if_strategy_unsupported()
708
709    grads_and_vars = tuple(grads_and_vars)
710    return distribution_strategy_context.get_replica_context().merge_call(
711        self._apply_gradients_cross_replica,
712        args=(grads_and_vars, name, experimental_aggregate_gradients))
713
714  def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name,
715                                     experimental_aggregate_gradients):
716    grads = [g for g, _ in grads_and_vars]
717    if isinstance(self._loss_scale, _DynamicLossScaleState):
718      loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)
719    else:
720      loss_scale_update_op = control_flow_ops.no_op()
721      should_apply_grads = True
722
723    def apply_fn():
724      # We do not want DistributionStrategy to unwrap any MirroredVariables in
725      # grads_and_vars, because even in a replica context, the wrapped optimizer
726      # expects mirrored variables. So we wrap the variables with an
727      # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
728      # MirroredVariables.
729      wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])
730      return distribution.extended.call_for_each_replica(
731          self._apply_gradients,
732          args=(grads, wrapped_vars, name, experimental_aggregate_gradients))
733
734    def do_not_apply_fn():
735      # Normally self._optimizer.iterations is incremented in
736      # self._optimizer.apply_gradients(). Since that is not called in this
737      # branch, we increment it here instead.
738      return self._optimizer.iterations.assign_add(1, read_value=False)
739
740    # Note: We must call this cond() in a cross-replica context.
741    # DistributionStrategy does not support having a cond in a replica context
742    # with a branch that calls `merge_call`, and self._optimizer.apply_gradients
743    # calls `merge_call`.
744    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
745                                           do_not_apply_fn)
746    return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
747
748  def _apply_gradients(self, grads, wrapped_vars, name,
749                       experimental_aggregate_gradients):
750    # TODO(reedwm): This will raise a fairly cryptic error message if
751    # self._optimizer.apply_gradients does not take
752    # experimental_aggregate_gradients.
753    return self._optimizer.apply_gradients(
754        list(zip(grads, wrapped_vars.value)), name,
755        experimental_aggregate_gradients=experimental_aggregate_gradients)
756
757  def get_config(self):
758    serialized_optimizer = optimizers.serialize(self._optimizer)
759    return {
760        'inner_optimizer': serialized_optimizer,
761        'dynamic': self.dynamic,
762        'initial_scale': self.initial_scale,
763        'dynamic_growth_steps': self.dynamic_growth_steps,
764    }
765
766  @classmethod
767  def from_config(cls, config, custom_objects=None):
768    config = config.copy()  # Make a copy, since we mutate config
769    if 'loss_scale' in config:
770      # If loss_scale is in config, we assume we are deserializing a
771      # LossScaleOptimizer from TF 2.3 or below. We convert the config so it
772      # can be deserialized in the current LossScaleOptimizer.
773      loss_scale = keras_loss_scale_module.deserialize(
774          config.pop('loss_scale'))
775      if isinstance(loss_scale, loss_scale_module.FixedLossScale):
776        config['dynamic'] = False
777        config['initial_scale'] = loss_scale._loss_scale_value  # pylint: disable=protected-access
778      elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
779        config['dynamic'] = True
780        config['initial_scale'] = loss_scale.initial_loss_scale
781        config['dynamic_growth_steps'] = loss_scale.increment_period
782        if loss_scale.multiplier != 2:
783          raise ValueError('Cannot deserialize LossScaleOptimizer with a '
784                           'DynamicLossScale whose multiplier is not 2. Got '
785                           'DynamicLossScale: %s' % (loss_scale,))
786      else:
787        raise ValueError(
788            'Serialized LossScaleOptimizers with a LossScale that is neither a '
789            'FixedLossScale nor a DynamicLossScale can no longer be '
790            'deserialized')
791      config['inner_optimizer'] = config.pop('optimizer')
792    config['inner_optimizer'] = optimizers.deserialize(
793        config['inner_optimizer'], custom_objects=custom_objects)
794    return cls(**config)
795
796  def _raise_if_strategy_unsupported(self):
797    if not strategy_supports_loss_scaling():
798      strategy = distribution_strategy_context.get_strategy()
799      if isinstance(strategy,
800                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
801                     tpu_strategy.TPUStrategyV2)):
802        raise ValueError(
803            'Loss scaling is not supported with TPUStrategy. Loss scaling is '
804            'unnecessary with TPUs, since they support bfloat16 instead of '
805            'float16 and bfloat16 does not require loss scaling. You should '
806            'remove the use of the LossScaleOptimizer when TPUs are used.')
807      else:
808        raise ValueError('Loss scaling is not supported with the '
809                         'tf.distribute.Strategy: %s. Try using a different '
810                         'Strategy, e.g. a MirroredStrategy' %
811                         strategy.__class__.__name__)
812
813  # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
814  # below.
815
816  @property
817  def iterations(self):
818    return self._optimizer.iterations
819
820  @iterations.setter
821  def iterations(self, variable):
822    self._optimizer.iterations = variable
823
824  def get_slot_names(self):
825    return self._optimizer.get_slot_names()
826
827  def variables(self):
828    return self._optimizer.variables()
829
830  @property
831  def weights(self):
832    return self._optimizer.weights
833
834  def get_weights(self):
835    return self._optimizer.get_weights()
836
837  def set_weights(self, weights):
838    return self._optimizer.set_weights(weights)
839
840  @property
841  def clipnorm(self):
842    return self._optimizer.clipnorm
843
844  @clipnorm.setter
845  def clipnorm(self, val):
846    self._optimizer.clipnorm = val
847
848  @property
849  def global_clipnorm(self):
850    return self._optimizer.global_clipnorm
851
852  @global_clipnorm.setter
853  def global_clipnorm(self, val):
854    self._optimizer.global_clipnorm = val
855
856  @property
857  def clipvalue(self):
858    return self._optimizer.clipvalue
859
860  @clipvalue.setter
861  def clipvalue(self, val):
862    self._optimizer.clipvalue = val
863
864  def _aggregate_gradients(self, grads_and_vars):
865    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access
866
867  def _restore_slot_variable(self, slot_name, variable, slot_variable):
868    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
869                                                  slot_variable)
870
871  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
872                                       variable):
873    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
874        slot_variable_position, slot_name, variable)
875
876  def get_slot(self, var, slot_name):
877    return self._optimizer.get_slot(var, slot_name)
878
879  def add_slot(self, var, slot_name, initializer='zeros'):
880    return self._optimizer.add_slot(var, slot_name, initializer)
881
882  def __getattribute__(self, name):
883    try:
884      return object.__getattribute__(self, name)
885    except AttributeError as e:
886      if name == '_optimizer' or name == '_hyper':
887        # Avoid infinite recursion
888        raise e
889
890      # Delegate hyperparameter accesses to inner optimizer.
891      if name == 'lr':
892        name = 'learning_rate'
893      if name in self._optimizer._hyper:
894        return self._optimizer._get_hyper(name)
895      raise e
896
897  def __dir__(self):
898    result = set(super(LossScaleOptimizer, self).__dir__())
899    if '_optimizer' in result:
900      result |= self._optimizer._hyper.keys()
901      if 'learning_rate' in self._optimizer._hyper.keys():
902        result.add('lr')
903    return list(result)
904
  def __setattr__(self, name, value):
    """Routes hyperparameter assignments to the wrapped optimizer.

    'lr' is treated as an alias for 'learning_rate'. The value is set through
    the wrapped optimizer's `_set_hyper` only when all of the following hold:
    the (aliased) name is not '_optimizer', it is a key of the wrapped
    optimizer's `_hyper` dict, and it is not already an attribute of this
    object. Otherwise the superclass `__setattr__` is used.
    """
    if name == 'lr':
      name = 'learning_rate'
    # Delegate setting hyperparameter to inner optimizer if the attribute does
    # not exist on the LossScaleOptimizer
    try:
      # We cannot check for the 'iterations' attribute as it cannot be set after
      # it is accessed.
      if name != 'iterations':
        object.__getattribute__(self, name)
      has_attribute = True
    except AttributeError:
      has_attribute = False
    # `name != '_optimizer'` is checked first so that assigning `_optimizer`
    # never evaluates `self._optimizer`, which does not exist before its first
    # assignment. NOTE(review): assigning any other attribute before
    # `_optimizer` is set would raise AttributeError here; presumably __init__
    # assigns `_optimizer` first -- confirm.
    if (name != '_optimizer' and name in self._optimizer._hyper
        and not has_attribute):
      self._optimizer._set_hyper(name, value)
    else:
      super(LossScaleOptimizer, self).__setattr__(name, value)
923
924  # We do not override some OptimizerV2 methods. For each, we describe why we do
925  # not delegate them to self._optimizer:
926  # * get_updates: get_updates() calls get_gradients(). Since we override
927  #   get_gradients(), we cannot delegate get_updates() to self._optimizer,
928  #   otherwise the overridden get_gradients() method would not be called.
929  #   Luckily, get_updates() does not access any OptimizerV2 fields, so
930  #   inheriting the OptimizerV2 version works fine.
  # * minimize: We don't delegate for a similar reason as get_updates(): it
  #   calls both self._compute_gradients() and self.apply_gradients(), and both
  #   need to have the LossScaleOptimizer version called.
934
935  # TODO(reedwm): Maybe throw an error if mixed precision is used without this
936  # optimizer being used.
937
938
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizerV1(LossScaleOptimizer):
  """A deprecated optimizer that applies loss scaling.

  Warning: This class is deprecated and will be removed in TensorFlow 2.5.
  Please use the non-experimental class
  `tf.keras.mixed_precision.LossScaleOptimizer` instead.

  This class is identical to the non-experimental
  `keras.mixed_precision.LossScaleOptimizer` except its constructor takes
  different arguments. For this class (the experimental version), the
  constructor takes a `loss_scale` argument.  For the non-experimental class,
  the constructor encodes the loss scaling information in multiple arguments.
  Note that unlike this class, the non-experimental class does not accept a
  `tf.compat.v1.mixed_precision.LossScale`, which is deprecated.

  If you currently use this class, you should switch to the non-experimental
  `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several
  examples of converting the use of the experimental class to the equivalent
  non-experimental class.

  >>> # In all of the examples below, `opt1` and `opt2` are identical
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> assert opt1.get_config() == opt2.get_config()

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=123)
  >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123
  >>> # refers to the initial loss scale, which is the single fixed loss scale
  >>> # when dynamic=False.
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123)
  >>> assert opt1.get_config() == opt2.get_config()

  >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale(
  ...     initial_loss_scale=2048, increment_period=500)
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=loss_scale)
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), initial_scale=2048,
  ...     dynamic_growth_steps=500)
  >>> assert opt1.get_config() == opt2.get_config()

  Make sure to also switch from this class to the non-experimental class in
  isinstance checks, if you have any. If you do not do this, your model may run
  into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses
  the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to
  switch isinstance checks to the non-experimental `LossScaleOptimizer` even
  before using the non-experimental `LossScaleOptimizer`.

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> # The experimental class subclasses the non-experimental class
  >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer)
  True
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> # The non-experimental class does NOT subclass the experimental class.
  >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
  False

  Args:
    optimizer: The Optimizer instance to wrap.
    loss_scale: The loss scale to scale the loss and gradients. This can
      either be an int/float to use a fixed loss scale, the string "dynamic"
      to use dynamic loss scaling, or an instance of a LossScale. The string
      "dynamic" is equivalent to passing `DynamicLossScale()`, and passing an
      int/float is equivalent to passing a FixedLossScale with the given loss
      scale. A dict is treated as a serialized LossScale and deserialized
      first. If a DynamicLossScale is passed, DynamicLossScale.multiplier must
      be 2 (the default).

  Raises:
    ValueError: If `loss_scale` is a DynamicLossScale whose multiplier is not
      2, or is not one of the accepted kinds of value.
    TypeError: If `loss_scale` is a LossScale that is neither a FixedLossScale
      nor a DynamicLossScale.
  """

  def __init__(self, optimizer, loss_scale):
    warn_msg_prefix = (
        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
        'instead. ')

    # A dict is treated as a serialized LossScale; deserialize it so the
    # isinstance checks below can classify it.
    if isinstance(loss_scale, dict):
      loss_scale = keras_loss_scale_module.deserialize(loss_scale)

    # Each accepted form is translated into the non-experimental constructor
    # arguments. The logged example shows the equivalent call on the
    # non-experimental class, which is what users should migrate to.
    if isinstance(loss_scale, (int, float)):
      tf_logging.warn(
          warn_msg_prefix + 'For example\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(loss_scale))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=loss_scale)
    elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
      ls_val = loss_scale._loss_scale_value  # pylint: disable=protected-access
      tf_logging.warn(
          warn_msg_prefix + 'For example\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(ls_val))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=ls_val)
    elif loss_scale == 'dynamic':
      tf_logging.warn(
          warn_msg_prefix + 'For example\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt)')
      super(LossScaleOptimizerV1, self).__init__(optimizer)
    elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
      # Only non-default settings are forwarded (and shown in the warning), so
      # the suggested replacement call stays minimal.
      kwargs = {}
      extra_arguments = ''
      if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
        kwargs['initial_scale'] = loss_scale.initial_loss_scale
        extra_arguments += (', initial_scale=%s' %
                            loss_scale.initial_loss_scale)
      if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
        kwargs['dynamic_growth_steps'] = loss_scale.increment_period
        extra_arguments += (', dynamic_growth_steps=%s' %
                            loss_scale.increment_period)
      if loss_scale.multiplier != 2:
        raise ValueError('When passing a DynamicLossScale to "loss_scale", '
                         'DynamicLossScale.multiplier must be 2. Got: %s'
                         % (loss_scale,))
      tf_logging.warn(
          warn_msg_prefix +
          'Note that the non-experimental LossScaleOptimizer does not take a '
          'DynamicLossScale but instead takes the dynamic configuration '
          'directly in the constructor. For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt{})\n'.format(extra_arguments))
      super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs)
    elif isinstance(loss_scale, loss_scale_module.LossScale):
      raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
                      'DynamicLossScale is no longer supported. Got: {}'
                      .format(loss_scale))
    else:
      raise ValueError('Invalid value passed to loss_scale. loss_scale '
                       'must be the string "dynamic" (recommended), an int, '
                       'a float, a FixedLossScale, or a DynamicLossScale. Got '
                       'value: {}'.format(loss_scale))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a LossScaleOptimizerV1 from either config format.

    Args:
      config: Either a TF 2.3-or-below config (with 'loss_scale' and
        'optimizer' keys) or a config as produced by
        `LossScaleOptimizer.get_config` (with 'inner_optimizer', 'dynamic',
        'initial_scale' and 'dynamic_growth_steps' keys). Not mutated.
      custom_objects: Optional mapping forwarded to `optimizers.deserialize`.

    Returns:
      A new instance of `cls`.

    Raises:
      ValueError: If a legacy config holds a DynamicLossScale whose multiplier
        is not 2.
    """
    config = config.copy()  # Make a copy, since we mutate config

    # If loss_scale is in config, we assume we are deserializing a
    # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are
    # deserializing a LossScaleOptimizer from TF 2.4 or above.
    if 'loss_scale' in config:
      config['loss_scale'] = keras_loss_scale_module.deserialize(
          config['loss_scale'])
      if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale)
          and config['loss_scale'].multiplier != 2):
        raise ValueError('Cannot deserialize LossScaleOptimizer with a '
                         'DynamicLossScale whose multiplier is not 2. Got '
                         'DynamicLossScale: %s' % (config['loss_scale'],))
      config['optimizer'] = optimizers.deserialize(
          config['optimizer'], custom_objects=custom_objects)
      return cls(**config)

    # We convert the config, as generated by LossScaleOptimizer.get_config, to a
    # version that can be passed to LossScaleOptimizerV1.__init__
    if config['dynamic']:
      config['loss_scale'] = loss_scale_module.DynamicLossScale(
          config['initial_scale'], config['dynamic_growth_steps'], multiplier=2)
    else:
      config['loss_scale'] = loss_scale_module.FixedLossScale(
          config['initial_scale'])

    del config['dynamic']
    del config['initial_scale']
    del config['dynamic_growth_steps']
    config['optimizer'] = optimizers.deserialize(
        config.pop('inner_optimizer'), custom_objects=custom_objects)
    return cls(**config)
1111
1112
class FakeOptimizerForRestoration(trackable.Trackable):
  """Stand-in optimizer that lets TensorFlow 2.2 checkpoints be restored.

  The LossScaleOptimizer checkpoint format changed after TF 2.2. This class
  exists so that TF 2.2 checkpoints can still be restored in newer versions of
  TensorFlow.

  In TF 2.2, LossScaleOptimizer.__init__ tracked the wrapped optimizer with

  ```
  self._track_trackable(self._optimizer, 'base_optimizer')
  ```

  so TF 2.2 checkpoints store a dependency from the LossScaleOptimizer to the
  wrapped optimizer. Newer versions store no such dependency. Instead, the
  LossScaleOptimizer overrides all Trackable methods and delegates them to the
  wrapped optimizer, so its checkpoint format matches that of the wrapped
  optimizer alone, plus the loss scale.

  To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a dependency
  on this class instead of on the inner optimizer. When restored, this class
  forwards slot-variable restoration to the inner optimizer. Since this class
  has no variables, it does not affect the checkpoint when saved.
  """

  def __init__(self, optimizer):
    self._optimizer = optimizer

  def get_slot_names(self):
    """Returns the slot names reported by the wrapped optimizer."""
    inner_optimizer = self._optimizer
    return inner_optimizer.get_slot_names()

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    """Forwards slot creation/restoration to the wrapped optimizer."""
    create_or_restore = self._optimizer._create_or_restore_slot_variable  # pylint: disable=protected-access
    return create_or_restore(slot_variable_position, slot_name, variable)
1151
1152
# pylint: disable=protected-access
# Registers LossScaleOptimizerV1 with the `mixed_precision` module as the
# wrapper class associated with `optimizer_v2.OptimizerV2`. NOTE(review): how
# the registration is consumed lives in `mixed_precision`, which is not visible
# here -- presumably it lets that module wrap Keras OptimizerV2 instances in a
# LossScaleOptimizerV1; confirm.
mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
                                                LossScaleOptimizerV1)
1156
1157
def _multiply_gradient(gradient, scale):
  """Returns `gradient * scale`, keeping IndexedSlices gradients sparse.

  `scale` is first cast to the gradient's dtype. For an `ops.IndexedSlices`
  gradient, only its `values` are scaled; `indices` and `dense_shape` are
  reused unchanged.
  """
  scale = math_ops.cast(scale, gradient.dtype)
  if not isinstance(gradient, ops.IndexedSlices):
    return gradient * scale
  scaled_values = gradient.values * scale
  return ops.IndexedSlices(
      scaled_values,
      gradient.indices,
      dense_shape=gradient.dense_shape)
1168
1169
def strategy_supports_loss_scaling():
  """Returns True if the current Strategy supports loss scaling.

  With no strategy set this is always True. Otherwise only the collective
  all-reduce, one-device and mirrored strategy classes (V1 and V2 variants)
  are considered supported.
  """
  if not distribution_strategy_context.has_strategy():
    return True
  # Strategies are supported if either there is only one replica or if variables
  # are replicated per device. Otherwise, the current model.fit() implementation
  # and most custom training loops incorrectly unscale the gradients. Currently,
  # gradients are unscaled once per compute replica, but they should be unscaled
  # once per variable replica. When there is one variable replica for each
  # compute replica, this works fine, but otherwise issues will occur.
  # TODO(reedwm): Support all strategies.
  current_strategy = distribution_strategy_context.get_strategy()
  supported_strategy_types = (
      collective_all_reduce_strategy.CollectiveAllReduceStrategy,
      collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
      one_device_strategy.OneDeviceStrategy,
      one_device_strategy.OneDeviceStrategyV1,
      mirrored_strategy.MirroredStrategy,
      mirrored_strategy.MirroredStrategyV1,
  )
  return isinstance(current_strategy, supported_strategy_types)
1190