# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to Python generators of array data.
"""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
    model: Keras Model instance.
    data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
      `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    steps_per_epoch: Total number of steps (batches of samples) before
      declaring one epoch finished and starting the next epoch. Ignored with
      the default value of `None`.
    epochs: Number of times to iterate over the data.
    verbose: 0, 1, or 2. Verbosity mode.
      0 = silent, 1 = progress bar, 2 = one line per epoch.
      Note that the progress bar is not particularly useful when
      logged to a file, so verbose=2 is recommended when not running
      interactively (e.g., in a production environment).
    callbacks: List of callbacks to be called during training.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
      `(x, y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    validation_freq: Only relevant if validation data is provided. Integer or
      `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
      integer, specifies how many training epochs to run before a new
      validation run is performed, e.g.
      `validation_freq=2` runs validation every 2 epochs. If a Container,
      specifies the epochs on which to run validation, e.g.
      `validation_freq=[1, 2, 10]` runs validation at the end of the 1st,
      2nd, and 10th epochs.
    class_weight: Dictionary mapping class indices to a weight for the class.
    max_queue_size: Integer. Maximum size for the generator queue. If
      unspecified, `max_queue_size` will default to 10.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not
      pass non-picklable arguments to the generator as they can't be passed
      easily to child processes.
    shuffle: Boolean. Whether to shuffle the order of the batches at the
      beginning of each epoch. Only used with instances of `Sequence`
      (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
      `None`.
    initial_epoch: Epoch at which to start training (useful for resuming a
      previous training run).
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    batch_size: Integer batch size or None if unknown. Will only be used if
      `data` is in NumPy/Tensor format.
    steps_name: The string name of the steps argument, either `steps`,
      `validation_steps`, or `steps_per_epoch`. Only used for error message
      formatting.
    **kwargs: Additional arguments for backwards compatibility. `steps` is
      accepted as an alias for `steps_per_epoch`.

  Returns:
    - In TRAIN mode: `History` object.
    - In TEST mode: Evaluation metrics.
    - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
    ValueError: in case of invalid arguments.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
    steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
        model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)
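
  # `batch_function` is now one of `model.train_on_batch`,
  # `model.test_on_batch`, or a thin wrapper around `model.predict_on_batch`,
  # depending on `mode`; see `_make_execution_function` below.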

  # Create the queue for the generator.
  enqueuer = None
  if not is_dataset:
    generator, enqueuer = _make_enqueued_generator(
        generator,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        shuffle=shuffle)

  num_samples_or_steps, use_steps = _get_num_samples_or_steps(
      data, steps_per_epoch)

  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      samples=num_samples_or_steps,
      count_mode=count_mode,
      verbose=verbose,
      mode=mode)

  if mode == ModeKeys.PREDICT:
    aggregator = training_utils_v1.OutputsAggregator(
        True, steps=steps_per_epoch)
  else:
    aggregator = training_utils_v1.MetricsAggregator(
        True, steps=steps_per_epoch)

  should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
  if should_set_learning_phase:
    learning_phase_scope = backend.eager_learning_phase_scope(
        1 if mode == ModeKeys.TRAIN else 0)
    learning_phase_scope.__enter__()

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Setup work for each epoch.
    model.reset_metrics()
    epoch_logs = {}
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)

    if steps_per_epoch is None:
      # Loop over dataset until `OutOfRangeError` is raised.
      target_steps = np.inf
    else:
      # Loop over dataset for the specified number of steps.
      target_steps = steps_per_epoch

    step = 0
    while step < target_steps:
      batch_data = _get_next_batch(generator)
      if batch_data is None:
        if is_dataset:
          # The dataset passed by the user ran out of batches. Now we know
          # the cardinality of the dataset. If `steps_per_epoch` was
          # specified, then running out of data is unexpected, so we stop
          # training and inform the user.
          if steps_per_epoch:
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset ran out of data; interrupting training. '
                'Make sure that your dataset can generate at least '
                '`%s * epochs` batches (in this case, %d batches). '
                'You may need to use the repeat() function when '
                'building your dataset.'
                % (steps_name, steps_per_epoch * epochs))
          elif step > 0:
            steps_per_epoch = step
            aggregator.steps = steps_per_epoch
        else:
          # We ran out of batches while the user passed an iterator (legacy).
          callbacks.model.stop_training = True
          logging.warning(
              'Your dataset iterator ran out of data; '
              'interrupting training. Make sure that your iterator '
              'can generate at least `%s * epochs` '
              'batches (in this case, %d batches). You may need to '
              'use the repeat() function when building your '
              'dataset.' % (steps_name, steps_per_epoch * epochs))
        break

      # `batch_size` used for validation data if validation
      # data is NumPy/EagerTensors.
      batch_size = int(nest.flatten(batch_data)[0].shape[0])
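      # (This takes the leading dimension of the first flattened input as the
      # batch size, assuming every element of `batch_data` shares that
      # leading dimension.)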

      # Callbacks batch begin.
      batch_logs = {'batch': step, 'size': batch_size}
      callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

      is_deferred = not model._is_compiled
      batch_outs = batch_function(*batch_data)
      if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]

      if step == 0:
        aggregator.create(batch_outs)

        if is_deferred:
          # Set callbacks params. We do this here when the model is compiled
          # in the first iteration of this loop (deferred build scenario).
          cbks.set_callback_parameters(
              callbacks,
              model,
              do_validation=do_validation,
              batch_size=batch_size,
              epochs=epochs,
              steps_per_epoch=steps_per_epoch,
              samples=num_samples_or_steps,
              verbose=verbose,
              mode=mode)

      # Aggregate results.
      aggregator.aggregate(batch_outs)

      # Callbacks batch end.
      batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
      callbacks._call_batch_hook(mode, 'end', step, batch_logs)
      step += 1

    if callbacks.model.stop_training:
      break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop during training, per `validation_freq`.
    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):
      val_results = model_iteration(
          model,
          validation_data,
          steps_per_epoch=validation_steps,
          batch_size=batch_size,
          class_weight=class_weight,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          callbacks=callbacks,
          verbose=verbose,
          mode=ModeKeys.TEST,
          steps_name='validation_steps')

      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)

    # Recreate dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      generator = dataset_ops.make_one_shot_iterator(original_dataset)

  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if enqueuer is not None:
    enqueuer.stop()

  if should_set_learning_phase:
    learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results


# Maintain compatibility with the existing names.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_generator = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
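

# A minimal usage sketch of the compatibility partials above. `model` and
# `my_generator` are hypothetical placeholders, not names defined in this
# module:
#
#   history = fit_generator(model, my_generator, steps_per_epoch=100, epochs=2)
#   eval_results = evaluate_generator(model, my_generator, steps=10)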


def _get_next_batch(generator):
  """Retrieves the next batch of input data."""
  try:
    generator_output = next(generator)
  except (StopIteration, errors.OutOfRangeError):
    return None

  if not isinstance(generator_output, tuple):
    # Always wrap in a tuple.
    generator_output = (generator_output,)
  if len(generator_output) not in [1, 2, 3]:
    raise ValueError(
        'Output of generator should be a tuple of 1 or 2 or 3 '
        'elements: (input,) or (input, target) or '
        '(input, target, sample_weights). Received {}'.format(
            generator_output))
  return generator_output


def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Raises errors if arguments are invalid.

  Args:
    is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
      instance.
    is_dataset: Boolean, whether data is a dataset instance.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not
      pass non-picklable arguments to the generator as they can't be passed
      easily to child processes.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    steps_per_epoch: Total number of steps (batches of samples) before
      declaring one epoch finished and starting the next epoch. Ignored with
      the default value of `None`.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
      `(x, y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    kwargs: Additional arguments for backwards compatibility.

  Raises:
    ValueError: If `steps_per_epoch` or `validation_steps` are not passed
      for data types that require them, or if unrecognized keyword
      arguments are passed.
  """
  if not is_sequence and use_multiprocessing and workers > 1:
    logging.warning(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.'))

  if steps_per_epoch is None and not is_dataset:
    arg_name = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(arg_name))

  val_gen = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.IteratorBase))
  if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
      not validation_steps):
    raise ValueError('Please specify the `validation_steps` argument.')

  if any(k != 'steps' for k in kwargs):
    raise ValueError('Invalid arguments passed: {}'.format(
        [k for k in kwargs if k != 'steps']))
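

# Illustrative sketch of a generator that satisfies the tuple contract checked
# by `_get_next_batch` above (the shapes are arbitrary examples, not an API
# requirement):
#
#   def my_generator():
#     while True:
#       x = np.zeros((32, 10))
#       y = np.zeros((32, 1))
#       yield x, y  # Also valid: (x,) or (x, y, sample_weights).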


def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Args:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or
      EagerTensors. If a tuple, the elements represent `(x, y, sample_weights)`
      and may be `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays
      or EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch. If `None` the
      number of steps will be read from the data (for
      `keras.utils.data_utils.Sequence` types).
    epochs: Total number of epochs to run.
    shuffle: Whether the data should be shuffled.

  Returns:
    - Generator, `keras.utils.data_utils.Sequence`, or `Iterator`.

  Raises:
    - ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`,
    # `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.IteratorBase):
    if isinstance(data, data_utils.Sequence):
      if steps_per_epoch is None:
        steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor input.
  num_samples = int(nest.flatten(data)[0].shape[0])
  if batch_size is None:
    raise ValueError(
        'When passing input data as arrays, do not specify '
        '`steps_per_epoch`/`steps` argument. Please use `batch_size` instead.')
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen(data):
    """Makes a generator out of a structure of NumPy/EagerTensors."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      batches = generic_utils.make_batches(num_samples, batch_size)
      for (batch_start, batch_end) in batches:
        batch_ids = index_array[batch_start:batch_end]
        flat_batch_data = training_utils.slice_arrays(
            nest.flatten(data), batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(data), steps_per_epoch


def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Create a buffered queue of next elements of the generator."""
  is_sequence = isinstance(generator, data_utils.Sequence)
  enqueuer = None
  if workers > 0:
    if is_sequence:
      enqueuer = data_utils.OrderedEnqueuer(
          generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
    else:
      enqueuer = data_utils.GeneratorEnqueuer(
          generator, use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()
  else:
    if is_sequence:
      output_generator = data_utils.iter_sequence_infinite(generator)
    else:
      output_generator = generator
  return output_generator, enqueuer
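

# Note: when `workers=0`, no enqueuer is created and the generator is consumed
# on the calling thread. The eager and array-based training loops below pass
# `workers=0` for exactly this reason.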


def _make_execution_function(model, mode, class_weight=None):
  """Makes function to run one step of model execution."""
  if mode == ModeKeys.TRAIN:
    f = functools.partial(model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    f = model.test_on_batch
  else:
    # Match signature of other modes to allow
    # 1, 2, or 3-tuples from generator
    def predict_on_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
      return model.predict_on_batch(x)

    f = predict_on_batch

  # Maintain stateful metrics across batch-level calls.
  if mode != ModeKeys.PREDICT:
    f = functools.partial(f, reset_metrics=False)

  return f


def _get_num_samples_or_steps(data, steps_per_epoch):
  """Returns number of samples or steps, and whether to use steps count mode."""
  flat_inputs = nest.flatten(data)
  if hasattr(flat_inputs[0], 'shape'):
    return int(flat_inputs[0].shape[0]), False
  return steps_per_epoch, True


class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
  """Generator-like.

  Input is a Python generator or a Sequence object.

  The difference between this class and `GeneratorLikeTrainingLoop` is that
  this class only handles inputs with x, y, and sample_weight fused into one
  param.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    training_utils_v1.check_generator_arguments(
        y, sample_weight, validation_split=validation_split)
    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    training_utils_v1.check_generator_arguments(y, sample_weight)
    return evaluate_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)
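

# Dispatch note (a summary of the classes in this module, not new behavior):
# `GeneratorOrSequenceTrainingLoop` serves generators and Sequences,
# `EagerDatasetOrIteratorTrainingLoop` serves eager Datasets and iterators,
# and `GeneratorLikeTrainingLoop` serves array-like and symbolic-tensor
# inputs.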


class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
  """A non-distributed Dataset or iterator in eager execution."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                             validation_split)
    if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) and
        shuffle):
      training_utils_v1.verify_dataset_shuffled(x)

    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight)
    return evaluate_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)
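

# In `GeneratorLikeTrainingLoop` below, `model._standardize_user_data` turns
# array-like `x`, `y`, and `sample_weight` into the `(x, y, sample_weights)`
# tuple that `model_iteration` converts back into a generator via
# `convert_to_generator_like`.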


class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """TrainingLoop that handles inputs like a Python generator.

  This is the default handler for most input data types, including symbolic
  tensors, NumPy array-likes, and Datasets and iterators in graph mode (since
  they generate symbolic tensors). This class is also used to handle models
  with `run_eagerly=True`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size,
                                                     steps_per_epoch, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    if validation_data:
      validation_data = model._prepare_validation_data(validation_data,
                                                       batch_size,
                                                       validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      (x, y, sample_weights, val_x, val_y,
       val_sample_weights) = (
           training_utils_v1.split_training_and_validation_data(
               x, y, sample_weights, validation_split))
      validation_data = (val_x, val_y, val_sample_weights)
    else:
      if validation_steps:
        raise ValueError('`validation_steps` should not be specified if '
                         '`validation_data` is None.')

    return fit_generator(
        model, (x, y, sample_weights),
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return evaluate_generator(
        model, (x, y, sample_weights),
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_generator(
        model,
        x,
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)