# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SGD optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.SGD")
class SGD(optimizer_v2.OptimizerV2):
  r"""Gradient descent (with momentum) optimizer.

  Update rule for parameter `w` with gradient `g` when `momentum` is 0:

  ```python
  w = w - learning_rate * g
  ```

  Update rule when `momentum` is larger than 0:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + velocity
  ```

  When `nesterov=True`, this rule becomes:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + momentum * velocity - learning_rate * g
  ```
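
  As an illustrative sketch only (plain Python rather than the fused op the
  optimizer actually runs), the momentum update above behaves like:

  ```python
  w, velocity = 1.0, 0.0
  learning_rate, momentum = 0.1, 0.9
  for _ in range(2):
    g = w  # gradient of the loss `w ** 2 / 2`
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
  # Step sizes taken: 0.1, then 0.18 -- matching the usage example below.
  ```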

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.01.
    momentum: float hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0, i.e.,
      vanilla gradient descent.
    nesterov: boolean. Whether to apply Nesterov momentum. Defaults to
      `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"SGD"`.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.

  Usage:

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  >>> var = tf.Variable(1.0)
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> # Step is `- learning_rate * grad`
  >>> var.numpy()
  0.9

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  >>> var = tf.Variable(1.0)
  >>> val0 = var.value()
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> # First step is `- learning_rate * grad`
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val1 = var.value()
  >>> (val0 - val1).numpy()
  0.1
  >>> # On later steps, the step size increases because of momentum
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val2 = var.value()
  >>> (val1 - val2).numpy()
  0.18

  Reference:
      - For `nesterov=True`, see [Sutskever et al., 2013](
        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.01,
               momentum=0.0,
               nesterov=False,
               name="SGD",
               **kwargs):
    super(SGD, self).__init__(name, **kwargs)
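    # The legacy Keras argument name "lr" is accepted as an alias for
    # "learning_rate".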
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)

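    # Momentum is considered active when it is a tensor, a callable/schedule,
    # or a positive constant; accumulator slots are created only in that case.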
    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self.nesterov = nesterov

  def _create_slots(self, var_list):
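    # Allocate one "momentum" accumulator slot per variable, but only when
    # momentum is in use.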
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(SGD, self)._prepare_local(var_device, var_dtype, apply_state)
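    # Cache the momentum coefficient alongside the "lr_t" entry the base
    # class has already placed in `apply_state`.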
    apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
        self._get_hyper("momentum", var_dtype))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

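    # With momentum, use the fused Keras momentum kernel (optionally with
    # Nesterov); otherwise apply a plain gradient-descent step.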
    if self._momentum:
      momentum_var = self.get_slot(var, "momentum")
      return gen_training_ops.ResourceApplyKerasMomentum(
          var=var.handle,
          accum=momentum_var.handle,
          lr=coefficients["lr_t"],
          grad=grad,
          momentum=coefficients["momentum"],
          use_locking=self._use_locking,
          use_nesterov=self.nesterov)
    else:
      return gen_training_ops.ResourceApplyGradientDescent(
          var=var.handle,
          alpha=coefficients["lr_t"],
          delta=grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices,
                                               **kwargs):
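    # Without momentum the sparse update is a pure scatter-add of
    # `-lr * grad`, which is correct even when `indices` contains duplicates;
    # with momentum, defer to the base implementation, which sums gradients
    # for duplicate indices before calling `_resource_apply_sparse`.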
    if self._momentum:
      return super(SGD, self)._resource_apply_sparse_duplicate_indices(
          grad, var, indices, **kwargs)
    else:
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = (kwargs.get("apply_state", {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

      return gen_resource_variable_ops.ResourceScatterAdd(
          resource=var.handle,
          indices=indices,
          updates=-grad * coefficients["lr_t"])

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization.
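    # (The momentum-free sparse path is handled directly in
    # `_resource_apply_sparse_duplicate_indices` above.)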
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    momentum_var = self.get_slot(var, "momentum")
    return gen_training_ops.ResourceSparseApplyKerasMomentum(
        var=var.handle,
        accum=momentum_var.handle,
        lr=coefficients["lr_t"],
        grad=grad,
        indices=indices,
        momentum=coefficients["momentum"],
        use_locking=self._use_locking,
        use_nesterov=self.nesterov)

  def get_config(self):
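    # Serialize hyperparameters so the optimizer can be re-created via
    # `from_config`.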
    config = super(SGD, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "momentum": self._serialize_hyperparameter("momentum"),
        "nesterov": self.nesterov,
    })
    return config
