1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Adam optimizer implementation.""" 16# pylint: disable=g-classes-have-attributes 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.eager import def_function 22from tensorflow.python.framework import ops 23from tensorflow.python.keras import backend_config 24from tensorflow.python.keras.optimizer_v2 import optimizer_v2 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import control_flow_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.ops import state_ops 29from tensorflow.python.training import gen_training_ops 30from tensorflow.python.util.tf_export import keras_export 31 32 33@keras_export('keras.optimizers.Adam') 34class Adam(optimizer_v2.OptimizerV2): 35 r"""Optimizer that implements the Adam algorithm. 36 37 Adam optimization is a stochastic gradient descent method that is based on 38 adaptive estimation of first-order and second-order moments. 39 40 According to 41 [Kingma et al., 2014](http://arxiv.org/abs/1412.6980), 42 the method is "*computationally 43 efficient, has little memory requirement, invariant to diagonal rescaling of 44 gradients, and is well suited for problems that are large in terms of 45 data/parameters*". 46 47 Args: 48 learning_rate: A `Tensor`, floating point value, or a schedule that is a 49 `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable 50 that takes no arguments and returns the actual value to use, The 51 learning rate. Defaults to 0.001. 52 beta_1: A float value or a constant float tensor, or a callable 53 that takes no arguments and returns the actual value to use. The 54 exponential decay rate for the 1st moment estimates. Defaults to 0.9. 55 beta_2: A float value or a constant float tensor, or a callable 56 that takes no arguments and returns the actual value to use, The 57 exponential decay rate for the 2nd moment estimates. Defaults to 0.999. 58 epsilon: A small constant for numerical stability. This epsilon is 59 "epsilon hat" in the Kingma and Ba paper (in the formula just before 60 Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to 61 1e-7. 62 amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from 63 the paper "On the Convergence of Adam and beyond". Defaults to `False`. 64 name: Optional name for the operations created when applying gradients. 65 Defaults to `"Adam"`. 66 **kwargs: Keyword arguments. Allowed to be one of 67 `"clipnorm"` or `"clipvalue"`. 68 `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips 69 gradients by value. 70 71 Usage: 72 73 >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1) 74 >>> var1 = tf.Variable(10.0) 75 >>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1 76 >>> step_count = opt.minimize(loss, [var1]).numpy() 77 >>> # The first step is `-learning_rate*sign(grad)` 78 >>> var1.numpy() 79 9.9 80 81 Reference: 82 - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) 83 - [Reddi et al., 2018]( 84 https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`. 85 86 Notes: 87 88 The default value of 1e-7 for epsilon might not be a good default in 89 general. For example, when training an Inception network on ImageNet a 90 current good choice is 1.0 or 0.1. Note that since Adam uses the 91 formulation just before Section 2.1 of the Kingma and Ba paper rather than 92 the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon 93 hat" in the paper. 94 95 The sparse implementation of this algorithm (used when the gradient is an 96 IndexedSlices object, typically because of `tf.gather` or an embedding 97 lookup in the forward pass) does apply momentum to variable slices even if 98 they were not used in the forward pass (meaning they have a gradient equal 99 to zero). Momentum decay (beta1) is also applied to the entire momentum 100 accumulator. This means that the sparse behavior is equivalent to the dense 101 behavior (in contrast to some momentum implementations which ignore momentum 102 unless a variable slice was actually used). 103 """ 104 105 _HAS_AGGREGATE_GRAD = True 106 107 def __init__(self, 108 learning_rate=0.001, 109 beta_1=0.9, 110 beta_2=0.999, 111 epsilon=1e-7, 112 amsgrad=False, 113 name='Adam', 114 **kwargs): 115 super(Adam, self).__init__(name, **kwargs) 116 self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) 117 self._set_hyper('decay', self._initial_decay) 118 self._set_hyper('beta_1', beta_1) 119 self._set_hyper('beta_2', beta_2) 120 self.epsilon = epsilon or backend_config.epsilon() 121 self.amsgrad = amsgrad 122 123 def _create_slots(self, var_list): 124 # Create slots for the first and second moments. 125 # Separate for-loops to respect the ordering of slot variables from v1. 126 for var in var_list: 127 self.add_slot(var, 'm') 128 for var in var_list: 129 self.add_slot(var, 'v') 130 if self.amsgrad: 131 for var in var_list: 132 self.add_slot(var, 'vhat') 133 134 def _prepare_local(self, var_device, var_dtype, apply_state): 135 super(Adam, self)._prepare_local(var_device, var_dtype, apply_state) 136 137 local_step = math_ops.cast(self.iterations + 1, var_dtype) 138 beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype)) 139 beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype)) 140 beta_1_power = math_ops.pow(beta_1_t, local_step) 141 beta_2_power = math_ops.pow(beta_2_t, local_step) 142 lr = (apply_state[(var_device, var_dtype)]['lr_t'] * 143 (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))) 144 apply_state[(var_device, var_dtype)].update( 145 dict( 146 lr=lr, 147 epsilon=ops.convert_to_tensor_v2_with_dispatch( 148 self.epsilon, var_dtype), 149 beta_1_t=beta_1_t, 150 beta_1_power=beta_1_power, 151 one_minus_beta_1_t=1 - beta_1_t, 152 beta_2_t=beta_2_t, 153 beta_2_power=beta_2_power, 154 one_minus_beta_2_t=1 - beta_2_t)) 155 156 def set_weights(self, weights): 157 params = self.weights 158 # If the weights are generated by Keras V1 optimizer, it includes vhats 159 # even without amsgrad, i.e, V1 optimizer has 3x + 1 variables, while V2 160 # optimizer has 2x + 1 variables. Filter vhats out for compatibility. 161 num_vars = int((len(params) - 1) / 2) 162 if len(weights) == 3 * num_vars + 1: 163 weights = weights[:len(params)] 164 super(Adam, self).set_weights(weights) 165 166 def _resource_apply_dense(self, grad, var, apply_state=None): 167 var_device, var_dtype = var.device, var.dtype.base_dtype 168 coefficients = ((apply_state or {}).get((var_device, var_dtype)) 169 or self._fallback_apply_state(var_device, var_dtype)) 170 171 m = self.get_slot(var, 'm') 172 v = self.get_slot(var, 'v') 173 174 if not self.amsgrad: 175 return gen_training_ops.ResourceApplyAdam( 176 var=var.handle, 177 m=m.handle, 178 v=v.handle, 179 beta1_power=coefficients['beta_1_power'], 180 beta2_power=coefficients['beta_2_power'], 181 lr=coefficients['lr_t'], 182 beta1=coefficients['beta_1_t'], 183 beta2=coefficients['beta_2_t'], 184 epsilon=coefficients['epsilon'], 185 grad=grad, 186 use_locking=self._use_locking) 187 else: 188 vhat = self.get_slot(var, 'vhat') 189 return gen_training_ops.ResourceApplyAdamWithAmsgrad( 190 var=var.handle, 191 m=m.handle, 192 v=v.handle, 193 vhat=vhat.handle, 194 beta1_power=coefficients['beta_1_power'], 195 beta2_power=coefficients['beta_2_power'], 196 lr=coefficients['lr_t'], 197 beta1=coefficients['beta_1_t'], 198 beta2=coefficients['beta_2_t'], 199 epsilon=coefficients['epsilon'], 200 grad=grad, 201 use_locking=self._use_locking) 202 203 def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 204 var_device, var_dtype = var.device, var.dtype.base_dtype 205 coefficients = ((apply_state or {}).get((var_device, var_dtype)) 206 or self._fallback_apply_state(var_device, var_dtype)) 207 208 # m_t = beta1 * m + (1 - beta1) * g_t 209 m = self.get_slot(var, 'm') 210 m_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] 211 m_t = state_ops.assign(m, m * coefficients['beta_1_t'], 212 use_locking=self._use_locking) 213 with ops.control_dependencies([m_t]): 214 m_t = self._resource_scatter_add(m, indices, m_scaled_g_values) 215 216 # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 217 v = self.get_slot(var, 'v') 218 v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t'] 219 v_t = state_ops.assign(v, v * coefficients['beta_2_t'], 220 use_locking=self._use_locking) 221 with ops.control_dependencies([v_t]): 222 v_t = self._resource_scatter_add(v, indices, v_scaled_g_values) 223 224 if not self.amsgrad: 225 v_sqrt = math_ops.sqrt(v_t) 226 var_update = state_ops.assign_sub( 227 var, coefficients['lr'] * m_t / (v_sqrt + coefficients['epsilon']), 228 use_locking=self._use_locking) 229 return control_flow_ops.group(*[var_update, m_t, v_t]) 230 else: 231 v_hat = self.get_slot(var, 'vhat') 232 v_hat_t = math_ops.maximum(v_hat, v_t) 233 with ops.control_dependencies([v_hat_t]): 234 v_hat_t = state_ops.assign( 235 v_hat, v_hat_t, use_locking=self._use_locking) 236 v_hat_sqrt = math_ops.sqrt(v_hat_t) 237 var_update = state_ops.assign_sub( 238 var, 239 coefficients['lr'] * m_t / (v_hat_sqrt + coefficients['epsilon']), 240 use_locking=self._use_locking) 241 return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t]) 242 243 def get_config(self): 244 config = super(Adam, self).get_config() 245 config.update({ 246 'learning_rate': self._serialize_hyperparameter('learning_rate'), 247 'decay': self._initial_decay, 248 'beta_1': self._serialize_hyperparameter('beta_1'), 249 'beta_2': self._serialize_hyperparameter('beta_2'), 250 'epsilon': self.epsilon, 251 'amsgrad': self.amsgrad, 252 }) 253 return config 254 255 256class NonFusedAdam(optimizer_v2.OptimizerV2): 257 r"""Optimizer that implements the Adam algorithm without fused kernels. 258 259 Adam optimization is a stochastic gradient descent method that is based on 260 adaptive estimation of first-order and second-order moments. 261 According to the paper 262 [Adam: A Method for Stochastic Optimization. Kingma et al., 263 2014](http://arxiv.org/abs/1412.6980), the method is "*computationally 264 efficient, has little memory requirement, invariant to diagonal rescaling of 265 gradients, and is well suited for problems that are large in terms of 266 data/parameters*". 267 268 For AMSGrad see [On The Convergence Of Adam And Beyond. 269 Reddi et al., 5-8](https://openreview.net/pdf?id=ryQu7f-RZ). 270 271 **If amsgrad = False**: 272 273 initialize $m_0$ as 1st moment vector 274 initialize $v_0$ as 2nd moment vector 275 276 The update rule for $\theta$ with gradient $g$ uses an optimization 277 described at the end of section 2 of the paper: 278 279 $$lr_t = \mathrm{learning\_rate} * 280 \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ 281 $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ 282 $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$ 283 $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ 284 285 **If amsgrad = True**: 286 287 initialize $m_0$ as 1st moment vector 288 initialize $v_0$ as 2nd moment vector 289 initialize $\hat{v}_0$ as 2nd moment vector 290 291 The update rule for $\theta$ with gradient $g$ uses an optimization 292 described at the end of section 2 of the paper: 293 294 $$lr_t = \mathrm{learning\_rate} * 295 \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ 296 297 $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ 298 $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$ 299 $$\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$$ 300 $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$ 301 302 The default value of 1e-7 for epsilon might not be a good default in 303 general. For example, when training an Inception network on ImageNet a 304 current good choice is 1.0 or 0.1. Note that since Adam uses the 305 formulation just before Section 2.1 of the Kingma and Ba paper rather than 306 the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon 307 hat" in the paper. 308 309 The sparse implementation of this algorithm (used when the gradient is an 310 IndexedSlices object, typically because of `tf.gather` or an embedding 311 lookup in the forward pass) does apply momentum to variable slices even if 312 they were not used in the forward pass (meaning they have a gradient equal 313 to zero). Momentum decay (beta1) is also applied to the entire momentum 314 accumulator. This means that the sparse behavior is equivalent to the dense 315 behavior (in contrast to some momentum implementations which ignore momentum 316 unless a variable slice was actually used). 317 318 Usage: 319 320 >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1) 321 >>> var1 = tf.Variable(10.0) 322 >>> loss = lambda: (var1 ** 2)/2.0 # d(loss)/d(var1) == var1 323 >>> step_count = opt.minimize(loss, [var1]).numpy() 324 >>> # The first step is `-learning_rate*sign(grad)` 325 >>> var1.numpy() 326 9.9 327 """ 328 329 _HAS_AGGREGATE_GRAD = True 330 331 def __init__(self, 332 learning_rate=0.001, 333 beta_1=0.9, 334 beta_2=0.999, 335 epsilon=1e-7, 336 amsgrad=False, 337 name='Adam', 338 **kwargs): 339 """Construct a new Adam optimizer. 340 341 Args: 342 learning_rate: A `Tensor`, floating point value, or a schedule that is a 343 `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable that 344 takes no arguments and returns the actual value to use, The learning 345 rate. Defaults to 0.001. 346 beta_1: A float value or a constant float tensor, or a callable that takes 347 no arguments and returns the actual value to use. The exponential decay 348 rate for the 1st moment estimates. Defaults to 0.9. 349 beta_2: A float value or a constant float tensor, or a callable that takes 350 no arguments and returns the actual value to use, The exponential decay 351 rate for the 2nd moment estimates. Defaults to 0.999. 352 epsilon: A small constant for numerical stability. This epsilon is 353 "epsilon hat" in the Kingma and Ba paper (in the formula just before 354 Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to 355 1e-7. 356 amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from 357 the paper "On the Convergence of Adam and beyond". Defaults to `False`. 358 name: Optional name for the operations created when applying gradients. 359 Defaults to "Adam". 360 **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, 361 `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip 362 gradients by value, `decay` is included for backward compatibility to 363 allow time inverse decay of learning rate. `lr` is included for backward 364 compatibility, recommended to use `learning_rate` instead. 365 """ 366 367 super(NonFusedAdam, self).__init__(name, **kwargs) 368 self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) 369 self._set_hyper('decay', self._initial_decay) 370 self._set_hyper('beta_1', beta_1) 371 self._set_hyper('beta_2', beta_2) 372 self.epsilon = epsilon or backend_config.epsilon() 373 self.amsgrad = amsgrad 374 375 def _create_slots(self, var_list): 376 # Create slots for the first and second moments. 377 # Separate for-loops to respect the ordering of slot variables from v1. 378 for var in var_list: 379 self.add_slot(var, 'm') 380 for var in var_list: 381 self.add_slot(var, 'v') 382 if self.amsgrad: 383 for var in var_list: 384 self.add_slot(var, 'vhat') 385 386 def _prepare_local(self, var_device, var_dtype, apply_state): 387 super(NonFusedAdam, self)._prepare_local(var_device, var_dtype, apply_state) 388 389 local_step = math_ops.cast(self.iterations + 1, var_dtype) 390 beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype)) 391 beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype)) 392 beta_1_power = math_ops.pow(beta_1_t, local_step) 393 beta_2_power = math_ops.pow(beta_2_t, local_step) 394 lr = ( 395 apply_state[(var_device, var_dtype)]['lr_t'] * 396 (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))) 397 apply_state[(var_device, var_dtype)].update( 398 dict( 399 lr=lr, 400 epsilon=ops.convert_to_tensor_v2_with_dispatch( 401 self.epsilon, var_dtype), 402 beta_1_t=beta_1_t, 403 beta_1_power=beta_1_power, 404 one_minus_beta_1_t=1 - beta_1_t, 405 beta_2_t=beta_2_t, 406 beta_2_power=beta_2_power, 407 one_minus_beta_2_t=1 - beta_2_t)) 408 409 def set_weights(self, weights): 410 params = self.weights 411 # If the weights are generated by Keras V1 optimizer, it includes vhats 412 # even without amsgrad, i.e, V1 optimizer has 3x + 1 variables, while V2 413 # optimizer has 2x + 1 variables. Filter vhats out for compatibility. 414 num_vars = int((len(params) - 1) / 2) 415 if len(weights) == 3 * num_vars + 1: 416 weights = weights[:len(params)] 417 super(NonFusedAdam, self).set_weights(weights) 418 419 @def_function.function(jit_compile=True) 420 def _resource_apply_dense(self, grad, var, apply_state=None): 421 var_device, var_dtype = var.device, var.dtype.base_dtype 422 coefficients = ((apply_state or {}).get((var_device, var_dtype)) or 423 self._fallback_apply_state(var_device, var_dtype)) 424 425 m = self.get_slot(var, 'm') 426 v = self.get_slot(var, 'v') 427 428 alpha = ( 429 coefficients['lr_t'] * math_ops.sqrt(1 - coefficients['beta_2_power']) / 430 (1 - coefficients['beta_1_power'])) 431 m.assign_add((grad - m) * (1 - coefficients['beta_1_t'])) 432 v.assign_add((math_ops.square(grad) - v) * (1 - coefficients['beta_2_t'])) 433 if self.amsgrad: 434 vhat = self.get_slot(var, 'vhat') 435 vhat.assign(math_ops.maximum(vhat, v)) 436 v = vhat 437 var.assign_sub( 438 (m * alpha) / (math_ops.sqrt(v) - coefficients['epsilon'])) 439 440 @def_function.function(jit_compile=True) 441 def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 442 var_device, var_dtype = var.device, var.dtype.base_dtype 443 coefficients = ((apply_state or {}).get((var_device, var_dtype)) or 444 self._fallback_apply_state(var_device, var_dtype)) 445 446 # m_t = beta1 * m + (1 - beta1) * g_t 447 m = self.get_slot(var, 'm') 448 m_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] 449 m.assign(m * coefficients['beta_1_t']) 450 m.scatter_add(ops.IndexedSlices(m_scaled_g_values, indices)) 451 452 # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 453 v = self.get_slot(var, 'v') 454 v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t'] 455 v.assign(v * coefficients['beta_2_t']) 456 v.scatter_add(ops.IndexedSlices(v_scaled_g_values, indices)) 457 458 if not self.amsgrad: 459 var.assign_sub(coefficients['lr'] * m / 460 (math_ops.sqrt(v) + coefficients['epsilon'])) 461 else: 462 v_hat = self.get_slot(var, 'vhat') 463 v_hat.assign(math_ops.maximum(v_hat, v)) 464 var.assign_sub(coefficients['lr'] * m / 465 (math_ops.sqrt(v_hat) + coefficients['epsilon'])) 466 467 def get_config(self): 468 config = super(NonFusedAdam, self).get_config() 469 config.update({ 470 'learning_rate': self._serialize_hyperparameter('learning_rate'), 471 'decay': self._initial_decay, 472 'beta_1': self._serialize_hyperparameter('beta_1'), 473 'beta_2': self._serialize_hyperparameter('beta_2'), 474 'epsilon': self.epsilon, 475 'amsgrad': self.amsgrad, 476 }) 477 return config 478