# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSprop optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.RMSprop")
class RMSprop(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the RMSprop algorithm.

  The gist of RMSprop is to:

  - Maintain a moving (discounted) average of the square of gradients
  - Divide the gradient by the root of this average

  This implementation of RMSprop uses plain momentum, not Nesterov momentum.

  The centered version additionally maintains a moving average of the
  gradients, and uses that average to estimate the variance.
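
  As a rough sketch, with `momentum == 0` and `centered=False` the update
  applied to each parameter (matching the non-fused fallback in this file) is:

      rms = rho * rms + (1 - rho) * grad ** 2
      var = var - learning_rate * grad / (sqrt(rms) + epsilon)

  The centered variant additionally tracks `mg = rho * mg + (1 - rho) * grad`
  and uses `sqrt(rms - mg ** 2) + epsilon` as the denominator instead.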

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.001.
    rho: Discounting factor for the moving average of squared gradients.
      Defaults to 0.9.
    momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
    epsilon: A small constant added to the denominator for numerical
      stability. Defaults to 1e-7.
    centered: Boolean. If `True`, gradients are normalized by the estimated
      variance of the gradient; if `False`, by the uncentered second moment.
      Setting this to `True` may help with training, but is slightly more
      expensive in terms of computation and memory. Defaults to `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"RMSprop"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Note that in the dense implementation of this algorithm, variables and their
  corresponding accumulators (momentum, gradient moving average, square
  gradient moving average) will be updated even if the gradient is zero
  (i.e. accumulators will decay, momentum will be applied). The sparse
  implementation (used when the gradient is an `IndexedSlices` object,
  typically because of `tf.gather` or an embedding lookup in the forward pass)
  will not update variable slices or their accumulators unless those slices
  were used in the forward pass (nor is there an "eventual" correction to
  account for these omitted updates). This leads to more efficient updates for
  large embedding lookup tables (where most of the slices are not accessed in
  a particular graph execution), but differs from the published algorithm.
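
  For example, a gradient produced by `tf.gather` arrives as a
  `tf.IndexedSlices` and is therefore applied through the sparse path
  (the variable and indices below are purely illustrative):

  >>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  >>> emb = tf.Variable(tf.ones([10, 2]))
  >>> with tf.GradientTape() as tape:
  ...   loss = tf.reduce_sum(tf.gather(emb, [0, 3]))
  >>> grads = tape.gradient(loss, [emb])
  >>> isinstance(grads[0], tf.IndexedSlices)
  True
  >>> _ = opt.apply_gradients(zip(grads, [emb]))  # Updates only rows 0 and 3.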

  Usage:

  >>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2) / 2.0    # d(loss) / d(var1) = var1
  >>> step_count = opt.minimize(loss, [var1]).numpy()
  >>> var1.numpy()
  9.683772

  Reference:
    - [Hinton, 2012](
      http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               rho=0.9,
               momentum=0.0,
               epsilon=1e-7,
               centered=False,
               name="RMSprop",
               **kwargs):
    """Construct a new RMSprop optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 0.001.
      rho: Discounting factor for the moving average of squared gradients.
        Defaults to 0.9.
      momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
      epsilon: A small constant added to the denominator for numerical
        stability. Defaults to 1e-7.
      centered: Boolean. If `True`, gradients are normalized by the estimated
        variance of the gradient; if `False`, by the uncentered second moment.
        Setting this to `True` may help with training, but is slightly more
        expensive in terms of computation and memory. Defaults to `False`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSprop".
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value. `decay` is included for backward compatibility to
        allow time-inverse decay of the learning rate. `lr` is included for
        backward compatibility; use `learning_rate` instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
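
    For example, the learning rate can be passed as a zero-argument callable
    (the constant lambda below is purely illustrative):

    >>> opt = tf.keras.optimizers.RMSprop(learning_rate=lambda: 1e-3)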
    """
    super(RMSprop, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("rho", rho)

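    # Momentum is treated as enabled whenever it could be non-zero; this
    # decides whether the "momentum" slot and the fused kernels are used.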
    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self.epsilon = epsilon or backend_config.epsilon()
    self.centered = centered

  def _create_slots(self, var_list):
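    # Every variable gets an "rms" slot; "momentum" and "mg" slots are only
    # created when momentum or centering is actually in use.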
    for var in var_list:
      self.add_slot(var, "rms")
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")
    if self.centered:
      for var in var_list:
        self.add_slot(var, "mg")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(RMSprop, self)._prepare_local(var_device, var_dtype, apply_state)

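    # Cache the per-(device, dtype) coefficients once so the apply ops below
    # can reuse them for every variable with that device/dtype combination.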
    rho = array_ops.identity(self._get_hyper("rho", var_dtype))
    apply_state[(var_device, var_dtype)].update(
        dict(
            neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
            epsilon=ops.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype),
            rho=rho,
            momentum=array_ops.identity(self._get_hyper("momentum", var_dtype)),
            one_minus_rho=1. - rho))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
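    # With momentum enabled, delegate to the fused (centered) RMSProp kernels;
    # otherwise apply the update with explicit assign ops on the slots below.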
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
    else:
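      # No momentum: decay the squared-gradient accumulator in place, then
      # divide the gradient by the root of the (optionally centered) average.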
      rms_t = (coefficients["rho"] * rms +
               coefficients["one_minus_rho"] * math_ops.square(grad))
      rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
      denom_t = rms_t
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_t = coefficients["rho"] * mg + coefficients["one_minus_rho"] * grad
        mg_t = state_ops.assign(mg, mg_t, use_locking=self._use_locking)
        denom_t = rms_t - math_ops.square(mg_t)
      var_t = var - coefficients["lr_t"] * grad / (
          math_ops.sqrt(denom_t) + coefficients["epsilon"])
      return state_ops.assign(var, var_t, use_locking=self._use_locking).op

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
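    # Same structure as the dense case: fused sparse kernels when momentum is
    # enabled, otherwise a scatter-based fallback on the slot variables.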
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceSparseApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceSparseApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
    else:
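      # No momentum: decay the full accumulator, then scatter-add the scaled
      # squared gradients at the given indices before updating those slices.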
      rms_scaled_g_values = (grad * grad) * coefficients["one_minus_rho"]
      rms_t = state_ops.assign(rms, rms * coefficients["rho"],
                               use_locking=self._use_locking)
      with ops.control_dependencies([rms_t]):
        rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
        rms_slice = array_ops.gather(rms_t, indices)
      denom_slice = rms_slice
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_scaled_g_values = grad * coefficients["one_minus_rho"]
        mg_t = state_ops.assign(mg, mg * coefficients["rho"],
                                use_locking=self._use_locking)
        with ops.control_dependencies([mg_t]):
          mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
          mg_slice = array_ops.gather(mg_t, indices)
          denom_slice = rms_slice - math_ops.square(mg_slice)
      var_update = self._resource_scatter_add(
          var, indices, coefficients["neg_lr_t"] * grad / (
              math_ops.sqrt(denom_slice) + coefficients["epsilon"]))
      if self.centered:
        return control_flow_ops.group(*[var_update, rms_t, mg_t])
      return control_flow_ops.group(*[var_update, rms_t])

  def set_weights(self, weights):
    params = self.weights
    # Override set_weights for backward compatibility with the Keras V1
    # optimizer, whose weight list does not include `iterations` at the head.
    # In that case, prepend an iteration count of 0.
    if len(params) == len(weights) + 1:
      weights = [np.array(0)] + weights
    super(RMSprop, self).set_weights(weights)

  def get_config(self):
    config = super(RMSprop, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self.epsilon,
        "centered": self.centered,
    })
    return config


RMSProp = RMSprop