# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Adamax."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.optimizer_v2 import adamax
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def adamax_update_numpy(param,
                        g_t,
                        t,
                        m,
                        v,
                        alpha=0.001,
                        beta1=0.9,
                        beta2=0.999,
                        epsilon=1e-8):
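  """NumPy reference for a single dense Adamax update.

  Implements, with a 0-based step index t:
    m_t = beta1 * m + (1 - beta1) * g_t
    v_t = max(beta2 * v, |g_t|)
    param_t = param - alpha / (1 - beta1**(t + 1)) * m_t / (v_t + epsilon)
  and returns (param_t, m_t, v_t).
  """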
  m_t = beta1 * m + (1 - beta1) * g_t
  v_t = np.maximum(beta2 * v, np.abs(g_t))
  param_t = param - (alpha / (1 - beta1**(t + 1))) * (m_t / (v_t + epsilon))
  return param_t, m_t, v_t


def adamax_sparse_update_numpy(param,
                               indices,
                               g_t,
                               t,
                               m,
                               v,
                               alpha=0.001,
                               beta1=0.9,
                               beta2=0.999,
                               epsilon=1e-8):
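  """NumPy reference for a sparse Adamax update.

  Applies the same update as `adamax_update_numpy`, but only to the rows of
  `param`, `m`, and `v` selected by `indices`; all other rows are left
  unchanged.
  """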
  m_t, v_t, param_t = np.copy(m), np.copy(v), np.copy(param)
  m_t_slice = beta1 * m[indices] + (1 - beta1) * g_t
  v_t_slice = np.maximum(beta2 * v[indices], np.abs(g_t))
  param_t_slice = param[indices] - (
      (alpha / (1 - beta1**(t + 1))) * (m_t_slice / (v_t_slice + epsilon)))
  m_t[indices] = m_t_slice
  v_t[indices] = v_t_slice
  param_t[indices] = param_t_slice
  return param_t, m_t, v_t


def get_beta_accumulators(opt, dtype):
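  """Returns beta_1**(iterations + 1), the bias-correction power of `opt`."""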
  local_step = math_ops.cast(opt.iterations + 1, dtype)
  beta_1_t = math_ops.cast(opt._get_hyper("beta_1"), dtype)
  beta_1_power = math_ops.pow(beta_1_t, local_step)
  return beta_1_power


class AdamaxOptimizerTest(test.TestCase):

  def doTestSparse(self, use_resource=False):
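    """Compares sparse (IndexedSlices) updates against the NumPy reference."""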
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        # Initialize variables for numpy implementation.
        zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype)  # pylint: disable=cell-var-from-loop
        m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
        var0_np = np.array([1.0, 2.0, 3.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([4.0, 5.0, 6.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        var1 = resource_variable_ops.ResourceVariable(var1_np)

        grads0_np_indices = np.array([0, 1], dtype=np.int32)
        grads0 = ops.IndexedSlices(
            constant_op.constant(grads0_np),
            constant_op.constant(grads0_np_indices), constant_op.constant([3]))
        grads1_np_indices = np.array([2, 1], dtype=np.int32)
        grads1 = ops.IndexedSlices(
            constant_op.constant(grads1_np),
            constant_op.constant(grads1_np_indices), constant_op.constant([3]))
        opt = adamax.Adamax()
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0, 3.0], var0.eval())
        self.assertAllClose([4.0, 5.0, 6.0], var1.eval())

        beta1_power = get_beta_accumulators(opt, dtype)

        # Run 3 steps of Adamax
        for t in range(3):
          self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
          update.run()

          var0_np, m0, v0 = adamax_sparse_update_numpy(
              var0_np, grads0_np_indices, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adamax_sparse_update_numpy(
              var1_np, grads1_np_indices, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, var0.eval())
          self.assertAllCloseAccordingToType(var1_np, var1.eval())

  @test_util.run_deprecated_v1
  def testResourceSparse(self):
    self.doTestSparse(use_resource=True)

  @test_util.run_deprecated_v1
  def testSparseDevicePlacement(self):
    for index_dtype in [dtypes.int32, dtypes.int64]:
      with self.cached_session(force_gpu=test.is_gpu_available()):
        # If a GPU is available, tests that all optimizer ops can be placed on
        # it (i.e. they have GPU kernels).
        var = variables.Variable([[1.0], [2.0]])
        indices = constant_op.constant([0, 1], dtype=index_dtype)
        g_sum = lambda: math_ops.reduce_sum(array_ops.gather(var, indices))  # pylint: disable=cell-var-from-loop
        optimizer = adamax.Adamax(3.0)
        minimize_op = optimizer.minimize(g_sum, var_list=[var])
        variables.global_variables_initializer().run()
        minimize_op.run()

  @test_util.run_deprecated_v1
  def testSparseRepeatedIndices(self):
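    """Checks that repeated indices in a sparse gradient are summed.

    Two gradient entries for the same row should produce the same update as
    a single pre-aggregated entry for that row.
    """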
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        repeated_index_update_var = variables.Variable(
            [[1.0], [2.0]], dtype=dtype)
        aggregated_update_var = variables.Variable(
            [[1.0], [2.0]], dtype=dtype)
        grad_repeated_index = ops.IndexedSlices(
            constant_op.constant(
                [0.1, 0.1], shape=[2, 1], dtype=dtype),
            constant_op.constant([1, 1]),
            constant_op.constant([2, 1]))
        grad_aggregated = ops.IndexedSlices(
            constant_op.constant(
                [0.2], shape=[1, 1], dtype=dtype),
            constant_op.constant([1]),
            constant_op.constant([2, 1]))
        repeated_update = adamax.Adamax().apply_gradients(
            [(grad_repeated_index, repeated_index_update_var)])
        aggregated_update = adamax.Adamax().apply_gradients(
            [(grad_aggregated, aggregated_update_var)])
        variables.global_variables_initializer().run()
        self.assertAllClose(aggregated_update_var.eval(),
                            repeated_index_update_var.eval())
        for _ in range(3):
          repeated_update.run()
          aggregated_update.run()
          self.assertAllClose(aggregated_update_var.eval(),
                              repeated_index_update_var.eval())

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testBasic(self):
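    """Compares dense updates against the NumPy reference."""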
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      with self.session(graph=ops.Graph()):
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        var0 = resource_variable_ops.ResourceVariable(
            var0_np, name="var0_%d" % i)
        var1 = resource_variable_ops.ResourceVariable(
            var1_np, name="var1_%d" % i)

        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        opt = adamax.Adamax()
        if not context.executing_eagerly():
          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
          self.evaluate(variables.global_variables_initializer())
          # Fetch params to validate initial values
          self.assertAllClose([1.0, 2.0], self.evaluate(var0))
          self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 3 steps of Adamax
        for t in range(3):
          beta_1_power = get_beta_accumulators(opt, dtype)
          self.assertAllCloseAccordingToType(0.9**(t + 1),
                                             self.evaluate(beta_1_power))
          if not context.executing_eagerly():
            self.evaluate(update)
          else:
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

          var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(
              var0_np, self.evaluate(var0), rtol=1e-2)
          self.assertAllCloseAccordingToType(
              var1_np, self.evaluate(var1), rtol=1e-2)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testBasicWithLearningRateDecay(self):
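    """Same as testBasic, but with inverse-time learning rate decay.

    The effective learning rate at step t is learning_rate / (1 + decay * t);
    the NumPy reference below applies the same schedule.
    """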
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      with self.session(graph=ops.Graph()):
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        var0 = resource_variable_ops.ResourceVariable(
            var0_np, name="var0_%d" % i)
        var1 = resource_variable_ops.ResourceVariable(
            var1_np, name="var1_%d" % i)

        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        learning_rate = 0.001
        decay = 0.002
        opt = adamax.Adamax(learning_rate=learning_rate, decay=decay)
        if not context.executing_eagerly():
          update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
          self.evaluate(variables.global_variables_initializer())
          # Fetch params to validate initial values
          self.assertAllClose([1.0, 2.0], self.evaluate(var0))
          self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 3 steps of Adamax
        for t in range(3):
          beta_1_power = get_beta_accumulators(opt, dtype)
          self.assertAllCloseAccordingToType(0.9**(t + 1),
                                             self.evaluate(beta_1_power))
          if not context.executing_eagerly():
            self.evaluate(update)
          else:
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

          lr = learning_rate / (1 + decay * t)

          var0_np, m0, v0 = adamax_update_numpy(
              var0_np, grads0_np, t, m0, v0, alpha=lr)
          var1_np, m1, v1 = adamax_update_numpy(
              var1_np, grads1_np, t, m1, v1, alpha=lr)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0),
                                             rtol=1e-2)
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1),
                                             rtol=1e-2)

  @test_util.run_deprecated_v1
  def testTensorLearningRate(self):
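    """Checks that the learning rate can be passed as a tensor."""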
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)
        opt = adamax.Adamax(constant_op.constant(0.001))
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        beta1_power = get_beta_accumulators(opt, dtype)

        # Run 3 steps of Adamax
        for t in range(3):
          self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
          update.run()

          var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, var0.eval())
          self.assertAllCloseAccordingToType(var1_np, var1.eval())

  @test_util.run_deprecated_v1
  def testSharing(self):
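    """Checks that two update ops created from one optimizer share state."""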
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)
        opt = adamax.Adamax()
        update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        beta1_power = get_beta_accumulators(opt, dtype)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        # Run 3 steps, alternating between the two update ops.
        for t in range(3):
          self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
          if t % 2 == 0:
            update1.run()
          else:
            update2.run()

          var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, var0.eval())
          self.assertAllCloseAccordingToType(var1_np, var1.eval())

  def testSlotsUniqueEager(self):
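    """Checks the variables created by the optimizer in eager mode."""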
    with context.eager_mode():
      v1 = resource_variable_ops.ResourceVariable(1.)
      v2 = resource_variable_ops.ResourceVariable(1.)
      opt = adamax.Adamax(1.)
      opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
      # There should be the iterations variable plus two slot variables
      # (m and v) for each of v1 and v2, i.e. five variables in total.
      self.assertEqual(5, len(set(opt.variables())))

  def testConstructAdamaxWithLR(self):
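    """Checks the deprecated `lr` alias for `learning_rate`.

    When both are passed, the value given as `lr` takes effect, as asserted
    below.
    """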
    opt = adamax.Adamax(lr=1.0)
    opt_2 = adamax.Adamax(learning_rate=0.1, lr=1.0)
    opt_3 = adamax.Adamax(learning_rate=0.1)
    self.assertIsInstance(opt.lr, variables.Variable)
    self.assertIsInstance(opt_2.lr, variables.Variable)
    self.assertIsInstance(opt_3.lr, variables.Variable)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(self.evaluate(opt.lr), (1.0))
    self.assertAllClose(self.evaluate(opt_2.lr), (1.0))
    self.assertAllClose(self.evaluate(opt_3.lr), (0.1))

  def testConstructAdamaxWithEpsilonValues(self):
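    """Checks that `epsilon=None` falls back to the default of 1e-7."""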
    opt = adamax.Adamax(epsilon=None)
    config = opt.get_config()
    self.assertEqual(config["epsilon"], 1e-7)

    opt = adamax.Adamax(epsilon=1e-8)
    config = opt.get_config()
    self.assertEqual(config["epsilon"], 1e-8)


if __name__ == "__main__":
  test.main()