1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Base class to make optimizers weight decay ready."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.opt.python.training import shampoo
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export
31
32
class DecoupledWeightDecayExtension(object):
  """This class allows to extend optimizers with decoupled weight decay.

  It implements the decoupled weight decay described by Loshchilov & Hutter
  (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
  decoupled from the optimization steps w.r.t. to the loss function.
  For SGD variants, this simplifies hyperparameter search since it decouples
  the settings of weight decay and learning rate.
  For adaptive gradient algorithms, it regularizes variables with large
  gradients more than L2 regularization would, which was shown to yield better
  training loss and generalization error in the paper above.

  This class alone is not an optimizer but rather extends existing
  optimizers with decoupled weight decay. We explicitly define the two examples
  used in the above paper (SGDW and AdamW), but in general this can extend
  any OptimizerX by using
  `extend_with_decoupled_weight_decay(OptimizerX)(weight_decay, ...)`.
  In order for it to work, it must be the first class the Optimizer with
  weight decay inherits from, e.g.

  ```python
  class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
    def __init__(self, weight_decay, *args, **kwargs):
      super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs)
  ```

  Note that this extension decays weights BEFORE applying the update based
  on the gradient, i.e. this extension only has the desired behaviour for
  optimizers which do not depend on the value of `var` in the update step!

  Note: when applying a decay to the learning rate, be sure to manually apply
  the decay to the `weight_decay` as well. For example:

  ```python
    schedule = tf.train.piecewise_constant(tf.train.get_global_step(),
                                           [10000, 15000], [1e-0, 1e-1, 1e-2])
    lr = 1e-1 * schedule()
    wd = lambda: 1e-4 * schedule()

    # ...

    optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr,
                                                  weight_decay=wd,
                                                  momentum=0.9,
                                                  use_nesterov=True)
  ```
  """

  def __init__(self, weight_decay, *args, **kwargs):
    """Construct the extension class that adds weight decay to an optimizer.

    Args:
      weight_decay: A `Tensor`, a floating point value, or a callable taking
        no arguments that returns one of those. It is the factor by which a
        variable is decayed in the update step. A callable is evaluated in
        `_prepare`.
      *args: Positional arguments forwarded to the base optimizer's
        constructor.
      **kwargs: Keyword arguments forwarded to the base optimizer's
        constructor.
    """
    # Set in `minimize` / `apply_gradients`. A falsy value means every
    # variable that receives an update is decayed.
    self._decay_var_list = None
    self._weight_decay = weight_decay
    # Resolved to a tensor in `_prepare`; read by the decay ops below.
    self._weight_decay_tensor = None
    super(DecoupledWeightDecayExtension, self).__init__(*args, **kwargs)

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=optimizer.Optimizer.GATE_OP,
               aggregation_method=None, colocate_gradients_with_ops=False,
               name=None, grad_loss=None, decay_var_list=None):
    """Add operations to minimize `loss` by updating `var_list` with decay.

    This function is the same as Optimizer.minimize except that it allows to
    specify the variables that should be decayed using decay_var_list.
    If decay_var_list is None (or empty), all variables in var_list are
    decayed.

    For more information see the documentation of Optimizer.minimize.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`.  Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or  `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      decay_var_list: Optional list of variables to decay.

    Returns:
      An Operation that updates the variables in `var_list`.  If `global_step`
      was not `None`, that operation also increments `global_step`.
    """
    self._set_decay_var_list(decay_var_list)
    return super(DecoupledWeightDecayExtension, self).minimize(
        loss, global_step=global_step, var_list=var_list,
        gate_gradients=gate_gradients, aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops, name=name,
        grad_loss=grad_loss)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None,
                      decay_var_list=None):
    """Apply gradients to variables and decay the variables.

    This function is the same as Optimizer.apply_gradients except that it
    allows to specify the variables that should be decayed using
    decay_var_list. If decay_var_list is None (or empty), all variables in
    var_list are decayed.

    For more information see the documentation of Optimizer.apply_gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the `Optimizer` constructor.
      decay_var_list: Optional list of variables to decay.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    self._set_decay_var_list(decay_var_list)
    return super(DecoupledWeightDecayExtension, self).apply_gradients(
        grads_and_vars, global_step=global_step, name=name)

  def _set_decay_var_list(self, decay_var_list):
    """Records which variables the next update should decay.

    A falsy `decay_var_list` (None or empty) is stored as False, which
    `_should_decay` treats as "decay every updated variable".
    """
    self._decay_var_list = set(decay_var_list) if decay_var_list else False

  def _should_decay(self, var):
    """Returns True if `var` is selected for decay by `_decay_var_list`."""
    return not self._decay_var_list or var in self._decay_var_list

  def _decay_factor(self, var):
    """Returns the prepared weight decay cast to `var`'s base dtype."""
    # Use the tensor built in `_prepare` (not the raw hyperparameter), so a
    # callable `weight_decay` works. Cast so non-float32 variables work too.
    return math_ops.cast(self._weight_decay_tensor, var.dtype.base_dtype)

  def _prepare(self):
    weight_decay = self._weight_decay
    if callable(weight_decay):
      weight_decay = weight_decay()
    self._weight_decay_tensor = ops.convert_to_tensor(
        weight_decay, name="weight_decay")
    # Call the optimizers _prepare function.
    super(DecoupledWeightDecayExtension, self)._prepare()

  def _decay_weights_op(self, var):
    """Returns an op computing `var -= weight_decay * var`, or a no-op."""
    if self._should_decay(var):
      return var.assign_sub(self._decay_factor(var) * var, self._use_locking)
    return control_flow_ops.no_op()

  def _decay_weights_sparse_op(self, var, indices, scatter_add):
    """Decays only the rows of `var` selected by `indices`, or a no-op."""
    if self._should_decay(var):
      update = -self._decay_factor(var) * array_ops.gather(var, indices)
      return scatter_add(var, indices, update, self._use_locking)
    return control_flow_ops.no_op()

  # Here, we overwrite the apply functions that the base optimizer calls.
  # super().apply_x resolves to the apply_x function of the BaseOptimizer.
  # Each override runs the decay op first (via control dependencies), then
  # the base optimizer's update.
  def _apply_dense(self, grad, var):
    with ops.control_dependencies([self._decay_weights_op(var)]):
      return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var)

  def _resource_apply_dense(self, grad, var):
    with ops.control_dependencies([self._decay_weights_op(var)]):
      return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
          grad, var)

  def _apply_sparse(self, grad, var):
    scatter_add = state_ops.scatter_add
    decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
    with ops.control_dependencies([decay_op]):
      return super(DecoupledWeightDecayExtension, self)._apply_sparse(
          grad, var)

  def _resource_scatter_add(self, x, i, v, _=None):
    # last argument allows for one overflow argument, to have the same function
    # signature as state_ops.scatter_add
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    scatter_add = self._resource_scatter_add
    decay_op = self._decay_weights_sparse_op(var, indices, scatter_add)
    with ops.control_dependencies([decay_op]):
      return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse(
          grad, var, indices)
216
217
def extend_with_decoupled_weight_decay(base_optimizer):
  """Factory function returning an optimizer class with decoupled weight decay.

  Returns an optimizer class. An instance of the returned class computes the
  update step of `base_optimizer` and additionally decays the weights.
  E.g., the class returned by
  `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to
  `tf.contrib.opt.AdamWOptimizer`.

  The API of the new optimizer class slightly differs from the API of the
  base optimizer:
  - The first argument to the constructor is the weight decay rate. The
    remaining constructor arguments are forwarded to `base_optimizer`'s
    constructor; passing them by keyword is the safest choice.
  - `minimize` and `apply_gradients` accept the optional keyword argument
    `decay_var_list`, which specifies the variables that should be decayed.
    If `None`, all variables that are optimized are decayed.

  Usage example:
  ```python
  # MyAdamW is a new class
  MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
  # Create a MyAdamW object
  optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
  sess.run(optimizer.minimize(loss, decay_var_list=[var1, var2]))
  ```

  Note that this extension decays weights BEFORE applying the update based
  on the gradient, i.e. this extension only has the desired behaviour for
  optimizers which do not depend on the value of `var` in the update step!

  Args:
    base_optimizer: An optimizer class that inherits from tf.train.Optimizer.

  Returns:
    A new optimizer class that inherits from DecoupledWeightDecayExtension
    and base_optimizer.
  """

  class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                          base_optimizer):
    """Base_optimizer with decoupled weight decay.

    This class computes the update step of `base_optimizer` and
    additionally decays the variable with the weight decay being decoupled from
    the optimization steps w.r.t. to the loss function, as described by
    Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf).
    For SGD variants, this simplifies hyperparameter search since
    it decouples the settings of weight decay and learning rate.
    For adaptive gradient algorithms, it regularizes variables with large
    gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.
    """

    def __init__(self, weight_decay, *args, **kwargs):
      # super delegation is necessary here
      # pylint: disable=useless-super-delegation
      super(OptimizerWithDecoupledWeightDecay, self).__init__(
          weight_decay, *args, **kwargs)
      # pylint: enable=useless-super-delegation

  return OptimizerWithDecoupledWeightDecay
278
279
@tf_export("contrib.opt.MomentumWOptimizer")
class MomentumWOptimizer(DecoupledWeightDecayExtension,
                         momentum_opt.MomentumOptimizer):
  """Optimizer that implements the Momentum algorithm with weight_decay.

  This is an implementation of the SGDW optimizer described in "Fixing
  Weight Decay Regularization in Adam" by Loshchilov & Hutter
  (https://arxiv.org/abs/1711.05101)
  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
  It computes the update step of `train.MomentumOptimizer` and additionally
  decays the variable. Note that this is different from adding
  L2 regularization on the variables to the loss. Decoupling the weight decay
  from other hyperparameters (in particular the learning rate) simplifies
  hyperparameter search.

  For further information see the documentation of the Momentum Optimizer.

  Note that this optimizer can also be instantiated as
  ```python
  extend_with_decoupled_weight_decay(tf.train.MomentumOptimizer)(
      weight_decay=weight_decay, learning_rate=learning_rate,
      momentum=momentum)
  ```
  """

  def __init__(self, weight_decay, learning_rate, momentum,
               use_locking=False, name="MomentumW", use_nesterov=False):
    """Construct a new MomentumW optimizer.

    For further information see the documentation of the Momentum Optimizer.

    Args:
      weight_decay:  A `Tensor` or a floating point value.  The weight decay.
      learning_rate: A `Tensor` or a floating point value.  The learning rate.
      momentum: A `Tensor` or a floating point value.  The momentum.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "MomentumW".
      use_nesterov: If `True` use Nesterov Momentum.
        See [Sutskever et al., 2013](
        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
        This implementation always computes gradients at the value of the
        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
        variable(s) track the values called `theta_t + mu*v_t` in the paper.

    @compatibility(eager)
    When eager execution is enabled, learning_rate, weight_decay and momentum
    can each be a callable that takes no arguments and returns the actual value
    to use. This can be useful for changing these values across different
    invocations of optimizer functions.
    @end_compatibility
    """
    super(MomentumWOptimizer, self).__init__(
        weight_decay, learning_rate=learning_rate, momentum=momentum,
        use_locking=use_locking, name=name, use_nesterov=use_nesterov)
334
335
@tf_export("contrib.opt.AdamWOptimizer")
class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
  """Optimizer that implements the Adam algorithm with weight decay.

  This is an implementation of the AdamW optimizer described in "Fixing
  Weight Decay Regularization in Adam" by Loshchilov & Hutter
  (https://arxiv.org/abs/1711.05101)
  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).

  It computes the update step of `train.AdamOptimizer` and additionally decays
  the variable. Note that this is different from adding L2 regularization on
  the variables to the loss: it regularizes variables with large
  gradients more than L2 regularization would, which was shown to yield better
  training loss and generalization error in the paper above.

  For further information see the documentation of the Adam Optimizer.

  Note that this optimizer can also be instantiated as
  ```python
  extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)(
      weight_decay=weight_decay)
  ```
  """

  def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999,
               epsilon=1e-8, use_locking=False, name="AdamW"):
    """Construct a new AdamW optimizer.

    For further information see the documentation of the Adam Optimizer.

    Args:
      weight_decay:  A `Tensor` or a floating point value.  The weight decay.
      learning_rate: A `Tensor` or a floating point value.  The learning rate.
      beta1: A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "AdamW".
    """
    super(AdamWOptimizer, self).__init__(
        weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
        epsilon=epsilon, use_locking=use_locking, name=name)
382
383
@tf_export("contrib.opt.ShampooWOptimizer")
class ShampooWOptimizer(DecoupledWeightDecayExtension,
                        shampoo.ShampooOptimizer):
  """Optimizer that implements the Shampoo algorithm with weight decay.

  For further information see the documentation of the Shampoo Optimizer.
  """

  def __init__(self,
               weight_decay,
               global_step,
               max_matrix_size=768,
               gbar_decay=0.0,
               gbar_weight=1.0,
               mat_gbar_decay=1.0,
               mat_gbar_weight=1.0,
               learning_rate=1.0,
               svd_interval=1,
               precond_update_interval=1,
               epsilon=1e-4,
               alpha=0.5,
               use_iterative_root=False,
               use_locking=False,
               name="ShampooW"):
    """Construct a new ShampooW optimizer.

    For further information see the documentation of the Shampoo Optimizer.

    Args:
      weight_decay:  A `Tensor` or a floating point value.  The weight decay.
      global_step: tensorflow variable indicating the step.
      max_matrix_size: We do not perform SVD for matrices larger than this.
      gbar_decay: Decay factor in the gbar update (see `gbar_weight`).
      gbar_weight:  Used to update gbar: gbar[t] = gbar_decay[t] * gbar[t-1] +
        gbar_weight[t] * g[t]
      mat_gbar_decay: Decay factor in the mat_gbar update (see
        `mat_gbar_weight`).
      mat_gbar_weight:  Used to update mat_gbar: mat_gbar_j[t] =
        mat_gbar_decay[t] * mat_gbar_j[t-1] + mat_gbar_weight[t] * gg_j[t]
      learning_rate: Similar to SGD
      svd_interval: We should do SVD after this many steps. Default = 1, i.e.
        every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
        also OK. May also want more often early, and less often later - set
        in caller as for example:
        `svd_interval = lambda T: tf.cond(
            T < 2000, lambda: 20.0, lambda: 1000.0)`
      precond_update_interval: We should update the preconditioners after this
        many steps. Default = 1. Usually less than svd_interval.
      epsilon:  epsilon * I_n is added to each mat_gbar_j for stability
      alpha:  total power of the preconditioners.
      use_iterative_root: should the optimizer use SVD (faster) or the iterative
        root method (for TPU) for finding the roots of PSD matrices.
      use_locking: If `True` use locks for update operations.
      name: name of optimizer.
    """
    # Forward every Shampoo hyperparameter under its own name. Previously
    # `mat_gbar_decay` received `mat_gbar_weight`, and `mat_gbar_weight` was
    # dropped, so the user's settings for both were ignored.
    super(ShampooWOptimizer, self).__init__(
        weight_decay,
        global_step=global_step,
        max_matrix_size=max_matrix_size,
        gbar_decay=gbar_decay,
        gbar_weight=gbar_weight,
        mat_gbar_decay=mat_gbar_decay,
        mat_gbar_weight=mat_gbar_weight,
        learning_rate=learning_rate,
        svd_interval=svd_interval,
        precond_update_interval=precond_update_interval,
        epsilon=epsilon,
        alpha=alpha,
        use_iterative_root=use_iterative_root,
        use_locking=use_locking,
        name=name)
453