1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for training routines."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22
23from absl.testing import parameterized
24import numpy as np
25
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.data.ops import iterator_ops
28from tensorflow.python.eager import context
29from tensorflow.python.keras import combinations
30from tensorflow.python.keras import keras_parameterized
31from tensorflow.python.keras import layers as layers_module
32from tensorflow.python.keras import losses
33from tensorflow.python.keras import metrics as metrics_module
34from tensorflow.python.keras import testing_utils
35from tensorflow.python.keras.engine import input_layer
36from tensorflow.python.keras.engine import training
37from tensorflow.python.keras.engine import training_generator_v1
38from tensorflow.python.keras.optimizer_v2 import rmsprop
39from tensorflow.python.keras.utils import data_utils
40from tensorflow.python.platform import test
41from tensorflow.python.util import nest
42
43
def custom_generator(mode=2):
  """Endlessly yields fixed-size batches cut from random NumPy arrays.

  Fifty random samples are drawn once, on the first `next()` call, and the
  generator then cycles over them in consecutive batches of ten.

  Args:
    mode: Selects what each batch contains. `1` yields only the inputs, `2`
      yields `(inputs, targets)`, and any other value yields
      `(inputs, targets, sample_weights)`.

  Yields:
    Inputs of shape `(10, 2)`, optionally accompanied by targets of shape
    `(10, 4)` and sample weights of shape `(10,)`, depending on `mode`.
  """
  batch_size = 10
  num_samples = 50
  # The draw order (inputs, targets, weights) is kept fixed so that seeded
  # runs produce the same arrays.
  features = np.random.random((num_samples, 2))
  targets = np.random.random((num_samples, 4))
  sample_weights = np.random.random((num_samples,))
  for step in itertools.count():
    # Wrap around to the first sample after the last full batch.
    lower = (step * batch_size) % num_samples
    window = slice(lower, lower + batch_size)
    if mode == 1:
      yield features[window]
    elif mode == 2:
      yield features[window], targets[window]
    else:
      yield features[window], targets[window], sample_weights[window]
65
66
def custom_generator_changing_batch_size(mode=2):
  """Endlessly yields batches whose size shrinks from ten down to one.

  Fifty random samples are drawn once, on the first `next()` call. Batch
  start offsets advance in strides of ten (wrapping at fifty), while the
  number of rows taken from each offset drops by one per step until it
  reaches one, where it stays.

  Args:
    mode: Selects what each batch contains. `1` yields only the inputs, `2`
      yields `(inputs, targets)`, and any other value yields
      `(inputs, targets, sample_weights)`.

  Yields:
    Inputs with 2 columns, optionally accompanied by targets with 4 columns
    and 1-D sample weights, all sharing the current batch length.
  """
  stride = 10
  num_samples = 50
  # The draw order (inputs, targets, weights) is kept fixed so that seeded
  # runs produce the same arrays.
  features = np.random.random((num_samples, 2))
  targets = np.random.random((num_samples, 4))
  sample_weights = np.random.random((num_samples,))
  rows = 11
  for step in itertools.count():
    # Shrink by one row per step, but never below a single row.
    rows = max(rows - 1, 1)
    lower = (step * stride) % num_samples
    window = slice(lower, lower + rows)
    if mode == 1:
      yield features[window]
    elif mode == 2:
      yield features[window], targets[window]
    else:
      yield features[window], targets[window], sample_weights[window]
91
# `custom_generator` wrapped with `data_utils.threadsafe_generator`. The tests
# below use this variant whenever they request more than one worker
# (presumably so concurrent `next()` calls are serialized -- confirm against
# `data_utils`).
custom_generator_threads = data_utils.threadsafe_generator(custom_generator)
93
94
class TestGeneratorMethods(keras_parameterized.TestCase):
  """Smoke tests for the generator-based fit/evaluate/predict entry points.

  Most tests only check that the calls complete without raising; the
  invalid-use-case test checks that malformed batches raise `ValueError`.
  """

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_fit_generator_method(self):
    """Runs `fit_generator` with several worker/validation configurations."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])

    # Multiple workers: use the thread-safe wrapper around the generator.
    model.fit_generator(custom_generator_threads(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        workers=4,
                        use_multiprocessing=True)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    # Validation data supplied as a generator.
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(),
                        validation_steps=10)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_evaluate_generator_method(self):
    """Runs `evaluate_generator` with 2, default, and 0 workers."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.evaluate_generator(custom_generator_threads(),
                             steps=5,
                             max_queue_size=10,
                             workers=2,
                             verbose=1,
                             use_multiprocessing=True)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False,
                             workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_predict_generator_method(self):
    """Runs `predict_generator` on `(x, y)` batches and on inputs-only batches."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.run_eagerly = testing_utils.should_run_eagerly()

    model.predict_generator(custom_generator_threads(),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            workers=0)
    # Test generator with just inputs (no targets)
    model.predict_generator(custom_generator_threads(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_with_sample_weights(self):
    """Runs fit/predict/evaluate on `(x, y, sample_weight)` batches."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(mode=3),
                        validation_steps=10)
    model.predict_generator(custom_generator(mode=3),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.evaluate_generator(custom_generator(mode=3),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_invalid_use_case(self):
    """Checks that 4-tuple batches raise `ValueError` in every entry point."""

    def invalid_generator():
      """Yields 4-element tuples, one more element than `(x, y, w)`."""
      while 1:
        yield (0, 0, 0, 0)

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        run_eagerly=testing_utils.should_run_eagerly())

    with self.assertRaises(ValueError):
      model.fit_generator(invalid_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False)
    # Valid training data, invalid validation data.
    with self.assertRaises(ValueError):
      model.fit_generator(custom_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False,
                          validation_data=invalid_generator(),
                          validation_steps=10)
    with self.assertRaises(ValueError):
      model.predict_generator(invalid_generator(),
                              steps=5,
                              max_queue_size=10,
                              use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.evaluate_generator(invalid_generator(),
                               steps=5,
                               max_queue_size=10,
                               use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_input_to_fit_eval_predict(self):
    """Passes generators directly to `fit`, `evaluate` and `predict`."""
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    def ones_generator():
      """Endlessly yields a constant `(x, y)` batch of ones."""
      while True:
        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(
        rmsprop.RMSprop(0.001),
        'binary_crossentropy',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(
        ones_generator(),
        steps_per_epoch=2,
        validation_data=val_data,
        epochs=2)
    model.evaluate(ones_generator(), steps=2)
    model.predict(ones_generator(), steps=2)

    # Test with a changing batch size
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator_changing_batch_size(),
                        validation_steps=10)

    model.fit(
        custom_generator_changing_batch_size(),
        steps_per_epoch=5,
        validation_data=custom_generator_changing_batch_size(),
        validation_steps=10,
        epochs=2)
    model.evaluate(custom_generator_changing_batch_size(), steps=5)
    model.predict(custom_generator_changing_batch_size(), steps=5)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_generator_dynamic_shapes(self):
    """Fits an RNN on batches whose padded sequence length varies."""

    # Each sentence must be followed by a comma: adjacent string literals are
    # concatenated by Python, which would leave `x` one element shorter than
    # `y` and make `zip(x, y)` silently drop a label.
    x = [
        'I think juice is great',
        'unknown is the best language since slicedbread',
        'a a a a a a a',
        'matmul',
        'Yaks are also quite nice',
    ]
    y = [1, 0, 0, 1, 1]

    # Map every distinct word to a positive id; 0 is used for padding below.
    vocab = {
        word: i + 1 for i, word in
        enumerate(
            sorted(set(itertools.chain(*[i.split() for i in x]))))
    }

    def data_gen(batch_size=2):
      """Yields `(token_ids, labels)` batches padded to the batch's max length."""
      np.random.seed(0)
      data = list(zip(x, y)) * 10
      np.random.shuffle(data)

      def pack_and_pad(queue):
        """Turns queued `(sentence, label)` pairs into arrays and empties `queue`."""
        x = [[vocab[j] for j in i[0].split()] for i in queue]
        pad_len = max(len(i) for i in x)
        x = np.array([i + [0] * (pad_len - len(i)) for i in x])
        y = np.array([i[1] for i in queue])
        # Clear in place so the caller's list is ready for the next batch.
        del queue[:]
        return x, y[:, np.newaxis]

      queue = []
      for i, element in enumerate(data):
        queue.append(element)
        if not (i + 1) % batch_size:
          yield pack_and_pad(queue)

      if queue:
        # Last partial batch
        yield pack_and_pad(queue)

    model = testing_utils.get_model_from_layers([
        layers_module.Embedding(input_dim=len(vocab) + 1, output_dim=4),
        layers_module.SimpleRNN(units=1),
        layers_module.Activation('sigmoid')
    ],
                                                input_shape=(None,))

    model.compile(loss=losses.binary_crossentropy, optimizer='sgd')
    model.fit(data_gen(), epochs=1, steps_per_epoch=5)
381
382
class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):
  """Exercises the Keras training entry points with `data_utils.Sequence`."""

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_training_with_sequences(self):
    """Calls `fit_generator` on a Sequence with multiprocessing on, then off."""

    class DummySequence(data_utils.Sequence):
      """Ten identical batches: zeros as inputs, ones as targets."""

      def __len__(self):
        return 10

      def __getitem__(self, idx):
        return np.zeros([10, 2]), np.ones([10, 4])

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3))

    # Same call twice; only the multiprocessing flag differs.
    for multiprocessing_flag in (True, False):
      model.fit_generator(
          DummySequence(),
          steps_per_epoch=10,
          validation_data=custom_generator(),
          validation_steps=1,
          max_queue_size=10,
          workers=0,
          use_multiprocessing=multiprocessing_flag)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_sequence_input_to_fit_eval_predict(self):
    """Passes Sequences to `fit`/`evaluate`/`predict` and checks rejected args."""

    def ones_batch(num_rows):
      """Returns an `(x, y)` pair of float32 ones with `num_rows` rows each."""
      return (np.ones([num_rows, 10], np.float32),
              np.ones([num_rows, 1], np.float32))

    val_data = ones_batch(10)

    class CustomSequence(data_utils.Sequence):
      """Two batches of ten rows each."""

      def __len__(self):
        return 2

      def __getitem__(self, idx):
        return ones_batch(10)

    class CustomSequenceChangingBatchSize(data_utils.Sequence):
      """Two batches; batch `idx` has `10 - idx` rows."""

      def __len__(self):
        return 2

      def __getitem__(self, idx):
        return ones_batch(10 - idx)

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(CustomSequence(), validation_data=val_data, epochs=2)
    model.evaluate(CustomSequence())
    model.predict(CustomSequence())

    # Targets and sample weights must come from the Sequence itself.
    with self.assertRaisesRegex(ValueError, '`y` argument is not supported'):
      model.fit(CustomSequence(), y=np.ones([10, 1]))

    with self.assertRaisesRegex(
        ValueError, '`sample_weight` argument is not supported'):
      model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))

    # Repeat the fit/evaluate/predict round with varying batch sizes.
    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(
        CustomSequenceChangingBatchSize(), validation_data=val_data, epochs=2)
    model.evaluate(CustomSequenceChangingBatchSize())
    model.predict(CustomSequenceChangingBatchSize())

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_sequence_on_epoch_end(self):
    """Checks `on_epoch_end` fires once per epoch during `fit`."""

    class MySequence(data_utils.Sequence):
      """Two constant batches; counts `on_epoch_end` calls in `self.epochs`."""

      def __init__(self):
        self.epochs = 0

      def __len__(self):
        return 2

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def on_epoch_end(self):
        self.epochs += 1

    inputs = input_layer.Input(10)
    model = training.Model(inputs, layers_module.Dense(1)(inputs))
    model.compile('sgd', 'mse')

    my_seq = MySequence()
    model.fit(my_seq, epochs=2)
    self.assertEqual(my_seq.epochs, 2)
486
487
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase):
  """Tests `training_generator_v1.convert_to_generator_like` on several inputs.

  The same `(x, y)` structures are fed in as a dataset, a one-shot iterator,
  a Python generator and raw NumPy arrays.
  """

  # Flat `(x, y)` pair and a pair whose x and y are themselves tuples.
  simple_inputs = (np.ones((10, 10)), np.ones((10, 1)))
  nested_inputs = ((np.ones((10, 10)), np.ones((10, 20))), (np.ones((10, 1)),
                                                            np.ones((10, 3))))

  def _make_dataset(self, inputs, batches):
    """Returns a dataset that repeats `inputs` as one element, `batches` times."""
    return dataset_ops.DatasetV2.from_tensors(inputs).repeat(batches)

  def _make_iterator(self, inputs, batches):
    """Returns a one-shot iterator over the dataset from `_make_dataset`."""
    return dataset_ops.make_one_shot_iterator(
        self._make_dataset(inputs, batches))

  def _make_generator(self, inputs, batches):
    """Returns a Python generator that yields `inputs` `batches` times."""

    def _gen():
      for _ in range(batches):
        yield inputs

    return _gen()

  def _make_numpy(self, inputs, _):
    """Returns `inputs` unchanged; the batch count is ignored."""
    return inputs

  # The factories are referenced as plain functions here (the class body is
  # still executing), so the test passes `self` to them explicitly.
  @parameterized.named_parameters(
      ('simple_dataset', _make_dataset, simple_inputs),
      ('simple_iterator', _make_iterator, simple_inputs),
      ('simple_generator', _make_generator, simple_inputs),
      ('simple_numpy', _make_numpy, simple_inputs),
      ('nested_dataset', _make_dataset, nested_inputs),
      ('nested_iterator', _make_iterator, nested_inputs),
      ('nested_generator', _make_generator, nested_inputs),
      ('nested_numpy', _make_numpy, nested_inputs))
  def test_convert_to_generator_like(self, input_fn, inputs):
    """Checks the step count and the structure of every produced batch.

    Args:
      input_fn: One of the `_make_*` factories above, called as
        `input_fn(self, inputs, batches)`.
      inputs: The `(x, y)` structure the factory wraps.
    """
    expected_batches = 5
    data = input_fn(self, inputs, expected_batches)

    # Dataset and Iterator not supported in Legacy Graph mode.
    if (not context.executing_eagerly() and
        isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
      return

    generator, steps = training_generator_v1.convert_to_generator_like(
        data, batch_size=2, steps_per_epoch=expected_batches)
    self.assertEqual(steps, expected_batches)

    # Validate every batch, not only the last one drawn.
    for _ in range(expected_batches):
      outputs = next(generator)
      nest.assert_same_structure(outputs, inputs)
537
538
# Run the tests through the TensorFlow test runner when this file is executed
# directly as a script.
if __name__ == '__main__':
  test.main()
541