# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base class to make optimizers weight decay ready."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.opt.python.training import shampoo
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export


class DecoupledWeightDecayExtension(object):
  """This class allows optimizers to be extended with decoupled weight decay.

  It implements the decoupled weight decay described by Loshchilov & Hutter
  (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
  decoupled from the optimization steps w.r.t. the loss function.
  For SGD variants, this simplifies hyperparameter search since it decouples
  the settings of weight decay and learning rate.
  For adaptive gradient algorithms, it regularizes variables with large
  gradients more than L2 regularization would, which was shown to yield better
  training loss and generalization error in the paper above.

  This class alone is not an optimizer, but rather extends existing optimizers
  with decoupled weight decay. We explicitly define the two examples used in
  the above paper (SGDW and AdamW), but in general this can extend any
  OptimizerX class: `extend_with_decoupled_weight_decay(OptimizerX)` returns a
  class whose constructor takes `weight_decay` as its first argument.
  In order for this to work, this extension must be the first class the
  optimizer with weight decay inherits from, e.g.

  ```python
  class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
    def __init__(self, weight_decay, *args, **kwargs):
      super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs)
  ```

  Note that this extension decays weights BEFORE applying the update based
  on the gradient, i.e. this extension only has the desired behaviour for
  optimizers which do not depend on the value of `var` in the update step!

  Note: when applying a decay to the learning rate, be sure to manually apply
  the decay to the `weight_decay` as well. For example:

  ```python
  schedule = tf.train.piecewise_constant(tf.train.get_global_step(),
                                         [10000, 15000], [1e-0, 1e-1, 1e-2])
  lr = 1e-1 * schedule()
  wd = lambda: 1e-4 * schedule()

  # ...
  optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr,
                                                weight_decay=wd,
                                                momentum=0.9,
                                                use_nesterov=True)
  ```
  """

  def __init__(self, weight_decay, **kwargs):
    """Construct the extension class that adds weight decay to an optimizer.

    Args:
      weight_decay: A `Tensor` or a floating point value, the factor by which
        a variable is decayed in the update step.
      **kwargs: Optional keyword arguments, passed on to the constructor of
        the wrapped optimizer.
    """
    self._decay_var_list = None  # is set in minimize or apply_gradients
    self._weight_decay = weight_decay
    # The tensors are initialized in the call to _prepare.
    self._weight_decay_tensor = None
    super(DecoupledWeightDecayExtension, self).__init__(**kwargs)

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=optimizer.Optimizer.GATE_OP,
               aggregation_method=None, colocate_gradients_with_ops=False,
               name=None, grad_loss=None, decay_var_list=None):
    """Add operations to minimize `loss` by updating `var_list` with decay.

    This function is the same as Optimizer.minimize except that it allows
    the variables that should be decayed to be specified via `decay_var_list`.
    If `decay_var_list` is None, all variables in `var_list` are decayed.

    For more information see the documentation of Optimizer.minimize.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      decay_var_list: Optional list of variables to decay.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.
    """
    self._decay_var_list = set(decay_var_list) if decay_var_list else False
    return super(DecoupledWeightDecayExtension, self).minimize(
        loss, global_step=global_step, var_list=var_list,
        gate_gradients=gate_gradients, aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops, name=name,
        grad_loss=grad_loss)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None,
                      decay_var_list=None):
    """Apply gradients to variables and decay the variables.

    This function is the same as Optimizer.apply_gradients except that it
    allows the variables that should be decayed to be specified via
    `decay_var_list`. If `decay_var_list` is None, all variables in
    `var_list` are decayed.

    For more information see the documentation of Optimizer.apply_gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.
      decay_var_list: Optional list of variables to decay.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    self._decay_var_list = set(decay_var_list) if decay_var_list else False
    return super(DecoupledWeightDecayExtension, self).apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
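
  # Illustrative sketch (comments only, not executed): besides `minimize`,
  # decayed updates can be applied through the compute/apply pair, e.g.
  #
  #   grads_and_vars = opt.compute_gradients(loss)
  #   train_op = opt.apply_gradients(grads_and_vars, decay_var_list=[var1])
  #
  # where `opt`, `loss` and `var1` stand for user code; only `var1` is
  # decayed, while every variable still receives its gradient update.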

  def _prepare(self):
    weight_decay = self._weight_decay
    if callable(weight_decay):
      weight_decay = weight_decay()
    self._weight_decay_tensor = ops.convert_to_tensor(
        weight_decay, name="weight_decay")
    # Call the optimizer's _prepare function.
    super(DecoupledWeightDecayExtension, self)._prepare()

  def _decay_weights_op(self, var):
    if not self._decay_var_list or var in self._decay_var_list:
      # Use the tensor prepared in _prepare() so callable weight decays work.
      return var.assign_sub(self._weight_decay_tensor * var, self._use_locking)
    return control_flow_ops.no_op()

  def _decay_weights_sparse_op(self, var, indices, scatter_add):
    if not self._decay_var_list or var in self._decay_var_list:
      update = -self._weight_decay_tensor * array_ops.gather(var, indices)
      return scatter_add(var, indices, update, self._use_locking)
    return control_flow_ops.no_op()

  # Here, we overwrite the apply functions that the base optimizer calls.
  # super().apply_x resolves to the apply_x function of the BaseOptimizer.
  def _apply_dense(self, grad, var):
    with ops.control_dependencies([self._decay_weights_op(var)]):
      return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var)

  def _resource_apply_dense(self, grad, var):
    with ops.control_dependencies([self._decay_weights_op(var)]):
      return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
          grad, var)

  def _apply_sparse(self, grad, var):
    scatter_add = state_ops.scatter_add
    decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
    with ops.control_dependencies([decay_op]):
      return super(DecoupledWeightDecayExtension, self)._apply_sparse(
          grad, var)

  def _resource_scatter_add(self, x, i, v, _=None):
    # The last argument allows for one overflow argument, so that this has the
    # same function signature as state_ops.scatter_add.
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    scatter_add = self._resource_scatter_add
    decay_op = self._decay_weights_sparse_op(var, indices, scatter_add)
    with ops.control_dependencies([decay_op]):
      return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse(
          grad, var, indices)
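

# For a variable w with gradient g and decay factor weight_decay, the classes
# built on this extension effectively perform, per step,
#
#   w = w - weight_decay * w       # decoupled decay, applied first
#   w = w - update(g)              # update computed by the base optimizer
#
# whereas L2 regularization would instead add weight_decay * w to g before the
# (possibly adaptive) base update. The two are only equivalent for plain SGD,
# up to rescaling weight_decay by the learning rate (Loshchilov & Hutter,
# https://arxiv.org/abs/1711.05101).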


def extend_with_decoupled_weight_decay(base_optimizer):
  """Factory function returning an optimizer class with decoupled weight decay.

  Returns an optimizer class. An instance of the returned class computes the
  update step of `base_optimizer` and additionally decays the weights.
  E.g., the class returned by
  `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to
  `tf.contrib.opt.AdamWOptimizer`.

  The API of the new optimizer class slightly differs from the API of the
  base optimizer:
  - The first argument to the constructor is the weight decay rate.
  - `minimize` and `apply_gradients` accept the optional keyword argument
    `decay_var_list`, which specifies the variables that should be decayed.
    If `None`, all variables that are optimized are decayed.

  Usage example:
  ```python
  # MyAdamW is a new class
  MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
  # Create a MyAdamW object
  optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
  sess.run(optimizer.minimize(loss, decay_var_list=[var1, var2]))
  ```

  Note that this extension decays weights BEFORE applying the update based
  on the gradient, i.e. this extension only has the desired behaviour for
  optimizers which do not depend on the value of `var` in the update step!

  Args:
    base_optimizer: An optimizer class that inherits from tf.train.Optimizer.

  Returns:
    A new optimizer class that inherits from DecoupledWeightDecayExtension
    and base_optimizer.
  """

  class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                          base_optimizer):
    """Base_optimizer with decoupled weight decay.

    This class computes the update step of `base_optimizer` and
    additionally decays the variable with the weight decay being decoupled from
    the optimization steps w.r.t. the loss function, as described by
    Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf).
    For SGD variants, this simplifies hyperparameter search since
    it decouples the settings of weight decay and learning rate.
    For adaptive gradient algorithms, it regularizes variables with large
    gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.
    """

    def __init__(self, weight_decay, *args, **kwargs):
      # super delegation is necessary here
      # pylint: disable=useless-super-delegation
      super(OptimizerWithDecoupledWeightDecay, self).__init__(
          weight_decay, *args, **kwargs)
      # pylint: enable=useless-super-delegation

  return OptimizerWithDecoupledWeightDecay


@tf_export("contrib.opt.MomentumWOptimizer")
class MomentumWOptimizer(DecoupledWeightDecayExtension,
                         momentum_opt.MomentumOptimizer):
  """Optimizer that implements the Momentum algorithm with weight decay.

  This is an implementation of the SGDW optimizer described in "Fixing
  Weight Decay Regularization in Adam" by Loshchilov & Hutter
  (https://arxiv.org/abs/1711.05101)
  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
  It computes the update step of `train.MomentumOptimizer` and additionally
  decays the variable. Note that this is different from adding
  L2 regularization on the variables to the loss. Decoupling the weight decay
  from other hyperparameters (in particular the learning rate) simplifies
  hyperparameter search.

  For further information see the documentation of the Momentum Optimizer.

  Note that this optimizer can also be instantiated as
  ```python
  MomentumW = extend_with_decoupled_weight_decay(tf.train.MomentumOptimizer)
  optimizer = MomentumW(weight_decay=weight_decay, learning_rate=learning_rate,
                        momentum=momentum)
  ```
  """

  def __init__(self, weight_decay, learning_rate, momentum,
               use_locking=False, name="MomentumW", use_nesterov=False):
    """Construct a new MomentumW optimizer.

    For further information see the documentation of the Momentum Optimizer.

    Args:
      weight_decay: A `Tensor` or a floating point value. The weight decay.
      learning_rate: A `Tensor` or a floating point value. The learning rate.
      momentum: A `Tensor` or a floating point value. The momentum.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "MomentumW".
      use_nesterov: If `True` use Nesterov Momentum.
        See [Sutskever et al., 2013](
        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
        This implementation always computes gradients at the value of the
        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
        variable(s) track the values called `theta_t + mu*v_t` in the paper.

    @compatibility(eager)
    When eager execution is enabled, learning_rate, weight_decay and momentum
    can each be a callable that takes no arguments and returns the actual value
    to use. This can be useful for changing these values across different
    invocations of optimizer functions.
    @end_compatibility
    """
    super(MomentumWOptimizer, self).__init__(
        weight_decay, learning_rate=learning_rate, momentum=momentum,
        use_locking=use_locking, name=name, use_nesterov=use_nesterov)


@tf_export("contrib.opt.AdamWOptimizer")
class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
  """Optimizer that implements the Adam algorithm with weight decay.

  This is an implementation of the AdamW optimizer described in "Fixing
  Weight Decay Regularization in Adam" by Loshchilov & Hutter
  (https://arxiv.org/abs/1711.05101)
  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).

  It computes the update step of `train.AdamOptimizer` and additionally decays
  the variable. Note that this is different from adding L2 regularization on
  the variables to the loss: it regularizes variables with large
  gradients more than L2 regularization would, which was shown to yield better
  training loss and generalization error in the paper above.

  For further information see the documentation of the Adam Optimizer.

  Note that this optimizer can also be instantiated as
  ```python
  AdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
  optimizer = AdamW(weight_decay=weight_decay, learning_rate=learning_rate)
  ```
  """

  def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999,
               epsilon=1e-8, use_locking=False, name="AdamW"):
    """Construct a new AdamW optimizer.

    For further information see the documentation of the Adam Optimizer.

    Args:
      weight_decay: A `Tensor` or a floating point value. The weight decay.
      learning_rate: A `Tensor` or a floating point value. The learning rate.
      beta1: A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "AdamW".
    """
    super(AdamWOptimizer, self).__init__(
        weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
        epsilon=epsilon, use_locking=use_locking, name=name)


@tf_export("contrib.opt.ShampooWOptimizer")
class ShampooWOptimizer(DecoupledWeightDecayExtension,
                        shampoo.ShampooOptimizer):
  """Optimizer that implements the Shampoo algorithm with weight decay.

  For further information see the documentation of the Shampoo Optimizer.
  """

  def __init__(self,
               weight_decay,
               global_step,
               max_matrix_size=768,
               gbar_decay=0.0,
               gbar_weight=1.0,
               mat_gbar_decay=1.0,
               mat_gbar_weight=1.0,
               learning_rate=1.0,
               svd_interval=1,
               precond_update_interval=1,
               epsilon=1e-4,
               alpha=0.5,
               use_iterative_root=False,
               use_locking=False,
               name="ShampooW"):
    """Construct a new ShampooW optimizer.

    For further information see the documentation of the Shampoo Optimizer.

    Args:
      weight_decay: A `Tensor` or a floating point value. The weight decay.
      global_step: A TensorFlow variable indicating the current step.
      max_matrix_size: We do not perform SVD for matrices larger than this.
      gbar_decay: Decay rate used to update gbar, see gbar_weight.
      gbar_weight: Used together with gbar_decay to update gbar:
        gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
      mat_gbar_decay: Decay rate used to update mat_gbar, see mat_gbar_weight.
      mat_gbar_weight: Used together with mat_gbar_decay to update mat_gbar:
        mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
                        + mat_gbar_weight[t] * gg_j[t]
      learning_rate: Similar to SGD.
      svd_interval: We should do SVD after this many steps. Default = 1, i.e.
        every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
        also OK. One may also want to do it more often early in training,
        and less often later - set in the caller as, for example:
        "svd_interval = lambda T: tf.cond(
            T < 2000, lambda: 20.0, lambda: 1000.0)"
      precond_update_interval: We should update the preconditioners after this
        many steps. Default = 1. Usually less than svd_interval.
      epsilon: epsilon * I_n is added to each mat_gbar_j for stability.
      alpha: Total power of the preconditioners.
      use_iterative_root: Should the optimizer use SVD (faster) or the
        iterative root method (for TPU) for finding the roots of PSD matrices.
      use_locking: If `True` use locks for update operations.
      name: Name of the optimizer.
    """
    super(ShampooWOptimizer, self).__init__(
        weight_decay,
        global_step=global_step,
        max_matrix_size=max_matrix_size,
        gbar_decay=gbar_decay,
        gbar_weight=gbar_weight,
        mat_gbar_decay=mat_gbar_decay,
        mat_gbar_weight=mat_gbar_weight,
        learning_rate=learning_rate,
        svd_interval=svd_interval,
        precond_update_interval=precond_update_interval,
        epsilon=epsilon,
        alpha=alpha,
        use_iterative_root=use_iterative_root,
        use_locking=use_locking,
        name=name)
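

# Illustrative usage sketch (comments only, not executed). It assumes graph
# mode with a user-defined scalar `loss` over variables `var1` and `var2`:
#
#   # Decay only var1 and var2; all trainable variables still receive the
#   # regular Adam update.
#   optimizer = AdamWOptimizer(weight_decay=1e-4, learning_rate=1e-3)
#   train_op = optimizer.minimize(loss, decay_var_list=[var1, var2])
#
#   # The same kind of optimizer can be built from the stock Adam optimizer:
#   AdamW = extend_with_decoupled_weight_decay(adam.AdamOptimizer)
#   optimizer = AdamW(weight_decay=1e-4, learning_rate=1e-3)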