# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Momentum for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.MomentumOptimizer")
class MomentumOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Momentum algorithm.

  Computes (if `use_nesterov = False`):

  ```
  accumulation = momentum * accumulation + gradient
  variable -= learning_rate * accumulation
  ```

  Note that in the dense version of this algorithm, `accumulation` is updated
  and applied regardless of a gradient's value, whereas the sparse version (when
  the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
  embedding) only updates variable slices and corresponding `accumulation` terms
  when that part of the variable was used in the forward pass.
  """

  def __init__(self, learning_rate, momentum,
               use_locking=False, name="Momentum", use_nesterov=False):
    """Construct a new Momentum optimizer.

    Args:
      learning_rate: A `Tensor` or a floating point value.  The learning rate.
      momentum: A `Tensor` or a floating point value.  The momentum.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Momentum".
      use_nesterov: If `True` use Nesterov Momentum.
        See [Sutskever et al., 2013](
        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
        This implementation always computes gradients at the value of the
        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
        variable(s) track the values called `theta_t + mu*v_t` in the paper.

    @compatibility(eager)
    When eager execution is enabled, learning_rate and momentum can each be a
    callable that takes no arguments and returns the actual value to use. This
    can be useful for changing these values across different invocations of
    optimizer functions.
    @end_compatibility
    """
    super(MomentumOptimizer, self).__init__(use_locking, name)
    self._learning_rate = learning_rate
    self._momentum = momentum
    self._use_nesterov = use_nesterov

  def _create_slots(self, var_list):
    """Creates one "momentum" slot per variable via `_zeros_slot`.

    Args:
      var_list: Iterable of variables that this optimizer will update.
    """
    for v in var_list:
      self._zeros_slot(v, "momentum", self._name)

  def _prepare(self):
    """Resolves the hyperparameters into tensors for the `_apply_*` methods.

    `learning_rate` and `momentum` may each be a zero-argument callable; a
    callable is invoked here so that the value it returns at this point is the
    one converted and stored in `_learning_rate_tensor` / `_momentum_tensor`.
    """
    learning_rate = self._learning_rate
    if callable(learning_rate):
      learning_rate = learning_rate()
    self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
                                                       name="learning_rate")
    momentum = self._momentum
    if callable(momentum):
      momentum = momentum()
    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")

  def _apply_dense(self, grad, var):
    """Applies a dense gradient to `var` with `apply_momentum`.

    Args:
      grad: The gradient for `var`.
      var: The variable to update; its "momentum" slot is used as accumulator.

    Returns:
      The `.op` of the result of `training_ops.apply_momentum`.
    """
    mom = self.get_slot(var, "momentum")
    # Hyperparameters are cast to the variable's base dtype before use.
    return training_ops.apply_momentum(
        var, mom,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        grad,
        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_dense(self, grad, var):
    """Applies a dense gradient to a resource variable.

    Args:
      grad: The gradient for `var`.
      var: The resource variable to update; `var.handle` and the handle of its
        "momentum" slot are passed to the kernel.

    Returns:
      The result of `training_ops.resource_apply_momentum`.
    """
    mom = self.get_slot(var, "momentum")
    # Hyperparameters are cast to the gradient's base dtype before use.
    return training_ops.resource_apply_momentum(
        var.handle, mom.handle,
        math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
        grad,
        math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)

  def _apply_sparse(self, grad, var):
    """Applies a sparse gradient to `var` with `sparse_apply_momentum`.

    Args:
      grad: The sparse gradient; its `values` and `indices` are passed to the
        kernel.
      var: The variable to update; its "momentum" slot is used as accumulator.

    Returns:
      The `.op` of the result of `training_ops.sparse_apply_momentum`.
    """
    mom = self.get_slot(var, "momentum")
    return training_ops.sparse_apply_momentum(
        var, mom,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        grad.values, grad.indices,
        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_sparse(self, grad, var, indices):
    """Applies a sparse gradient to a resource variable.

    Args:
      grad: The gradient values for the slices selected by `indices`.
      var: The resource variable to update; `var.handle` and the handle of its
        "momentum" slot are passed to the kernel.
      indices: The indices of the slices of `var` that `grad` applies to.

    Returns:
      The result of `training_ops.resource_sparse_apply_momentum`.
    """
    mom = self.get_slot(var, "momentum")
    # Cast to the base dtype, matching `_resource_apply_dense`.
    return training_ops.resource_sparse_apply_momentum(
        var.handle, mom.handle,
        math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
        grad, indices,
        math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)