# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Nadam')
class Nadam(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the NAdam algorithm.

  Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
  Nesterov momentum.

  Args:
    learning_rate: A Tensor or a floating point value. The learning rate.
    beta_1: A float value or a constant float tensor. The exponential decay
      rate for the 1st moment estimates.
    beta_2: A float value or a constant float tensor. The exponential decay
      rate for the 2nd moment estimates.
    epsilon: A small constant for numerical stability.
    name: Optional name for the operations created when applying gradients.
      Defaults to `"Nadam"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Usage Example:
    >>> opt = tf.keras.optimizers.Nadam(learning_rate=0.2)
    >>> var1 = tf.Variable(10.0)
    >>> loss = lambda: (var1 ** 2) / 2.0
    >>> step_count = opt.minimize(loss, [var1]).numpy()
    >>> "{:.1f}".format(var1.numpy())
    '9.8'

  Reference:
    - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               name='Nadam',
               **kwargs):
    # Backwards compatibility with keras NAdam optimizer.
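    # The legacy keras NAdam exposed `schedule_decay` and `lr`; here
    # `schedule_decay` is folded into the v2 `decay` hyperparameter
    # (defaulting to 0.004), and a legacy `lr` kwarg, if present, overrides
    # `learning_rate`.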
    kwargs['decay'] = kwargs.pop('schedule_decay', 0.004)
    learning_rate = kwargs.get('lr', learning_rate)
    if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
      raise ValueError('The Nadam optimizer does not support '
                       'tf.keras.optimizers.LearningRateSchedules as the '
                       'learning rate.')

    super(Nadam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self._m_cache = None

  def _create_slots(self, var_list):
    var_dtype = var_list[0].dtype.base_dtype
    if self._m_cache is None:
      self._m_cache = self.add_weight(
          'momentum_cache',
          shape=[],
          dtype=var_dtype,
          initializer='ones',
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._m_cache)
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      # Create slots for the first moments.
      self.add_slot(var, 'm')
    for var in var_list:
      # Create slots for the second moments.
      self.add_slot(var, 'v')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    next_step = math_ops.cast(self.iterations + 2, var_dtype)

    decay_base = math_ops.cast(0.96, var_dtype)

    m_t = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * local_step)))
    m_t_1 = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * next_step)))

    m_schedule_new = math_ops.cast(self._m_cache_read, var_dtype) * m_t
    if var_dtype is self._m_cache.dtype:
      m_schedule_new = array_ops.identity(state_ops.assign(
          self._m_cache, m_schedule_new, use_locking=self._use_locking))
    m_schedule_next = m_schedule_new * m_t_1

    apply_state[(var_device, var_dtype)] = dict(
        lr_t=lr_t,
        neg_lr_t=-lr_t,
        epsilon=ops.convert_to_tensor_v2_with_dispatch(
            self.epsilon, var_dtype),
        beta_1_t=beta_1_t,
        beta_2_t=beta_2_t,
        m_t=m_t,
        m_t_1=m_t_1,
        one_minus_beta_1_t=1 - beta_1_t,
        one_minus_beta_2_t=1 - beta_2_t,
        one_minus_m_t=1. - m_t,
        one_minus_m_schedule_new=1. - m_schedule_new,
        one_minus_m_schedule_next=1. - m_schedule_next,
        v_t_prime_denominator=1. - math_ops.pow(beta_2_t, local_step),
    )

  def _prepare(self, var_list):
    # Get the value of the momentum cache before starting to apply gradients.
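    # Reading it once here, rather than in `_prepare_local`, guarantees that
    # every (device, dtype) bucket scales the same cached value even though
    # `_prepare_local` writes the updated schedule back into `_m_cache`.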
    self._m_cache_read = array_ops.identity(self._m_cache)
    return super(Nadam, self)._prepare(var_list)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']
    m_t = (coefficients['beta_1_t'] * m +
           coefficients['one_minus_beta_1_t'] * grad)
    m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
    m_t_prime = m_t / coefficients['one_minus_m_schedule_next']
    v_t = (coefficients['beta_2_t'] * v +
           coefficients['one_minus_beta_2_t'] * math_ops.square(grad))
    v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)
    v_t_prime = v_t / coefficients['v_t_prime_denominator']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)
    var_t = var - coefficients['lr_t'] * m_t_bar / (
        math_ops.sqrt(v_t_prime) + coefficients['epsilon'])
    return state_ops.assign(var, var_t, use_locking=self._use_locking).op

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']

    # m_t = beta1 * m + (1 - beta1) * g_t
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
      m_t_slice = array_ops.gather(m_t, indices)

    m_t_prime = m_t_slice / coefficients['one_minus_m_schedule_next']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
      v_t_slice = array_ops.gather(v_t, indices)

    v_t_prime = v_t_slice / coefficients['v_t_prime_denominator']
    v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + coefficients['epsilon']

    var_update = self._resource_scatter_add(
        var, indices,
        coefficients['neg_lr_t'] * m_t_bar / v_prime_sqrt_plus_eps)
    return control_flow_ops.group(*[var_update, m_t_bar, v_t])

  def get_config(self):
    config = super(Nadam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
    })
    return config
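

# A minimal, illustrative sketch (an editorial addition, not part of the
# original module): like any `OptimizerV2` subclass, `Nadam` round-trips
# through `get_config` / `from_config`. The `__main__` guard keeps the sketch
# from running when the module is imported as a library.
if __name__ == '__main__':
  _opt = Nadam(learning_rate=0.002, beta_1=0.9)
  _restored = Nadam.from_config(_opt.get_config())
  assert _restored.get_config()['beta_1'] == 0.9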