1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Part of the Keras training engine related to Python generators of array data.
16"""
17# pylint: disable=protected-access
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import math
24
25import numpy as np
26
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.data.ops import iterator_ops
29from tensorflow.python.eager import context
30from tensorflow.python.framework import errors
31from tensorflow.python.keras import backend
32from tensorflow.python.keras import callbacks as cbks
33from tensorflow.python.keras.engine import training_utils
34from tensorflow.python.keras.engine import training_utils_v1
35from tensorflow.python.keras.utils import data_utils
36from tensorflow.python.keras.utils import generic_utils
37from tensorflow.python.keras.utils.mode_keys import ModeKeys
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.util import nest
40
41
def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (eg, in a production environment).
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1. If
        0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to children processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.

  The enqueuer created for non-dataset inputs is stopped, and the eager
  learning-phase scope (if one was entered) is exited, even when the loop
  raises part-way through.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # From here on we may start enqueuer workers and enter the eager
  # learning-phase scope; both are released in `finally` so an exception
  # raised during training (or validation) does not leak them.
  enqueuer = None
  learning_phase_scope = None
  try:
    # Create the queue for the generator.
    if not is_dataset:
      generator, enqueuer = _make_enqueued_generator(
          generator,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          shuffle=shuffle)

    num_samples_or_steps, use_steps = _get_num_samples_or_steps(
        data, steps_per_epoch)

    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        samples=num_samples_or_steps,
        count_mode=count_mode,
        verbose=verbose,
        mode=mode)

    if mode == ModeKeys.PREDICT:
      aggregator = training_utils_v1.OutputsAggregator(
          True, steps=steps_per_epoch)
    else:
      aggregator = training_utils_v1.MetricsAggregator(
          True, steps=steps_per_epoch)

    if context.executing_eagerly() and model.run_eagerly:
      scope = backend.eager_learning_phase_scope(
          1 if mode == ModeKeys.TRAIN else 0)
      scope.__enter__()
      # Record the scope only once it has actually been entered, so the
      # `finally` clause never exits a scope that was not entered.
      learning_phase_scope = scope

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode)

    for epoch in range(initial_epoch, epochs):
      if callbacks.model.stop_training:
        break

      # Setup work for each epoch.
      model.reset_metrics()
      epoch_logs = {}
      if mode == ModeKeys.TRAIN:
        callbacks.on_epoch_begin(epoch, epoch_logs)

      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_data = _get_next_batch(generator)
        if batch_data is None:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.steps = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        # `batch_size` used for validation data if validation
        # data is NumPy/EagerTensors.
        batch_size = int(nest.flatten(batch_data)[0].shape[0])

        # Callbacks batch begin.
        batch_logs = {'batch': step, 'size': batch_size}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

        is_deferred = not model._is_compiled
        batch_outs = batch_function(*batch_data)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if step == 0:
          aggregator.create(batch_outs)

          if is_deferred:
            # Set callbacks params. We do this here when model is compiled only
            # in the first iteration of this loop (deferred build scenario).
            cbks.set_callback_parameters(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=num_samples_or_steps,
                verbose=verbose,
                mode=mode)

        # Aggregate results.
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break

      aggregator.finalize()
      results = aggregator.results
      epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
      if len(results) == 1:
        results = results[0]

      # Run the test loop every epoch during training.
      if (do_validation and
          training_utils_v1.should_run_validation(validation_freq, epoch) and
          not callbacks.model.stop_training):
        val_results = model_iteration(
            model,
            validation_data,
            steps_per_epoch=validation_steps,
            batch_size=batch_size,
            class_weight=class_weight,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            callbacks=callbacks,
            verbose=verbose,
            mode=ModeKeys.TEST,
            steps_name='validation_steps')

        if not isinstance(val_results, list):
          val_results = [val_results]
        epoch_logs = cbks.make_logs(
            model, epoch_logs, val_results, mode, prefix='val_')

      if mode == ModeKeys.TRAIN:
        # Epochs only apply to `fit`.
        callbacks.on_epoch_end(epoch, epoch_logs)

      # Recreate dataset iterator for the next epoch.
      if reset_dataset_after_each_epoch and epoch < epochs - 1:
        generator = dataset_ops.make_one_shot_iterator(original_dataset)

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)
  finally:
    # Same teardown order as the success path: stop the enqueuer, then leave
    # the learning-phase scope.
    if enqueuer is not None:
      enqueuer.stop()
    if learning_phase_scope is not None:
      learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results
340
341
# Backwards-compatible public entry points kept under their historical names.
# Each one is `model_iteration` with `mode` pre-bound; evaluation and
# prediction also default to `shuffle=False`. Keywords bound by
# `functools.partial` can still be overridden by the caller.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(model_iteration,
                                       mode=ModeKeys.TEST,
                                       shuffle=False)
predict_generator = functools.partial(model_iteration,
                                      mode=ModeKeys.PREDICT,
                                      shuffle=False)
348
349
350def _get_next_batch(generator):
351  """Retrieves the next batch of input data."""
352  try:
353    generator_output = next(generator)
354  except (StopIteration, errors.OutOfRangeError):
355    return None
356
357  if not isinstance(generator_output, tuple):
358    # Always wrap in a tuple.
359    generator_output = (generator_output,)
360  if len(generator_output) not in [1, 2, 3]:
361    raise ValueError(
362        'Output of generator should be a tuple of 1 or 2 or 3 '
363        'elements: (input,) or (input, target) or '
364        '(input, target, sample_weights). Received {}'.format(generator_output))
365  return generator_output
366
367
def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Raises errors if arguments are invalid.

  Args:
    is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
      instance.
    is_dataset: Boolean, whether data is a dataset instance.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not pass
      non-picklable arguments to the generator as they can't be passed easily to
      children processes.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    steps_per_epoch: Total number of steps (batches of samples) before declaring
      one epoch finished and starting the next epoch. Ignored with the default
      value of `None`.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x,
      y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    kwargs: Additional arguments for backwards compatibility. Only `steps` is
      accepted.

  Raises:
    ValueError: If `steps_per_epoch` or `validation_steps` are not passed
      for data types that require them, or if unrecognized keyword
      arguments are passed.
  """
  if not is_sequence and use_multiprocessing and workers > 1:
    # Log the message text directly rather than an exception instance.
    logging.warning(
        'Using a generator with `use_multiprocessing=True`'
        ' and multiple workers may duplicate your data.'
        ' Please consider using the `keras.utils.Sequence`'
        ' class.')

  if steps_per_epoch is None and not is_dataset:
    arg_name = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(arg_name))

  # Generators and iterators have no length, so validation needs an explicit
  # step count; a `Sequence` can report its own length.
  val_gen = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.IteratorBase))
  if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
      not validation_steps):
    raise ValueError('Please specify the `validation_steps` argument.')

  # `steps` is the only backwards-compatibility keyword accepted here.
  invalid_kwargs = [k for k in kwargs if k != 'steps']
  if invalid_kwargs:
    raise ValueError('Invalid arguments passed: {}'.format(invalid_kwargs))
423
424
def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Turns `data` into something that supports `next()`.

  Generators, `keras.utils.data_utils.Sequence` objects and iterators are
  returned unchanged (a `Sequence` supplies `steps_per_epoch` from its length
  when none is given). A `DatasetV2` is wrapped in a one-shot iterator. A
  tuple of NumPy arrays / EagerTensors is sliced into mini-batches by a
  Python generator that runs over the data `epochs` times.

  Args:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or EagerTensors.
      If a tuple, the elements represent `(x, y, sample_weights)` and may be
      `None` or `[None]`; such all-`None` elements are dropped.
    batch_size: Required when `data` is arrays/EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch. Ignored (and
      recomputed as `ceil(num_samples / batch_size)`) for array inputs.
    epochs: Number of passes the array-backed generator makes over the data.
    shuffle: Whether array-backed batches are drawn from a reshuffled index
      order each epoch.

  Returns:
    A `(generator_like, steps_per_epoch)` pair.

  Raises:
    ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs.
  """
  if isinstance(data, tuple):
    # Keep only elements that contain at least one non-`None` leaf.
    data = tuple(
        element for element in data
        if any(leaf is not None for leaf in nest.flatten(element)))

  is_passthrough = (
      data_utils.is_generator_or_sequence(data) or
      isinstance(data, iterator_ops.IteratorBase))
  if is_passthrough:
    if steps_per_epoch is None and isinstance(data, data_utils.Sequence):
      steps_per_epoch = len(data)
    return data, steps_per_epoch

  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Array-like input: the leading dimension of the first leaf is the sample
  # count.
  num_samples = int(nest.flatten(data)[0].shape[0])
  if batch_size is None:
    raise ValueError(
        'When passing input data as arrays, do not specify '
        '`steps_per_epoch`/`steps` argument. Please use `batch_size` instead.')
  num_batches = int(math.ceil(num_samples / batch_size))

  def _batch_generator(structure):
    """Yields `structure` cut into mini-batches, `epochs` times over."""
    order = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(order)
      for start, end in generic_utils.make_batches(num_samples, batch_size):
        sliced = training_utils.slice_arrays(
            nest.flatten(structure),
            order[start:end],
            contiguous=(not shuffle))
        yield nest.pack_sequence_as(structure, sliced)

  return _batch_generator(data), num_batches
488
489
def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Wraps `generator` in a buffered enqueuer when workers are requested.

  Returns:
    `(output_generator, enqueuer)`. With `workers > 0` an `OrderedEnqueuer`
    (for `Sequence` inputs, honouring `shuffle`) or a `GeneratorEnqueuer`
    (for everything else) is started and returned so the caller can `stop()`
    it. Otherwise `enqueuer` is None: a `Sequence` is wrapped with
    `data_utils.iter_sequence_infinite`, and any other input is returned
    unchanged.
  """
  is_sequence = isinstance(generator, data_utils.Sequence)

  if workers > 0:
    enqueuer_cls = (
        data_utils.OrderedEnqueuer if is_sequence
        else data_utils.GeneratorEnqueuer)
    if is_sequence:
      enqueuer = enqueuer_cls(
          generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
    else:
      enqueuer = enqueuer_cls(
          generator, use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    return enqueuer.get(), enqueuer

  # No workers: batches are consumed on the calling thread.
  if is_sequence:
    return data_utils.iter_sequence_infinite(generator), None
  return generator, None
513
514
def _make_execution_function(model, mode, class_weight=None):
  """Returns a callable that runs `model` on a single batch for `mode`.

  TRAIN uses `model.train_on_batch` with `class_weight` bound; TEST uses
  `model.test_on_batch`; any other mode uses a wrapper around
  `model.predict_on_batch` that accepts and ignores `y` and `sample_weights`,
  so 1-, 2- and 3-tuples from the generator can all be splatted into it.
  Except in PREDICT mode, `reset_metrics=False` is bound so that stateful
  metrics are maintained across batch-level calls.
  """
  if mode == ModeKeys.TRAIN:
    run_batch = functools.partial(
        model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    run_batch = model.test_on_batch
  else:
    def run_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
      return model.predict_on_batch(x)

  if mode == ModeKeys.PREDICT:
    return run_batch
  return functools.partial(run_batch, reset_metrics=False)
534
535
def _get_num_samples_or_steps(data, steps_per_epoch):
  """Chooses the unit callbacks should count progress in.

  Returns:
    `(count, use_steps)`. If the first flattened element of `data` has a
    `shape`, `count` is its leading dimension (number of samples) and
    `use_steps` is False; otherwise `count` is `steps_per_epoch` and
    `use_steps` is True.
  """
  first_element = nest.flatten(data)[0]
  if not hasattr(first_element, 'shape'):
    return steps_per_epoch, True
  return int(first_element.shape[0]), False
542
543
class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for Python generators and `keras.utils.Sequence` objects.

  Unlike `GeneratorLikeTrainingLoop`, this loop only handles inputs where x, y
  and sample weights are fused into the single `x` argument (e.g. a generator
  yielding `(x, y, sample_weight)` tuples). The separate `y`/`sample_weight`
  arguments are passed to `training_utils_v1.check_generator_arguments` in
  `fit` and `evaluate`. Every method forwards the caller's queue settings
  (`max_queue_size`, `workers`, `use_multiprocessing`) to the module-level
  `*_generator` entry points.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):
    """Runs `fit_generator` over the generator/`Sequence` passed as `x`."""
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    training_utils_v1.check_generator_arguments(
        y, sample_weight, validation_split=validation_split)
    queue_options = {
        'max_queue_size': max_queue_size,
        'workers': workers,
        'use_multiprocessing': use_multiprocessing,
    }
    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        initial_epoch=initial_epoch,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        shuffle=shuffle,
        steps_name='steps_per_epoch',
        **queue_options)

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
    """Runs `evaluate_generator` over the generator/`Sequence` in `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    training_utils_v1.check_generator_arguments(y, sample_weight)
    queue_options = {
        'max_queue_size': max_queue_size,
        'workers': workers,
        'use_multiprocessing': use_multiprocessing,
    }
    return evaluate_generator(
        model, x, steps=steps, verbose=verbose, callbacks=callbacks,
        **queue_options)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    """Runs `predict_generator` over the generator/`Sequence` in `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    queue_options = {
        'max_queue_size': max_queue_size,
        'workers': workers,
        'use_multiprocessing': use_multiprocessing,
    }
    return predict_generator(
        model, x, steps=steps, verbose=verbose, callbacks=callbacks,
        **queue_options)
639
640
class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for a non-distributed Dataset or iterator in eager mode.

  Every method forwards to the module-level `*_generator` entry points with
  `workers=0`, so no enqueuer is started. Unrecognized keyword arguments are
  accepted and ignored.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Runs `fit_generator` over the dataset/iterator passed as `x`."""
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    # `y`, `sample_weight` and `validation_split` are checked here because a
    # dataset/iterator carries its own targets and weights.
    training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                             validation_split)
    x_is_dataset = isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
    if x_is_dataset and shuffle:
      training_utils_v1.verify_dataset_shuffled(x)

    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        initial_epoch=initial_epoch,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        shuffle=shuffle,
        workers=0,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Runs `evaluate_generator` over the dataset/iterator in `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    training_utils_v1.validate_dataset_input(x, y, sample_weight)
    return evaluate_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        workers=0)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Runs `predict_generator` over the dataset/iterator in `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        workers=0)
713
714
class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop that handles inputs the way a Python generator would.

  This is the default handler for most input data types, including symbolic
  tensors, NumPy array-likes, and Datasets/iterators in graph mode (since
  those produce symbolic tensors). It is used for models with
  `run_eagerly=True`. Inputs are first standardized with
  `model._standardize_user_data` and then fed to the module-level
  `*_generator` entry points with `workers=0`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Standardizes the inputs, sets up validation, then runs `fit_generator`.

    Raises:
      ValueError: If `validation_steps` is given without any validation data
        (neither `validation_data` nor a usable `validation_split`).
    """
    batch_size = model._validate_or_infer_batch_size(batch_size,
                                                     steps_per_epoch, x)
    x, y, sample_weights = model._standardize_user_data(
        x, y,
        sample_weight=sample_weight, class_weight=class_weight,
        batch_size=batch_size, check_steps=True,
        steps_name='steps_per_epoch', steps=steps_per_epoch,
        validation_split=validation_split, shuffle=shuffle)

    if validation_data:
      validation_data = model._prepare_validation_data(
          validation_data, batch_size, validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      (x, y, sample_weights,
       val_x, val_y, val_sample_weights) = (
           training_utils_v1.split_training_and_validation_data(
               x, y, sample_weights, validation_split))
      validation_data = (val_x, val_y, val_sample_weights)
    elif validation_steps:
      raise ValueError('`validation_steps` should not be specified if '
                       '`validation_data` is None.')

    # `class_weight` is deliberately not forwarded: it was already passed to
    # `_standardize_user_data` above.
    training_data = (x, y, sample_weights)
    return fit_generator(
        model,
        training_data,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs,
        initial_epoch=initial_epoch,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        shuffle=shuffle,
        workers=0,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Standardizes the inputs, then runs `evaluate_generator` on them."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    inputs, targets, weights = model._standardize_user_data(
        x, y,
        sample_weight=sample_weight, batch_size=batch_size,
        check_steps=True, steps_name='steps', steps=steps)
    return evaluate_generator(
        model, (inputs, targets, weights),
        steps=steps, batch_size=batch_size, verbose=verbose,
        workers=0, callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Standardizes the inputs, then runs `predict_generator` on them."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    inputs, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_generator(
        model, inputs,
        steps=steps, batch_size=batch_size, verbose=verbose,
        workers=0, callbacks=callbacks)
832