1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Adamax optimizer implementation."""
16# pylint: disable=g-classes-have-attributes
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.keras import backend_config
24from tensorflow.python.keras.optimizer_v2 import optimizer_v2
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.training import gen_training_ops
29from tensorflow.python.util.tf_export import keras_export
30
31
32@keras_export('keras.optimizers.Adamax')
33class Adamax(optimizer_v2.OptimizerV2):
34  """Optimizer that implements the Adamax algorithm.
35
36  It is a variant of Adam based on the infinity norm.
37  Default parameters follow those provided in the paper.
38  Adamax is sometimes superior to adam, specially in models with embeddings.
39
40  Initialization:
41
42  ```python
43  m = 0  # Initialize initial 1st moment vector
44  v = 0  # Initialize the exponentially weighted infinity norm
45  t = 0  # Initialize timestep
46  ```
47
48  The update rule for parameter `w` with gradient `g` is
49  described at the end of section 7.1 of the paper:
50
51  ```python
52  t += 1
53  m = beta1 * m + (1 - beta) * g
54  v = max(beta2 * v, abs(g))
55  current_lr = learning_rate / (1 - beta1 ** t)
56  w = w - current_lr * m / (v + epsilon)
57  ```
58
59  Similarly to `Adam`, the epsilon is added for numerical stability
60  (especially to get rid of division by zero when `v_t == 0`).
61
62  In contrast to `Adam`, the sparse implementation of this algorithm
63  (used when the gradient is an IndexedSlices object, typically because of
64  `tf.gather` or an embedding lookup in the forward pass) only updates
65  variable slices and corresponding `m_t`, `v_t` terms when that part of
66  the variable was used in the forward pass. This means that the sparse
67  behavior is contrast to the dense behavior (similar to some momentum
68  implementations which ignore momentum unless a variable slice was actually
69  used).
70
71  Args:
72    learning_rate: A `Tensor`, floating point value, or a schedule that is a
73      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
74    beta_1: A float value or a constant float tensor. The exponential decay
75      rate for the 1st moment estimates.
76    beta_2: A float value or a constant float tensor. The exponential decay
77      rate for the exponentially weighted infinity norm.
78    epsilon: A small constant for numerical stability.
79    name: Optional name for the operations created when applying gradients.
80      Defaults to `"Adamax"`.
81    **kwargs: Keyword arguments. Allowed to be one of
82      `"clipnorm"` or `"clipvalue"`.
83      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
84      gradients by value.
85
86  Reference:
87    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
88  """
89
90  _HAS_AGGREGATE_GRAD = True
91
92  def __init__(self,
93               learning_rate=0.001,
94               beta_1=0.9,
95               beta_2=0.999,
96               epsilon=1e-7,
97               name='Adamax',
98               **kwargs):
99    super(Adamax, self).__init__(name, **kwargs)
100    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
101    self._set_hyper('decay', self._initial_decay)
102    self._set_hyper('beta_1', beta_1)
103    self._set_hyper('beta_2', beta_2)
104    self.epsilon = epsilon or backend_config.epsilon()
105
106  def _create_slots(self, var_list):
107    # Separate for-loops to respect the ordering of slot variables from v1.
108    for var in var_list:
109      self.add_slot(var, 'm')  # Create slots for the first moments.
110    for var in var_list:
111      self.add_slot(var, 'v')  # Create slots for the second moments.
112
113  def _prepare_local(self, var_device, var_dtype, apply_state):
114    super(Adamax, self)._prepare_local(var_device, var_dtype, apply_state)
115
116    local_step = math_ops.cast(self.iterations + 1, var_dtype)
117    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
118    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
119    beta_1_power = math_ops.pow(beta_1_t, local_step)
120    lr_t = apply_state[(var_device, var_dtype)]['lr_t']
121
122    apply_state[(var_device, var_dtype)].update(
123        dict(
124            neg_scaled_lr=-lr_t / (1 - beta_1_power),
125            epsilon=ops.convert_to_tensor_v2_with_dispatch(
126                self.epsilon, var_dtype),
127            beta_1_t=beta_1_t,
128            beta_1_power=beta_1_power,
129            one_minus_beta_1_t=1 - beta_1_t,
130            beta_2_t=beta_2_t,
131            zero=array_ops.zeros((), dtype=dtypes.int64)))
132
133  def _resource_apply_dense(self, grad, var, apply_state=None):
134    var_device, var_dtype = var.device, var.dtype.base_dtype
135    coefficients = ((apply_state or {}).get((var_device, var_dtype))
136                    or self._fallback_apply_state(var_device, var_dtype))
137
138    m = self.get_slot(var, 'm')
139    v = self.get_slot(var, 'v')
140    return gen_training_ops.ResourceApplyAdaMax(
141        var=var.handle,
142        m=m.handle,
143        v=v.handle,
144        beta1_power=coefficients['beta_1_power'],
145        lr=coefficients['lr_t'],
146        beta1=coefficients['beta_1_t'],
147        beta2=coefficients['beta_2_t'],
148        epsilon=coefficients['epsilon'],
149        grad=grad,
150        use_locking=self._use_locking)
151
152  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
153    var_device, var_dtype = var.device, var.dtype.base_dtype
154    coefficients = ((apply_state or {}).get((var_device, var_dtype))
155                    or self._fallback_apply_state(var_device, var_dtype))
156
157    # m_t = beta1 * m + (1 - beta1) * g_t
158    m = self.get_slot(var, 'm')
159    m_slice = array_ops.gather(m, indices, axis=coefficients['zero'])
160    m_t_slice = (m_slice * coefficients['beta_1_t'] +
161                 grad * coefficients['one_minus_beta_1_t'])
162    with ops.control_dependencies([m_t_slice]):
163      m_t = self._resource_scatter_update(m, indices, m_t_slice)
164
165    # u_t = max(beta2 * u, abs(g_t))
166    v = self.get_slot(var, 'v')
167    v_slice = array_ops.gather(v, indices, axis=coefficients['zero'])
168    v_t_slice = math_ops.maximum(v_slice * coefficients['beta_2_t'],
169                                 math_ops.abs(grad))
170    with ops.control_dependencies([v_t_slice]):
171      v_t = self._resource_scatter_update(v, indices, v_t_slice)
172    # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t
173    var_slice = coefficients['neg_scaled_lr'] * (
174        m_t_slice / (v_t_slice + coefficients['epsilon']))
175    with ops.control_dependencies([var_slice]):
176      var_update = self._resource_scatter_add(var, indices, var_slice)
177    return control_flow_ops.group(*[var_update, m_t, v_t])
178
179  def get_config(self):
180    config = super(Adamax, self).get_config()
181    config.update({
182        'learning_rate': self._serialize_hyperparameter('learning_rate'),
183        'decay': self._initial_decay,
184        'beta_1': self._serialize_hyperparameter('beta_1'),
185        'beta_2': self._serialize_hyperparameter('beta_2'),
186        'epsilon': self.epsilon,
187    })
188    return config
189