1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Part of the Keras training engine related to plain array data.
16"""
17# pylint: disable=protected-access
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23
24import numpy as np
25
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.data.ops import iterator_ops
28from tensorflow.python.eager import context
29from tensorflow.python.framework import errors
30from tensorflow.python.keras import backend as K
31from tensorflow.python.keras import callbacks as cbks
32from tensorflow.python.keras.distribute import distributed_training_utils_v1
33from tensorflow.python.keras.engine import training_utils_v1
34from tensorflow.python.keras.utils.generic_utils import make_batches
35from tensorflow.python.keras.utils.generic_utils import slice_arrays
36from tensorflow.python.keras.utils.mode_keys import ModeKeys
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.util import nest
39
40try:
41  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
42except ImportError:
43  issparse = None
44
45
def model_iteration(model,
                    inputs,
                    targets=None,
                    sample_weights=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    val_inputs=None,
                    val_targets=None,
                    val_sample_weights=None,
                    shuffle=True,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1,
                    mode=ModeKeys.TRAIN,
                    validation_in_fit=False,
                    prepared_feed_values_from_dataset=False,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
      model: Keras Model instance.
      inputs: Either a list or dictionary of arrays, or a dataset instance.
      targets: List/dictionary of target arrays.
      sample_weights: Optional list of sample weight arrays.
      batch_size: Integer batch size or None if unknown.
      epochs: Number of times to iterate over the data
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (eg, in a production environment).
      callbacks: List of callbacks to be called during training
      val_inputs: Either a list or dictionary of arrays, or a dataset instance.
      val_targets: List/dictionary of target arrays.
      val_sample_weights: Optional list of sample weight arrays.
      shuffle: Whether to shuffle the data at the beginning of each epoch.
        The string `'batch'` shuffles in batch-sized chunks instead. Only
        consulted in the sample-wise loop (i.e. when no steps are used).
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run)
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      validation_steps: Number of steps to run validation for (only if doing
        validation from data tensors). Ignored with the default value of
        `None`.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      validation_in_fit: if true, then this method is invoked from within
        training iteration (for validation). In the case where `val_inputs` is
        a dataset, this flag indicates that its iterator and feed values are
        already created so should properly reuse resources.
      prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
        tensors returned from `_prepare_feed_values` call on the validation
        dataset, so do not call it again on `inputs`. Should only be used for
        inline validation (i.e., only if `validation_in_fit` is also True).
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. Only `steps`
        is accepted; when present it overrides `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
      TypeError: if unknown keyword arguments are passed, or if slicing the
        input arrays into batches fails.
  """
  # Backwards compatibility.
  if 'steps' in kwargs:
    steps_per_epoch = kwargs.pop('steps')
  if kwargs:
    raise TypeError('Unknown arguments: %s' % (kwargs,))

  # In case we were passed a dataset, we extract symbolic tensors from it.
  reset_dataset_after_each_epoch = False
  input_iterator = None
  is_dataset = isinstance(inputs,
                          (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
  # TODO(fchollet): consider moving `steps_per_epoch` inference to
  # _standardize_user_data and set reset_dataset_after_each_epoch as an
  # attribute on the dataset instance.
  if is_dataset:
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
    input_iterator = _get_iterator(inputs, model._distribution_strategy)

  # Enter tf.distribute.Strategy scope.
  # NOTE: the scope is entered and exited by hand (see the matching `__exit__`
  # call at the end of this function), so an exception raised in between
  # leaves the function without exiting it.
  if model._distribution_strategy:
    scope = distributed_training_utils_v1.distributed_scope(
        strategy=model._distribution_strategy,
        learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
    scope.__enter__()

  use_steps = is_dataset or steps_per_epoch is not None
  do_validation = val_inputs is not None

  # Prepare input data.
  inputs = input_iterator or inputs
  if validation_in_fit and prepared_feed_values_from_dataset:
    # When invoking validation in training loop, avoid creating iterator and
    # list of feed values for the same validation dataset multiple times (which
    # essentially would call `iterator.get_next()` that slows down execution and
    # leads to OOM errors eventually.
    ins = inputs
  else:
    ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode)
    # `ins` is a function when a distribute strategy is used in Eager mode.  In
    # that case `is_dataset` is True.  The code branches that have requirements
    # about the type of `ins` do not trigger in the distributed case.

  if not is_dataset:
    num_samples_or_steps = _get_num_samples_or_steps(ins, batch_size,
                                                     steps_per_epoch)
  else:
    num_samples_or_steps = steps_per_epoch

  # Update sample_weight_mode of the model if sample_weights is specified by the
  # user. We need to call this function after we have a handle on the inputs
  # (both numpy arrays and datasets) in order to determine if the user has
  # specified sample_weights.
  _update_sample_weight_mode(model, mode, ins)

  # Get step function and loop type. As part of building the execution
  # function we recompile the metrics based on the updated
  # sample_weight_mode value.
  f = _make_execution_function(model, mode)

  # Prepare validation data. Hold references to the iterator and the input list
  # to properly reinitialize and reuse in multiple validation passes.
  val_iterator = None
  if isinstance(val_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
    if validation_steps is None:
      # Because we pass an iterator feed instead of a Dataset to the eval
      # model_iteration() call, it will not trigger the dataset-input path
      # that determines the number of steps required. To avoid this issue,
      # set validation_steps here if validation_steps is None.
      validation_steps = training_utils_v1.infer_steps_for_dataset(
          model,
          val_inputs,
          validation_steps,
          epochs=epochs,
          steps_name='validation_steps')
    val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
    val_inputs = _prepare_feed_values(
        model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST)
    # Get num steps for printing.
    val_samples_or_steps = validation_steps
  else:
    # Get num samples for printing.
    val_samples_or_steps = val_inputs and nest.flatten(
        val_inputs)[0].shape[0] or None

  if mode == ModeKeys.TRAIN and verbose:
    _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset)

  # Configure callbacks.
  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=num_samples_or_steps,
      count_mode=count_mode,
      verbose=verbose,
      mode=mode)

  # Find beforehand arrays that need sparse-to-dense conversion.
  # `indices_for_conversion_to_dense` is only defined (and only read) on the
  # sample-wise path when scipy is importable.
  if issparse is not None and not use_steps:
    indices_for_conversion_to_dense = []
    feed = _get_model_feed(model, mode)
    for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)):
      if issparse(input_data) and not K.is_sparse(feed_tensor):
        indices_for_conversion_to_dense.append(i)

  # Select aggregation method.
  if mode == ModeKeys.PREDICT:
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps,
        num_samples=None if steps_per_epoch else num_samples_or_steps,
        steps=steps_per_epoch)
  else:
    aggregator = training_utils_v1.MetricsAggregator(
        use_steps,
        num_samples=None if steps_per_epoch else num_samples_or_steps,
        steps=steps_per_epoch)

  if model._compile_distribution:
    distributed_training_utils_v1._copy_weights_to_distributed_model(
        model, mode)

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Setup work for each epoch
    epoch_logs = {}
    if mode != ModeKeys.PREDICT:
      # Collecting and resetting metrics has non-zero cost and will needlessly
      # slow down model.predict.
      model.reset_metrics()
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)

    if use_steps:
      # Step-wise loop.
      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_logs = {'batch': step, 'size': 1}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

        # Get outputs.
        try:
          # `ins` can be callable in tf.distribute.Strategy + eager case.
          if not callable(ins) or (model._distribution_strategy and
                                   not distributed_training_utils_v1
                                   .is_distributing_by_cloning(model)):
            actual_inputs = ins
          else:
            actual_inputs = ins()
          batch_outs = f(actual_inputs)
        except errors.OutOfRangeError:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.steps = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if model._distribution_strategy:
          batch_outs = (
              distributed_training_utils_v1._per_replica_aggregate_batch(
                  model._distribution_strategy, batch_outs, model, mode))

        # Aggregate results.
        if step == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break
    else:
      # Sample-wise loop.
      index_array = np.arange(num_samples_or_steps)
      if shuffle == 'batch':
        index_array = training_utils_v1.batch_shuffle(index_array, batch_size)
      elif shuffle:
        np.random.shuffle(index_array)
      batches = make_batches(num_samples_or_steps, batch_size)
      for batch_index, (batch_start, batch_end) in enumerate(batches):
        batch_ids = index_array[batch_start:batch_end]
        # Slice into a batch.
        if len(batches) == 1:
          # If we only have one batch, do not slice. This takes care of
          # composite tensors in non-Dataset modes; we currently don't support
          # slicing them.
          # TODO(b/133517906): Add slicing support.
          ins_batch = ins
        else:
          try:
            if ins and isinstance(ins[-1], int):
              # Do not slice the training phase flag.
              ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
            else:
              ins_batch = slice_arrays(ins, batch_ids)
          except TypeError:
            raise TypeError('TypeError while preparing batch. '
                            'If using HDF5 input data, '
                            'pass shuffle="batch".')

        # Sparse to dense conversion.
        if issparse is not None:
          for i in indices_for_conversion_to_dense:
            ins_batch[i] = ins_batch[i].toarray()

        # Callbacks batch_begin.
        batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
        callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs)

        # Get outputs.
        batch_outs = f(ins_batch)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        # Aggregate results.
        if batch_index == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs, batch_start, batch_end)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs)

        if callbacks.model.stop_training:
          break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    # A single result is returned bare rather than as a one-element list.
    if len(results) == 1:
      results = results[0]

    # Run the test loop every `validation_freq` epochs during training.
    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        distributed_training_utils_v1._copy_weights_to_original_model(
            model, ModeKeys.TRAIN)

      val_results = model_iteration(
          model,
          val_inputs,
          targets=val_targets,
          sample_weights=val_sample_weights,
          batch_size=batch_size,
          steps_per_epoch=validation_steps,
          callbacks=callbacks,
          verbose=0,
          mode=ModeKeys.TEST,
          validation_in_fit=True,
          prepared_feed_values_from_dataset=(val_iterator is not None),
          steps_name='validation_steps')
      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')
      if val_iterator and epoch < epochs - 1:
        _reinitialize_iterator(val_iterator, model._distribution_strategy)

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)

    # Reinitialize dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      _reinitialize_iterator(input_iterator, model._distribution_strategy)

  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._distribution_strategy:
    if model._compile_distribution:
      # TODO(priyag, psv): Copy back metrics to the original model as well?
      distributed_training_utils_v1._copy_weights_to_original_model(model, mode)
    scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results
457
458
def _get_model_feed(model, mode):
  """Returns the model's feed list for the given mode.

  In PREDICT mode this is `model._feed_inputs` only; in every other mode the
  feed targets and feed sample weights are appended, in that order.
  """
  feed = model._feed_inputs
  if mode != ModeKeys.PREDICT:
    # Build a new list; the model's own `_feed_inputs` must not be mutated.
    feed = feed + model._feed_targets + model._feed_sample_weights
  return feed
466
467
def _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset):
  """Prints a one-line summary of the training (and validation) size.

  Args:
    num_samples_or_steps: Count reported for training.
    val_samples_or_steps: Count reported for validation; the validation part
      of the message is omitted when this is falsy.
    is_dataset: If true the counts are labelled 'steps', otherwise 'samples'.
  """
  unit = 'steps' if is_dataset else 'samples'
  parts = ['Train on {} {}'.format(num_samples_or_steps, unit)]
  if val_samples_or_steps:
    parts.append('validate on {} {}'.format(val_samples_or_steps, unit))
  print(', '.join(parts))
476
477
def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
  """Returns `steps_per_epoch` if truthy, else the checked sample count.

  When `steps_per_epoch` is falsy the value is delegated to
  `training_utils_v1.check_num_samples`.
  """
  if not steps_per_epoch:
    return training_utils_v1.check_num_samples(
        ins, batch_size, steps_per_epoch, 'steps_per_epoch')
  return steps_per_epoch
484
485
def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Builds the values fed to the model's execution function.

  Args:
    model: Model to prepare feed values for.
    inputs: List or dict of model inputs.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode. When the model has a
    distribution strategy and eager execution is on, a zero-argument callable
    producing the feed values is returned instead.
  """
  dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
  strategy = model._distribution_strategy

  if strategy:
    if isinstance(inputs, dataset_types):
      inputs = distributed_training_utils_v1.get_iterator(inputs, strategy)

    def get_distributed_inputs():
      return distributed_training_utils_v1._prepare_feed_values(
          model, inputs, targets, sample_weights, mode)

    # In the eager case the inputs must be fetched once per step, so the
    # callable itself is handed back; in graph mode it is invoked right away.
    # This applies only to the Distribution Strategy case, which shares one
    # code path for eager and graph modes.
    # TODO(priyag,omalleyt): Either we should move the training DS with
    # IteratorBase to use training_generator code path, or figure out how to
    # set a symbolic Iterator out of a Dataset when in eager mode.
    if not context.executing_eagerly():
      return get_distributed_inputs()
    return get_distributed_inputs

  if isinstance(inputs, dataset_types + (iterator_ops.Iterator,)):
    inputs, targets, sample_weights = model._standardize_user_data(
        inputs,
        extract_tensors_from_dataset=True)

  input_list = training_utils_v1.ModelInputs(inputs).as_list()
  target_list = list(targets or [])
  weight_list = list(sample_weights or [])
  # Concatenate into a fresh list so `input_list` itself is left untouched.
  feed_values = input_list + target_list + weight_list

  needs_learning_phase = (
      mode == ModeKeys.TRAIN and
      not isinstance(K.symbolic_learning_phase(), int))
  if needs_learning_phase:
    feed_values.append(True)  # Add learning phase value.
  return feed_values
534
535
def _get_iterator(inputs, distribution_strategy=None):
  """Returns an iterator for `inputs`, strategy-aware when one is given."""
  if not distribution_strategy:
    return training_utils_v1.get_iterator(inputs)
  return distributed_training_utils_v1.get_iterator(
      inputs, distribution_strategy)
541
542
def _reinitialize_iterator(iterator, distribution_strategy=None):
  """Re-initializes `iterator`, strategy-aware when one is given."""
  if not distribution_strategy:
    training_utils_v1.initialize_iterator(iterator)
    return
  distributed_training_utils_v1.initialize_iterator(
      iterator, distribution_strategy)
549
550
def _make_execution_function(model, mode):
  """Returns the function that runs one step of `model` in `mode`.

  Models with a distribution strategy are routed through the distributed
  utilities; all others use the model's own `_make_execution_function`.
  """
  strategy = model._distribution_strategy
  if not strategy:
    return model._make_execution_function(mode)
  return distributed_training_utils_v1._make_execution_function(model, mode)
556
557
def _update_sample_weight_mode(model, mode, inputs):
  """Updates the model's sample_weight_mode from the prepared feed values.

  Args:
    model: Model whose sample weight modes are updated.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. Nothing is
      done in PREDICT mode.
    inputs: Either the flat feed list (model inputs + targets + sample
      weights + optional learning phase value) or a callable; a callable is
      not inspected.
  """
  # Bail out early in PREDICT mode: `model._feed_targets` accesses certain
  # model properties that may not be set there.
  if mode == ModeKeys.PREDICT:
    return

  weights = None
  if not callable(inputs):
    # Everything past the inputs and targets is sample weights, possibly
    # followed by the learning phase value.
    weights_start = len(model._feed_inputs) + len(model._feed_targets)
    weights = inputs[weights_start:]
    if (mode == ModeKeys.TRAIN and
        not isinstance(K.symbolic_learning_phase(), int)):
      # Drop the trailing learning phase value.
      weights = weights[:-1]
    model._update_sample_weight_modes(sample_weights=weights)

  # Call the DistributionStrategy specific function to update the
  # sample_weight_mode on the model.
  if model._distribution_strategy:
    distributed_training_utils_v1._update_sample_weight_modes(
        model, mode, weights)
583
# For backwards compatibility for internal users of these loops.
# Each alias is `model_iteration` with `mode` pre-bound as a keyword argument;
# the test and predict variants also pre-bind `shuffle=False`.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_loop = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
590
591
class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """TrainingLoop that handles array-like inputs.

  This is the default handler for most of the input data types, including
  symbolic tensors or Numpy array-like, Datasets and iterators in graph mode
  (since they generate symbolic tensors). It is used to handle models with
  `run_eagerly` = False.

  Each method standardizes the user data through the model and then delegates
  to the matching module-level loop (`fit_loop`, `test_loop`, `predict_loop`).
  Extra `**kwargs` are accepted but not forwarded.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Standardizes training/validation data and runs `fit_loop`.

    Raises:
      ValueError: if `validation_steps` is given but there is neither
        `validation_data` nor a `validation_split` strictly between 0 and 1.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)

    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    # Default to "no validation"; the branches below override as needed.
    val_x = val_y = val_sample_weights = None
    if validation_data:
      val_x, val_y, val_sample_weights = model._prepare_validation_data(
          validation_data, batch_size, validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      x, y, sample_weights, val_x, val_y, val_sample_weights = (
          training_utils_v1.split_training_and_validation_data(
              x, y, sample_weights, validation_split))
    elif validation_steps:
      raise ValueError('`validation_steps` should not be specified if '
                       '`validation_data` is None.')

    return fit_loop(
        model,
        inputs=x,
        targets=y,
        sample_weights=sample_weights,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_x,
        val_targets=val_y,
        val_sample_weights=val_sample_weights,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Standardizes the evaluation data and runs `test_loop`."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    inputs, targets, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return test_loop(
        model,
        inputs=inputs,
        targets=targets,
        sample_weights=sample_weights,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Standardizes the prediction inputs and runs `predict_loop`."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    # Only the standardized inputs are needed for prediction.
    inputs, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_loop(
        model,
        inputs,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)
713