# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Momentum for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.MomentumOptimizer")
class MomentumOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Momentum algorithm.

  Computes (if `use_nesterov = False`):

  ```
  accumulation = momentum * accumulation + gradient
  variable -= learning_rate * accumulation
  ```

  Note that in the dense version of this algorithm, `accumulation` is updated
  and applied regardless of a gradient's value, whereas the sparse version (when
  the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
  embedding) only updates variable slices and corresponding `accumulation` terms
  when that part of the variable was used in the forward pass.
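
  A minimal usage sketch (the learning rate and momentum values are
  illustrative, and `loss` is assumed to be a scalar loss tensor defined by
  the caller):

  ```
  opt = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
  # A "momentum" slot is created for each trainable variable when the
  # update op is built.
  train_op = opt.minimize(loss)
  ```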
44  """
45
46  def __init__(self, learning_rate, momentum,
47               use_locking=False, name="Momentum", use_nesterov=False):
48    """Construct a new Momentum optimizer.
49
50    Args:
51      learning_rate: A `Tensor` or a floating point value.  The learning rate.
52      momentum: A `Tensor` or a floating point value.  The momentum.
53      use_locking: If `True` use locks for update operations.
54      name: Optional name prefix for the operations created when applying
55        gradients.  Defaults to "Momentum".
56      use_nesterov: If `True` use Nesterov Momentum.
57        See [Sutskever et al., 2013](
58        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
59        This implementation always computes gradients at the value of the
60        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
61        variable(s) track the values called `theta_t + mu*v_t` in the paper.
62
63    @compatibility(eager)
64    When eager execution is enabled, learning_rate and momentum can each be a
65    callable that takes no arguments and returns the actual value to use. This
66    can be useful for changing these values across different invocations of
67    optimizer functions.
68    @end_compatibility
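
    For example, a minimal sketch under eager execution (`hparams` is a
    hypothetical dict owned by the caller; the names and values are
    illustrative only):

    ```
    hparams = {"lr": 0.1, "momentum": 0.9}
    opt = tf.train.MomentumOptimizer(
        learning_rate=lambda: hparams["lr"],
        momentum=lambda: hparams["momentum"])
    # Under eager execution the callables are re-evaluated when gradients
    # are applied, so later changes to hparams take effect.
    ```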
69    """
70    super(MomentumOptimizer, self).__init__(use_locking, name)
71    self._learning_rate = learning_rate
72    self._momentum = momentum
73    self._use_nesterov = use_nesterov
74
75  def _create_slots(self, var_list):
76    for v in var_list:
77      self._zeros_slot(v, "momentum", self._name)
78
79  def _prepare(self):
80    learning_rate = self._learning_rate
81    if callable(learning_rate):
82      learning_rate = learning_rate()
83    self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
84                                                       name="learning_rate")
85    momentum = self._momentum
86    if callable(momentum):
87      momentum = momentum()
88    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
89
90  def _apply_dense(self, grad, var):
91    mom = self.get_slot(var, "momentum")
92    return training_ops.apply_momentum(
93        var, mom,
94        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
95        grad,
96        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
97        use_locking=self._use_locking,
98        use_nesterov=self._use_nesterov).op
99
100  def _resource_apply_dense(self, grad, var):
101    mom = self.get_slot(var, "momentum")
102    return training_ops.resource_apply_momentum(
103        var.handle, mom.handle,
104        math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
105        grad,
106        math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
107        use_locking=self._use_locking,
108        use_nesterov=self._use_nesterov)
109
110  def _apply_sparse(self, grad, var):
111    mom = self.get_slot(var, "momentum")
112    return training_ops.sparse_apply_momentum(
113        var, mom,
114        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
115        grad.values, grad.indices,
116        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
117        use_locking=self._use_locking,
118        use_nesterov=self._use_nesterov).op
119
120  def _resource_apply_sparse(self, grad, var, indices):
121    mom = self.get_slot(var, "momentum")
122    return training_ops.resource_sparse_apply_momentum(
123        var.handle, mom.handle,
124        math_ops.cast(self._learning_rate_tensor, grad.dtype),
125        grad, indices,
126        math_ops.cast(self._momentum_tensor, grad.dtype),
127        use_locking=self._use_locking,
128        use_nesterov=self._use_nesterov)
129