# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adam optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Adam')
class Adam(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the Adam algorithm.

  Adam optimization is a stochastic gradient descent method that is based on
  adaptive estimation of first-order and second-order moments.

  According to
  [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
  the method is "*computationally
  efficient, has little memory requirement, invariant to diagonal rescaling of
  gradients, and is well suited for problems that are large in terms of
  data/parameters*".
  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.001.
    beta_1: A float value or a constant float tensor, or a callable
      that takes no arguments and returns the actual value to use. The
      exponential decay rate for the 1st moment estimates. Defaults to 0.9.
    beta_2: A float value or a constant float tensor, or a callable
      that takes no arguments and returns the actual value to use. The
      exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
    epsilon: A small constant for numerical stability. This epsilon is
      "epsilon hat" in the Kingma and Ba paper (in the formula just before
      Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
      1e-7.
    amsgrad: Boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond". Defaults to
      `False`.
    name: Optional name for the operations created when applying gradients.
      Defaults to `"Adam"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value. See the example below this list.
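
  For instance, gradient clipping can be enabled through these keyword
  arguments (the clipping threshold below is purely illustrative):

  >>> opt = tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0)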

  Usage:

  >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2)/2.0       # d(loss)/d(var1) == var1
  >>> step_count = opt.minimize(loss, [var1]).numpy()
  >>> # The first step is `-learning_rate*sign(grad)`
  >>> var1.numpy()
  9.9

  Reference:
    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
    - [Reddi et al., 2018](
        https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.

  Notes:

  The default value of 1e-7 for epsilon might not be a good default in
  general. For example, when training an Inception network on ImageNet, a
  current good choice is 1.0 or 0.1. Note that since Adam uses the
  formulation just before Section 2.1 of the Kingma and Ba paper rather than
  the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
  hat" in the paper.
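
  As a sketch, a larger epsilon can simply be passed at construction time
  (the value below is illustrative rather than a tuned recommendation):

  >>> opt = tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=0.1)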

  The sparse implementation of this algorithm (used when the gradient is an
  IndexedSlices object, typically because of `tf.gather` or an embedding
  lookup in the forward pass) does apply momentum to variable slices even if
  they were not used in the forward pass (meaning they have a gradient equal
  to zero). Momentum decay (beta1) is also applied to the entire momentum
  accumulator. This means that the sparse behavior is equivalent to the dense
  behavior (in contrast to some momentum implementations which ignore momentum
  unless a variable slice was actually used).
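
  A minimal sketch of the sparse path (the toy embedding below is purely
  illustrative): the gradient from `tf.gather` is an `IndexedSlices`, yet
  the whole `m` and `v` state is still decayed on every step.

  >>> emb = tf.Variable(tf.ones([4, 2]))
  >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
  >>> with tf.GradientTape() as tape:
  ...   loss = tf.reduce_sum(tf.gather(emb, [0, 1]))
  >>> grads = tape.gradient(loss, [emb])
  >>> _ = opt.apply_gradients(zip(grads, [emb]))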
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               amsgrad=False,
               name='Adam',
               **kwargs):
    super(Adam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self.amsgrad = amsgrad

  def _create_slots(self, var_list):
    # Create slots for the first and second moments.
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, 'm')
    for var in var_list:
      self.add_slot(var, 'v')
    if self.amsgrad:
      for var in var_list:
        self.add_slot(var, 'vhat')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Adam, self)._prepare_local(var_device, var_dtype, apply_state)

    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
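    # Bias-corrected step size: lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t).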
    lr = (apply_state[(var_device, var_dtype)]['lr_t'] *
          (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)))
    apply_state[(var_device, var_dtype)].update(
        dict(
            lr=lr,
            epsilon=ops.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_1_power=beta_1_power,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            beta_2_power=beta_2_power,
            one_minus_beta_2_t=1 - beta_2_t))

  def set_weights(self, weights):
    params = self.weights
    # If the weights were generated by the Keras V1 optimizer, they include
    # vhats even without amsgrad, i.e., the V1 optimizer has 3x + 1 variables,
    # while the V2 optimizer has 2x + 1 variables. Filter vhats out for
    # compatibility.
    num_vars = int((len(params) - 1) / 2)
    if len(weights) == 3 * num_vars + 1:
      weights = weights[:len(params)]
    super(Adam, self).set_weights(weights)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

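    # Both branches delegate to a fused training op that applies the full
    # Adam (or AMSGrad) update in a single kernel.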
    if not self.amsgrad:
      return gen_training_ops.ResourceApplyAdam(
          var=var.handle,
          m=m.handle,
          v=v.handle,
          beta1_power=coefficients['beta_1_power'],
          beta2_power=coefficients['beta_2_power'],
          lr=coefficients['lr_t'],
          beta1=coefficients['beta_1_t'],
          beta2=coefficients['beta_2_t'],
          epsilon=coefficients['epsilon'],
          grad=grad,
          use_locking=self._use_locking)
    else:
      vhat = self.get_slot(var, 'vhat')
      return gen_training_ops.ResourceApplyAdamWithAmsgrad(
          var=var.handle,
          m=m.handle,
          v=v.handle,
          vhat=vhat.handle,
          beta1_power=coefficients['beta_1_power'],
          beta2_power=coefficients['beta_2_power'],
          lr=coefficients['lr_t'],
          beta1=coefficients['beta_1_t'],
          beta2=coefficients['beta_2_t'],
          epsilon=coefficients['epsilon'],
          grad=grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                           use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                           use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

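    # theta_t = theta_{t-1} - lr * m_t / (sqrt(v_t) + epsilon), where v_t is
    # replaced by max(v_hat_{t-1}, v_t) when amsgrad=True.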
    if not self.amsgrad:
      v_sqrt = math_ops.sqrt(v_t)
      var_update = state_ops.assign_sub(
          var, coefficients['lr'] * m_t / (v_sqrt + coefficients['epsilon']),
          use_locking=self._use_locking)
      return control_flow_ops.group(*[var_update, m_t, v_t])
    else:
      v_hat = self.get_slot(var, 'vhat')
      v_hat_t = math_ops.maximum(v_hat, v_t)
      with ops.control_dependencies([v_hat_t]):
        v_hat_t = state_ops.assign(
            v_hat, v_hat_t, use_locking=self._use_locking)
      v_hat_sqrt = math_ops.sqrt(v_hat_t)
      var_update = state_ops.assign_sub(
          var,
          coefficients['lr'] * m_t / (v_hat_sqrt + coefficients['epsilon']),
          use_locking=self._use_locking)
      return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])

  def get_config(self):
    config = super(Adam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
        'amsgrad': self.amsgrad,
    })
    return config


class NonFusedAdam(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the Adam algorithm without fused kernels.

  Adam optimization is a stochastic gradient descent method that is based on
  adaptive estimation of first-order and second-order moments.
  According to the paper
  [Adam: A Method for Stochastic Optimization. Kingma et al.,
  2014](http://arxiv.org/abs/1412.6980), the method is "*computationally
  efficient, has little memory requirement, invariant to diagonal rescaling of
  gradients, and is well suited for problems that are large in terms of
  data/parameters*".

  For AMSGrad see [On The Convergence Of Adam And Beyond.
  Reddi et al., 2018](https://openreview.net/pdf?id=ryQu7f-RZ).

  **If amsgrad = False**:

  initialize $m_0$ as 1st moment vector
  initialize $v_0$ as 2nd moment vector

  The update rule for $\theta$ with gradient $g$ uses an optimization
  described at the end of section 2 of the paper:

  $$lr_t = \mathrm{learning\_rate} *
    \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
  $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
  $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
  $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$

  **If amsgrad = True**:

  initialize $m_0$ as 1st moment vector
  initialize $v_0$ as 2nd moment vector
  initialize $\hat{v}_0$ as 2nd moment vector

  The update rule for $\theta$ with gradient $g$ uses an optimization
  described at the end of section 2 of the paper:

  $$lr_t = \mathrm{learning\_rate} *
    \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$

  $$m_t = \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
  $$v_t = \beta_2 * v_{t-1} + (1 - \beta_2) * g^2$$
  $$\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$$
  $$\theta_t = \theta_{t-1} - lr_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$

  The default value of 1e-7 for epsilon might not be a good default in
  general. For example, when training an Inception network on ImageNet, a
  current good choice is 1.0 or 0.1. Note that since Adam uses the
  formulation just before Section 2.1 of the Kingma and Ba paper rather than
  the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
  hat" in the paper.

  The sparse implementation of this algorithm (used when the gradient is an
  IndexedSlices object, typically because of `tf.gather` or an embedding
  lookup in the forward pass) does apply momentum to variable slices even if
  they were not used in the forward pass (meaning they have a gradient equal
  to zero). Momentum decay (beta1) is also applied to the entire momentum
  accumulator. This means that the sparse behavior is equivalent to the dense
  behavior (in contrast to some momentum implementations which ignore momentum
  unless a variable slice was actually used).

  Usage:

  >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2)/2.0       # d(loss)/d(var1) == var1
  >>> step_count = opt.minimize(loss, [var1]).numpy()
  >>> # The first step is `-learning_rate*sign(grad)`
  >>> var1.numpy()
  9.9
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               amsgrad=False,
               name='Adam',
               **kwargs):
    """Construct a new Adam optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable that
        takes no arguments and returns the actual value to use. The learning
        rate. Defaults to 0.001.
      beta_1: A float value or a constant float tensor, or a callable that takes
        no arguments and returns the actual value to use. The exponential decay
        rate for the 1st moment estimates. Defaults to 0.9.
      beta_2: A float value or a constant float tensor, or a callable that takes
        no arguments and returns the actual value to use. The exponential decay
        rate for the 2nd moment estimates. Defaults to 0.999.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
        1e-7.
      amsgrad: Boolean. Whether to apply the AMSGrad variant of this algorithm
        from the paper "On the Convergence of Adam and Beyond". Defaults to
        `False`.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time-inverse decay of the learning rate; `lr` is included for
        backward compatibility, but using `learning_rate` instead is
        recommended.
    """

    super(NonFusedAdam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self.amsgrad = amsgrad

  def _create_slots(self, var_list):
    # Create slots for the first and second moments.
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, 'm')
    for var in var_list:
      self.add_slot(var, 'v')
    if self.amsgrad:
      for var in var_list:
        self.add_slot(var, 'vhat')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(NonFusedAdam, self)._prepare_local(var_device, var_dtype, apply_state)

    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    lr = (
        apply_state[(var_device, var_dtype)]['lr_t'] *
        (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)))
    apply_state[(var_device, var_dtype)].update(
        dict(
            lr=lr,
            epsilon=ops.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_1_power=beta_1_power,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            beta_2_power=beta_2_power,
            one_minus_beta_2_t=1 - beta_2_t))

  def set_weights(self, weights):
    params = self.weights
    # If the weights were generated by the Keras V1 optimizer, they include
    # vhats even without amsgrad, i.e., the V1 optimizer has 3x + 1 variables,
    # while the V2 optimizer has 2x + 1 variables. Filter vhats out for
    # compatibility.
    num_vars = int((len(params) - 1) / 2)
    if len(weights) == 3 * num_vars + 1:
      weights = weights[:len(params)]
    super(NonFusedAdam, self).set_weights(weights)

  @def_function.function(jit_compile=True)
  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                    self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

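    # alpha is the bias-corrected step size:
    # lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t).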
    alpha = (
        coefficients['lr_t'] * math_ops.sqrt(1 - coefficients['beta_2_power']) /
        (1 - coefficients['beta_1_power']))
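    # Incremental form of the moment updates:
    #   m = beta_1*m + (1-beta_1)*g    <=>  m += (g - m) * (1 - beta_1)
    #   v = beta_2*v + (1-beta_2)*g^2  <=>  v += (g^2 - v) * (1 - beta_2)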
    m.assign_add((grad - m) * (1 - coefficients['beta_1_t']))
    v.assign_add((math_ops.square(grad) - v) * (1 - coefficients['beta_2_t']))
    if self.amsgrad:
      vhat = self.get_slot(var, 'vhat')
      vhat.assign(math_ops.maximum(vhat, v))
      v = vhat
    var.assign_sub(
        (m * alpha) / (math_ops.sqrt(v) + coefficients['epsilon']))

  @def_function.function(jit_compile=True)
  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                    self._fallback_apply_state(var_device, var_dtype))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m.assign(m * coefficients['beta_1_t'])
    m.scatter_add(ops.IndexedSlices(m_scaled_g_values, indices))

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v.assign(v * coefficients['beta_2_t'])
    v.scatter_add(ops.IndexedSlices(v_scaled_g_values, indices))

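    # theta_t = theta_{t-1} - lr * m_t / (sqrt(v_t) + epsilon), using
    # max(v_hat, v_t) in place of v_t when amsgrad=True.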
    if not self.amsgrad:
      var.assign_sub(coefficients['lr'] * m /
                     (math_ops.sqrt(v) + coefficients['epsilon']))
    else:
      v_hat = self.get_slot(var, 'vhat')
      v_hat.assign(math_ops.maximum(v_hat, v))
      var.assign_sub(coefficients['lr'] * m /
                     (math_ops.sqrt(v_hat) + coefficients['epsilon']))

  def get_config(self):
    config = super(NonFusedAdam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
        'amsgrad': self.amsgrad,
    })
    return config
