# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adamax optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Adamax')
class Adamax(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Adamax algorithm.

  Adamax is a variant of Adam based on the infinity norm. Default parameters
  follow those provided in the paper. Adamax is sometimes superior to Adam,
  especially in models with embeddings.

  Initialization:

  ```python
  m = 0  # Initialize the 1st moment vector
  v = 0  # Initialize the exponentially weighted infinity norm
  t = 0  # Initialize the timestep
  ```

  The update rule for parameter `w` with gradient `g` is
  described at the end of section 7.1 of the paper:

  ```python
  t += 1
  m = beta1 * m + (1 - beta1) * g
  v = max(beta2 * v, abs(g))
  current_lr = learning_rate / (1 - beta1 ** t)
  w = w - current_lr * m / (v + epsilon)
  ```

  Similarly to `Adam`, the epsilon is added for numerical stability
  (especially to avoid division by zero when `v_t == 0`).

  In contrast to `Adam`, the sparse implementation of this algorithm
  (used when the gradient is an `IndexedSlices` object, typically because of
  `tf.gather` or an embedding lookup in the forward pass) only updates
  variable slices and the corresponding `m_t`, `v_t` terms when that part of
  the variable was used in the forward pass. This means that the sparse
  behavior contrasts with the dense behavior (similar to some momentum
  implementations which ignore momentum unless a variable slice was actually
  used).
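
  A minimal usage sketch (the variable and loss below are illustrative, not
  part of this module):

  ```python
  import tensorflow as tf

  opt = tf.keras.optimizers.Adamax(learning_rate=0.001)
  var = tf.Variable(10.0)
  loss = lambda: (var ** 2) / 2.0  # simple quadratic loss; gradient is `var`
  for _ in range(3):
    opt.minimize(loss, var_list=[var])  # each call applies one Adamax step
  ```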

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
    beta_1: A float value or a constant float tensor. The exponential decay
      rate for the 1st moment estimates.
    beta_2: A float value or a constant float tensor. The exponential decay
      rate for the exponentially weighted infinity norm.
    epsilon: A small constant for numerical stability.
    name: Optional name for the operations created when applying gradients.
      Defaults to `"Adamax"`.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.

  Reference:
    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               name='Adamax',
               **kwargs):
    super(Adamax, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()

  def _create_slots(self, var_list):
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, 'm')  # Create slots for the first moments.
    for var in var_list:
      self.add_slot(var, 'v')  # Create slots for the second moments.

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Adamax, self)._prepare_local(var_device, var_dtype, apply_state)

    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    lr_t = apply_state[(var_device, var_dtype)]['lr_t']

    apply_state[(var_device, var_dtype)].update(
        dict(
            neg_scaled_lr=-lr_t / (1 - beta_1_power),
            epsilon=ops.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_1_power=beta_1_power,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            zero=array_ops.zeros((), dtype=dtypes.int64)))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    return gen_training_ops.ResourceApplyAdaMax(
        var=var.handle,
        m=m.handle,
        v=v.handle,
        beta1_power=coefficients['beta_1_power'],
        lr=coefficients['lr_t'],
        beta1=coefficients['beta_1_t'],
        beta2=coefficients['beta_2_t'],
        epsilon=coefficients['epsilon'],
        grad=grad,
        use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_slice = array_ops.gather(m, indices, axis=coefficients['zero'])
    m_t_slice = (m_slice * coefficients['beta_1_t'] +
                 grad * coefficients['one_minus_beta_1_t'])
    with ops.control_dependencies([m_t_slice]):
      m_t = self._resource_scatter_update(m, indices, m_t_slice)

    # u_t = max(beta2 * u, abs(g_t))
    v = self.get_slot(var, 'v')
    v_slice = array_ops.gather(v, indices, axis=coefficients['zero'])
    v_t_slice = math_ops.maximum(v_slice * coefficients['beta_2_t'],
                                 math_ops.abs(grad))
    with ops.control_dependencies([v_t_slice]):
      v_t = self._resource_scatter_update(v, indices, v_t_slice)
    # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t
    var_slice = coefficients['neg_scaled_lr'] * (
        m_t_slice / (v_t_slice + coefficients['epsilon']))
    with ops.control_dependencies([var_slice]):
      var_update = self._resource_scatter_add(var, indices, var_slice)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def get_config(self):
    config = super(Adamax, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
    })
    return config