# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Adadelta Optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adadelta


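# For reference while reading the expected values below: a minimal NumPy
# sketch of a single ADADELTA step (Zeiler, 2012), mirroring the recurrences
# that doTestBasic recomputes inline. It is illustrative only and is not
# called by the tests; the helper name is hypothetical, not part of the
# TensorFlow API.
def _adadelta_reference_step(var, grad, accum, accum_update, lr, rho, epsilon):
  accum = rho * accum + (1 - rho) * grad**2
  update = np.sqrt(accum_update + epsilon) / np.sqrt(accum + epsilon) * grad
  accum_update = rho * accum_update + (1 - rho) * update**2
  return var - lr * update, accum, accum_update

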
class AdadeltaOptimizerTest(test.TestCase):

  def doTestBasic(self, use_resource=False, use_callable_params=False):
    num_updates = 4  # number of ADADELTA steps to perform
    for dtype in [dtypes.half, dtypes.float32]:
      for grad in [0.2, 0.1, 0.01]:
        for lr in [1.0, 0.5, 0.1]:
          var0_init = [1.0, 2.0]
          var1_init = [3.0, 4.0]
          if use_resource:
            var0 = resource_variable_ops.ResourceVariable(
                var0_init, dtype=dtype)
            var1 = resource_variable_ops.ResourceVariable(
                var1_init, dtype=dtype)
          else:
            var0 = variables.Variable(var0_init, dtype=dtype)
            var1 = variables.Variable(var1_init, dtype=dtype)

          grads = constant_op.constant([grad, grad], dtype=dtype)

          accum = 0.0
          accum_update = 0.0

          # ADADELTA gradient optimizer
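          # rho is the decay rate of the two running averages tracked below;
          # epsilon both conditions the divisions and seeds the very first
          # update, since accum_update starts at zero.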
          rho = 0.95
          epsilon = 1e-8
          if use_callable_params:
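            # Hyperparameters may also be supplied as zero-argument callables,
            # deferring their evaluation until the optimizer uses them; this
            # branch is exercised in eager mode by testBasicCallableParams.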
            adadelta_opt = adadelta.AdadeltaOptimizer(
                learning_rate=lambda: lr,  # pylint: disable=cell-var-from-loop
                rho=lambda: rho,  # pylint: disable=cell-var-from-loop
                epsilon=lambda: epsilon)  # pylint: disable=cell-var-from-loop
          else:
            adadelta_opt = adadelta.AdadeltaOptimizer(
                learning_rate=lr, rho=rho, epsilon=epsilon)
          if not context.executing_eagerly():
            adadelta_update = adadelta_opt.apply_gradients(
                zip([grads, grads], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())

            # TODO(lxuechen): This is hard to test in eager mode,
            # since the optimizer is not fully initialized until the first
            # call to `apply_gradients`
            opt_vars = adadelta_opt.variables()
            self.assertStartsWith(opt_vars[0].name, var0._shared_name)
            self.assertStartsWith(opt_vars[1].name, var0._shared_name)
            self.assertStartsWith(opt_vars[2].name, var1._shared_name)
            self.assertStartsWith(opt_vars[3].name, var1._shared_name)
            self.assertEqual(4, len(opt_vars))
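            # Each variable owns two slots ("accum" and "accum_update"),
            # which is why four optimizer variables are expected above.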
            # Assign slots
            slot = [None] * 2
            slot_update = [None] * 2
            self.assertEqual(["accum", "accum_update"],
                             adadelta_opt.get_slot_names())
            slot[0] = adadelta_opt.get_slot(var0, "accum")
            self.assertEqual(slot[0].get_shape(), var0.get_shape())
            self.assertNotIn(slot[0], variables.trainable_variables())

            slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
            self.assertEqual(slot_update[0].get_shape(), var0.get_shape())
            self.assertNotIn(slot_update[0], variables.trainable_variables())

            slot[1] = adadelta_opt.get_slot(var1, "accum")
            self.assertEqual(slot[1].get_shape(), var1.get_shape())
            self.assertNotIn(slot[1], variables.trainable_variables())

            slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
            self.assertEqual(slot_update[1].get_shape(), var1.get_shape())
            self.assertNotIn(slot_update[1], variables.trainable_variables())

          # Fetch params to validate initial values
          self.assertAllClose(var0_init, self.evaluate(var0))
          self.assertAllClose(var1_init, self.evaluate(var1))

          update = [None] * num_updates
          tot_update = 0
          for step in range(num_updates):
            # Run adadelta update for comparison
            if not context.executing_eagerly():
              self.evaluate(adadelta_update)
            else:
              adadelta_opt.apply_gradients(zip([grads, grads], [var0, var1]))

            # Compute the expected update using the ADADELTA recurrences
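            # accum is a decayed average of squared gradients and accum_update
            # a decayed average of squared updates; their RMS ratio rescales
            # the raw gradient before it is applied with the learning rate.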
            accum = accum * rho + (grad**2) * (1 - rho)
            update[step] = (
                np.sqrt(accum_update + epsilon) *
                (1. / np.sqrt(accum + epsilon)) * grad)
            accum_update = (
                accum_update * rho + (update[step]**2) * (1.0 - rho))
            tot_update += update[step] * lr

            if not context.executing_eagerly():
              # Check that the accumulators have been updated
              # TODO(lxuechen): This is hard to test in eager mode
              for slot_idx in range(2):
                self.assertAllCloseAccordingToType(
                    np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
                    self.evaluate(slot[slot_idx]),
                    rtol=1e-5)

                self.assertAllCloseAccordingToType(
                    np.array(
                        [accum_update, accum_update],
                        dtype=dtype.as_numpy_dtype()),
                    self.evaluate(slot_update[slot_idx]),
                    rtol=1e-5)

              # Check that the parameters have been updated
              self.assertAllCloseAccordingToType(
                  np.array(
                      [var0_init[0] - tot_update, var0_init[1] - tot_update],
                      dtype=dtype.as_numpy_dtype()),
                  self.evaluate(var0),
                  rtol=1e-5)

              self.assertAllCloseAccordingToType(
                  np.array(
                      [var1_init[0] - tot_update, var1_init[1] - tot_update],
                      dtype=dtype.as_numpy_dtype()),
                  self.evaluate(var1),
                  rtol=1e-5)

  def testBasic(self):
    with self.cached_session():
      self.doTestBasic(use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTestBasic(use_resource=True)

  def testBasicCallableParams(self):
    with context.eager_mode():
      self.doTestBasic(use_resource=True, use_callable_params=True)

  @test_util.run_deprecated_v1
  def testMinimizeSparseResourceVariable(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        adadelta_op = adadelta.AdadeltaOptimizer(1.0, 1.0, 1.0).minimize(loss)
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
        # Run 1 step of adadelta
        adadelta_op.run()
        # Validate updated params
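        # With lr = rho = epsilon = 1.0 both accumulators stay at zero, so the
        # ADADELTA step reduces to subtracting the raw gradient of the squared
        # loss: pred = 1*4 + 2*5 = 14, grad = 2 * 14 * [4, 5] = [112, 140],
        # hence [1 - 112, 2 - 140] = [-111, -138].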
        self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))


if __name__ == "__main__":
  test.main()