# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SGD optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.SGD")
class SGD(optimizer_v2.OptimizerV2):
  r"""Gradient descent (with momentum) optimizer.

  Update rule for parameter `w` with gradient `g` when `momentum` is 0:

  ```python
  w = w - learning_rate * g
  ```

  Update rule when `momentum` is larger than 0:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + velocity
  ```

  When `nesterov=True`, this rule becomes:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + momentum * velocity - learning_rate * g
  ```

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.01.
    momentum: float hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0, i.e.,
      vanilla gradient descent.
    nesterov: boolean. Whether to apply Nesterov momentum.
      Defaults to `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"SGD"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float)
      clips gradients by value.

  Usage:

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  >>> var = tf.Variable(1.0)
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> # Step is `- learning_rate * grad`
  >>> var.numpy()
  0.9

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  >>> var = tf.Variable(1.0)
  >>> val0 = var.value()
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> # First step is `- learning_rate * grad`
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val1 = var.value()
  >>> (val0 - val1).numpy()
  0.1
  >>> # On later steps, step-size increases because of momentum
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val2 = var.value()
  >>> (val1 - val2).numpy()
  0.18

  Reference:
    - For `nesterov=True`, see [Sutskever et al., 2013](
      http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.01,
               momentum=0.0,
               nesterov=False,
               name="SGD",
               **kwargs):
    super(SGD, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)

    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self.nesterov = nesterov

  def _create_slots(self, var_list):
    # A "momentum" accumulator slot is needed per variable only when momentum
    # is enabled.
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(SGD, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
        self._get_hyper("momentum", var_dtype))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    if self._momentum:
      # Use the fused momentum kernel, which updates the accumulator slot and
      # the variable in one op.
      momentum_var = self.get_slot(var, "momentum")
      return gen_training_ops.ResourceApplyKerasMomentum(
          var=var.handle,
          accum=momentum_var.handle,
          lr=coefficients["lr_t"],
          grad=grad,
          momentum=coefficients["momentum"],
          use_locking=self._use_locking,
          use_nesterov=self.nesterov)
    else:
      # Plain gradient descent: var -= lr * grad.
      return gen_training_ops.ResourceApplyGradientDescent(
          var=var.handle,
          alpha=coefficients["lr_t"],
          delta=grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices,
                                               **kwargs):
    if self._momentum:
      return super(SGD, self)._resource_apply_sparse_duplicate_indices(
          grad, var, indices, **kwargs)
    else:
      # Without momentum, a sparse update is simply a scatter-add of
      # `-lr * grad`, which already handles duplicate indices correctly.
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = (kwargs.get("apply_state", {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

      return gen_resource_variable_ops.ResourceScatterAdd(
          resource=var.handle,
          indices=indices,
          updates=-grad * coefficients["lr_t"])

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization.
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    momentum_var = self.get_slot(var, "momentum")
    return gen_training_ops.ResourceSparseApplyKerasMomentum(
        var=var.handle,
        accum=momentum_var.handle,
        lr=coefficients["lr_t"],
        grad=grad,
        indices=indices,
        momentum=coefficients["momentum"],
        use_locking=self._use_locking,
        use_nesterov=self.nesterov)

  def get_config(self):
    config = super(SGD, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "momentum": self._serialize_hyperparameter("momentum"),
        "nesterov": self.nesterov,
    })
    return config
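
# The sketch below is illustrative only and is not part of this module's API:
# it re-derives the heavy-ball momentum update rule from the class docstring in
# plain Python, so the step sizes quoted in the `Usage` doctests above (0.1 on
# the first step, 0.18 on the second) can be checked by hand. It only runs when
# this file is executed directly.
if __name__ == "__main__":
  w, velocity = 1.0, 0.0
  learning_rate, momentum = 0.1, 0.9
  for step in range(2):
    g = w  # gradient of loss = (w ** 2) / 2 with respect to w
    velocity = momentum * velocity - learning_rate * g
    w_next = w + velocity
    print("step %d: delta = %.2f" % (step, w - w_next))  # 0.10, then 0.18
    w = w_next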