# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.keras models using DistributionStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl.testing import parameterized
import numpy as np

from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import test
from tensorflow.python.estimator import keras as keras_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop

_RANDOM_SEED = 1337
_TRAIN_SIZE = 200
_INPUT_SIZE = (10,)
_NUM_CLASS = 2

# Note: Please make sure the tests in this file are also covered in
# keras_backward_compat_test for features that are supported with both APIs.


# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as
# part of the tf.keras unit tests suite.
def simple_sequential_model():
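  """Builds a small Sequential MLP used throughout these tests."""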
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE))
  model.add(keras.layers.Dropout(0.1))
  model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax'))
  return model


def simple_functional_model():
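  """Builds the functional-API equivalent of simple_sequential_model."""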
  a = keras.layers.Input(shape=_INPUT_SIZE)
  b = keras.layers.Dense(16, activation='relu')(a)
  b = keras.layers.Dropout(0.1)(b)
  b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b)
  model = keras.models.Model(inputs=[a], outputs=[b])
  return model


def simple_subclassed_model(num_labels=_NUM_CLASS):
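  """Builds a minimal subclassed model wrapping a single Dense layer."""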

  class _SimpleMLP(keras.Model):

    def __init__(self, num_labels):
      super(_SimpleMLP, self).__init__()
      self.dense = keras.layers.Dense(num_labels)

    def call(self, inputs):
      return self.dense(inputs)

  return _SimpleMLP(num_labels)


def simple_multi_inputs_multi_outputs_model():
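  """Builds a model that concatenates two inputs into two softmax heads."""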
  input_a = keras.layers.Input(shape=(16,), name='input_a')
  input_b = keras.layers.Input(shape=(16,), name='input_b')

  merged = keras.layers.concatenate([input_a, input_b], name='merge')
  output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
  output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
  model = keras.models.Model(
      inputs=[input_a, input_b], outputs=[output_c, output_d])
  return model


def multi_inputs_multi_outputs_model():
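  """Builds and compiles a model with two numeric inputs and a string input.

  The string input is converted to numbers inside the graph, and the two
  softmax outputs are named 'dense_2' and 'dense_3' to match the metrics
  dict passed to compile below.
  """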
  input_a = keras.layers.Input(shape=(16,), name='input_a')
  input_b = keras.layers.Input(shape=(16,), name='input_b')
  input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
  dense = keras.layers.Dense(8, name='dense_1')

  interm_a = dense(input_a)
  # Convert the string input m to numbers so it can be multiplied below.
  interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)
  interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])
  interm_b = dense(input_b)
  merged = keras.layers.concatenate([interm_s, interm_b], name='merge')
  output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
  output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
  model = keras.models.Model(
      inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])
  model.compile(
      loss='categorical_crossentropy',
      optimizer=gradient_descent.GradientDescentOptimizer(0.001),
      metrics={
          'dense_2': 'categorical_accuracy',
          'dense_3': 'categorical_accuracy'
      })
  return model


def get_ds_train_input_fn():
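  """Returns a batched Dataset of deterministic training data."""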
  np.random.seed(_RANDOM_SEED)
  (x_train, y_train), _ = testing_utils.get_test_data(
      train_samples=_TRAIN_SIZE,
      test_samples=50,
      input_shape=_INPUT_SIZE,
      num_classes=_NUM_CLASS)
  y_train = keras.utils.to_categorical(y_train)

  dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
  dataset = dataset.batch(32)
  return dataset


def get_ds_test_input_fn():
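  """Returns a batched Dataset of deterministic test data."""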
  np.random.seed(_RANDOM_SEED)
  _, (x_test, y_test) = testing_utils.get_test_data(
      train_samples=_TRAIN_SIZE,
      test_samples=50,
      input_shape=_INPUT_SIZE,
      num_classes=_NUM_CLASS)
  y_test = keras.utils.to_categorical(y_test)

  dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test))
  dataset = dataset.batch(32)
  return dataset


def get_multi_inputs_multi_outputs_data():
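  """Returns (train_data, test_data) dicts keyed by input/output name."""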
  (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
      train_samples=_TRAIN_SIZE,
      test_samples=50,
      input_shape=(16,),
      num_classes=3,
      random_seed=_RANDOM_SEED)
  (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
      train_samples=_TRAIN_SIZE,
      test_samples=50,
      input_shape=(16,),
      num_classes=2,
      random_seed=_RANDOM_SEED)
  (m_train, _), (m_test, _) = testing_utils.get_test_data(
      train_samples=_TRAIN_SIZE,
      test_samples=50,
      input_shape=(8,),
      num_classes=2,
      random_seed=_RANDOM_SEED)

  c_train = keras.utils.to_categorical(c_train)
  c_test = keras.utils.to_categorical(c_test)
  d_train = keras.utils.to_categorical(d_train)
  d_test = keras.utils.to_categorical(d_test)

  train_data = {
      'input_a': a_train,
      'input_b': b_train,
      'input_m': m_train,
      'output_c': c_train,
      'output_d': d_train
  }
  test_data = {
      'input_a': a_test,
      'input_b': b_test,
      'input_m': m_test,
      'output_c': c_test,
      'output_d': d_test
  }

  return (train_data, test_data)


def batch_wrapper(dataset, batch_size, distribution, repeat=None):
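  """Batches `dataset`, dropping the remainder when running on TPU."""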
  if repeat:
    dataset = dataset.repeat(repeat)
  # TPUs currently require fully defined input shapes; drop_remainder ensures
  # the inputs have fully defined shapes.
  if isinstance(distribution, tpu_strategy.TPUStrategy):
    return dataset.batch(batch_size, drop_remainder=True)
  else:
    return dataset.batch(batch_size)


def get_model():
  x = keras.layers.Input(shape=(3,), name='input')
  y = keras.layers.Dense(4, name='dense')(x)
  model = keras.Model(x, y)
  return model


def get_dataset(distribution):
  inputs = np.zeros((10, 3), dtype=np.float32)
  targets = np.zeros((10, 4), dtype=np.float32)
  dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = batch_wrapper(dataset, 10, distribution)
  return dataset


def get_predict_dataset(distribution):
  inputs = np.zeros((10, 3), dtype=np.float32)
  dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
  dataset = dataset.repeat(100)
  dataset = batch_wrapper(dataset, 10, distribution)
  return dataset


def convert_numpy_to_dataset_with_unknown_cardinality(inputs, targets=None):
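  """Wraps numpy data in a Dataset whose cardinality is UNKNOWN.

  The no-op `filter()` hides the dataset's size from the cardinality
  computation, so callers can exercise code paths that must infer the
  number of steps at run time.
  """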
  if targets is not None:
    input_slices = (inputs, targets)
    dummy_op = (lambda inp, target: True)
  else:
    input_slices = inputs
    dummy_op = (lambda inp: True)

  original_dataset = (dataset_ops.Dataset.from_tensor_slices(
      input_slices))
  ds_with_unknown_cardinality = (original_dataset.filter(dummy_op).
                                 batch(10, drop_remainder=True))
  return ds_with_unknown_cardinality


def multi_input_output_model():
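  """Builds a functional model with two inputs and two 7-dim outputs."""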
  a = keras.layers.Input(shape=(3,), name='input_a')
  b = keras.layers.Input(shape=(5,), name='input_b')
  # TODO(anjalisridhar): Change the output dimension of the second Dense layer
  # once the iterator output validation issue has been fixed.
  dense_1 = keras.layers.Dense(7, name='dense_1')
  dense_2 = keras.layers.Dense(7, name='dense_2')
  c = dense_1(a)
  d = dense_2(b)
  e = keras.layers.Dropout(0.5, name='dropout')(c)
  model = keras.models.Model([a, b], [d, e])
  return model


# TODO(josh11b): Add combinations.one_device_strategy_gpu once it works with
# TestDistributionStrategyWithCallbacks.test_callbacks_in_predict.
strategies_minus_tpu = [
    combinations.default_strategy,
    combinations.one_device_strategy,
    combinations.one_device_strategy_gpu,
    combinations.mirrored_strategy_with_gpu_and_cpu,
    combinations.mirrored_strategy_with_two_gpus,
    combinations.core_mirrored_strategy_with_gpu_and_cpu,
    combinations.core_mirrored_strategy_with_two_gpus]

tpu_strategies = [
    combinations.tpu_strategy,  # steps_per_run=2
    combinations.tpu_strategy_one_step]


def strategy_minus_tpu_combinations():
  return combinations.combine(distribution=strategies_minus_tpu,
                              mode=['graph', 'eager'])


def tpu_strategy_combinations():
  return combinations.combine(distribution=tpu_strategies,
                              mode=['graph'])


def all_strategy_combinations():
  return strategy_minus_tpu_combinations() + tpu_strategy_combinations()


def all_strategy_minus_default_and_tpu_combinations():
  return combinations.combine(
      distribution=[
          combinations.one_device_strategy,
          combinations.one_device_strategy_gpu,
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.mirrored_strategy_with_two_gpus,
          combinations.core_mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_two_gpus],
      mode=['graph', 'eager'])


def all_strategy_combinations_minus_default():
  return (all_strategy_minus_default_and_tpu_combinations() +
          tpu_strategy_combinations())


def strategy_and_optimizer_combinations():
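  """Crosses every strategy combination with v1 and Keras v2 optimizers."""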
  return combinations.times(
      all_strategy_combinations(),
      combinations.combine(optimizer=[
          combinations.adagrad_optimizer_v1_fn,
          combinations.adagrad_optimizer_keras_v2_fn,
          combinations.adam_optimizer_v1_fn,
          combinations.adam_optimizer_keras_v2_fn,
          combinations.gradient_descent_optimizer_v1_fn,
          combinations.gradient_descent_optimizer_keras_v2_fn,
          combinations.rmsprop_optimizer_v1_fn,
          combinations.rmsprop_optimizer_keras_v2_fn
      ]))


class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,
                                        parameterized.TestCase):

  def setUp(self):
    super(TestEstimatorDistributionStrategy, self).setUp()
    self._base_dir = os.path.join(self.get_temp_dir(),
                                  'keras_mirrored_strategy_test')
    gfile.MakeDirs(self._base_dir)
    self._config = run_config_lib.RunConfig(
        tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)

  def tearDown(self):
    super(TestEstimatorDistributionStrategy, self).tearDown()
    writer_cache.FileWriterCache.clear()
    if os.path.isdir(self._base_dir):
      gfile.DeleteRecursively(self._base_dir)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.mirrored_strategy_with_two_gpus,
          combinations.core_mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_two_gpus],
      mode=['graph']))
  def test_train_functional_with_distribution_strategy(self, distribution):
    keras_model = simple_functional_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        metrics=[keras.metrics.CategoricalAccuracy()],
        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution,
                                      eval_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=config)
      before_eval_results = est_keras.evaluate(
          input_fn=get_ds_test_input_fn, steps=1)
      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE // 16)
      after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
                                              steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.mirrored_strategy_with_two_gpus,
          combinations.core_mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_two_gpus],
      mode=['graph']))
  def test_train_sequential_with_distribution_strategy(self, distribution):
    keras_model = simple_sequential_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        metrics=[keras.metrics.CategoricalAccuracy()],
        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=config)
      before_eval_results = est_keras.evaluate(
          input_fn=get_ds_test_input_fn, steps=1)
      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE // 16)
      after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
                                              steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph']))
  def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution):
    train_data, test_data = get_multi_inputs_multi_outputs_data()

    def train_input_fn():
      input_dict = {
          'input_a': train_data['input_a'],
          'input_b': train_data['input_b'],
          'input_m': train_data['input_m'].astype(np.str)
      }
      output_dict = {
          'dense_2': train_data['output_c'],
          'dense_3': train_data['output_d']
      }
      return dataset_ops.Dataset.from_tensor_slices((input_dict,
                                                     output_dict)).batch(16)

    def eval_input_fn():
      input_dict = {
          'input_a': test_data['input_a'],
          'input_b': test_data['input_b'],
          'input_m': test_data['input_m'].astype(np.str)
      }
      output_dict = {
          'dense_2': test_data['output_c'],
          'dense_3': test_data['output_d']
      }
      return dataset_ops.Dataset.from_tensor_slices((input_dict,
                                                     output_dict)).batch(16)

    self.do_test_multi_inputs_multi_outputs_with_input_fn(
        distribution, train_input_fn, eval_input_fn)

  def do_test_multi_inputs_multi_outputs_with_input_fn(
      self, distribution, train_input_fn, eval_input_fn):
    config = run_config_lib.RunConfig(
        tf_random_seed=_RANDOM_SEED,
        model_dir=self._base_dir,
        train_distribute=distribution)
    with self.cached_session():
      model = multi_inputs_multi_outputs_model()
      est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
      baseline_eval_results = est_keras.evaluate(
          input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE // 16)
      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph']))
  def test_keras_optimizer_with_distribution_strategy(self, distribution):
    keras_model = simple_sequential_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=keras.optimizers.rmsprop(lr=0.01))

    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                               config=config)
      with self.assertRaisesRegexp(ValueError,
                                   'Only TensorFlow native optimizers are '
                                   'supported with DistributionStrategy.'):
        est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE // 16)

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)


class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                                              parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_calculating_input_params_no_steps_no_batch_size(self, distribution):
    # Calculate the per_replica_batch_size scaling factor for strategies
    # that use per_core_batch_size.
    replica_scale_factor = 1.0
    if not distributed_training_utils.global_batch_size_supported(distribution):
      replica_scale_factor = distribution.num_replicas_in_sync

    with self.cached_session():
      # Input samples of different sizes
      input_20_samples = np.zeros((20, 3), dtype=np.float32)
      input_63_samples = np.zeros((63, 3), dtype=np.float32)
      input_64_samples = np.zeros((64, 3), dtype=np.float32)

      # The default global batch size of 32 covers 64 samples in 2 steps.
      steps, batch_size = distributed_training_utils.get_input_params(
          distribution, input_64_samples, steps=None, batch_size=None)
      self.assertEqual(batch_size, 32 // replica_scale_factor)
      self.assertEqual(steps, 2)

      # The computed global batch size drops below 32 when we pass fewer
      # samples.
      steps, batch_size = distributed_training_utils.get_input_params(
          distribution, input_20_samples, steps=None, batch_size=None)
      self.assertEqual(batch_size, 20 // replica_scale_factor)
      self.assertEqual(steps, 1)

      # The default global batch size of 32 cannot be used with 63 samples.
      with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'):
        distributed_training_utils.get_input_params(
            distribution, input_63_samples, steps=None, batch_size=None)

  @combinations.generate(all_strategy_combinations())
  def test_calculating_input_params_with_steps_no_batch_size(self,
                                                             distribution):
    # Calculate the per_replica_batch_size scaling factor for strategies
    # that use per_core_batch_size.
    replica_scale_factor = 1.0
    if not distributed_training_utils.global_batch_size_supported(distribution):
      replica_scale_factor = distribution.num_replicas_in_sync

    with self.cached_session():
      # Input samples of different sizes
      input_63_samples = np.zeros((63, 3), dtype=np.float32)
      input_64_samples = np.zeros((64, 3), dtype=np.float32)

      # The computed global batch size is correct when 1 step is specified.
      steps, batch_size = distributed_training_utils.get_input_params(
          distribution, input_64_samples, steps=1, batch_size=None)
      self.assertEqual(batch_size, 64 // replica_scale_factor)
      self.assertEqual(steps, 1)

      # The computed global batch size is correct when 2 steps are specified.
      steps, batch_size = distributed_training_utils.get_input_params(
          distribution, input_64_samples, steps=2, batch_size=None)
      self.assertEqual(batch_size, 32 // replica_scale_factor)
      self.assertEqual(steps, 2)

      # All samples cannot be consumed in the specified number of steps.
      with self.assertRaisesRegexp(ValueError, 'not divisible by steps'):
        distributed_training_utils.get_input_params(
            distribution, input_63_samples, steps=2, batch_size=None)

      # This case differs between strategies because the supported batch
      # size may be global or per-replica.
      if replica_scale_factor == 1:
        # The computed global batch size is correct even if not shardable.
        steps, batch_size = distributed_training_utils.get_input_params(
            distribution, input_63_samples, steps=3, batch_size=None)
        self.assertEqual(batch_size, 21)
        self.assertEqual(steps, 3)
      else:
        # The computed global batch size cannot be sharded across replicas.
        with self.assertRaisesRegexp(ValueError, 'could not be sharded evenly '
                                     'across the sync replicas'):
          distributed_training_utils.get_input_params(
              distribution, input_63_samples, steps=1, batch_size=None)

  @combinations.generate(all_strategy_combinations())
  def test_calculating_input_params_no_steps_with_batch_size(self,
                                                             distribution):
    # Calculate the per_replica_batch_size scaling factor for strategies
    # that use per_core_batch_size.
    replica_scale_factor = 1.0
    if not distributed_training_utils.global_batch_size_supported(distribution):
      replica_scale_factor = distribution.num_replicas_in_sync

    with self.cached_session():
      input_64_samples = np.zeros((64, 3), dtype=np.float32)

      # The computed steps is correct for a specified batch size of 16.
      steps, batch_size = distributed_training_utils.get_input_params(
          distribution, input_64_samples, steps=None, batch_size=16)
      self.assertEqual(batch_size, 16)
      self.assertEqual(steps, 4 // replica_scale_factor)

      # The computed steps is correct for a specified batch size of 32.
      steps, batch_size = distributed_training_utils.get_input_params(
          distribution, input_64_samples, steps=None, batch_size=32)
      self.assertEqual(batch_size, 32)
      self.assertEqual(steps, 2 // replica_scale_factor)

      # The number of samples is not divisible by the global batch size.
      with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'):
        distributed_training_utils.get_input_params(
            distribution, input_64_samples, steps=None, batch_size=20)

      # The number of samples is not divisible by the global batch size.
      with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'):
        distributed_training_utils.get_input_params(
            distribution, input_64_samples, steps=None, batch_size=3)

  @combinations.generate(all_strategy_combinations())
  def test_calculating_input_params_with_steps_with_batch_size(self,
                                                               distribution):
    with self.cached_session():
      input_64_samples = np.zeros((64, 3), dtype=np.float32)

      # No change to steps and batch size if both are specified and feasible.
      steps, batch_size = distributed_training_utils.get_input_params(
          distribution, input_64_samples, steps=5, batch_size=3)
      self.assertEqual(batch_size, 3)
      self.assertEqual(steps, 5)

      # The number of samples is less than global batch size * steps.
      with self.assertRaisesRegexp(ValueError, 'less than samples required'):
        distributed_training_utils.get_input_params(
            distribution, input_64_samples, steps=10, batch_size=13)

  @combinations.generate(all_strategy_combinations())
  def test_calling_model_with_numpy_arrays(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = gradient_descent.GradientDescentOptimizer(0.001)
        loss = 'mse'
        metrics = ['mae']
        model.compile(optimizer, loss, metrics=metrics)

        inputs = np.zeros((64, 3), dtype=np.float32)
        targets = np.zeros((64, 4), dtype=np.float32)

        # Call fit with validation data
        model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
                  validation_data=(inputs, targets))

        # TODO(anjalisridhar): We need tests for when the batch size and steps
        # are smaller and result in a 0 batch_size and steps value.
        model.evaluate(inputs, targets)
        # with steps
        model.evaluate(inputs, targets, steps=2)
        # with batch_size
        model.evaluate(inputs, targets, batch_size=8)

        model.predict(inputs)
        # with steps
        model.predict(inputs, steps=2)
        # with batch_size
        model.predict(inputs, batch_size=8)

  @combinations.generate(all_strategy_combinations())
  def test_calling_model_with_nested_numpy_arrays(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = multi_input_output_model()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss)

      input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
      input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32)
      inputs = [input_a_np, input_b_np]

      output_d_np = np.asarray(np.random.random((64, 7)), dtype=np.float32)
      output_e_np = np.asarray(np.random.random((64, 7)), dtype=np.float32)
      targets = [output_d_np, output_e_np]

      # Call fit with validation data
      model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0)

      # TODO(anjalisridhar): We need tests for when the batch size and steps
      # are smaller and result in a 0 batch_size and steps value.
      model.evaluate(inputs, targets)
      # with steps
      model.evaluate(inputs, targets, steps=2)
      # with batch_size
      model.evaluate(inputs, targets, batch_size=8)

      model.predict(inputs)
      # with steps
      model.predict(inputs, steps=2)
      # with batch_size
      model.predict(inputs, batch_size=8)

  @combinations.generate(combinations.combine(
      distribution=strategies_minus_tpu, mode=['graph']))
  def test_numpy_with_sample_weights(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss)

      inputs = np.zeros((20, 3), np.float32)
      targets = np.zeros((20, 4), np.float32)
      sample_weights = np.ones((20), np.float32)

      model.fit(inputs, targets, sample_weight=sample_weights, epochs=1,
                steps_per_epoch=2, verbose=1)

  @combinations.generate(all_strategy_combinations())
  def test_flatten_predict_outputs(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = multi_input_output_model()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss)

      # We take 6 input samples with each input having a dimension of 3 or 5.
      input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32)
      input_b_np = np.asarray(np.random.random((6, 5)), dtype=np.float32)
      inputs = [input_a_np, input_b_np]

      outs = model.predict(inputs, steps=1)
      # `predict` returns a list that is equal in length to the number of
      # model outputs. In this test our model has two outputs, and each
      # element of `outs` corresponds to all the samples of one model output.
      self.assertLen(outs, 2)
      # Each of the output samples has a dimension of 7. We should process
      # all the available input samples (6).
      self.assertAllEqual([6, 7], outs[0].shape)
      self.assertAllEqual([6, 7], outs[1].shape)

  @combinations.generate(tpu_strategy_combinations())
  def test_predict_with_partial_batch(self, distribution):
    with self.cached_session():
      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
      loss = 'mse'

      with distribution.scope():
        model_with_ds_strategy = get_model()
        model_with_ds_strategy.compile(optimizer, loss)

      cpu_model = get_model()
      cpu_model.compile(optimizer, loss)

      inputs = np.zeros((10, 3), dtype=np.float32)

      # As the sample size is 10, we batch by 4 so that the last batch is a
      # partial batch. Also, `predict()` on a numpy array without a
      # distribution strategy uses the entire array as a single batch, so we
      # omit the `batch_size` and `steps` parameters for the ground truth.
      predict_ground_truth = cpu_model.predict(inputs)
      cpu_model.set_weights(model_with_ds_strategy.get_weights())
      self.assertAllClose(
          model_with_ds_strategy.predict(inputs, batch_size=4, steps=3),
          predict_ground_truth,
          atol=1e-5,
          rtol=1e-5)
      # Test that `steps` is inferred correctly when a final partial batch
      # exists.
      self.assertAllClose(
          model_with_ds_strategy.predict(inputs, batch_size=4),
          predict_ground_truth,
          atol=1e-5,
          rtol=1e-5)

  @combinations.generate(tpu_strategy_combinations())
  def test_predict_multi_output_model_with_partial_batch(
      self, distribution):
    with self.cached_session():
      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
      loss = 'mse'

      with distribution.scope():
        model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
        model_with_ds_strategy.compile(optimizer, loss)

      cpu_model = simple_multi_inputs_multi_outputs_model()
      cpu_model.compile(optimizer, loss)

      input_data, _ = get_multi_inputs_multi_outputs_data()
      input_dict = {
          'input_a': input_data['input_a'],
          'input_b': input_data['input_b'],
      }

      # As the sample size is 200, we batch by 18 so that the last batch is a
      # partial batch. Also, `predict()` on a numpy array without a
      # distribution strategy uses the entire array as a single batch, so we
      # omit the `batch_size` and `steps` parameters for the CPU model.
      cpu_model.set_weights(model_with_ds_strategy.get_weights())
      self.assertAllClose(
          model_with_ds_strategy.predict(input_dict, batch_size=18, steps=12),
          cpu_model.predict(input_dict),
          atol=1e-4, rtol=1e-4)


class TestDistributionStrategyWithDatasets(test.TestCase,
                                           parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_calling_model_on_same_dataset(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = gradient_descent.GradientDescentOptimizer(0.001)
        loss = 'mse'
        metrics = ['mae', keras.metrics.CategoricalAccuracy()]
        model.compile(optimizer, loss, metrics=metrics)

      dataset = get_dataset(distribution)

      # Call fit with validation data
      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                validation_data=dataset, validation_steps=2)
      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                validation_data=dataset, validation_steps=2)
      model.predict(get_predict_dataset(distribution), steps=2)

  @combinations.generate(all_strategy_combinations())
  def test_model_interleaved_eval_same_as_direct_eval(self, distribution):
    with self.cached_session():
      with distribution.scope():
        user_controlled_model = get_model()
        user_controlled_model.compile(
            gradient_descent.GradientDescentOptimizer(0.001),
            loss='mse',
            metrics=['mae', keras.metrics.CategoricalAccuracy()])

        interleaved_model = get_model()
        interleaved_model.set_weights(user_controlled_model.get_weights())
        interleaved_model.compile(
            gradient_descent.GradientDescentOptimizer(0.001),
            loss='mse',
            metrics=['mae', keras.metrics.CategoricalAccuracy()])

      dataset = get_dataset(distribution)

      # Call fit with validation interleaved
      interleaved_output = interleaved_model.fit(
          dataset, epochs=2, steps_per_epoch=2, verbose=1,
          validation_data=dataset, validation_steps=2, shuffle=False)

      # Manually control the validation running after each epoch.
      user_controlled_output = []
      for _ in range(2):
        user_controlled_model.fit(
            dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False)
        user_controlled_output.append(
            user_controlled_model.evaluate(dataset, steps=2))

      self.assertEqual(interleaved_output.history['val_loss'],
                       [x[0] for x in user_controlled_output])
      val_mean_absolute_error = interleaved_output.history.get(
          'val_mean_absolute_error')
      if not val_mean_absolute_error:
        # The name of the metric changed in TF 2.0.
        val_mean_absolute_error = interleaved_output.history['val_mae']
      self.assertEqual(val_mean_absolute_error,
                       [x[1] for x in user_controlled_output])
      self.assertEqual(interleaved_output.history['val_categorical_accuracy'],
                       [x[2] for x in user_controlled_output])

  # TODO(priyag): Enable this test for TPU. Currently tuples/dicts don't work
  # because clone_model's input_tensors argument only accepts lists, not
  # tuples or dicts.

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = multi_input_output_model()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=0.001)
        loss = 'mse'
        metrics = ['mae', keras.metrics.CategoricalAccuracy()]
        model.compile(optimizer, loss, metrics=metrics)

      input_a_np = np.random.random((10, 3))
      input_b_np = np.random.random((10, 5))
      output_d_np = np.random.random((10, 7))
      output_e_np = np.random.random((10, 7))

      # Test with tuples
      dataset_tuple = dataset_ops.Dataset.from_tensor_slices((
          (input_a_np, input_b_np), (output_d_np, output_e_np)))
      dataset_tuple = dataset_tuple.repeat(100)
      dataset_tuple = dataset_tuple.batch(10)

      model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1)

      # Test with dict
      dataset_dict = dataset_ops.Dataset.from_tensor_slices((
          {'input_a': input_a_np, 'input_b': input_b_np},
          (output_d_np, output_e_np)))
      dataset_dict = dataset_dict.repeat(100)
      dataset_dict = dataset_dict.batch(10)

      model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)

  @combinations.generate(all_strategy_combinations())
  def test_fit_eval_and_predict_methods_on_dataset_without_steps(
      self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = gradient_descent.GradientDescentOptimizer(0.001)
        loss = 'mse'
        metrics = ['mae', keras.metrics.CategoricalAccuracy()]
        model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((1000, 3), dtype=np.float32)
      targets = np.zeros((1000, 4), dtype=np.float32)
      # steps/steps_per_epoch are calculated when using numpy arrays as
      # input data.
      fit_with_numpy = model.fit(inputs, targets, epochs=1,
                                 batch_size=10).history
      eval_with_numpy = model.evaluate(inputs, targets, batch_size=10)
      predict_with_numpy = model.predict(inputs, batch_size=10)

      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.batch(10, drop_remainder=True)
      fit_with_ds = model.fit(dataset, epochs=1).history
      eval_with_ds = model.evaluate(dataset)
      predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
      predict_dataset = predict_dataset.batch(10, drop_remainder=True)
      predict_with_ds = model.predict(predict_dataset)
      self.assertAllClose(
          fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4)
      self.assertAllClose(
          eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4)
      self.assertAllClose(
          predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)

  @combinations.generate(all_strategy_combinations())
  def test_on_dataset_with_unknown_cardinality_without_steps(
      self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = gradient_descent.GradientDescentOptimizer(0.001)
        loss = 'mse'
        metrics = ['mae', keras.metrics.CategoricalAccuracy()]
        model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((1000, 3), dtype=np.float32)
      targets = np.zeros((1000, 4), dtype=np.float32)
      # steps/steps_per_epoch are calculated when using numpy arrays as
      # input data.
      fit_with_numpy = model.fit(inputs, targets, epochs=1,
                                 batch_size=10).history
      fit_with_numpy_multiple_epochs = model.fit(
          inputs, targets, epochs=2, batch_size=10).history
      eval_with_numpy = model.evaluate(inputs, targets, batch_size=10)
      predict_with_numpy = model.predict(inputs, batch_size=10)

      dataset = convert_numpy_to_dataset_with_unknown_cardinality(
          inputs, targets)
      predict_dataset = convert_numpy_to_dataset_with_unknown_cardinality(
          inputs)

      self.assertEqual(keras.backend.get_value(cardinality.cardinality(
          dataset)), cardinality.UNKNOWN)
      self.assertEqual(keras.backend.get_value(cardinality.cardinality(
          predict_dataset)), cardinality.UNKNOWN)

      eval_with_ds = model.evaluate(dataset)
      predict_with_ds = model.predict(predict_dataset)
      self.assertAllClose(
          eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4)
      self.assertAllClose(
          predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)

      if (distributed_training_utils.is_tpu_strategy(distribution) and
          distribution.extended.steps_per_run != 1):
        with self.assertRaisesRegexp(ValueError, '`steps_per_epoch` '
                                     'should be specified'):
          fit_with_ds = model.fit(dataset, epochs=1)
      else:
        fit_with_ds = model.fit(dataset,
                                epochs=1).history
        fit_with_ds_multiple_epochs = model.fit(dataset,
                                                epochs=2).history
        self.assertAllClose(
            fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4)
        self.assertAllClose(
            fit_with_numpy_multiple_epochs,
            fit_with_ds_multiple_epochs, atol=1e-4, rtol=1e-4)

  @combinations.generate(all_strategy_combinations())
  def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = gradient_descent.GradientDescentOptimizer(0.001)
        loss = 'mse'
        metrics = ['mae', keras.metrics.CategoricalAccuracy()]
        model.compile(optimizer, loss, metrics=metrics)

      dataset = get_dataset(distribution)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
      model.evaluate(dataset, steps=2, verbose=1)
      model.predict(get_predict_dataset(distribution), steps=2)

  @combinations.generate(strategy_and_optimizer_combinations())
  def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        loss = 'mse'
        model.compile(optimizer(), loss)

      dataset = get_dataset(distribution)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
      model.evaluate(dataset, steps=2, verbose=1)
      model.predict(get_predict_dataset(distribution), steps=2)

  @combinations.generate(strategy_minus_tpu_combinations())
  def test_dataset_with_sample_weights(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss)

      inputs = np.zeros((10, 3), np.float32)
      targets = np.zeros((10, 4), np.float32)
      sample_weights = np.ones((10), np.float32)
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
                                                        sample_weights))
      dataset = dataset.repeat()
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
      model.evaluate(dataset, steps=2, verbose=1)
      model.predict(dataset, steps=2)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  # TODO(b/120943676, b/120957836): Re-enable once the validation code is
  # restored.
  def DISABLED_test_dataset_wrong_input_shape(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss)

      # Wrong input shape
      inputs = np.zeros((10, 5), dtype=np.float32)
      targets = np.zeros((10, 4), dtype=np.float32)
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      with self.assertRaisesRegexp(ValueError,
                                   'expected input to have shape'):
        model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)

  @combinations.generate(combinations.combine(
      distribution=[combinations.mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  # TODO(b/120943676, b/120957836): Re-enable once the validation code is
  # restored.
  def DISABLED_test_dataset_no_batch_input_validation(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss)

      # User forgets to batch the dataset
      inputs = np.zeros((10, 3), dtype=np.float32)
      targets = np.zeros((10, 4), dtype=np.float32)
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)

      with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
        model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)

  @combinations.generate(combinations.combine(
      distribution=[combinations.tpu_strategy_one_step],
      mode=['graph']))
  def test_dataset_input_shape_fully_defined(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
        loss = 'mse'
        model.compile(optimizer, loss)

      dataset = get_dataset(distribution)
      # Input shapes are not fully known: the batch dimension is unknown
      # because we are not using the drop_remainder argument.
      dataset = dataset.repeat(100).batch(10)

      with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'):
        model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.mirrored_strategy_with_two_gpus,
          combinations.core_mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_two_gpus],
      mode=['graph', 'eager']))
  def test_learning_phase_value(self, distribution):
    # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
    # meaningful values. Currently we don't pass the learning phase if the
    # Lambda layer uses the learning phase.
    with self.cached_session():
      with distribution.scope():
        x = keras.layers.Input(shape=(1,), name='input')
        y = keras.layers.Dense(1, kernel_initializer='ones')(x)
        z = keras.layers.Dropout(0.9999)(y)
        model = keras.Model(x, z)
        initial_weights = model.get_weights()

        optimizer = gradient_descent.GradientDescentOptimizer(0.005)
        loss = 'mse'
        metrics = ['acc']
        model.compile(optimizer, loss, metrics=metrics)

      batch_size = 8
      if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy):
        # CoreMirroredStrategy uses global batch size.
        batch_size = 8 * distribution.num_replicas_in_sync

      inputs = np.ones((10, 1), dtype=np.float32)
      targets = np.ones((10, 1), dtype=np.float32)
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat().batch(batch_size)
      hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1)
      self.assertAlmostEqual(hist.history['acc'][0], 0, 0)

      with distribution.scope():
        model.set_weights(initial_weights)
      # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185.
      # evaluate_output = model.evaluate(dataset, steps=20)
      # self.assertAlmostEqual(evaluate_output[1], 1, 0)

      inputs = np.ones((10, 1), dtype=np.float32)
      predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs)

      predict_dataset = predict_dataset.repeat().batch(batch_size)
      output = model.predict(predict_dataset, steps=10)
      # `predict` runs for 10 steps
      ref_output = np.ones((160, 1), dtype=np.float32)
      self.assertArrayNear(output, ref_output, 1e-1)

  @combinations.generate(all_strategy_combinations())
  def testOptimizerWithCallbacks(self, distribution):
    with self.cached_session():
      with distribution.scope():
        model = get_model()
        optimizer = gradient_descent_keras.SGD(0.01)
        loss = 'mse'
        model.compile(optimizer, loss)

      dataset = get_dataset(distribution)

      def schedule(_):
        return 0.001

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
      self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))

  @combinations.generate(tpu_strategy_combinations())
  def test_predict_with_dataset_with_partial_batch(self, distribution):
    with self.cached_session():
      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
      loss = 'mse'

      with distribution.scope():
        model_with_ds_strategy = get_model()
        model_with_ds_strategy.compile(optimizer, loss)

      cpu_model = get_model()
      cpu_model.compile(optimizer, loss)

      inputs = np.zeros((10, 3), dtype=np.float32)
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs))

      # As the sample size is 10, we batch by 4 so that the last batch is
      # a partial batch.
      dataset_with_partial_batch = dataset.batch(4)
      cpu_model.set_weights(model_with_ds_strategy.get_weights())

      self.assertAllClose(
          model_with_ds_strategy.predict(dataset_with_partial_batch, steps=3),
          cpu_model.predict(dataset_with_partial_batch, steps=3),
          atol=1e-5, rtol=1e-5)

  @combinations.generate(tpu_strategy_combinations())
  def test_predict_multi_output_model_with_dataset_with_partial_batch(
      self, distribution):
    with self.cached_session():
      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
      loss = 'mse'

      with distribution.scope():
        model_with_ds_strategy = simple_multi_inputs_multi_outputs_model()
        model_with_ds_strategy.compile(optimizer, loss)

      cpu_model = simple_multi_inputs_multi_outputs_model()
      cpu_model.compile(optimizer, loss)

      input_data, _ = get_multi_inputs_multi_outputs_data()
      input_dict = {
          'input_a': input_data['input_a'],
          'input_b': input_data['input_b'],
      }

      dataset = dataset_ops.Dataset.from_tensor_slices(input_dict)

      # As the sample size is 200, we batch by 18 using 12 steps per epoch so
      # that the last batch is a partial batch.
      dataset_with_partial_batch = dataset.batch(18)
      cpu_model.set_weights(model_with_ds_strategy.get_weights())

      self.assertAllClose(
          model_with_ds_strategy.predict(dataset_with_partial_batch, steps=12),
          cpu_model.predict(dataset_with_partial_batch, steps=12),
          atol=1e-4, rtol=1e-4)


class TestRegularizerLoss(test.TestCase, parameterized.TestCase):
  class IdentityRegularizer(keras.regularizers.Regularizer):

    def __call__(self, x):
      return array_ops.identity(x)

  class AddLayer(keras.layers.Layer):

    def build(self, _):
      self.v = self.add_weight(
          'v', (), initializer='ones',
          regularizer=TestRegularizerLoss.IdentityRegularizer())

    def call(self, inputs):
      return inputs + self.v

  @staticmethod
  def loss_fn(_, y_pred):
    return math_ops.reduce_mean(y_pred)

  @combinations.generate(all_strategy_combinations_minus_default())
  def test_regularizer_loss(self, distribution):
    batch_size = 2
    if not distributed_training_utils.global_batch_size_supported(distribution):
      batch_size //= distribution.num_replicas_in_sync

    # Given an input x, which is always 1, and variable v, this model computes
    # Loss=x+v+regularizer_loss, where regularizer_loss=v and the variable is
    # initialized to 1. Therefore, this model computes Loss=1+2v, and so the
    # gradient dLoss/dv = 2. This gradient of 2 is averaged over all examples
    # in a batch and then multiplied by the learning rate of 1. As a result,
    # the model update for one batch should subtract 2 from v, resulting in v
    # being -1. If the regularizer loss is not scaled correctly by the number
    # of replicas, the variable value will be incorrect when the number of
    # replicas is >1. For example, it will be -2 if there are 2 replicas.
    with distribution.scope():
      x = keras.layers.Input(shape=(), batch_size=batch_size)
      y = TestRegularizerLoss.AddLayer()(x)
      model = keras.models.Model(inputs=x, outputs=y)
      opt = gradient_descent_keras.SGD(1.)
      model.compile(opt, loss=TestRegularizerLoss.loss_fn)
      model.fit(
          x=np.array([[1.], [1.]], dtype=np.float32),
          y=np.array([[1.], [1.]], dtype=np.float32),
          batch_size=batch_size)
      v = model.get_weights()[0]
      self.assertEqual(-1.0, v)


class TestDistributionStrategyWithKerasModels(test.TestCase,
                                              parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_distribution_strategy_on_sequential_model(self, distribution):
    with distribution.scope():
      model = simple_sequential_model()
      optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      model.compile(optimizer, loss)

      inputs = np.zeros((20, 10), np.float32)
      targets = np.zeros((20, 2), np.float32)

    model.fit(inputs, targets, epochs=1, steps_per_epoch=2)
    model.predict(inputs, steps=1)
    model.evaluate(inputs, targets, steps=1)

  @combinations.generate(all_strategy_combinations())
  def test_distribution_strategy_on_functional_model(self, distribution):
    with distribution.scope():
      model = get_model()
      optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      model.compile(optimizer, loss)

      inputs = np.zeros((64, 3), dtype=np.float32)
      targets = np.zeros((64, 4), dtype=np.float32)

    model.fit(inputs, targets, epochs=1, steps_per_epoch=2)
    model.predict(inputs, steps=1)
    model.evaluate(inputs, targets, steps=1)

  @combinations.generate(all_strategy_combinations_minus_default())
  def test_distribution_strategy_one_dimensional(self, distribution):
    with distribution.scope():
      inp = keras.layers.Input(shape=(10,))
      out = keras.layers.Dense(3, activation='softmax')(inp)
      model = keras.Model(inputs=[inp], outputs=[out])
      model.compile(
          optimizer='rmsprop',
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'],
      )

      x = np.random.random((64, 10)).astype('float32')
      y = np.random.randint(3, size=64)

      model.fit(x, y, epochs=1, steps_per_epoch=2)


if __name__ == '__main__':
  test.main()