1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains LossScale classes."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.distribute import distribution_strategy_context
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import smart_cond
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.training import optimizer
26from tensorflow.python.training.experimental import loss_scale as loss_scale_module
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
29
30
@deprecation.deprecated_endpoints(
    'train.experimental.MixedPrecisionLossScaleOptimizer')
@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer',
               'train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are
  computed, such as through `minimize()`.
  """

  def __init__(self, opt, loss_scale):
    """Initializes the loss scale optimizer.

    Args:
      opt: The `tf.compat.v1.train.Optimizer` instance to wrap. Gradient
        computation and application are delegated to it.
      loss_scale: The loss scale to use. This is passed through
        `loss_scale_module.get`, so it may be anything that helper accepts
        (presumably a number, the string 'dynamic', or a `LossScale` instance
        — confirm against `loss_scale_module.get`).

    Raises:
      ValueError: If `opt` is not an `Optimizer`, or if `loss_scale` resolves
        to `None`.
    """
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    # Mirror the wrapped optimizer's locking behavior and name so this wrapper
    # behaves as a drop-in replacement for `opt`.
    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    # Register both dependencies so they are saved/restored with checkpoints
    # that include this optimizer.
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`. The gradients have already been divided by the
      loss scale, so they are on the same scale as unscaled gradients.
    """
    # Scale the loss up front so every intermediate gradient in the backward
    # pass carries the same multiplier, then undo it on the final gradients.
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    """Returns `loss` multiplied by the current loss scale.

    If `loss` is a callable (required under eager execution), returns a new
    callable that scales the result; otherwise scales the tensor directly.
    The loss scale is cast to the loss's dtype before multiplying.
    """
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    """Divides each gradient by the loss scale; `None` entries pass through."""
    loss_scale = self._loss_scale()
    # Multiply by the reciprocal rather than dividing each gradient, so the
    # division happens once instead of once per gradient.
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    """Multiplies one gradient by `loss_scale_reciprocal`.

    `IndexedSlices` gradients (e.g. from embedding lookups) are handled by
    scaling only their values, preserving indices and dense shape.
    """
    if isinstance(grad, ops.IndexedSlices):
      grad_vals = grad.values * loss_scale_reciprocal
      return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape)
    return grad * loss_scale_reciprocal

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    conditionally applies gradients if all gradient values are finite.
    Otherwise no update is performed (nor is `global_step` incremented).

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation.  Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context; this method must be
        called in a replica context (use `_distributed_apply()` from a
        cross-replica context instead).
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')

    # With a fixed loss scale there is no finiteness check or loss-scale
    # update, so the wrapped optimizer can apply the gradients directly.
    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    grads_and_vars = tuple(grads_and_vars)

    # Dynamic loss scaling needs a cross-replica view of all gradients to
    # decide whether to apply them and how to adjust the loss scale, hence
    # the merge_call.
    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross replica context.

    When users are in a cross replica strategy, they must call this rather than
    `apply_gradients()`.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()` and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    # `update` returns an op that adjusts the loss scale plus a boolean tensor
    # gating whether the gradients should be applied (presumably whether they
    # are all finite — confirm against the `LossScale.update` contract).
    loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads))

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    # smart_cond folds to a plain Python branch when `should_apply_grads` is a
    # static value, and builds a tf.cond otherwise; non-finite gradients result
    # in a no-op (variables and global_step untouched).
    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    # The loss scale is updated regardless of whether the gradients applied.
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    # All gradient application is delegated to the wrapped optimizer, so the
    # per-variable apply hooks of the base Optimizer must never be reached.
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the Optimizer.

    Includes the wrapped optimizer's variables plus the loss scale's internal
    weights (e.g. the dynamic loss scale state).
    """
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access
255