1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional test for OptimizerV2."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python import keras
24from tensorflow.python.eager import context
25from tensorflow.python.eager import def_function
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import test_util
30from tensorflow.python.keras import backend
31from tensorflow.python.keras import callbacks
32from tensorflow.python.keras import keras_parameterized
33from tensorflow.python.keras import optimizers
34from tensorflow.python.keras import testing_utils
35from tensorflow.python.keras.engine import input_layer
36from tensorflow.python.keras.engine import sequential
37from tensorflow.python.keras.engine import training
38from tensorflow.python.keras.layers import core
39from tensorflow.python.keras.optimizer_v2 import adadelta
40from tensorflow.python.keras.optimizer_v2 import adagrad
41from tensorflow.python.keras.optimizer_v2 import adam
42from tensorflow.python.keras.optimizer_v2 import adamax
43from tensorflow.python.keras.optimizer_v2 import gradient_descent
44from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
45from tensorflow.python.keras.optimizer_v2 import nadam
46from tensorflow.python.keras.optimizer_v2 import optimizer_v2
47from tensorflow.python.keras.optimizer_v2 import rmsprop
48from tensorflow.python.ops import array_ops
49from tensorflow.python.ops import clip_ops
50from tensorflow.python.ops import resource_variable_ops
51from tensorflow.python.ops import state_ops
52from tensorflow.python.ops import variables
53from tensorflow.python.platform import test
54from tensorflow.python.training import momentum
55from tensorflow.python.training import training_util
56
57
class OptimizerTest(test.TestCase):
  """Functional tests for the core `OptimizerV2` API (mostly via SGD/Adam)."""

  @test_util.run_in_graph_and_eager_modes
  def testBasic(self):
    """One SGD step (lr=3.0) on `5 * var0 + 3 * var1` for each float dtype."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1])
        # Initialize again: minimize() may have created optimizer variables.
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params: var0 -= 3 * 5, var1 -= 3 * 3.
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testAdaptiveLearningRate(self):
    """Changing `learning_rate` between steps affects subsequent updates."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)

      def loss():
        return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

      sgd = gradient_descent.SGD(1.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd.minimize(loss, [var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # Validate updated params
      # var0 = [1., 2.] - 1.0 * [5, 5]
      self.assertAllClose([-4., -3.], self.evaluate(var0))
      # var1 = [3., 4.] - 1.0 * [3, 3]
      self.assertAllClose([0., 1.], self.evaluate(var1))

      sgd.learning_rate = 0.5
      if context.executing_eagerly():
        sgd.minimize(loss, [var0, var1])
      else:
        # Graph mode: re-running the existing update op picks up the new
        # learning-rate value (verified by the assertions below).
        self.evaluate(opt_op)
      # Validate updated params
      # var0 = [-4., -3.] - 0.5 * [5, 5]
      self.assertAllClose([-6.5, -5.5], self.evaluate(var0))
      # var1 = [0., 1.] - 0.5 * [3, 3]
      self.assertAllClose([-1.5, -0.5], self.evaluate(var1))

      sgd.learning_rate = learning_rate_schedule.InverseTimeDecay(
          0.5, decay_steps=1.0, decay_rate=0.5)
      if context.executing_eagerly():
        sgd.minimize(loss, [var0, var1])
        # Two steps have already run, so the schedule yields
        # 0.5 / (1 + 0.5 * 2 / 1.0) = 0.25.
        # var0 = [-6.5, -5.5] - 0.25 * [5, 5]
        self.assertAllClose([-7.75, -6.75], self.evaluate(var0))
        # var1 = [-1.5, -0.5] - 0.25 * [3, 3]
        self.assertAllClose([-2.25, -1.25], self.evaluate(var1))
      else:
        # Graph mode: opt_op was built before the schedule was assigned and
        # presumably still reads the old learning-rate variable, so only
        # that it still runs is checked here.
        self.evaluate(opt_op)

  @test_util.run_in_graph_and_eager_modes
  def testPrecomputedGradient(self):
    """`grad_loss` scales each variable's gradient of the loss."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        grad_loss = constant_op.constant([42, -42], dtype=dtype)
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params: gradient = coefficient * grad_loss.
        self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
                            self.evaluate(var0))
        self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
                            self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testNoGradients(self):
    """minimize() raises when a listed variable has no gradient."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegexp(ValueError, 'No gradients'):
          # var1 has no gradient
          sgd_op.minimize(loss, var_list=[var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_Minimize(self):
    """minimize() raises when the loss does not depend on any variable."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        loss = lambda: constant_op.constant(5.0)

        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegexp(ValueError,
                                     'No gradients provided for any variable'):
          sgd_op.minimize(loss, var_list=[var0, var1])

  @test_util.run_in_graph_and_eager_modes
  def testNoGradientsForAnyVariables_ApplyGradients(self):
    """apply_gradients() raises when every gradient is None."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegexp(ValueError,
                                     'No gradients provided for any variable'):
          sgd_op.apply_gradients([(None, var0), (None, var1)])

  @test_util.run_in_graph_and_eager_modes
  def testGradientsAsVariables(self):
    """apply_gradients() accepts gradients stored in variables."""
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

        sgd = gradient_descent.SGD(3.0)
        grads_and_vars = sgd._compute_gradients(loss, [var0, var1])
        # Convert gradients to tf.Variables; `i` keeps names unique per dtype.
        converted_grads = [
            resource_variable_ops.ResourceVariable(
                array_ops.zeros([2], dtype), name='c_%d_%d' % (i, j))
            for j, gv in enumerate(grads_and_vars)
        ]
        convert_ops = [
            state_ops.assign(converted_grads[j], gv[0])
            for j, gv in enumerate(grads_and_vars)
        ]

        # Run convert_ops to achieve the gradients converting
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 1 step of sgd through optimizer
        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
        opt_op = sgd.apply_gradients(converted_grads_and_vars)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        self.evaluate(opt_op)

        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testComputeGradientsWithTensors(self):
    """Gradients w.r.t. plain tensors compute, but cannot be applied."""
    with self.cached_session():
      x = ops.convert_to_tensor(1.0)

      def f():
        return x * x

      sgd = gradient_descent.SGD(3.0)
      grads_and_vars = sgd._compute_gradients(f, [x])
      self.assertEqual(1, len(grads_and_vars))
      grad, x_as_var = grads_and_vars[0]
      self.assertIs(x, x_as_var)
      # d(x * x)/dx at x = 1.0
      self.assertEqual(2.0, self.evaluate(grad))

      with self.assertRaises(NotImplementedError):
        sgd.apply_gradients(grads_and_vars)

  @test_util.run_in_graph_and_eager_modes
  def testConstraint(self):
    """Variable constraints are applied to the updated values."""
    constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
    constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
    with self.cached_session():
      var0 = variables.Variable([1.0, 2.0],
                                constraint=constraint_01)
      var1 = variables.Variable([3.0, 4.0],
                                constraint=constraint_0)
      loss = lambda: 5 * var0 + 3 * var1
      sgd = gradient_descent.SGD(3.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd.minimize(loss, var_list=[var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # Unconstrained results would be [-14, -13] and [-6, -5]; the
      # constraints clip them into [-0.1, 0] and [0, 1] respectively.
      self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
      self.assertAllClose([0., 0.], self.evaluate(var1))

  @test_util.run_in_graph_and_eager_modes
  def testIterationWithoutMinimize(self):
    """`iterations` is readable (and 0) before any update runs."""
    with self.cached_session():
      sgd = gradient_descent.SGD(3.0)
      self.evaluate(sgd.iterations.initializer)
      self.assertEqual(0, self.evaluate(sgd.iterations))

  @test_util.run_in_graph_and_eager_modes
  def testConfig(self):
    """get_config()/from_config() round-trips the learning rate."""
    with self.cached_session():
      opt = gradient_descent.SGD(learning_rate=1.0)
      config = opt.get_config()
      opt2 = gradient_descent.SGD.from_config(config)
      lr = opt._get_hyper('learning_rate')
      lr2 = opt2._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      # assert both are equal float values.
      self.assertEqual(self.evaluate(lr), self.evaluate(lr2))
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      loss = lambda: 3 * var0
      # learning rate variable created when calling minimize.
      opt.minimize(loss, [var0])
      opt3 = gradient_descent.SGD.from_config(config)
      lr3 = opt3._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(self.evaluate(lr), self.evaluate(lr3))

  @test_util.run_in_graph_and_eager_modes
  def testConfigWithLearningRateDecay(self):
    """get_config()/from_config() round-trips a learning-rate schedule."""
    with self.cached_session():
      decay_schedule = learning_rate_schedule.InverseTimeDecay(
          0.5, decay_steps=1.0, decay_rate=0.1)
      step = 10
      opt = gradient_descent.SGD(decay_schedule)
      config = opt.get_config()
      opt2 = gradient_descent.SGD.from_config(config)
      # assert both are equal float values.
      self.assertAllEqual(
          decay_schedule(step),
          opt._get_hyper('learning_rate')(step))
      self.assertAllEqual(
          decay_schedule(step),
          opt2._get_hyper('learning_rate')(step))
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      loss = lambda: 3 * var0
      # learning rate variable created when calling minimize.
      opt.minimize(loss, [var0])
      self.evaluate(variables.global_variables_initializer())
      config = opt.get_config()
      opt3 = gradient_descent.SGD.from_config(config)
      self.assertAllEqual(
          self.evaluate(opt._get_hyper('learning_rate')(step)),
          opt3._get_hyper('learning_rate')(step))

  @test_util.run_in_graph_and_eager_modes
  def testGradClipValue(self):
    """`clipvalue` clips each gradient element (3 -> 1 here)."""
    with self.cached_session():
      var = resource_variable_ops.ResourceVariable([1.0, 2.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      self.assertAllClose([0., 1.], self.evaluate(var))

  @test_util.run_in_graph_and_eager_modes
  def testGradClipNorm(self):
    """`clipnorm` rescales the gradient to norm 1 (3 -> 1 here)."""
    with self.cached_session():
      var = resource_variable_ops.ResourceVariable([1.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      self.assertAllClose([0.], self.evaluate(var))

  @test_util.run_in_graph_and_eager_modes
  def testInvalidClipNorm(self):
    """A negative `clipnorm` is rejected at construction time."""
    with self.assertRaisesRegexp(ValueError, '>= 0'):
      gradient_descent.SGD(learning_rate=1.0, clipnorm=-1.0)

  @test_util.run_in_graph_and_eager_modes
  def testInvalidKwargs(self):
    """Unknown constructor keyword arguments raise TypeError."""
    with self.assertRaisesRegexp(TypeError, 'Unexpected keyword argument'):
      gradient_descent.SGD(learning_rate=1.0, invalidkwargs=1.0)

  @test_util.run_in_graph_and_eager_modes
  def testWeights(self):
    """get_weights()/set_weights() transfer optimizer state between copies."""
    with self.cached_session():
      opt1 = adam.Adam(learning_rate=1.0)
      var1 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                    dtype=dtypes.float32)
      loss1 = lambda: 3 * var1
      opt_op_1 = opt1.minimize(loss1, [var1])
      self.evaluate(variables.global_variables_initializer())
      config = opt1.get_config()
      opt2 = adam.Adam.from_config(config)
      var2 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                    dtype=dtypes.float32)
      loss2 = lambda: 3 * var2
      opt_op_2 = opt2.minimize(loss2, [var2])
      weights = opt1.get_weights()

      # Assert set_weights and both variables get updated to same value.
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_1, opt_op_2])
      self.assertAllClose(self.evaluate(var1), self.evaluate(var2))
      self.assertEqual(1, self.evaluate(opt1.iterations))
      self.assertEqual(1, self.evaluate(opt2.iterations))

      var3 = resource_variable_ops.ResourceVariable([1.0, 2.0, 3.0],
                                                    dtype=dtypes.float32)
      var4 = resource_variable_ops.ResourceVariable([4.0, 5.0, 6.0],
                                                    dtype=dtypes.float32)
      loss3 = lambda: 3 * var3 + 5 * var4
      opt_op_3 = opt1.minimize(loss3, [var3, var4])

      # Assert set_weights with ValueError since weight list does not match.
      self.evaluate(variables.global_variables_initializer())
      weights = opt1.get_weights()
      with self.assertRaisesRegexp(ValueError, 'but the optimizer was'):
        opt2.set_weights(weights)

      # Assert set_weights and variables get updated to same value.
      var5 = resource_variable_ops.ResourceVariable([1.0, 2.0, 3.0],
                                                    dtype=dtypes.float32)
      var6 = resource_variable_ops.ResourceVariable([4.0, 5.0, 6.0],
                                                    dtype=dtypes.float32)
      loss4 = lambda: 3 * var5 + 5 * var6
      opt_op_4 = opt2.minimize(loss4, [var5, var6])
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_3, opt_op_4])
      self.assertAllClose(
          self.evaluate([var3, var4]), self.evaluate([var5, var6]))

  @test_util.run_in_graph_and_eager_modes
  def testGettingHyperParameters(self):
    """Hyperparameters are readable/settable as attributes (`opt.lr`)."""
    opt = adam.Adam(learning_rate=1.0)
    var = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                 dtype=dtypes.float32)
    loss = lambda: 3 * var
    opt_op = opt.minimize(loss, [var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)

    lr = self.evaluate(opt.lr)
    self.assertEqual(1.0, lr)

    opt.lr = 2.0
    lr = self.evaluate(opt.lr)
    self.assertEqual(2.0, lr)

    self.evaluate(opt.lr.assign(3.0))
    lr = self.evaluate(opt.lr)
    self.assertEqual(3.0, lr)

    # Unknown attributes are not silently created as hyperparameters.
    with self.assertRaises(AttributeError):
      opt.not_an_attr += 3

  @test_util.run_in_graph_and_eager_modes
  def testGettingHyperParametersWithLrInConstructor(self):
    """The legacy `lr` constructor kwarg maps to `learning_rate`."""
    opt = gradient_descent.SGD(lr=3.0)
    var = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                 dtype=dtypes.float32)
    loss = lambda: 3 * var
    opt_op = opt.minimize(loss, [var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)

    self.assertTrue(isinstance(opt.lr, resource_variable_ops.ResourceVariable))
    self.assertTrue(
        isinstance(opt.learning_rate, resource_variable_ops.ResourceVariable))

    lr = self.evaluate(opt.lr)
    self.assertEqual(3.0, lr)

    opt.lr = 2.0
    lr = self.evaluate(opt.lr)
    self.assertEqual(2.0, lr)

    self.evaluate(opt.lr.assign(4.0))
    lr = self.evaluate(opt.lr)
    self.assertEqual(4.0, lr)

  @test_util.run_in_graph_and_eager_modes
  def testOptimizerWithKerasModel(self):
    """An OptimizerV2 instance can train a multi-input/output Keras model."""
    a = input_layer.Input(shape=(3,), name='input_a')
    b = input_layer.Input(shape=(3,), name='input_b')

    dense = core.Dense(4, name='dense')
    c = dense(a)
    d = dense(b)
    e = core.Dropout(0.5, name='dropout')(c)

    model = training.Model([a, b], [d, e])

    optimizer = gradient_descent.SGD(learning_rate=0.001)
    loss = 'mse'
    model.compile(optimizer, loss, metrics=['mae'])

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))

    output_d_np = np.random.random((10, 4))
    output_e_np = np.random.random((10, 4))

    model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
              epochs=1,
              batch_size=5)

  @test_util.run_in_graph_and_eager_modes
  def testOptimizerWithCallbacks(self):
    """ReduceLROnPlateau can read and lower an OptimizerV2 learning rate."""
    np.random.seed(1331)
    input_np = np.random.random((10, 3))
    output_np = np.random.random((10, 4))
    a = input_layer.Input(shape=(3,), name='input_a')
    model = sequential.Sequential()
    model.add(core.Dense(4, name='dense'))
    model.add(core.Dropout(0.5, name='dropout'))
    model(a)
    optimizer = gradient_descent.SGD(learning_rate=0.1)
    model.compile(optimizer, loss='mse', metrics=['mae'])
    # This does not reduce the LR after the first epoch (due to low delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.1, min_delta=0, patience=1, cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=0)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.1, atol=1e-4)

    # This should reduce the LR after the first epoch (due to high delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            min_delta=10,
            patience=1,
            cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=0)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.01, atol=1e-4)

  def testOptimizerSetIterations(self):
    """An external global step can be used as the optimizer's iterations."""
    global_step = training_util.get_or_create_global_step()
    opt = adam.Adam(learning_rate=1.0)
    opt.iterations = global_step
    var = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                 dtype=dtypes.float32)
    self.evaluate(variables.global_variables_initializer())
    init_step_value = self.evaluate(global_step)
    loss = lambda: 3 * var
    opt_op = opt.minimize(loss, [var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    new_step_value = self.evaluate(global_step)
    self.assertEqual(new_step_value, init_step_value + 1)

  def testVarKey(self):
    """Graph-mode variables sharing a name get distinct `_var_key`s."""
    with context.graph_mode():
      a = variables.Variable([1., 2.], name='var')
      b = variables.Variable([1.], name='var')
      self.assertTrue(a._in_graph_mode)
      self.assertTrue(b._in_graph_mode)
      var_key = optimizer_v2._var_key(a)
      self.assertEqual('var', var_key)
      var_key = optimizer_v2._var_key(b)
      self.assertEqual('var_1', var_key)
550
551
@keras_parameterized.run_with_all_model_types
class OptimizersCompatibilityTest(keras_parameterized.TestCase):
  """Checks Keras v1 optimizers against their OptimizerV2 counterparts."""

  # Shape of the small classification problem shared by every test below.
  _INPUT_DIM = 3
  _NUM_CLASSES = 2
  _NUM_HIDDEN = 5

  def _get_train_data(self):
    """Returns seeded, one-hot encoded `(x, y)` training data."""
    np.random.seed(1331)
    (x, y), _ = testing_utils.get_test_data(
        train_samples=20,
        test_samples=10,
        input_shape=(self._INPUT_DIM,),
        num_classes=self._NUM_CLASSES)
    return x, keras.utils.to_categorical(y)

  def _get_model(self):
    """Returns a fresh small sequential MLP for the shared problem shape."""
    return testing_utils.get_small_sequential_mlp(
        num_hidden=self._NUM_HIDDEN,
        num_classes=self._NUM_CLASSES,
        input_dim=self._INPUT_DIM)

  def _testOptimizersCompatibility(self, opt_v1, opt_v2, test_weights=True):
    """Trains identically-initialized models with both optimizers and compares.

    Args:
      opt_v1: A Keras v1 optimizer instance.
      opt_v2: The equivalent `OptimizerV2` instance.
      test_weights: If True, copy `opt_v1`'s weights into `opt_v2` before the
        compared epoch so both optimizers start from the same state.
    """
    with self.cached_session():
      x, y = self._get_train_data()

      model_v1 = self._get_model()
      model_v1.compile(opt_v1, loss='categorical_crossentropy', metrics=[])
      # Warm-up epoch, presumably so opt_v1 has created non-trivial weights.
      model_v1.fit(x, y, batch_size=5, epochs=1)

      model_v2 = self._get_model()
      model_v2.set_weights(model_v1.get_weights())
      model_v2.compile(opt_v2, loss='categorical_crossentropy', metrics=[])
      # Build the train function, presumably so opt_v2's weights exist before
      # set_weights() is called on it.
      model_v2._make_train_function()
      if test_weights:
        opt_v2.set_weights(opt_v1.get_weights())

      hist_1 = model_v1.fit(x, y, batch_size=5, epochs=1, shuffle=False)
      hist_2 = model_v2.fit(x, y, batch_size=5, epochs=1, shuffle=False)
      self.assertAllClose(model_v1.get_weights(), model_v2.get_weights(),
                          rtol=1e-5, atol=1e-5)
      self.assertAllClose(hist_1.history['loss'], hist_2.history['loss'],
                          rtol=1e-5, atol=1e-5)

  def testAdadeltaCompatibility(self):
    opt_v1 = optimizers.Adadelta(lr=0.01)
    opt_v2 = adadelta.Adadelta(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdagradCompatibility(self):
    opt_v1 = optimizers.Adagrad(lr=0.01)
    opt_v2 = adagrad.Adagrad(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdamCompatibility(self):
    opt_v1 = optimizers.Adam()
    opt_v2 = adam.Adam()
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdamaxCompatibility(self):
    opt_v1 = optimizers.Adamax(lr=0.01)
    opt_v2 = adamax.Adamax(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testNadamCompatibility(self):
    opt_v1 = optimizers.Nadam(lr=0.001)
    opt_v2 = nadam.Nadam(learning_rate=0.001)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testMomentumCompatibility(self):
    opt_v1 = optimizers.SGD(lr=0.01, momentum=0.9)
    opt_v2 = gradient_descent.SGD(learning_rate=0.01, momentum=0.9)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testRMSpropCompatibility(self):
    opt_v1 = optimizers.RMSprop()
    opt_v2 = rmsprop.RMSprop()
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testSGDCompatibility(self):
    opt_v1 = optimizers.SGD(lr=0.01)
    opt_v2 = gradient_descent.SGD(learning_rate=0.01)
    # Optimizer weights are not transferred here; presumably the v1 and v2
    # plain-SGD weight lists do not line up. Verify before changing.
    self._testOptimizersCompatibility(opt_v1, opt_v2, test_weights=False)

  def testNumericEquivalenceForNesterovMomentum(self):
    """Keras v1, OptimizerV2 and tf.train Nesterov momentum agree."""
    with self.cached_session():
      x, y = self._get_train_data()

      model_k_v1 = self._get_model()
      model_k_v2 = self._get_model()
      model_k_v2.set_weights(model_k_v1.get_weights())
      model_tf = self._get_model()
      model_tf.set_weights(model_k_v2.get_weights())

      # All three optimizers must share the same learning rate; omitting it
      # for the v2 SGD would silently fall back to that class's default.
      opt_k_v1 = optimizers.SGD(lr=0.001, momentum=0.9, nesterov=True)
      opt_k_v2 = gradient_descent.SGD(
          learning_rate=0.001, momentum=0.9, nesterov=True)
      opt_tf = momentum.MomentumOptimizer(
          learning_rate=0.001, momentum=0.9, use_nesterov=True)

      model_k_v1.compile(opt_k_v1, loss='categorical_crossentropy', metrics=[])
      model_k_v2.compile(opt_k_v2, loss='categorical_crossentropy', metrics=[])
      model_tf.compile(opt_tf, loss='categorical_crossentropy', metrics=[])

      hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_tf = model_tf.fit(x, y, batch_size=5, epochs=10, shuffle=False)

      self.assertAllClose(model_k_v1.get_weights(), model_tf.get_weights())
      self.assertAllClose(model_k_v1.get_weights(), model_k_v2.get_weights())
      self.assertAllClose(opt_k_v1.get_weights(), opt_k_v2.get_weights())
      self.assertAllClose(hist_k_v1.history['loss'], hist_tf.history['loss'])
      self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss'])

  def testNumericEquivalenceForAmsgrad(self):
    """Keras v1 and OptimizerV2 Adam with `amsgrad=True` agree."""
    with self.cached_session():
      x, y = self._get_train_data()

      model_k_v1 = self._get_model()
      model_k_v2 = self._get_model()
      model_k_v2.set_weights(model_k_v1.get_weights())

      opt_k_v1 = optimizers.Adam(amsgrad=True)
      opt_k_v2 = adam.Adam(amsgrad=True)

      model_k_v1.compile(opt_k_v1, loss='categorical_crossentropy', metrics=[])
      model_k_v2.compile(opt_k_v2, loss='categorical_crossentropy', metrics=[])

      hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)

      self.assertAllClose(model_k_v1.get_weights(), model_k_v2.get_weights())
      self.assertAllClose(opt_k_v1.get_weights(), opt_k_v2.get_weights())
      self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss'])
703
704
# Note: These tests are kept in a separate class because some Python
# distributions have bugs that break AutoGraph, which tf.function relies on.
class OptimizerWithFunctionTest(test.TestCase):
  """OptimizerV2 behavior inside `tf.function` and with eager variables."""

  def testBasic(self):
    """minimize() inside a tf.function updates the variable on every call."""
    with context.eager_mode():
      var = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                   dtype=dtypes.float32)
      loss = lambda: 3 * var
      opt = adam.Adam(learning_rate=1.0)

      @def_function.function
      def fn():
        opt.minimize(loss, [var])
        return var

      # Each call moves var by approximately -1 (Adam step with lr=1.0).
      self.assertAllClose([0., 1.], fn(), atol=1e-4)
      self.assertAllClose([-1, 0.], fn(), atol=1e-4)

  def testVarKeyWithVarCreatedInEager(self):
    """Eager variables sharing a name still get distinct `_var_key`s."""
    with context.eager_mode():
      a = variables.Variable([1., 2.], name='var')
      b = variables.Variable([1.], name='var')

      @test_util.also_run_as_tf_function
      def var_key_test():
        self.assertFalse(a._in_graph_mode)
        self.assertFalse(b._in_graph_mode)
        var_key_a = optimizer_v2._var_key(a)
        self.assertStartsWith(var_key_a, 'var_')
        var_key_b = optimizer_v2._var_key(b)
        self.assertStartsWith(var_key_b, 'var_')
        # assertNotEqual, not the deprecated `assertNotEquals` alias (removed
        # in Python 3.12).
        self.assertNotEqual(var_key_a, var_key_b)

      var_key_test()
740
741
if __name__ == '__main__':
  # Run every test case in this module via TensorFlow's test runner.
  test.main()
744