# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSprop optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.RMSprop")
class RMSprop(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the RMSprop algorithm.

  The gist of RMSprop is to:

  - Maintain a moving (discounted) average of the square of gradients
  - Divide the gradient by the root of this average

  This implementation of RMSprop uses plain momentum, not Nesterov momentum.

  The centered version additionally maintains a moving average of the
  gradients, and uses that average to estimate the variance.

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.001.
    rho: Discounting factor for the history of gradients. Defaults to 0.9.
    momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
    epsilon: A small constant for numerical stability, added to the
      denominator of the update. Defaults to 1e-7.
    centered: Boolean. If `True`, gradients are normalized by the estimated
      variance of the gradient; if `False`, by the uncentered second moment.
      Setting this to `True` may help with training, but is slightly more
      expensive in terms of computation and memory. Defaults to `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"RMSprop"`.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.

  Note that in the dense implementation of this algorithm, variables and their
  corresponding accumulators (momentum, gradient moving average, square
  gradient moving average) will be updated even if the gradient is zero
  (i.e. accumulators will decay, momentum will be applied). The sparse
  implementation (used when the gradient is an `IndexedSlices` object,
  typically because of `tf.gather` or an embedding lookup in the forward pass)
  will not update variable slices or their accumulators unless those slices
  were used in the forward pass (nor is there an "eventual" correction to
  account for these omitted updates). This leads to more efficient updates for
  large embedding lookup tables (where most of the slices are not accessed in
  a particular graph execution), but differs from the published algorithm.
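
  With the default `momentum=0.0` and `centered=False`, the update applied to
  each variable can be sketched as follows (`g` is the gradient; this mirrors
  the fallback Python branch of `_resource_apply_dense` below):

      rms = rho * rms + (1 - rho) * g ** 2
      var = var - learning_rate * g / (sqrt(rms) + epsilon)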

  Usage:

  >>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss) / d(var1) = var1
  >>> step_count = opt.minimize(loss, [var1]).numpy()
  >>> var1.numpy()
  9.683772

  Reference:
    - [Hinton, 2012](
      http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               rho=0.9,
               momentum=0.0,
               epsilon=1e-7,
               centered=False,
               name="RMSprop",
               **kwargs):
    """Construct a new RMSprop optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 0.001.
      rho: Discounting factor for the history of gradients. Defaults to 0.9.
      momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
      epsilon: A small constant for numerical stability, added to the
        denominator of the update. Defaults to 1e-7.
      centered: Boolean. If `True`, gradients are normalized by the estimated
        variance of the gradient; if `False`, by the uncentered second moment.
        Setting this to `True` may help with training, but is slightly more
        expensive in terms of computation and memory. Defaults to `False`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to `"RMSprop"`.
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
        `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate; `lr` is included for
        backward compatibility, but using `learning_rate` is recommended
        instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
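    # Note: `lr` is accepted as a legacy alias for `learning_rate`; when both
    # are supplied, the legacy `lr` value takes precedence (see the
    # `kwargs.get("lr", ...)` lookup below).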
    super(RMSprop, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("rho", rho)

    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self.epsilon = epsilon or backend_config.epsilon()
    self.centered = centered

  def _create_slots(self, var_list):
    for var in var_list:
      self.add_slot(var, "rms")
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")
    if self.centered:
      for var in var_list:
        self.add_slot(var, "mg")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(RMSprop, self)._prepare_local(var_device, var_dtype, apply_state)

    rho = array_ops.identity(self._get_hyper("rho", var_dtype))
    apply_state[(var_device, var_dtype)].update(
        dict(
            neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
            epsilon=ops.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype),
            rho=rho,
            momentum=array_ops.identity(self._get_hyper("momentum", var_dtype)),
            one_minus_rho=1. - rho))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
    else:
      rms_t = (coefficients["rho"] * rms +
               coefficients["one_minus_rho"] * math_ops.square(grad))
      rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
      denom_t = rms_t
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_t = coefficients["rho"] * mg + coefficients["one_minus_rho"] * grad
        mg_t = state_ops.assign(mg, mg_t, use_locking=self._use_locking)
        denom_t = rms_t - math_ops.square(mg_t)
      var_t = var - coefficients["lr_t"] * grad / (
          math_ops.sqrt(denom_t) + coefficients["epsilon"])
      return state_ops.assign(var, var_t, use_locking=self._use_locking).op
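
  # Sparse gradients arrive as `IndexedSlices` (e.g. produced by `tf.gather`).
  # With momentum the update is delegated to the fused
  # `ResourceSparseApply(Centered)RMSProp` kernels; otherwise it falls back to
  # explicit scatter updates below.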
  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceSparseApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceSparseApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
    else:
      # Decay the accumulator, then scatter-add the new `(1 - rho)`-weighted
      # squared-gradient terms into the rows named by `indices`.
      rms_scaled_g_values = (grad * grad) * coefficients["one_minus_rho"]
      rms_t = state_ops.assign(rms, rms * coefficients["rho"],
                               use_locking=self._use_locking)
      with ops.control_dependencies([rms_t]):
        rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
        rms_slice = array_ops.gather(rms_t, indices)
        denom_slice = rms_slice
        if self.centered:
          mg = self.get_slot(var, "mg")
          mg_scaled_g_values = grad * coefficients["one_minus_rho"]
          mg_t = state_ops.assign(mg, mg * coefficients["rho"],
                                  use_locking=self._use_locking)
          with ops.control_dependencies([mg_t]):
            mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
            mg_slice = array_ops.gather(mg_t, indices)
            denom_slice = rms_slice - math_ops.square(mg_slice)
        var_update = self._resource_scatter_add(
            var, indices, coefficients["neg_lr_t"] * grad / (
                math_ops.sqrt(denom_slice) + coefficients["epsilon"]))
        if self.centered:
          return control_flow_ops.group(*[var_update, rms_t, mg_t])
        return control_flow_ops.group(*[var_update, rms_t])

  def set_weights(self, weights):
    params = self.weights
    # Override set_weights for backward compatibility of Keras V1 optimizer
    # since it does not include iteration at head of the weight list. Set
    # iteration to 0.
    if len(params) == len(weights) + 1:
      weights = [np.array(0)] + weights
    super(RMSprop, self).set_weights(weights)

  def get_config(self):
    config = super(RMSprop, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self.epsilon,
        "centered": self.centered,
    })
    return config


RMSProp = RMSprop
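
# A short serialization round-trip sketch (an illustration, not part of the
# library API; assumes the public `tf` namespace is available):
#
#   opt = tf.keras.optimizers.RMSprop(rho=0.95, centered=True)
#   config = opt.get_config()  # plain Python dict, includes "rho", "centered"
#   restored = tf.keras.optimizers.RMSprop.from_config(config)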