# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to plain array data.
"""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.generic_utils import make_batches
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging

try:
  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
except ImportError:
  issparse = None


def model_iteration(model,
                    inputs,
                    targets=None,
                    sample_weights=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    val_inputs=None,
                    val_targets=None,
                    val_sample_weights=None,
                    shuffle=True,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1,
                    mode=ModeKeys.TRAIN,
                    validation_in_fit=False,
                    prepared_feed_values_from_dataset=False,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Arguments:
      model: Keras Model instance.
      inputs: Either a list or dictionary of arrays, or a dataset instance.
      targets: List/dictionary of target arrays.
      sample_weights: Optional list of sample weight arrays.
      batch_size: Integer batch size or None if unknown.
      epochs: Number of times to iterate over the data.
      verbose: Verbosity mode, 0, 1 or 2.
      callbacks: List of callbacks to be called during training.
      val_inputs: Either a list or dictionary of arrays, or a dataset instance.
      val_targets: List/dictionary of target arrays.
      val_sample_weights: Optional list of sample weight arrays.
      shuffle: Whether to shuffle the data at the beginning of each epoch.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      validation_steps: Number of steps to run validation for (only if doing
        validation from data tensors).
        Ignored with the default value of `None`.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      validation_in_fit: if true, then this method is invoked from within a
        training iteration (for validation). In the case where `val_inputs`
        is a dataset, this flag indicates that its iterator and feed values
        are already created, so the existing resources should be reused.
      prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
        tensors returned from a `_prepare_feed_values` call on the validation
        dataset, so do not call it again on `inputs`. Should only be used for
        inline validation (i.e., only if `validation_in_fit` is also True).
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
  """
  # Backwards compatibility.
  if 'steps' in kwargs:
    steps_per_epoch = kwargs.pop('steps')
  if kwargs:
    raise TypeError('Unknown arguments: %s' % (kwargs,))

  # In case we were passed a dataset, we extract symbolic tensors from it.
  reset_dataset_after_each_epoch = False
  input_iterator = None
  is_dataset = isinstance(inputs,
                          (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
  # TODO(fchollet): consider moving `steps_per_epoch` inference to
  # _standardize_user_data and set reset_dataset_after_each_epoch as an
  # attribute on the dataset instance.
  if is_dataset:
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
    steps_per_epoch = training_utils.infer_steps_for_dataset(
        inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
    input_iterator = _get_iterator(inputs, model._distribution_strategy)

  if mode == ModeKeys.TRAIN:
    _print_train_info(inputs, val_inputs, steps_per_epoch, verbose)

  # Enter DistributionStrategy scope.
  if model._distribution_strategy:
    scope = distributed_training_utils.distributed_scope(
        strategy=model._distribution_strategy,
        learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
    scope.__enter__()

  # Get step function and loop type.
  f = _make_execution_function(model, mode)
  use_steps = is_dataset or steps_per_epoch is not None
  do_validation = val_inputs is not None

  # Convert Eager Tensors to NumPy arrays to support batching/shuffling.
  inputs, targets, sample_weights = training_utils. \
      convert_eager_tensors_to_numpy((inputs, targets, sample_weights))

  # Prepare input data.
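  # If a dataset was passed in, feed from the iterator created above rather
  # than from the dataset object itself.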
  inputs = input_iterator or inputs
  if validation_in_fit and prepared_feed_values_from_dataset:
    # When invoking validation in the training loop, avoid creating the
    # iterator and the list of feed values for the same validation dataset
    # multiple times (which essentially would call `iterator.get_next()`,
    # slowing down execution and eventually leading to OOM errors).
    ins = inputs
  else:
    ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode)
  if not is_dataset:
    num_samples_or_steps = _get_num_samples_or_steps(ins, batch_size,
                                                     steps_per_epoch)
  else:
    num_samples_or_steps = steps_per_epoch

  # Prepare validation data. Hold references to the iterator and the input list
  # to properly reinitialize and reuse in multiple validation passes.
  val_iterator = None
  if isinstance(val_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
    if validation_steps is None:
      # Because we pass an iterator feed instead of a Dataset to the eval
      # model_iteration() call, it will not trigger the dataset-input path
      # that determines the number of steps required. To avoid this issue,
      # set validation_steps here if validation_steps is None.
      validation_steps = training_utils.infer_steps_for_dataset(
          val_inputs,
          validation_steps,
          epochs=epochs,
          steps_name='validation_steps')
    val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
    val_inputs = _prepare_feed_values(
        model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST)

  # Configure callbacks.
  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=num_samples_or_steps,
      verbose=0,  # Handle ProgBarLogger separately in this loop.
      mode=mode)
  # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
  progbar = training_utils.get_progbar(model, count_mode)
  progbar.params = callbacks.params
  progbar.params['verbose'] = verbose

  # Find beforehand arrays that need sparse-to-dense conversion.
  if issparse is not None and not use_steps:
    indices_for_conversion_to_dense = []
    feed = _get_model_feed(model, mode)
    for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)):
      if issparse(input_data) and not K.is_sparse(feed_tensor):
        indices_for_conversion_to_dense.append(i)

  # Select aggregation method.
  if mode == ModeKeys.PREDICT:
    aggregator = training_utils.OutputsAggregator(use_steps,
                                                  num_samples_or_steps)
  else:
    aggregator = training_utils.MetricsAggregator(use_steps,
                                                  num_samples_or_steps)

  if model._compile_distribution:
    distributed_training_utils._copy_weights_to_distributed_model(model, mode)

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)
  progbar.on_train_begin()

  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Set up work for each epoch.
    epoch_logs = {}
    model.reset_metrics()
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)
    progbar.on_epoch_begin(epoch, epoch_logs)

    if use_steps:
      # Step-wise loop.
      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_logs = {'batch': step, 'size': 1}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
        progbar.on_batch_begin(step, batch_logs)

        # Get outputs.
        try:
          # `ins` can be callable in DistributionStrategy + eager case.
          actual_inputs = ins() if callable(ins) else ins
          batch_outs = f(actual_inputs)
        except errors.OutOfRangeError:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.num_samples_or_steps = steps_per_epoch
              if mode == ModeKeys.TRAIN:
                progbar.params['steps'] = steps_per_epoch
                progbar.progbar.target = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if model._distribution_strategy:
          batch_outs = distributed_training_utils._per_device_aggregate_batch(
              batch_outs, model, mode)

        # Aggregate results.
        if step == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        progbar.on_batch_end(step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break
    else:
      # Sample-wise loop.
      index_array = np.arange(num_samples_or_steps)
      if shuffle == 'batch':
        index_array = training_utils.batch_shuffle(index_array, batch_size)
      elif shuffle:
        np.random.shuffle(index_array)
      batches = make_batches(num_samples_or_steps, batch_size)

      for batch_index, (batch_start, batch_end) in enumerate(batches):
        batch_ids = index_array[batch_start:batch_end]

        # Slice into a batch.
        try:
          if ins and isinstance(ins[-1], int):
            # Do not slice the training phase flag.
            ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
          else:
            ins_batch = slice_arrays(ins, batch_ids)
        except TypeError:
          raise TypeError('TypeError while preparing batch. '
                          'If using HDF5 input data, '
                          'pass shuffle="batch".')

        # Sparse to dense conversion.
        if issparse is not None:
          for i in indices_for_conversion_to_dense:
            ins_batch[i] = ins_batch[i].toarray()

        # Callbacks batch_begin.
        batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
        callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs)
        progbar.on_batch_begin(batch_index, batch_logs)

        # Get outputs.
        batch_outs = f(ins_batch)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        # Aggregate results.
        if batch_index == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs, batch_start, batch_end)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs)
        progbar.on_batch_end(batch_index, batch_logs)

        if callbacks.model.stop_training:
          break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop every `validation_freq` epochs during training.
    if (do_validation and
        training_utils.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        distributed_training_utils._copy_weights_to_original_model(
            model, ModeKeys.TRAIN)

      val_results = model_iteration(
          model,
          val_inputs,
          targets=val_targets,
          sample_weights=val_sample_weights,
          batch_size=batch_size,
          steps_per_epoch=validation_steps,
          callbacks=callbacks,
          verbose=0,
          mode=ModeKeys.TEST,
          validation_in_fit=True,
          prepared_feed_values_from_dataset=(val_iterator is not None),
          steps_name='validation_steps')
      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')
      if val_iterator and epoch < epochs - 1:
        _reinitialize_iterator(val_iterator, model._distribution_strategy)

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)
    progbar.on_epoch_end(epoch, epoch_logs)

    # Reinitialize dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      _reinitialize_iterator(input_iterator, model._distribution_strategy)

  callbacks._call_end_hook(mode)

  if model._distribution_strategy:
    if model._compile_distribution:
      # TODO(priyag, psv): Copy back metrics to the original model as well?
      distributed_training_utils._copy_weights_to_original_model(model, mode)
    scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results


def _get_model_feed(model, mode):
  if mode == ModeKeys.PREDICT:
    feed = model._feed_inputs
  else:
    feed = (
        model._feed_inputs + model._feed_targets + model._feed_sample_weights)
  return feed


def _print_train_info(inputs, val_inputs, steps_per_epoch, verbose):
  if (val_inputs and steps_per_epoch is None and verbose and inputs and
      hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
    print('Train on %d samples, validate on %d samples' %
          (inputs[0].shape[0], val_inputs[0].shape[0]))


def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
  """Returns total number of samples (when training in batch mode) or steps."""
  if steps_per_epoch:
    return steps_per_epoch
  return training_utils.check_num_samples(ins, batch_size, steps_per_epoch,
                                          'steps_per_epoch')


def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Prepare feed values for the model execution function.

  Arguments:
    model: Model to prepare feed values for.
    inputs: List or dict of model inputs.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode.
  """
  if model._distribution_strategy:
    if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
      inputs = distributed_training_utils.get_iterator(
          inputs, model._distribution_strategy)

    def get_distributed_inputs():
      return distributed_training_utils._prepare_feed_values(
          model, inputs, targets, sample_weights, mode)

    # In the eager case, we want to call the input method per step, so return
    # a lambda from here that can be called. Note that this is applicable only
    # in Distribution Strategy case as it follows the same code path for both
    # eager and graph modes.
    # TODO(priyag,omalleyt): Either we should move the training DS with
    # EagerIterator to use training_generator code path, or figure out how to
    # set a symbolic Iterator out of a Dataset when in eager mode.
    if context.executing_eagerly():
      return get_distributed_inputs
    else:
      return get_distributed_inputs()

  if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
                         iterator_ops.Iterator)):
    inputs, targets, sample_weights = model._standardize_user_data(
        inputs,
        extract_tensors_from_dataset=True)

  inputs = training_utils.ModelInputs(inputs).as_list()
  targets = targets or []
  sample_weights = sample_weights or []
  ins = inputs + targets + sample_weights
  if mode == ModeKeys.TRAIN and not isinstance(K.symbolic_learning_phase(),
                                               int):
    ins += [True]  # Add learning phase value.
  return ins


def _get_iterator(inputs, distribution_strategy=None):
  if distribution_strategy:
    return distributed_training_utils.get_iterator(
        inputs, distribution_strategy)
  return training_utils.get_iterator(inputs)


def _reinitialize_iterator(iterator, distribution_strategy=None):
  if distribution_strategy:
    distributed_training_utils.initialize_iterator(
        iterator, distribution_strategy)
  else:
    training_utils.initialize_iterator(iterator)


def _make_execution_function(model, mode):
  """Makes function to run one step of model execution."""
  if model._distribution_strategy:
    return distributed_training_utils._make_execution_function(model, mode)
  return model._make_execution_function(mode)


# For backwards compatibility for internal users of these loops.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_loop = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
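

# Illustrative sketch only, not part of this module's API: a minimal, hedged
# example of how these loops are typically reached. The model architecture and
# data below are hypothetical. In graph mode, calling `Model.fit`, `evaluate`,
# or `predict` on NumPy arrays dispatches (after input standardization) to
# `fit_loop`, `test_loop`, and `predict_loop` above, i.e. to `model_iteration`
# with the corresponding mode.
def _example_array_training_sketch():  # pragma: no cover
  """Sketch of driving the array-based loops through the public Model API."""
  from tensorflow.python.keras import layers  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

  # Hypothetical model and data; shapes are arbitrary.
  model = models.Sequential([layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer='sgd', loss='mse')
  x = np.random.random((32, 4)).astype('float32')
  y = np.random.random((32, 1)).astype('float32')

  # Fitting on plain arrays reaches `model_iteration(..., mode=ModeKeys.TRAIN)`.
  history = model.fit(x, y, batch_size=8, epochs=2, verbose=0)
  # Evaluation and prediction follow the TEST and PREDICT modes respectively.
  loss = model.evaluate(x, y, batch_size=8, verbose=0)
  preds = model.predict(x, batch_size=8)
  return history, loss, preds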