# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains MixedPrecisionLossScaleOptimizer, an optimizer that applies loss scaling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated_endpoints(
    'train.experimental.MixedPrecisionLossScaleOptimizer')
@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer',
               'train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but it can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However, it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are computed, such as
  through `minimize()`.
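
  For example, the following illustrative sketch wraps a
  `GradientDescentOptimizer` (the wrapped optimizer, the learning rate, the
  `'dynamic'` loss scale, and the placeholder `loss` are arbitrary choices, not
  requirements):

  ```
  opt = tf.compat.v1.train.GradientDescentOptimizer(0.01)
  opt = tf.compat.v1.mixed_precision.MixedPrecisionLossScaleOptimizer(
      opt, loss_scale='dynamic')
  train_op = opt.minimize(loss)
  ```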
  """

  def __init__(self, opt, loss_scale):
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.
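
    For example, in the following illustrative sketch (`opt` is a
    `MixedPrecisionLossScaleOptimizer` and `loss` and `var_list` are
    placeholders), the returned gradients are already unscaled:

    ```
    grads_and_vars = opt.compute_gradients(loss, var_list)
    # The gradients have already been divided by the loss scale, so they can
    # be inspected or clipped before being passed to `apply_gradients()`.
    ```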
    """
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    loss_scale = self._loss_scale()
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    if isinstance(grad, ops.IndexedSlices):
      grad_vals = grad.values * loss_scale_reciprocal
      return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape)
    return grad * loss_scale_reciprocal

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    conditionally applies gradients if all gradient values are finite.
    Otherwise no update is performed (nor is `global_step` incremented).

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context; in that case,
        `_distributed_apply()` should be used instead.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError(
          'apply_gradients() must be called in a replica context.')

    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    grads_and_vars = tuple(grads_and_vars)

    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-replica context.

    When in a cross-replica context of a distribution strategy, this must be
    called rather than `apply_gradients()`.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()` and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads))

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross-replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the Optimizer."""
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access