# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adadelta optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Adadelta')
class Adadelta(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the Adadelta algorithm.

  Adadelta optimization is a stochastic gradient descent method that adapts
  the learning rate per dimension to address two drawbacks:

  - The continual decay of learning rates throughout training.
  - The need for a manually selected global learning rate.

  Adadelta is a more robust extension of Adagrad that adapts learning rates
  based on a moving window of gradient updates, instead of accumulating all
  past gradients. This way, Adadelta continues learning even when many updates
  have been done. Compared to Adagrad, the original version of Adadelta does
  not require an initial learning rate to be set. In this version, the initial
  learning rate can be set, as in most other Keras optimizers.

  According to section 4.3 ("Effective Learning rates") of the paper, near the
  end of training step sizes converge to 1, which is effectively a high
  learning rate that would cause divergence. This occurs only near the end of
  training, because gradients and step sizes are small and the epsilon
  constant in the numerator and denominator dominates past gradients and
  parameter updates, driving the effective learning rate towards 1.

  According to section 4.4 ("Speech Data"), where a large neural network with
  4 hidden layers was trained on a corpus of US English data, Adadelta with
  100 network replicas, epsilon=1e-6 and rho=0.95 converged faster than
  Adagrad; that configuration corresponds to constructing the optimizer as
  `Adadelta(lr=1.0, rho=0.95, epsilon=1e-6, decay=0.)`.

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
      To match the exact form in the original paper use 1.0.
    rho: A `Tensor` or a floating point value. The decay rate.
    epsilon: A `Tensor` or a floating point value. A small constant used to
      better condition the grad update.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"Adadelta"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float)
      clips gradients by value.
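
  Usage (an illustrative sketch only; the values below simply follow the
  paper's formulation with `learning_rate=1.0` and the default `rho`):

  >>> opt = tf.keras.optimizers.Adadelta(learning_rate=1.0, rho=0.95)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss)/d(var1) == var1
  >>> step_count = opt.minimize(loss, [var1])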

  Reference:
    - [Zeiler, 2012](http://arxiv.org/abs/1212.5701)
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               rho=0.95,
               epsilon=1e-7,
               name='Adadelta',
               **kwargs):
    super(Adadelta, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('rho', rho)
    self.epsilon = epsilon or backend_config.epsilon()

  def _create_slots(self, var_list):
    # Separate for-loops to respect the ordering of slot variables from v1.
    for v in var_list:
      self.add_slot(v, 'accum_grad')  # Running average of squared grads.
    for v in var_list:
      self.add_slot(v, 'accum_var')  # Running average of squared updates.

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Adadelta, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)].update(
        dict(
            epsilon=ops.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype),
            rho=array_ops.identity(self._get_hyper('rho', var_dtype))))

  def set_weights(self, weights):
    params = self.weights
    # Override set_weights for backward compatibility of Keras V1 optimizer
    # since it does not include iteration at head of the weight list. Set
    # iteration to 0.
    if len(params) == len(weights) + 1:
      weights = [np.array(0)] + weights
    super(Adadelta, self).set_weights(weights)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    accum_grad = self.get_slot(var, 'accum_grad')
    accum_var = self.get_slot(var, 'accum_var')
    # The fused kernel updates both accumulators and applies the scaled
    # update to `var` in a single op.
    return gen_training_ops.ResourceApplyAdadelta(
        var=var.handle,
        accum=accum_grad.handle,
        accum_update=accum_var.handle,
        lr=coefficients['lr_t'],
        rho=coefficients['rho'],
        epsilon=coefficients['epsilon'],
        grad=grad,
        use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    accum_grad = self.get_slot(var, 'accum_grad')
    accum_var = self.get_slot(var, 'accum_var')
    return gen_training_ops.ResourceSparseApplyAdadelta(
        var=var.handle,
        accum=accum_grad.handle,
        accum_update=accum_var.handle,
        lr=coefficients['lr_t'],
        rho=coefficients['rho'],
        epsilon=coefficients['epsilon'],
        grad=grad,
        indices=indices,
        use_locking=self._use_locking)

  def get_config(self):
    config = super(Adadelta, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'rho': self._serialize_hyperparameter('rho'),
        'epsilon': self.epsilon,
    })
    return config