# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@six.add_metaclass(abc.ABCMeta)
@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
                                  'train.experimental.LossScale')
@tf_export(
    'mixed_precision.experimental.LossScale',
    'train.experimental.LossScale',
    v1=[
        'mixed_precision.LossScale',
        'mixed_precision.experimental.LossScale',
        'train.experimental.LossScale'
    ])
class LossScale(trackable.Trackable):
  """Base class for all TF1 loss scales.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only
  be accessible as `tf.compat.v1.mixed_precision.LossScale`. Additionally in
  2.5, you will no longer be able to pass a `LossScale` to a
  `tf.keras.mixed_precision.Policy`. All the functionality in this class has
  been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class
  is no longer needed.

  This is an abstract base class, so you cannot instantiate it directly.
  Instead, use one of its concrete subclasses:
    * `tf.compat.v1.mixed_precision.DynamicLossScale`
    * `tf.compat.v1.mixed_precision.FixedLossScale`

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling instances of this
  class returns the loss scale as a scalar float32 tensor, while method
  `update()` updates the loss scale depending on the values of the gradients.
  Optimizers use instances of this class to scale loss and gradients.

  In most functions that accept a LossScale, you can also pass an int (such as
  8) to create a `FixedLossScale` or the string `"dynamic"` to create a dynamic
  loss scale.
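
  For example, the following sketch shows how an optimizer might use a
  `LossScale`; everything here other than the `LossScale` methods is
  illustrative pseudocode, as in the snippet above:

  ```
  loss_scale = tf.compat.v1.mixed_precision.DynamicLossScale()
  scaled_loss = loss * loss_scale()
  scaled_grads = gradients(scaled_loss, vars)
  grads = [g / loss_scale() for g in scaled_grads]
  update_op, should_apply_gradients = loss_scale.update(grads)
  ```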
  """

  def __init__(self):
    """Initializes the loss scale class."""
    self._weights = {}

  @abc.abstractmethod
  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    pass

  @abc.abstractmethod
  def update(self, grads):
    """Updates the value of the loss scale.

    The loss scale is potentially updated, based on the value of `grads`. The
    tensor returned by calling this class is only updated when this function is
    evaluated.

    In eager mode, this directly updates the loss scale, so that calling
    `__call__` will return the newly updated loss scale. In graph mode,
    this returns an op that, when evaluated, updates the loss scale.

    This function also returns a `should_apply_gradients` bool. If False,
    gradients should not be applied to the variables for that step, as
    nonfinite gradients were found, and the loss scale has been updated to
    reduce the chance of finding nonfinite gradients in the next step. Some
    loss scale classes will always return True, as they cannot adjust
    themselves in response to nonfinite gradients.

    When a DistributionStrategy is used, this function may only be called in a
    cross-replica context.

    Args:
      grads: A nested structure of unscaled gradients, each of which is the
        gradient of the loss with respect to a weight. The gradients should
        have already been divided by the loss scale before being passed to this
        function. `None` gradients are accepted, and are ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
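
    A sketch of typical graph-mode usage (the optimizer and gradients shown are
    illustrative, not part of this API):

    ```
    update_op, should_apply_gradients = loss_scale.update(grads)
    apply_op = tf.cond(should_apply_gradients,
                       lambda: optimizer.apply_gradients(grads_and_vars),
                       tf.no_op)
    train_op = tf.group(apply_op, update_op)
    ```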
    """
    pass

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added.
    """
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access

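    # Key weights by (name, graph_key) so that, in graph mode, reusing this
    # LossScale object in multiple graphs creates a separate variable for each
    # graph.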
    key = (name, graph_key)
    if self._weights.get(key, None) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    return variable

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific weights to save."""
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    weights = []
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights.append(trackable.TrackableReference(name=name, ref=v))
    return super(LossScale, self)._checkpoint_dependencies + weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(LossScale, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    return self._weights.get((name, graph_key), None)

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of this loss scale."""
    pass

  @classmethod
  def from_config(cls, config):
    """Creates the LossScale from its config."""
    return cls(**config)


@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
                                  'train.experimental.FixedLossScale')
@tf_export(
    'mixed_precision.experimental.FixedLossScale',
    'train.experimental.FixedLossScale',
    v1=[
        'mixed_precision.FixedLossScale',
        'mixed_precision.experimental.FixedLossScale',
        'train.experimental.FixedLossScale'
    ])
class FixedLossScale(LossScale):
  """Loss scale with a fixed value.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only
  be accessible as `tf.compat.v1.mixed_precision.FixedLossScale`. Additionally
  in 2.5, you will no longer be able to pass a `FixedLossScale` to a
  `tf.keras.mixed_precision.Policy`. All the functionality in this class has
  been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class
  is no longer needed.

  The loss scale is not updated for the lifetime of instances of this class.
  A given instance of this class always returns the same number when called.
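
  For example, a sketch of using a fixed loss scale directly (the surrounding
  code is illustrative):

  ```
  loss_scale = tf.compat.v1.mixed_precision.FixedLossScale(128.)
  loss_scale()  # A scalar float32 tensor equal to 128.
  update_op, should_apply_gradients = loss_scale.update(grads)
  # should_apply_gradients is always True for a fixed loss scale.
  ```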
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
            'LossScaleOptimizer now has all the functionality of '
            'FixedLossScale')
  def __init__(self, loss_scale_value):
    """Creates the fixed loss scale.

    Args:
      loss_scale_value: A Python float. Its ideal value varies from model to
        model. A loss scale that is too small may degrade model quality; one
        that is too large may cause gradients to overflow to inf or nan. There
        is no single right value. A relatively large value is harmless as long
        as no nan or inf is encountered during training.

    Raises:
      ValueError: If loss_scale_value is less than 1.
    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, six.integer_types + (float,)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might be created on a different device or in a different
    # tf.function than where the tensor is later used. This would hurt
    # performance. Therefore, we do not create a tensor from loss_scale_value,
    # but instead leave it as a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
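    # The loss scale is fixed, so there is nothing to update and gradients can
    # always be applied.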
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}


def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  is_finite_per_grad = [
      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
  ]
  return math_ops.reduce_all(is_finite_per_grad)


def _op_in_graph_mode(tensor):
  """Returns the tensor's op in graph mode, or the tensor in eager mode.

  This is useful because sometimes an op is needed in graph mode instead of a
  tensor. In eager mode, there are no ops.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode, or the tensor itself in eager mode.
  """
  if context.executing_eagerly():
    return tensor
  return tensor.op


def _assign_if_finite(var, value):
  """Assigns a value to a variable if the value is finite."""
  return control_flow_ops.cond(
      math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
      control_flow_ops.no_op)


@deprecation.deprecated_endpoints(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale')
@tf_export(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale',
    v1=[
        'mixed_precision.DynamicLossScale',
        'mixed_precision.experimental.DynamicLossScale',
        'train.experimental.DynamicLossScale'
    ])
class DynamicLossScale(LossScale):
  """Loss scale that dynamically adjusts itself.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only
  be accessible as `tf.compat.v1.mixed_precision.DynamicLossScale`. Additionally
  in 2.5, you will no longer be able to pass a `DynamicLossScale` to a
  `tf.keras.mixed_precision.Policy`. All the functionality in this class has
  been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class
  is no longer needed.

  Dynamic loss scaling works by adjusting the loss scale as training progresses.
  The goal is to keep the loss scale as high as possible without overflowing the
  gradients. As long as the gradients do not overflow, raising the loss scale
  never hurts.

  The algorithm starts by setting the loss scale to an initial value. Every N
  steps that the gradients are finite, the loss scale is increased by some
  factor. However, if a NaN or Inf gradient is found, the gradients for that
  step are not applied, and the loss scale is decreased by the factor. This
  process tends to keep the loss scale as high as possible without gradients
  overflowing.
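
  In pseudocode, the adjustment described above is roughly the following (a
  sketch, not the exact implementation):

  ```
  loss_scale = initial_loss_scale
  good_steps = 0
  for each training step:
    if all gradients are finite:
      good_steps += 1
      if good_steps >= increment_period:
        loss_scale *= multiplier
        good_steps = 0
      apply gradients
    else:
      loss_scale = max(loss_scale / multiplier, 1)
      good_steps = 0
      skip applying gradients
  ```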
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
            'LossScaleOptimizer now has all the functionality of '
            'DynamicLossScale')
  def __init__(self,
               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
               increment_period=2000,
               multiplier=2.):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: A Python float.  The loss scale to use at the
        beginning. It's better to start this at a very high number, because a
        loss scale that is too high gets lowered far more quickly than a loss
        scale that is too low gets raised. The default is 2 ** 15, which is
        approximately half the maximum float16 value.
      increment_period: Increases loss scale every `increment_period`
        consecutive steps that finite gradients are encountered. If a nonfinite
        gradient is encountered, the count is reset back to zero.
      multiplier: The multiplier to use when increasing or decreasing the loss
        scale.
    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)

    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale.
    self._num_good_steps = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Updates the loss scale based on whether the gradients are finite."""
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
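      # The sum equals the number of replicas only if every replica reported
      # finite gradients, so this reduces to a global "all gradients finite"
      # check.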
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
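        # Multiplying by the multiplier can overflow the loss scale to inf.
        # _assign_if_finite only assigns the new value when it is finite, so an
        # overflowed loss scale is left unchanged.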
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return control_flow_ops.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

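      # Reduce the loss scale by the multiplier, but never let it drop below 1.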
      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period, self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period, self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,
        'increment_period': self.increment_period,
        'multiplier': self.multiplier,
    }


def get(identifier):
  """Gets a loss scale object from an identifier.

  The identifier can be an int or float (creating a `FixedLossScale` with that
  value), the string 'dynamic' (creating a `DynamicLossScale` with default
  arguments), a `LossScale` instance (returned as-is), or None (returning
  None).
  """
  if isinstance(identifier, six.integer_types + (float,)):
    return FixedLossScale(identifier)
  if identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     identifier)