1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related part of the Keras engine.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import numpy as np
23
24from tensorflow.python import tf2
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import iterator_ops
27from tensorflow.python.distribute import distribute_coordinator as dc
28from tensorflow.python.distribute import distribution_strategy_context
29from tensorflow.python.eager import context
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.keras import backend as K
34from tensorflow.python.keras import losses
35from tensorflow.python.keras import metrics as metrics_module
36from tensorflow.python.keras import optimizers
37from tensorflow.python.keras.engine import distributed_training_utils
38from tensorflow.python.keras.engine import network
39from tensorflow.python.keras.engine import training_arrays
40from tensorflow.python.keras.engine import training_distributed
41from tensorflow.python.keras.engine import training_eager
42from tensorflow.python.keras.engine import training_generator
43from tensorflow.python.keras.engine import training_utils
44from tensorflow.python.keras.saving import saving_utils
45from tensorflow.python.keras.utils import data_utils
46from tensorflow.python.keras.utils import losses_utils
47from tensorflow.python.keras.utils.generic_utils import slice_arrays
48from tensorflow.python.keras.utils.mode_keys import ModeKeys
49from tensorflow.python.ops import math_ops
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.training.tracking import base as trackable
52from tensorflow.python.util import nest
53from tensorflow.python.util.tf_export import keras_export
54
55
56@keras_export('keras.models.Model', 'keras.Model')
57class Model(network.Network):
58  """`Model` groups layers into an object with training and inference features.
59
60  There are two ways to instantiate a `Model`:
61
62  1 - With the "functional API", where you start from `Input`,
63  you chain layer calls to specify the model's forward pass,
64  and finally you create your model from inputs and outputs:
65
66  ```python
67  import tensorflow as tf
68
69  inputs = tf.keras.Input(shape=(3,))
70  x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
71  outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
72  model = tf.keras.Model(inputs=inputs, outputs=outputs)
73  ```
74
75  2 - By subclassing the `Model` class: in that case, you should define your
76  layers in `__init__` and you should implement the model's forward pass
77  in `call`.
78
79  ```python
80  import tensorflow as tf
81
82  class MyModel(tf.keras.Model):
83
84    def __init__(self):
85      super(MyModel, self).__init__()
86      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
87      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
88
89    def call(self, inputs):
90      x = self.dense1(inputs)
91      return self.dense2(x)
92
93  model = MyModel()
94  ```
95
96  If you subclass `Model`, you can optionally have
97  a `training` argument (boolean) in `call`, which you can use to specify
98  a different behavior in training and inference:
99
100  ```python
101  import tensorflow as tf
102
103  class MyModel(tf.keras.Model):
104
105    def __init__(self):
106      super(MyModel, self).__init__()
107      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
108      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
109      self.dropout = tf.keras.layers.Dropout(0.5)
110
111    def call(self, inputs, training=False):
112      x = self.dense1(inputs)
113      if training:
114        x = self.dropout(x, training=training)
115      return self.dense2(x)
116
117  model = MyModel()
118  ```
119  """
120
121  def __init__(self, *args, **kwargs):
122    super(Model, self).__init__(*args, **kwargs)
123    # initializing _distribution_strategy here since it is possible to call
124    # predict on a model without compiling it.
125    self._distribution_strategy = None
126    # This flag is used to track if the user is using the deprecated path of
127    # passing distribution strategy to compile rather than creating the model
128    # under distribution strategy scope.
129    self._compile_distribution = False
130
131    self.run_eagerly = None
132
133  def get_weights(self):
134    """Retrieves the weights of the model.
135
136    Returns:
137        A flat list of Numpy arrays.
138    """
139    if self._distribution_strategy:
140      with self._distribution_strategy.scope():
141        return super(Model, self).get_weights()
142    return super(Model, self).get_weights()
143
144  def load_weights(self, filepath, by_name=False):
145    """Loads all layer weights, either from a TensorFlow or an HDF5 file."""
146    if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
147      if (self._distribution_strategy.extended.steps_per_run > 1 and
148          (not network._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
149        raise ValueError('Load weights is not yet supported with TPUStrategy '
150                         'with steps_per_run greater than 1.')
151    return super(Model, self).load_weights(filepath, by_name)
152
153  @trackable.no_automatic_dependency_tracking
154  def compile(self,
155              optimizer,
156              loss=None,
157              metrics=None,
158              loss_weights=None,
159              sample_weight_mode=None,
160              weighted_metrics=None,
161              target_tensors=None,
162              distribute=None,
163              **kwargs):
164    """Configures the model for training.
165
166    Arguments:
167        optimizer: String (name of optimizer) or optimizer instance.
168            See `tf.keras.optimizers`.
169        loss: String (name of objective function), objective function or
170            `tf.losses.Loss` instance. See `tf.losses`. If the model has
171            multiple outputs, you can use a different loss on each output by
172            passing a dictionary or a list of losses. The loss value that will
173            be minimized by the model will then be the sum of all individual
174            losses.
175        metrics: List of metrics to be evaluated by the model during training
176            and testing. Typically you will use `metrics=['accuracy']`.
177            To specify different metrics for different outputs of a
178            multi-output model, you could also pass a dictionary, such as
179            `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
180            You can also pass a list (len = len(outputs)) of lists of metrics
181            such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or
182            `metrics=['accuracy', ['accuracy', 'mse']]`.
183        loss_weights: Optional list or dictionary specifying scalar
184            coefficients (Python floats) to weight the loss contributions
185            of different model outputs.
186            The loss value that will be minimized by the model
187            will then be the *weighted sum* of all individual losses,
188            weighted by the `loss_weights` coefficients.
189            If a list, it is expected to have a 1:1 mapping
190            to the model's outputs. If a tensor, it is expected to map
191            output names (strings) to scalar coefficients.
192        sample_weight_mode: If you need to do timestep-wise
193            sample weighting (2D weights), set this to `"temporal"`.
194            `None` defaults to sample-wise weights (1D).
195            If the model has multiple outputs, you can use a different
196            `sample_weight_mode` on each output by passing a
197            dictionary or a list of modes.
198        weighted_metrics: List of metrics to be evaluated and weighted
199            by sample_weight or class_weight during training and testing.
200        target_tensors: By default, Keras will create placeholders for the
201            model's target, which will be fed with the target data during
202            training. If instead you would like to use your own
203            target tensors (in turn, Keras will not expect external
204            Numpy data for these targets at training time), you
205            can specify them via the `target_tensors` argument. It can be
206            a single tensor (for a single-output model), a list of tensors,
207            or a dict mapping output names to target tensors.
208        distribute: NOT SUPPORTED IN TF 2.0, please create and compile the
209            model under distribution strategy scope instead of passing it to
210            compile.
211        **kwargs: Any additional arguments.
212
213    Raises:
214        ValueError: In case of invalid arguments for
215            `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
216    """
217    run_eagerly = kwargs.pop('run_eagerly', None)
218    if run_eagerly and getattr(self, '_contains_symbolic_tensors', False):
219      raise ValueError(
220          'We currently do not support enabling `run_eagerly` on compile if '
221          '`model.add_loss(tensor)` or `model.add_metric(tensor)` '
222          'has been called.')
223
224    self._run_eagerly = run_eagerly
225    optimizer = optimizers.get(optimizer)
226
227    if distribute is not None:
228      if tf2.enabled():
229        raise ValueError(
230            'Distribute argument in compile is not available in TF 2.0 please '
231            'create the model under the distribution strategy scope.')
232      logging.warning('Distribute argument in compile is deprecated please '
233                      'create the model under the distribution strategy scope.')
234      self._distribution_strategy = distribute
235      self._compile_distribution = True
236    else:
237      if distribution_strategy_context.has_strategy():
238        # When the user builds the model in the DS scope and cross replica
239        # context we want distribution strategy to be set but when building the
240        # replica copies of the models internally we should not be compiling
241        # with distribution strategy and use the default compilation path.
242        if distribution_strategy_context.in_cross_replica_context():
243          self._distribution_strategy = (
244              distribution_strategy_context.get_strategy())
245
246    # Validate that arguments passed by the user to `compile` are supported by
247    # DistributionStrategy.
248    if self._distribution_strategy:
249      if sample_weight_mode:
250        raise NotImplementedError('sample_weight_mode is not supported with '
251                                  'DistributionStrategy.')
252      if weighted_metrics:
253        raise NotImplementedError('weighted_metrics is not supported with '
254                                  'DistributionStrategy.')
255      if target_tensors:
256        raise ValueError('target_tensors is not supported with '
257                         'DistributionStrategy.')
258
259      if run_eagerly:
260        raise ValueError(
261            'We currently do not support enabling `run_eagerly` with '
262            'distribution strategy.')
263
264      if getattr(self, '_contains_symbolic_tensors', False):
265        raise ValueError(
266            'We currently do not support compiling the model with distribution '
267            'strategy if `model.add_loss(tensor)` or `model.add_metric(tensor)`'
268            ' has been called.')
269
270      if not self.built or not self.inputs or not self.outputs:
271        raise ValueError(
272            'We currently do not support distribution strategy with a '
273            '`Sequential` model that is created without `input_shape`/'
274            '`input_dim` set in its first layer or a subclassed model.')
275
276    loss = loss or {}
277
278    self.optimizer = optimizer
279    # We've disabled automatic dependency tracking for this method, but do want
280    # to add a checkpoint dependency on the optimizer if it's trackable.
281    if isinstance(self.optimizer, trackable.Trackable):
282      self._track_trackable(
283          self.optimizer, name='optimizer', overwrite=True)
284    self.loss = loss
285    self._compile_metrics = metrics or []
286    self.loss_weights = loss_weights
287    self.sample_weight_mode = sample_weight_mode
288    self._compile_weighted_metrics = weighted_metrics
289    if self.run_eagerly and target_tensors is not None:
290      raise ValueError(
291          'target_tensors argument is not supported when '
292          'running a model eagerly.')
293    self.target_tensors = target_tensors
294
295    # Set DistributionStrategy specific parameters.
296    self._distributed_model_cache = {}
297
298    if self._distribution_strategy is not None:
299      # Ensures a Session is created and configured correctly for Distribution
300      # Strategy.
301      K.configure_and_create_distributed_session(self._distribution_strategy)
302    # Initialize model metric attributes.
303    self._init_metric_attributes()
304    if not self.built or not self.inputs or not self.outputs:
305      # Model is not compilable because it does not know its number of inputs
306      # and outputs, nor their shapes and names. We will compile after the first
307      # time the model gets called on training data.
308      return
309    self._is_compiled = True
310
311    # Prepare list of loss functions, same size of model outputs.
312    self.loss_functions = training_utils.prepare_loss_functions(
313        loss, self.output_names)
314
315    self._feed_outputs = []
316    self._feed_output_names = []
317    self._feed_output_shapes = []
318    self._feed_loss_fns = []
319    # if loss function is None, then this output will be skipped during total
320    # loss calculation and feed targets preparation.
321    skip_target_indices = []
322    skip_target_weighing_indices = []
323    for i, loss_function in enumerate(self.loss_functions):
324      if loss_function is None:
325        skip_target_indices.append(i)
326        skip_target_weighing_indices.append(i)
327
328    # Prepare output masks.
329    if not self.run_eagerly:
330      masks = [getattr(x, '_keras_mask', None) for x in self.outputs]
331
332    # Prepare list loss weights, same size of model outputs.
333    self.loss_weights_list = training_utils.prepare_loss_weights(
334        self.output_names, loss_weights)
335
336    # Initialization for Eager mode execution.
337    if self.run_eagerly:
338      # Prepare sample weights.
339      self._set_sample_weight_attributes(sample_weight_mode,
340                                         skip_target_weighing_indices)
341      # Save all metric attributes per output of the model.
342      self._cache_output_metric_attributes(metrics, weighted_metrics)
343
344      if target_tensors is not None:
345        raise ValueError('target_tensors are not currently supported in Eager '
346                         'mode.')
347      self.total_loss = None
348
349      # Set metric attributes on model.
350      self._set_metric_attributes(skip_target_indices=skip_target_indices)
351
352      self.targets = []
353      for i in range(len(self.outputs)):
354        self._feed_output_names.append(self.output_names[i])
355      self._collected_trainable_weights = self.trainable_weights
356      return
357
358    with K.get_graph().as_default():
359      # Prepare targets of model.
360      self.targets = []
361      self._feed_targets = []
362      if target_tensors not in (None, []):
363        if isinstance(target_tensors, list):
364          if len(target_tensors) != len(self.outputs):
365            raise ValueError(
366                'When passing a list as `target_tensors`, '
367                'it should have one entry per model output. '
368                'The model has %s outputs, but you passed target_tensors=%s' %
369                (len(self.outputs), target_tensors))
370        elif isinstance(target_tensors, dict):
371          for name in target_tensors:
372            if name not in self.output_names:
373              raise ValueError(
374                  'Unknown entry in `target_tensors` '
375                  'dictionary: "' + name + '". '
376                  'Only expected the following keys: ' + str(self.output_names))
377          tmp_target_tensors = []
378          for name in self.output_names:
379            tmp_target_tensors.append(target_tensors.get(name, None))
380          target_tensors = tmp_target_tensors
381        elif tensor_util.is_tensor(target_tensors):
382          target_tensors = [target_tensors]
383        else:
384          raise TypeError('Expected `target_tensors` to be a list or tuple or '
385                          'dict or a single tensor, but got:', target_tensors)
386
387      for i in range(len(self.outputs)):
388        if i in skip_target_indices:
389          self.targets.append(None)
390        else:
391          shape = K.int_shape(self.outputs[i])
392          name = self.output_names[i]
393          if target_tensors not in (None, []):
394            target = target_tensors[i]
395          else:
396            target = None
397          if target is None or K.is_placeholder(target):
398            if target is None:
399              target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
400                  self.loss_functions[i],
401                  K.dtype(self.outputs[i]))
402
403              target = K.placeholder(
404                  ndim=len(shape),
405                  name=name + '_target',
406                  sparse=K.is_sparse(self.outputs[i]),
407                  dtype=target_dtype)
408            self._feed_targets.append(target)
409            self._feed_outputs.append(self.outputs[i])
410            self._feed_output_names.append(name)
411            self._feed_output_shapes.append(shape)
412            self._feed_loss_fns.append(self.loss_functions[i])
413          else:
414            skip_target_weighing_indices.append(i)
415          self.targets.append(target)
416
417      # Prepare sample weights.
418      self._set_sample_weight_attributes(sample_weight_mode,
419                                         skip_target_weighing_indices)
420      # Save all metric attributes per output of the model.
421      self._cache_output_metric_attributes(metrics, weighted_metrics)
422
423      # Set metric attributes on model.
424      self._set_metric_attributes(skip_target_indices=skip_target_indices)
425
426      # Invoke metric functions for all the outputs.
427      self._handle_metrics(
428          self.outputs,
429          masks=masks,
430          targets=self.targets,
431          skip_target_indices=skip_target_indices,
432          sample_weights=self.sample_weights)
433
434      # Compute total loss.
435      # Used to keep track of the total loss value (stateless).
436      # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
437      #                   loss_weight_2 * output_2_loss_fn(...) +
438      #                   layer losses.
439      self.total_loss = self._prepare_total_loss(skip_target_indices, masks)
440
441      # Functions for train, test and predict will
442      # be compiled lazily when required.
443      # This saves time when the user is not using all functions.
444      self._function_kwargs = kwargs
445
446      self.train_function = None
447      self.test_function = None
448      self.predict_function = None
449
450      # Collected trainable weights, sorted in topological order.
451      trainable_weights = self.trainable_weights
452      self._collected_trainable_weights = trainable_weights
453
454      # Validate all variables were correctly created in distribution scope.
455      if self._distribution_strategy and not self._compile_distribution:
456        for v in self.variables:
457          strategy = self._distribution_strategy
458          if not strategy.extended.variable_created_in_scope(v):
459            raise ValueError(
460                'Variable (%s) was not created in the distribution strategy '
461                'scope of (%s). It is most likely due to not all layers or '
462                'the model or optimizer being created outside the distribution '
463                'strategy scope. Try to make sure your code looks similar '
464                'to the following.\n'
465                'with strategy.scope():\n'
466                '  model=_create_model()\n'
467                '  model.compile(...)'% (v, strategy))
468
469  @property
470  def metrics(self):
471    """Returns the model's metrics added using `compile`, `add_metric` APIs."""
472    metrics = []
473    if self._is_compiled:
474      metrics += self._compile_metric_functions
475    return metrics + super(Model, self).metrics
476
477  @property
478  def metrics_names(self):
479    """Returns the model's display labels for all outputs."""
480    metrics_names = []
481    if self._is_compiled:
482      metrics_names += self._compile_metrics_names  # Includes names of losses.
483
484    # Add metric names from layers.
485    for layer in self.layers:
486      metrics_names += [m.name for m in layer._metrics]  # pylint: disable=protected-access
487    metrics_names += [m.name for m in self._metrics]
488    return metrics_names
489
490  @property
491  def run_eagerly(self):
492    """Settable attribute indicating whether the model should run eagerly.
493
494    Running eagerly means that your model will be run step by step,
495    like Python code. Your model might run slower, but it should become easier
496    for you to debug it by stepping into individual layer calls.
497
498    By default, we will attempt to compile your model to a static graph to
499    deliver the best execution performance.
500
501    Returns:
502      Boolean, whether the model should run eagerly.
503    """
504    if self._run_eagerly is True and not context.executing_eagerly():
505      raise ValueError('You can only set `run_eagerly=True` if eager execution '
506                       'is enabled.')
507    if not self.dynamic:
508      if self._run_eagerly is None:
509        return False
510      else:
511        return self._run_eagerly
512    else:
513      if not context.executing_eagerly():
514        raise ValueError('Your model contains layers that can only be '
515                         'successfully run in eager execution (layers '
516                         'constructed with `dynamic=True`). '
517                         'You must enable eager execution with '
518                         '`tf.enable_eager_execution()`.')
519      if self._run_eagerly is False:
520        # TODO(fchollet): consider using py_func to enable this.
521        raise ValueError('Your model contains layers that can only be '
522                         'successfully run in eager execution (layers '
523                         'constructed with `dynamic=True`). '
524                         'You cannot set `run_eagerly=False`.')
525      return context.executing_eagerly()
526
527  @run_eagerly.setter
528  def run_eagerly(self, value):
529    self._run_eagerly = value
530
531  def fit(self,
532          x=None,
533          y=None,
534          batch_size=None,
535          epochs=1,
536          verbose=1,
537          callbacks=None,
538          validation_split=0.,
539          validation_data=None,
540          shuffle=True,
541          class_weight=None,
542          sample_weight=None,
543          initial_epoch=0,
544          steps_per_epoch=None,
545          validation_steps=None,
546          validation_freq=1,
547          max_queue_size=10,
548          workers=1,
549          use_multiprocessing=False,
550          **kwargs):
551    """Trains the model for a fixed number of epochs (iterations on a dataset).
552
553    Arguments:
554        x: Input data. It could be:
555          - A Numpy array (or array-like), or a list of arrays
556            (in case the model has multiple inputs).
557          - A TensorFlow tensor, or a list of tensors
558            (in case the model has multiple inputs).
559          - A dict mapping input names to the corresponding array/tensors,
560            if the model has named inputs.
561          - A `tf.data` dataset or a dataset iterator. Should return a tuple
562            of either `(inputs, targets)` or
563            `(inputs, targets, sample_weights)`.
564          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
565            or `(inputs, targets, sample weights)`.
566        y: Target data. Like the input data `x`,
567          it could be either Numpy array(s) or TensorFlow tensor(s).
568          It should be consistent with `x` (you cannot have Numpy inputs and
569          tensor targets, or inversely). If `x` is a dataset, dataset
570          iterator, generator, or `keras.utils.Sequence` instance, `y` should
571          not be specified (since targets will be obtained from `x`).
572        batch_size: Integer or `None`.
573            Number of samples per gradient update.
574            If unspecified, `batch_size` will default to 32.
575            Do not specify the `batch_size` if your data is in the
576            form of symbolic tensors, dataset, dataset iterators,
577            generators, or `keras.utils.Sequence` instances (since they generate
578            batches).
579        epochs: Integer. Number of epochs to train the model.
580            An epoch is an iteration over the entire `x` and `y`
581            data provided.
582            Note that in conjunction with `initial_epoch`,
583            `epochs` is to be understood as "final epoch".
584            The model is not trained for a number of iterations
585            given by `epochs`, but merely until the epoch
586            of index `epochs` is reached.
587        verbose: Integer. 0, 1, or 2. Verbosity mode.
588            0 = silent, 1 = progress bar, 2 = one line per epoch.
589        callbacks: List of `keras.callbacks.Callback` instances.
590            List of callbacks to apply during training.
591            See `tf.keras.callbacks`.
592        validation_split: Float between 0 and 1.
593            Fraction of the training data to be used as validation data.
594            The model will set apart this fraction of the training data,
595            will not train on it, and will evaluate
596            the loss and any model metrics
597            on this data at the end of each epoch.
598            The validation data is selected from the last samples
599            in the `x` and `y` data provided, before shuffling. This argument is
600            not supported when `x` is a dataset, dataset iterator, generator or
601           `keras.utils.Sequence` instance.
602        validation_data: Data on which to evaluate
603            the loss and any model metrics at the end of each epoch.
604            The model will not be trained on this data.
605            `validation_data` will override `validation_split`.
606            `validation_data` could be:
607              - tuple `(x_val, y_val)` of Numpy arrays or tensors
608              - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
609              - dataset or a dataset iterator
610            For the first two cases, `batch_size` must be provided.
611            For the last case, `validation_steps` must be provided.
612        shuffle: Boolean (whether to shuffle the training data
613            before each epoch) or str (for 'batch').
614            'batch' is a special option for dealing with the
615            limitations of HDF5 data; it shuffles in batch-sized chunks.
616            Has no effect when `steps_per_epoch` is not `None`.
617        class_weight: Optional dictionary mapping class indices (integers)
618            to a weight (float) value, used for weighting the loss function
619            (during training only).
620            This can be useful to tell the model to
621            "pay more attention" to samples from
622            an under-represented class.
623        sample_weight: Optional Numpy array of weights for
624            the training samples, used for weighting the loss function
625            (during training only). You can either pass a flat (1D)
626            Numpy array with the same length as the input samples
627            (1:1 mapping between weights and samples),
628            or in the case of temporal data,
629            you can pass a 2D array with shape
630            `(samples, sequence_length)`,
631            to apply a different weight to every timestep of every sample.
632            In this case you should make sure to specify
633            `sample_weight_mode="temporal"` in `compile()`. This argument is not
634            supported when `x` is a dataset, dataset iterator, generator, or
635           `keras.utils.Sequence` instance, instead provide the sample_weights
636            as the third element of `x`.
637        initial_epoch: Integer.
638            Epoch at which to start training
639            (useful for resuming a previous training run).
640        steps_per_epoch: Integer or `None`.
641            Total number of steps (batches of samples)
642            before declaring one epoch finished and starting the
643            next epoch. When training with input tensors such as
644            TensorFlow data tensors, the default `None` is equal to
645            the number of samples in your dataset divided by
646            the batch size, or 1 if that cannot be determined. If x is a
647            `tf.data` dataset or a dataset iterator, and 'steps_per_epoch'
648            is None, the epoch will run until the input dataset is exhausted.
649        validation_steps: Only relevant if `validation_data` is provided and
650            is a dataset or dataset iterator. Total number of steps (batches of
651            samples) to draw before stopping when performing validation
652            at the end of every epoch. If validation_data is a `tf.data` dataset
653            or a dataset iterator, and 'validation_steps' is None, validation
654            will run until the `validation_data` dataset is exhausted.
655        validation_freq: Only relevant if validation data is provided. Integer
656            or `collections.Container` instance (e.g. list, tuple, etc.). If an
657            integer, specifies how many training epochs to run before a new
658            validation run is performed, e.g. `validation_freq=2` runs
659            validation every 2 epochs. If a Container, specifies the epochs on
660            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
661            validation at the end of the 1st, 2nd, and 10th epochs.
662        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
663            input only. Maximum size for the generator queue.
664            If unspecified, `max_queue_size` will default to 10.
665        workers: Integer. Used for generator or `keras.utils.Sequence` input
666            only. Maximum number of processes to spin up
667            when using process-based threading. If unspecified, `workers`
668            will default to 1. If 0, will execute the generator on the main
669            thread.
670        use_multiprocessing: Boolean. Used for generator or
671            `keras.utils.Sequence` input only. If `True`, use process-based
672            threading. If unspecified, `use_multiprocessing` will default to
673            `False`. Note that because this implementation relies on
674            multiprocessing, you should not pass non-picklable arguments to
675            the generator as they can't be passed easily to children processes.
676        **kwargs: Used for backwards compatibility.
677
678    Returns:
679        A `History` object. Its `History.history` attribute is
680        a record of training loss values and metrics values
681        at successive epochs, as well as validation loss values
682        and validation metrics values (if applicable).
683
684    Raises:
685        RuntimeError: If the model was never compiled.
686        ValueError: In case of mismatch between the provided input data
687            and what the model expects.
688    """
689    # Legacy support
690    if 'nb_epoch' in kwargs:
691      logging.warning(
692          'The `nb_epoch` argument in `fit` '
693          'has been renamed `epochs`.')
694      epochs = kwargs.pop('nb_epoch')
695    if kwargs:
696      raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
697
698    # Case 1: distribution strategy.
699    if self._distribution_strategy:
700      if K.in_multi_worker_mode():
701        # Multi-Worker mode runs the Keras training loop on multiple
702        # servers via the Distribute Coordinator.
703        def _worker_fn(_):
704          """Run training inside the distributed coordinator."""
705          filtered_callbacks = distributed_training_utils \
706              .filter_distributed_callbacks(callbacks)
707          return training_distributed.fit_distributed(
708              self,
709              x=x,
710              y=y,
711              batch_size=batch_size,
712              epochs=epochs,
713              verbose=verbose,
714              callbacks=filtered_callbacks,
715              validation_split=validation_split,
716              validation_data=validation_data,
717              shuffle=shuffle,
718              class_weight=class_weight,
719              sample_weight=sample_weight,
720              initial_epoch=initial_epoch,
721              steps_per_epoch=steps_per_epoch,
722              validation_steps=validation_steps,
723              validation_freq=validation_freq)
724
725        # Independent worker only for now.
726        return dc.run_distribute_coordinator(
727            _worker_fn,
728            self._distribution_strategy,
729            mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
730      else:
731        return training_distributed.fit_distributed(
732            self,
733            x=x,
734            y=y,
735            batch_size=batch_size,
736            epochs=epochs,
737            verbose=verbose,
738            callbacks=callbacks,
739            validation_split=validation_split,
740            validation_data=validation_data,
741            shuffle=shuffle,
742            class_weight=class_weight,
743            sample_weight=sample_weight,
744            initial_epoch=initial_epoch,
745            steps_per_epoch=steps_per_epoch,
746            validation_steps=validation_steps,
747            validation_freq=validation_freq)
748
749    batch_size = self._validate_or_infer_batch_size(
750        batch_size, steps_per_epoch, x)
751
752    # Case 2: generator-like. Input is Python generator, or Sequence object,
753    # or a non-distributed Dataset or iterator in eager execution.
754    if data_utils.is_generator_or_sequence(x):
755      training_utils.check_generator_arguments(
756          y, sample_weight, validation_split=validation_split)
757      return self.fit_generator(
758          x,
759          steps_per_epoch=steps_per_epoch,
760          epochs=epochs,
761          verbose=verbose,
762          callbacks=callbacks,
763          validation_data=validation_data,
764          validation_steps=validation_steps,
765          validation_freq=validation_freq,
766          class_weight=class_weight,
767          max_queue_size=max_queue_size,
768          workers=workers,
769          use_multiprocessing=use_multiprocessing,
770          shuffle=shuffle,
771          initial_epoch=initial_epoch)
772    if training_utils.is_eager_dataset_or_iterator(x):
773      # Make sure that y, sample_weights, validation_split are not passed.
774      training_utils.validate_dataset_input(x, y, sample_weight,
775                                            validation_split)
776      if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
777          and shuffle):
778        training_utils.verify_dataset_shuffled(x)
779
780      return self.fit_generator(
781          x,
782          steps_per_epoch=steps_per_epoch,
783          epochs=epochs,
784          verbose=verbose,
785          callbacks=callbacks,
786          validation_data=validation_data,
787          validation_steps=validation_steps,
788          validation_freq=validation_freq,
789          class_weight=class_weight,
790          workers=0,
791          shuffle=shuffle,
792          initial_epoch=initial_epoch)
793
794    # Case 3: Symbolic tensors or Numpy array-like.
795    # This includes Datasets and iterators in graph mode (since they
796    # generate symbolic tensors).
797    x, y, sample_weights = self._standardize_user_data(
798        x,
799        y,
800        sample_weight=sample_weight,
801        class_weight=class_weight,
802        batch_size=batch_size,
803        check_steps=True,
804        steps_name='steps_per_epoch',
805        steps=steps_per_epoch,
806        validation_split=validation_split,
807        shuffle=shuffle)
808
809    # Prepare validation data.
810    if validation_data:
811      val_x, val_y, val_sample_weights = self._unpack_validation_data(
812          validation_data)
813      val_x, val_y, val_sample_weights = self._standardize_user_data(
814          val_x,
815          val_y,
816          sample_weight=val_sample_weights,
817          batch_size=batch_size,
818          steps=validation_steps,
819          steps_name='validation_steps')
820    elif validation_split and 0. < validation_split < 1.:
821      if training_utils.has_symbolic_tensors(x):
822        raise ValueError('If your data is in the form of symbolic tensors, '
823                         'you cannot use `validation_split`.')
824      if hasattr(x[0], 'shape'):
825        split_at = int(x[0].shape[0] * (1. - validation_split))
826      else:
827        split_at = int(len(x[0]) * (1. - validation_split))
828      x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
829      y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
830      sample_weights, val_sample_weights = (slice_arrays(
831          sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
832    elif validation_steps:
833      val_x = []
834      val_y = []
835      val_sample_weights = []
836    else:
837      val_x = None
838      val_y = None
839      val_sample_weights = None
840
841    if self.run_eagerly:
842      return training_generator.fit_generator(
843          self, (x, y, sample_weights),
844          steps_per_epoch=steps_per_epoch,
845          batch_size=batch_size,
846          epochs=epochs,
847          verbose=verbose,
848          callbacks=callbacks,
849          validation_data=validation_data,
850          validation_steps=validation_steps,
851          validation_freq=validation_freq,
852          workers=0,
853          shuffle=shuffle,
854          initial_epoch=initial_epoch,
855          steps_name='steps_per_epoch')
856    else:
857      return training_arrays.fit_loop(
858          self,
859          x,
860          y,
861          sample_weights=sample_weights,
862          batch_size=batch_size,
863          epochs=epochs,
864          verbose=verbose,
865          callbacks=callbacks,
866          val_inputs=val_x,
867          val_targets=val_y,
868          val_sample_weights=val_sample_weights,
869          shuffle=shuffle,
870          initial_epoch=initial_epoch,
871          steps_per_epoch=steps_per_epoch,
872          validation_steps=validation_steps,
873          validation_freq=validation_freq,
874          steps_name='steps_per_epoch')
875
876  def evaluate(self,
877               x=None,
878               y=None,
879               batch_size=None,
880               verbose=1,
881               sample_weight=None,
882               steps=None,
883               callbacks=None,
884               max_queue_size=10,
885               workers=1,
886               use_multiprocessing=False):
887    """Returns the loss value & metrics values for the model in test mode.
888
889    Computation is done in batches.
890
891    Arguments:
892        x: Input data. It could be:
893          - A Numpy array (or array-like), or a list of arrays
894            (in case the model has multiple inputs).
895          - A TensorFlow tensor, or a list of tensors
896            (in case the model has multiple inputs).
897          - A dict mapping input names to the corresponding array/tensors,
898            if the model has named inputs.
899          - A `tf.data` dataset or a dataset iterator.
900          - A generator or `keras.utils.Sequence` instance.
901        y: Target data. Like the input data `x`,
902          it could be either Numpy array(s) or TensorFlow tensor(s).
903          It should be consistent with `x` (you cannot have Numpy inputs and
904          tensor targets, or inversely).
905          If `x` is a dataset, dataset iterator, generator or
906          `keras.utils.Sequence` instance, `y` should not be specified (since
907          targets will be obtained from the iterator/dataset).
908        batch_size: Integer or `None`.
909            Number of samples per gradient update.
910            If unspecified, `batch_size` will default to 32.
911            Do not specify the `batch_size` is your data is in the
912            form of symbolic tensors, dataset, dataset iterators,
913            generators, or `keras.utils.Sequence` instances (since they generate
914            batches).
915        verbose: 0 or 1. Verbosity mode.
916            0 = silent, 1 = progress bar.
917        sample_weight: Optional Numpy array of weights for
918            the test samples, used for weighting the loss function.
919            You can either pass a flat (1D)
920            Numpy array with the same length as the input samples
921            (1:1 mapping between weights and samples),
922            or in the case of temporal data,
923            you can pass a 2D array with shape
924            `(samples, sequence_length)`,
925            to apply a different weight to every timestep of every sample.
926            In this case you should make sure to specify
927            `sample_weight_mode="temporal"` in `compile()`. This argument is not
928            supported when `x` is a dataset or a dataset iterator, instead pass
929            sample weights as the third element of `x`.
930        steps: Integer or `None`.
931            Total number of steps (batches of samples)
932            before declaring the evaluation round finished.
933            Ignored with the default value of `None`.
934            If x is a `tf.data` dataset or a dataset iterator, and `steps` is
935            None, 'evaluate' will run until the dataset is exhausted.
936        callbacks: List of `keras.callbacks.Callback` instances.
937            List of callbacks to apply during evaluation.
938            See [callbacks](/api_docs/python/tf/keras/callbacks).
939        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
940            input only. Maximum size for the generator queue.
941            If unspecified, `max_queue_size` will default to 10.
942        workers: Integer. Used for generator or `keras.utils.Sequence` input
943            only. Maximum number of processes to spin up when using
944            process-based threading. If unspecified, `workers` will default
945            to 1. If 0, will execute the generator on the main thread.
946        use_multiprocessing: Boolean. Used for generator or
947            `keras.utils.Sequence` input only. If `True`, use process-based
948            threading. If unspecified, `use_multiprocessing` will default to
949            `False`. Note that because this implementation relies on
950            multiprocessing, you should not pass non-picklable arguments to
951            the generator as they can't be passed easily to children processes.
952
953    Returns:
954        Scalar test loss (if the model has a single output and no metrics)
955        or list of scalars (if the model has multiple outputs
956        and/or metrics). The attribute `model.metrics_names` will give you
957        the display labels for the scalar outputs.
958
959    Raises:
960        ValueError: in case of invalid arguments.
961    """
962    # Case 1: distribution strategy.
963    if self._distribution_strategy:
964      if K.in_multi_worker_mode():
965        # Multi-Worker mode runs the Keras evaluation loop on multiple
966        # servers via the Distribute Coordinator.
967        def _worker_fn(_):
968          """Run evaluation inside the distributed coordinator."""
969          filtered_callbacks = distributed_training_utils \
970              .filter_distributed_callbacks(callbacks)
971          return training_distributed.evaluate_distributed(
972              self,
973              x=x,
974              y=y,
975              batch_size=batch_size,
976              verbose=verbose,
977              sample_weight=sample_weight,
978              steps=steps,
979              callbacks=filtered_callbacks)
980
981        # Independent worker only for now.
982        return dc.run_distribute_coordinator(
983            _worker_fn,
984            self._distribution_strategy,
985            mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
986      else:
987        return training_distributed.evaluate_distributed(
988            self,
989            x=x,
990            y=y,
991            batch_size=batch_size,
992            verbose=verbose,
993            sample_weight=sample_weight,
994            steps=steps,
995            callbacks=callbacks)
996
997    batch_size = self._validate_or_infer_batch_size(batch_size, steps, x)
998
999    # Case 2: generator-like. Input is Python generator, or Sequence object,
1000    # or a non-distributed Dataset or iterator in eager execution.
1001    if data_utils.is_generator_or_sequence(x):
1002      training_utils.check_generator_arguments(y, sample_weight)
1003      return self.evaluate_generator(
1004          x,
1005          steps=steps,
1006          verbose=verbose,
1007          callbacks=callbacks,
1008          max_queue_size=max_queue_size,
1009          workers=workers,
1010          use_multiprocessing=use_multiprocessing)
1011    if training_utils.is_eager_dataset_or_iterator(x):
1012      # Make sure that y, sample_weights are not passed.
1013      training_utils.validate_dataset_input(x, y, sample_weight)
1014      return training_generator.evaluate_generator(
1015          self, x,
1016          steps=steps,
1017          batch_size=batch_size,
1018          verbose=verbose,
1019          workers=0,
1020          callbacks=callbacks)
1021
1022    # Case 3: Symbolic tensors or Numpy array-like.
1023    # This includes Datasets and iterators in graph mode (since they
1024    # generate symbolic tensors).
1025    x, y, sample_weights = self._standardize_user_data(
1026        x,
1027        y,
1028        sample_weight=sample_weight,
1029        batch_size=batch_size,
1030        check_steps=True,
1031        steps_name='steps',
1032        steps=steps)
1033
1034    if self.run_eagerly:
1035      return training_generator.evaluate_generator(
1036          self, (x, y, sample_weights),
1037          steps=steps,
1038          batch_size=batch_size,
1039          verbose=verbose,
1040          workers=0,
1041          callbacks=callbacks)
1042    else:
1043      return training_arrays.test_loop(
1044          self,
1045          inputs=x,
1046          targets=y,
1047          sample_weights=sample_weights,
1048          batch_size=batch_size,
1049          verbose=verbose,
1050          steps=steps,
1051          callbacks=callbacks)
1052
1053  def predict(self,
1054              x,
1055              batch_size=None,
1056              verbose=0,
1057              steps=None,
1058              callbacks=None,
1059              max_queue_size=10,
1060              workers=1,
1061              use_multiprocessing=False):
1062    """Generates output predictions for the input samples.
1063
1064    Computation is done in batches.
1065
1066    Arguments:
1067         x: Input samples. It could be:
1068          - A Numpy array (or array-like), or a list of arrays
1069            (in case the model has multiple inputs).
1070          - A TensorFlow tensor, or a list of tensors
1071            (in case the model has multiple inputs).
1072          - A `tf.data` dataset or a dataset iterator.
1073          - A generator or `keras.utils.Sequence` instance.
1074        batch_size: Integer or `None`.
1075            Number of samples per gradient update.
1076            If unspecified, `batch_size` will default to 32.
1077            Do not specify the `batch_size` is your data is in the
1078            form of symbolic tensors, dataset, dataset iterators,
1079            generators, or `keras.utils.Sequence` instances (since they generate
1080            batches).
1081        verbose: Verbosity mode, 0 or 1.
1082        steps: Total number of steps (batches of samples)
1083            before declaring the prediction round finished.
1084            Ignored with the default value of `None`. If x is a `tf.data`
1085            dataset or a dataset iterator, and `steps` is None, `predict` will
1086            run until the input dataset is exhausted.
1087        callbacks: List of `keras.callbacks.Callback` instances.
1088            List of callbacks to apply during prediction.
1089            See [callbacks](/api_docs/python/tf/keras/callbacks).
1090        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1091            input only. Maximum size for the generator queue.
1092            If unspecified, `max_queue_size` will default to 10.
1093        workers: Integer. Used for generator or `keras.utils.Sequence` input
1094            only. Maximum number of processes to spin up when using
1095            process-based threading. If unspecified, `workers` will default
1096            to 1. If 0, will execute the generator on the main thread.
1097        use_multiprocessing: Boolean. Used for generator or
1098            `keras.utils.Sequence` input only. If `True`, use process-based
1099            threading. If unspecified, `use_multiprocessing` will default to
1100            `False`. Note that because this implementation relies on
1101            multiprocessing, you should not pass non-picklable arguments to
1102            the generator as they can't be passed easily to children processes.
1103
1104
1105    Returns:
1106        Numpy array(s) of predictions.
1107
1108    Raises:
1109        ValueError: In case of mismatch between the provided
1110            input data and the model's expectations,
1111            or in case a stateful model receives a number of samples
1112            that is not a multiple of the batch size.
1113    """
1114    # Case 1: distribution strategy.
1115    if self._distribution_strategy:
1116      return training_distributed.predict_distributed(self,
1117                                                      x=x,
1118                                                      batch_size=batch_size,
1119                                                      verbose=verbose,
1120                                                      steps=steps,
1121                                                      callbacks=callbacks)
1122
1123    batch_size = self._validate_or_infer_batch_size(batch_size, steps, x)
1124
1125    # Case 2: generator-like. Input is Python generator, or Sequence object,
1126    # or a non-distributed Dataset or iterator in eager execution.
1127    if data_utils.is_generator_or_sequence(x):
1128      return self.predict_generator(
1129          x,
1130          steps=steps,
1131          verbose=verbose,
1132          callbacks=callbacks,
1133          max_queue_size=max_queue_size,
1134          workers=workers,
1135          use_multiprocessing=use_multiprocessing)
1136    if training_utils.is_eager_dataset_or_iterator(x):
1137      return training_generator.predict_generator(
1138          self,
1139          x,
1140          steps=steps,
1141          batch_size=batch_size,
1142          verbose=verbose,
1143          workers=0,
1144          callbacks=callbacks)
1145
1146    # Case 3: Symbolic tensors or Numpy array-like.
1147    # This includes Datasets and iterators in graph mode (since they
1148    # generate symbolic tensors).
1149    x, _, _ = self._standardize_user_data(
1150        x, check_steps=True, steps_name='steps', steps=steps)
1151
1152    if self.run_eagerly:
1153      return training_generator.predict_generator(
1154          self,
1155          x,
1156          steps=steps,
1157          batch_size=batch_size,
1158          verbose=verbose,
1159          workers=0,
1160          callbacks=callbacks)
1161    else:
1162      return training_arrays.predict_loop(
1163          self,
1164          x,
1165          batch_size=batch_size,
1166          verbose=verbose,
1167          steps=steps,
1168          callbacks=callbacks)
1169
1170  def reset_metrics(self):
1171    """Resets the state of metrics."""
1172    if hasattr(self, 'metrics'):
1173      for m in self.metrics:
1174        m.reset_states()
1175
1176    # Reset the state of loss metric wrappers.
1177    if getattr(self, '_output_loss_metrics', None) is not None:
1178      for m in self._output_loss_metrics:
1179        m.reset_states()
1180
1181    # Reset metrics on all the distributed (cloned) models.
1182    if self._distribution_strategy:
1183      distributed_training_utils._reset_metrics(self)  # pylint: disable=protected-access
1184
1185  def train_on_batch(self,
1186                     x,
1187                     y=None,
1188                     sample_weight=None,
1189                     class_weight=None,
1190                     reset_metrics=True):
1191    """Runs a single gradient update on a single batch of data.
1192
1193    Arguments:
1194        x: Input data. It could be:
1195          - A Numpy array (or array-like), or a list of arrays
1196              (in case the model has multiple inputs).
1197          - A TensorFlow tensor, or a list of tensors
1198              (in case the model has multiple inputs).
1199          - A dict mapping input names to the corresponding array/tensors,
1200              if the model has named inputs.
1201          - A `tf.data` dataset or a dataset iterator.
1202        y: Target data. Like the input data `x`, it could be either Numpy
1203          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1204          (you cannot have Numpy inputs and tensor targets, or inversely). If
1205          `x` is a dataset or a dataset iterator, `y` should not be specified
1206          (since targets will be obtained from the iterator).
1207        sample_weight: Optional array of the same length as x, containing
1208          weights to apply to the model's loss for each sample. In the case of
1209          temporal data, you can pass a 2D array with shape (samples,
1210          sequence_length), to apply a different weight to every timestep of
1211          every sample. In this case you should make sure to specify
1212          sample_weight_mode="temporal" in compile(). This argument is not
1213          supported when `x` is a dataset or a dataset iterator.
1214        class_weight: Optional dictionary mapping class indices (integers) to a
1215          weight (float) to apply to the model's loss for the samples from this
1216          class during training. This can be useful to tell the model to "pay
1217          more attention" to samples from an under-represented class.
1218        reset_metrics: If `True`, the metrics returned will be only for this
1219          batch. If `False`, the metrics will be statefully accumulated across
1220          batches.
1221
1222    Returns:
1223        Scalar training loss
1224        (if the model has a single output and no metrics)
1225        or list of scalars (if the model has multiple outputs
1226        and/or metrics). The attribute `model.metrics_names` will give you
1227        the display labels for the scalar outputs.
1228
1229    Raises:
1230      ValueError: In case of invalid user-provided arguments.
1231    """
1232    if self._distribution_strategy:
1233      raise NotImplementedError('`train_on_batch` is not supported for models '
1234                                'compiled with DistributionStrategy.')
1235    # Validate and standardize user data.
1236    x, y, sample_weights = self._standardize_user_data(
1237        x, y, sample_weight=sample_weight, class_weight=class_weight,
1238        extract_tensors_from_dataset=True)
1239
1240    if self.run_eagerly:
1241      outputs = training_eager.train_on_batch(
1242          self,
1243          x,
1244          y,
1245          sample_weights=sample_weights,
1246          output_loss_metrics=self._output_loss_metrics)
1247    else:
1248      x = training_utils.ModelInputs(x).as_list()
1249      ins = x + (y or []) + (sample_weights or [])
1250
1251      if not isinstance(K.symbolic_learning_phase(), int):
1252        ins += [True]  # Add learning phase value.
1253
1254      self._make_train_function()
1255      outputs = self.train_function(ins)  # pylint: disable=not-callable
1256
1257    if reset_metrics:
1258      self.reset_metrics()
1259
1260    if len(outputs) == 1:
1261      return outputs[0]
1262    return outputs
1263
1264  def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True):
1265    """Test the model on a single batch of samples.
1266
1267    Arguments:
1268        x: Input data. It could be:
1269          - A Numpy array (or array-like), or a list of arrays
1270            (in case the model has multiple inputs).
1271          - A TensorFlow tensor, or a list of tensors
1272            (in case the model has multiple inputs).
1273          - A dict mapping input names to the corresponding array/tensors,
1274            if the model has named inputs.
1275          - A `tf.data` dataset or a dataset iterator.
1276        y: Target data. Like the input data `x`,
1277          it could be either Numpy array(s) or TensorFlow tensor(s).
1278          It should be consistent with `x` (you cannot have Numpy inputs and
1279          tensor targets, or inversely). If `x` is a dataset or a
1280          dataset iterator, `y` should not be specified
1281          (since targets will be obtained from the iterator).
1282        sample_weight: Optional array of the same length as x, containing
1283            weights to apply to the model's loss for each sample.
1284            In the case of temporal data, you can pass a 2D array
1285            with shape (samples, sequence_length),
1286            to apply a different weight to every timestep of every sample.
1287            In this case you should make sure to specify
1288            sample_weight_mode="temporal" in compile(). This argument is not
1289            supported when `x` is a dataset or a dataset iterator.
1290        reset_metrics: If `True`, the metrics returned will be only for this
1291          batch. If `False`, the metrics will be statefully accumulated across
1292          batches.
1293
1294    Returns:
1295        Scalar test loss (if the model has a single output and no metrics)
1296        or list of scalars (if the model has multiple outputs
1297        and/or metrics). The attribute `model.metrics_names` will give you
1298        the display labels for the scalar outputs.
1299
1300    Raises:
1301        ValueError: In case of invalid user-provided arguments.
1302    """
1303    if self._distribution_strategy:
1304      raise NotImplementedError('`test_on_batch` is not supported for models '
1305                                'compiled with DistributionStrategy.')
1306    # Validate and standardize user data.
1307    x, y, sample_weights = self._standardize_user_data(
1308        x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)
1309
1310    if self.run_eagerly:
1311      outputs = training_eager.test_on_batch(
1312          self,
1313          x,
1314          y,
1315          sample_weights=sample_weights,
1316          output_loss_metrics=self._output_loss_metrics)
1317    else:
1318      x = training_utils.ModelInputs(x).as_list()
1319      inputs = x + (y or []) + (sample_weights or [])
1320
1321      self._make_test_function()
1322      outputs = self.test_function(inputs)  # pylint: disable=not-callable
1323
1324    if reset_metrics:
1325      self.reset_metrics()
1326
1327    if len(outputs) == 1:
1328      return outputs[0]
1329    return outputs
1330
1331  def predict_on_batch(self, x):
1332    """Returns predictions for a single batch of samples.
1333
1334    Arguments:
1335        x: Input data. It could be:
1336          - A Numpy array (or array-like), or a list of arrays
1337            (in case the model has multiple inputs).
1338          - A TensorFlow tensor, or a list of tensors
1339            (in case the model has multiple inputs).
1340          - A `tf.data` dataset or a dataset iterator.
1341
1342    Returns:
1343        Numpy array(s) of predictions.
1344
1345    Raises:
1346        ValueError: In case of mismatch between given number of inputs and
1347          expectations of the model.
1348    """
1349    if self._distribution_strategy:
1350      raise NotImplementedError('`predict_on_batch` is not supported for '
1351                                'models compiled with DistributionStrategy.')
1352    # Validate and standardize user data.
1353    inputs, _, _ = self._standardize_user_data(
1354        x, extract_tensors_from_dataset=True)
1355    if self.run_eagerly:
1356      if (isinstance(inputs, iterator_ops.EagerIterator) or
1357          (isinstance(inputs, dataset_ops.DatasetV2))):
1358        inputs = training_utils.cast_if_floating_dtype(inputs)
1359      elif isinstance(inputs, collections.Sequence):
1360        inputs = [
1361            ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs]
1362
1363        # Unwrap lists with only one input, as we do when training on batch
1364        if len(inputs) == 1:
1365          inputs = inputs[0]
1366
1367      return self(inputs)  # pylint: disable=not-callable
1368
1369    self._make_predict_function()
1370    outputs = self.predict_function(inputs)
1371
1372    if len(outputs) == 1:
1373      return outputs[0]
1374    return outputs
1375
  def fit_generator(self,
                    generator,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=True,
                    initial_epoch=0):
    """Fits the model on data yielded batch-by-batch by a Python generator.

    The generator is run in parallel to the model, for efficiency.
    For instance, this allows you to do real-time data augmentation
    on images on CPU in parallel to training your model on GPU.

    The use of `keras.utils.Sequence` guarantees the ordering
    and guarantees the single use of every input per epoch when
    using `use_multiprocessing=True`.

    Arguments:
        generator: A generator or an instance of `Sequence`
            (`keras.utils.Sequence`) object in order to avoid duplicate data
            when using multiprocessing.
            The output of the generator must be either
            - a tuple `(inputs, targets)`
            - a tuple `(inputs, targets, sample_weights)`.
            This tuple (a single output of the generator) makes a single batch.
            Therefore, all arrays in this tuple must have the same length (equal
            to the size of this batch). Different batches may have different
            sizes. For example, the last batch of the epoch is commonly smaller
            than the others, if the size of the dataset is not divisible by the
            batch size.
            The generator is expected to loop over its data
            indefinitely. An epoch finishes when `steps_per_epoch`
            batches have been seen by the model.
        steps_per_epoch: Total number of steps (batches of samples)
            to yield from `generator` before declaring one epoch
            finished and starting the next epoch. It should typically
            be equal to the number of samples of your dataset
            divided by the batch size.
            Optional for `Sequence`: if unspecified, will use
            the `len(generator)` as a number of steps.
        epochs: Integer, total number of iterations on the data.
        verbose: Verbosity mode, 0, 1, or 2.
        callbacks: List of callbacks to be called during training.
        validation_data: This can be either
            - a generator for the validation data
            - a tuple (inputs, targets)
            - a tuple (inputs, targets, sample_weights).
        validation_steps: Only relevant if `validation_data`
            is a generator. Total number of steps (batches of samples)
            to yield from `generator` before stopping.
            Optional for `Sequence`: if unspecified, will use
            the `len(validation_data)` as a number of steps.
        validation_freq: Only relevant if validation data is provided. Integer
            or `collections.Container` instance (e.g. list, tuple, etc.). If an
            integer, specifies how many training epochs to run before a new
            validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.
        class_weight: Dictionary mapping class indices to a weight
            for the class.
        max_queue_size: Integer. Maximum size for the generator queue.
            If unspecified, `max_queue_size` will default to 10.
        workers: Integer. Maximum number of processes to spin up
            when using process-based threading.
            If unspecified, `workers` will default to 1. If 0, will
            execute the generator on the main thread.
        use_multiprocessing: Boolean.
            If `True`, use process-based threading.
            If unspecified, `use_multiprocessing` will default to `False`.
            Note that because this implementation relies on multiprocessing,
            you should not pass non-picklable arguments to the generator
            as they can't be passed easily to children processes.
        shuffle: Boolean. Whether to shuffle the order of the batches at
            the beginning of each epoch. Only used with instances
            of `Sequence` (`keras.utils.Sequence`).
            Has no effect when `steps_per_epoch` is not `None`.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run).

    Returns:
        A `History` object.

    Example:

    ```python
        def generate_arrays_from_file(path):
            while 1:
                f = open(path)
                for line in f:
                    # create numpy arrays of input data
                    # and labels, from each line in the file
                    x1, x2, y = process_line(line)
                    yield ({'input_1': x1, 'input_2': x2}, {'output': y})
                f.close()

        model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                            steps_per_epoch=10000, epochs=10)
    ```

    Raises:
        NotImplementedError: If the model has a distribution strategy set
            (i.e. it was compiled with DistributionStrategy).
        ValueError: In case the generator yields data in an invalid format.
    """
    # Generator-based training is rejected up front for distributed models.
    if self._distribution_strategy:
      raise NotImplementedError('`fit_generator` is not supported for '
                                'models compiled with DistributionStrategy.')
    # All the work is delegated to the `training_generator` module.
    return training_generator.fit_generator(
        self,
        generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')
1509
  def evaluate_generator(self,
                         generator,
                         steps=None,
                         callbacks=None,
                         max_queue_size=10,
                         workers=1,
                         use_multiprocessing=False,
                         verbose=0):
    """Evaluates the model on a data generator.

    The generator should return the same kind of data
    as accepted by `test_on_batch`.

    Arguments:
        generator: Generator yielding tuples (inputs, targets)
            or (inputs, targets, sample_weights)
            or an instance of `keras.utils.Sequence`
            object in order to avoid duplicate data
            when using multiprocessing.
        steps: Total number of steps (batches of samples)
            to yield from `generator` before stopping.
            Optional for `Sequence`: if unspecified, will use
            the `len(generator)` as a number of steps.
        callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during evaluation.
            See [callbacks](/api_docs/python/tf/keras/callbacks).
        max_queue_size: Maximum size for the generator queue.
        workers: Integer. Maximum number of processes to spin up
            when using process-based threading.
            If unspecified, `workers` will default to 1. If 0, will
            execute the generator on the main thread.
        use_multiprocessing: Boolean.
            If `True`, use process-based threading.
            If unspecified, `use_multiprocessing` will default to `False`.
            Note that because this implementation relies on multiprocessing,
            you should not pass non-picklable arguments to the generator
            as they can't be passed easily to children processes.
        verbose: Verbosity mode, 0 or 1.

    Returns:
        Scalar test loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the scalar outputs.

    Raises:
        NotImplementedError: If the model has a distribution strategy set
            (i.e. it was compiled with DistributionStrategy).
        ValueError: In case of invalid arguments, or in case the generator
            yields data in an invalid format.
    """
    # Generator-based evaluation is rejected up front for distributed models.
    if self._distribution_strategy:
      raise NotImplementedError('`evaluate_generator` is not supported for '
                                'models compiled with DistributionStrategy.')
    # All the work is delegated to the `training_generator` module.
    return training_generator.evaluate_generator(
        self,
        generator,
        steps=steps,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        verbose=verbose,
        callbacks=callbacks)
1573
  def predict_generator(self,
                        generator,
                        steps=None,
                        callbacks=None,
                        max_queue_size=10,
                        workers=1,
                        use_multiprocessing=False,
                        verbose=0):
    """Generates predictions for the input samples from a data generator.

    The generator should return the same kind of data as accepted by
    `predict_on_batch`.

    Arguments:
        generator: Generator yielding batches of input samples
            or an instance of `keras.utils.Sequence` object in order to
            avoid duplicate data when using multiprocessing.
        steps: Total number of steps (batches of samples)
            to yield from `generator` before stopping.
            Optional for `Sequence`: if unspecified, will use
            the `len(generator)` as a number of steps.
        callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during prediction.
            See [callbacks](/api_docs/python/tf/keras/callbacks).
        max_queue_size: Maximum size for the generator queue.
        workers: Integer. Maximum number of processes to spin up
            when using process-based threading.
            If unspecified, `workers` will default to 1. If 0, will
            execute the generator on the main thread.
        use_multiprocessing: Boolean.
            If `True`, use process-based threading.
            If unspecified, `use_multiprocessing` will default to `False`.
            Note that because this implementation relies on multiprocessing,
            you should not pass non-picklable arguments to the generator
            as they can't be passed easily to children processes.
        verbose: Verbosity mode, 0 or 1.

    Returns:
        Numpy array(s) of predictions.

    Raises:
        NotImplementedError: If the model has a distribution strategy set
            (i.e. it was compiled with DistributionStrategy).
        ValueError: In case the generator yields data in an invalid format.
    """
    # Generator-based prediction is rejected up front for distributed models.
    if self._distribution_strategy:
      raise NotImplementedError('`predict_generator` is not supported for '
                                'models compiled with DistributionStrategy.')
    # All the work is delegated to the `training_generator` module.
    return training_generator.predict_generator(
        self,
        generator,
        steps=steps,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        verbose=verbose,
        callbacks=callbacks)
1629
1630  def _prepare_total_loss(self, skip_target_indices=None, masks=None):
1631    """Computes total loss from loss functions.
1632
1633    Arguments:
1634        skip_target_indices: A list of indices of model outputs where loss
1635          function is None.
1636        masks: List of mask values corresponding to each model output.
1637
1638    Returns:
1639        A list of loss weights of python floats.
1640
1641    Raises:
1642        TypeError: If model run_eagerly is True.
1643    """
1644    if self.run_eagerly:
1645      raise TypeError('total loss can not be computed when compiled with '
1646                      'run_eagerly = True.')
1647    skip_target_indices = skip_target_indices or []
1648    total_loss = None
1649    with K.name_scope('loss'):
1650      zipped_inputs = zip(self.targets, self.outputs, self.loss_functions,
1651                          self.sample_weights, masks, self.loss_weights_list)
1652      for i, (y_true, y_pred, loss_fn, sample_weight, mask,
1653              loss_weight) in enumerate(zipped_inputs):
1654        if i in skip_target_indices:
1655          continue
1656        loss_name = self.output_names[i] + '_loss'
1657        with K.name_scope(loss_name):
1658          if mask is not None:
1659            mask = math_ops.cast(mask, y_pred.dtype)
1660            # Update weights with mask.
1661            if sample_weight is None:
1662              sample_weight = mask
1663            else:
1664              # Update dimensions of weights to match with mask if possible.
1665              mask, _, sample_weight = (
1666                  losses_utils.squeeze_or_expand_dimensions(
1667                      mask, None, sample_weight))
1668              sample_weight *= mask
1669
1670          # Reset reduction on the loss so that we can get the per sample loss
1671          # value. We use this to get both the stateless and stateful loss
1672          # values without having to compute the underlying loss function
1673          # twice.
1674          weighted_losses = None
1675          if hasattr(loss_fn, 'reduction'):
1676            current_loss_reduction = loss_fn.reduction
1677            loss_fn.reduction = losses_utils.ReductionV2.NONE
1678            weighted_losses = loss_fn(
1679                y_true, y_pred, sample_weight=sample_weight)
1680            loss_fn.reduction = current_loss_reduction
1681
1682            # Compute the stateless loss value.
1683            output_loss = losses_utils.reduce_weighted_loss(
1684                weighted_losses, reduction=current_loss_reduction)
1685          else:
1686            # Compute the stateless loss value for a custom loss class.
1687            # Here we assume that the class takes care of loss reduction
1688            # because if this class returns a vector value we cannot
1689            # differentiate between use case where a custom optimizer
1690            # expects a vector loss value vs unreduced per-sample loss value.
1691            output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
1692
1693        if len(self.outputs) > 1:
1694          # Keep track of stateful result tensor and function for the loss.
1695          # Compute the stateful loss value.
1696          if weighted_losses is not None:
1697            # TODO(b/120571621): Directly call metric when the bug is fixed.
1698            aggregated_output_loss = self._call_fn_for_each_replica(
1699                self._output_loss_metrics[i], weighted_losses)
1700          else:
1701            # Custom loss class.
1702            aggregated_output_loss = self._call_metric_fn(
1703                self._output_loss_metrics[i], y_true, y_pred, sample_weight)
1704          self._compile_metrics_tensors[loss_name] = aggregated_output_loss
1705
1706        if total_loss is None:
1707          total_loss = loss_weight * output_loss
1708        else:
1709          total_loss += loss_weight * output_loss
1710      if total_loss is None:
1711        if not self.losses:
1712          raise ValueError('The model cannot be compiled '
1713                           'because it has no loss to optimize.')
1714        else:
1715          total_loss = 0.
1716
1717      # Add regularization penalties and other layer-specific losses.
1718      if self.losses:
1719        total_loss += losses_utils.scale_loss_for_distribution(
1720            math_ops.add_n(self.losses))
1721    return total_loss
1722
1723  def _get_callback_model(self):
1724    """Returns the Callback Model for this Model."""
1725
1726    if hasattr(self, '_replicated_model') and self._replicated_model:
1727      # When using training_distributed, we set the callback model
1728      # to an instance of the `DistributedModel` that we create in
1729      # the `compile` call. The `DistributedModel` is initialized
1730      # with the first replicated model. We need to set the callback
1731      # model to a DistributedModel to allow us to override saving
1732      # and loading weights when we checkpoint the model during training.
1733      return self._replicated_model
1734    if hasattr(self, 'callback_model') and self.callback_model:
1735      return self.callback_model
1736    return self
1737
1738  def _make_callback_model(self, grouped_model):
1739    first_replicated_model = self._distribution_strategy.unwrap(
1740        grouped_model)[0]
1741    # We initialize the callback model with the first replicated model.
1742    self._replicated_model = DistributedCallbackModel(first_replicated_model)
1743    self._replicated_model.set_original_model(self)
1744
1745  def _validate_or_infer_batch_size(self, batch_size, steps, x):
1746    """Validates that the `batch_size` provided is consistent with InputLayer.
1747
1748    It's possible that the user specified a static batch size in their
1749    InputLayer. If so, this method checks the provided `batch_size` and `x`
1750    arguments are consistent with this static batch size. Also, if
1751    `batch_size` is `None`, this method will attempt to infer the batch size
1752    from the static batch size of the InputLayer. Lastly, ValueError will be
1753    raised if `x` is a tf.data.Dataset and `batch_size` is specified as we
1754    expect users to provide batched datasets.
1755
1756    Arguments:
1757      batch_size: The batch_size provided as an argument to
1758        fit/evaluate/predict.
1759      steps: The steps provided as an argument to fit/evaluate/predict.
1760      x: The data passed as `x` to fit/evaluate/predict.
1761
1762    Returns:
1763      The validated batch_size, auto-inferred from the first layer if not
1764      provided.
1765    """
1766    if batch_size is not None and isinstance(x, dataset_ops.DatasetV2):
1767      raise ValueError('The `batch_size` argument must not be specified when'
1768                       ' using dataset as an input.')
1769
1770    layers = super(Model, self).layers  # Avoids the override in Sequential.
1771    if layers:
1772      first_layer = layers[0]
1773      static_batch_size = training_utils.get_static_batch_size(first_layer)
1774      if static_batch_size is not None:
1775
1776        # Check `batch_size` argument is consistent with InputLayer.
1777        if batch_size is not None and batch_size != static_batch_size:
1778          raise ValueError('The `batch_size` argument value {} is incompatible '
1779                           'with the specified batch size of your Input Layer: '
1780                           '{}'.format(batch_size, static_batch_size))
1781
1782        # Check Dataset/Iterator batch size is consistent with InputLayer.
1783        if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
1784                          iterator_ops.EagerIterator)):
1785          ds_batch_size = tensor_shape.as_dimension(
1786              nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
1787          if ds_batch_size is not None and ds_batch_size != static_batch_size:
1788            raise ValueError('The batch output shape of your `Dataset` is {}, '
1789                             'which is incompatible with the specified batch '
1790                             'size of your Input Layer: {}'.format(
1791                                 ds_batch_size, static_batch_size))
1792
1793        # Set inferred batch size from the InputLayer.
1794        if steps is None:
1795          batch_size = static_batch_size
1796
1797    if batch_size is None and steps is None:
1798      # Backwards compatibility
1799      batch_size = 32
1800    return batch_size
1801
1802  def _list_functions_for_serialization(self):
1803    return {
1804        '_default_save_signature': saving_utils.trace_model_call(self)
1805    }
1806
1807  def _set_sample_weight_attributes(self, sample_weight_mode,
1808                                    skip_target_weighing_indices):
1809    """Sets sample weight related attributes on the model."""
1810    sample_weights, sample_weight_modes = training_utils.prepare_sample_weights(
1811        self.output_names, sample_weight_mode, skip_target_weighing_indices)
1812    self.sample_weights = sample_weights
1813    self.sample_weight_modes = sample_weight_modes
1814    self._feed_sample_weight_modes = [
1815        sample_weight_modes[i]
1816        for i in range(len(self.outputs))
1817        if i not in skip_target_weighing_indices
1818    ]
1819    self._feed_sample_weights = [
1820        sample_weights[i]
1821        for i in range(len(sample_weights))
1822        if i not in skip_target_weighing_indices
1823    ]
1824
1825  def _cache_output_metric_attributes(self, metrics, weighted_metrics):
1826    """Caches metric name and function attributes for every model output."""
1827    output_shapes = []
1828    for output in self.outputs:
1829      if output is None or output.shape.rank is None:
1830        output_shapes.append(None)
1831      else:
1832        output_shapes.append(output.shape.as_list())
1833    self._per_output_metrics = training_utils.collect_per_output_metric_info(
1834        metrics, self.output_names, output_shapes, self.loss_functions)
1835    self._per_output_weighted_metrics = (
1836        training_utils.collect_per_output_metric_info(
1837            weighted_metrics,
1838            self.output_names,
1839            output_shapes,
1840            self.loss_functions,
1841            is_weighted=True))
1842
1843  def _add_unique_metric_name(self, metric_name, output_index):
1844    """Makes the metric name unique and adds it to the model's metric name list.
1845
1846      If there are multiple outputs for which the metrics are calculated, the
1847      metric names have to be made unique by appending an integer.
1848
1849    Arguments:
1850      metric_name: Metric name that corresponds to the metric specified by the
1851          user. For example: 'acc'.
1852      output_index: The index of the model output for which the metric name is
1853        being added.
1854
1855    Returns:
1856      string, name of the model's unique metric name
1857    """
1858    if len(self.output_names) > 1:
1859      metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
1860    j = 1
1861    base_metric_name = metric_name
1862    while metric_name in self._compile_metrics_names:
1863      metric_name = '%s_%d' % (base_metric_name, j)
1864      j += 1
1865
1866    return metric_name
1867
1868  @property
1869  def _all_metrics_tensors(self):
1870    """Returns a dictionary that maps metric names to metric result tensors.
1871
1872    This maps metric names from `model.metric_names` to result tensors.
1873    Just like model.metric_names, this includes loss names and tensors.
1874    """
1875    metrics_tensors = {}
1876    if self._is_compiled:
1877      metrics_tensors.update(self._compile_metrics_tensors)
1878    metrics_tensors.update(super(Model, self)._all_metrics_tensors)
1879    return metrics_tensors
1880
  def _init_metric_attributes(self):
    """Initializes the model's compile-time metric attributes.

    Resets the metric name list, the metric function list, the metric result
    tensor dict and the output-loss metric wrappers to their starting values.
    """
    # List of all metric names in the model. This includes loss metrics.
    # 'loss' is always the first entry.
    self._compile_metrics_names = ['loss']
    # List of stateful metric functions. Used for resetting metric state during
    # training/eval. This includes loss metric functions.
    self._compile_metric_functions = []
    # Dict of all aggregated metric result tensors. This includes aggregated
    # loss result tensors.
    self._compile_metrics_tensors = {}
    # List of metric wrappers on output losses.
    self._output_loss_metrics = None
1893
1894  def _set_per_output_metric_attributes(self, metrics_dict, output_index):
1895    """Sets the metric attributes on the model for the given output.
1896
1897    Arguments:
1898      metrics_dict: A dict with metric names as keys and metric fns as values.
1899      output_index: The index of the model output for which the metric
1900        attributes are added.
1901
1902    Returns:
1903      Metrics dict updated with unique metric names as keys.
1904    """
1905    updated_metrics_dict = collections.OrderedDict()
1906    for metric_name, metric_fn in metrics_dict.items():
1907      metric_name = self._add_unique_metric_name(metric_name, output_index)
1908
1909      # Update the name on the metric class to be the unique generated name.
1910      metric_fn._name = metric_name  # pylint: disable=protected-access
1911      updated_metrics_dict[metric_name] = metric_fn
1912      # Keep track of metric name and function.
1913      self._compile_metrics_names.append(metric_name)
1914      self._compile_metric_functions.append(metric_fn)
1915    return updated_metrics_dict
1916
1917  def _set_metric_attributes(self, skip_target_indices=None):
1918    """Sets the metric attributes on the model for all the model outputs."""
1919    # Add loss metric names to the model metric names list.
1920    if len(self.outputs) > 1:
1921      output_names = [
1922          self.output_names[i] + '_loss'
1923          for i in range(len(self.outputs))
1924          if i not in skip_target_indices
1925      ]
1926      self._compile_metrics_names.extend(output_names)
1927
1928    skip_target_indices = skip_target_indices or []
1929    updated_per_output_metrics = []
1930    updated_per_output_weighted_metrics = []
1931    for i in range(len(self.outputs)):
1932      if i in skip_target_indices:
1933        updated_per_output_metrics.append(self._per_output_metrics[i])
1934        updated_per_output_weighted_metrics.append(
1935            self._per_output_weighted_metrics[i])
1936        continue
1937      updated_per_output_metrics.append(
1938          self._set_per_output_metric_attributes(self._per_output_metrics[i],
1939                                                 i))
1940      updated_per_output_weighted_metrics.append(
1941          self._set_per_output_metric_attributes(
1942              self._per_output_weighted_metrics[i], i))
1943
1944    # Create a metric wrapper for each output loss.
1945    if len(self.outputs) > 1:
1946      self._output_loss_metrics = [
1947          metrics_module.SumOverBatchSize() if hasattr(loss_fn, 'reduction')
1948          else metrics_module.SumOverBatchSizeMetricWrapper(loss_fn)
1949          for loss_fn in self.loss_functions
1950      ]
1951
1952    self._per_output_metrics = updated_per_output_metrics
1953    self._per_output_weighted_metrics = updated_per_output_weighted_metrics
1954
1955  def _call_metric_fn(self, metric_fn, y_true, y_pred, weights, mask=None):
1956    # TODO(b/120571621): Remove this function when the bug is fixed.
1957    """Helper function to call metric function with distribution strategy."""
1958    return self._call_fn_for_each_replica(
1959        training_utils.call_metric_function,
1960        metric_fn,
1961        y_true,
1962        y_pred,
1963        weights=weights,
1964        mask=mask)
1965
1966  def _call_fn_for_each_replica(self, fn, *args, **kwargs):
1967    # TODO(b/120571621): We want to avoid metric reductions here since
1968    # since TPUStrategy does not implement replica local variables.
1969    # Remove this hack once we support TPUReplicaLocalVariables.
1970    is_tpu = distributed_training_utils.is_tpu_strategy(
1971        self._distribution_strategy)
1972    if ((not is_tpu) and self._distribution_strategy and
1973        distribution_strategy_context.in_cross_replica_context()):
1974      with self._distribution_strategy.scope():
1975        return self._distribution_strategy.extended.call_for_each_replica(
1976            fn, args, kwargs)
1977    return fn(*args, **kwargs)
1978
1979  def _handle_per_output_metrics(self,
1980                                 metrics_dict,
1981                                 y_true,
1982                                 y_pred,
1983                                 mask,
1984                                 weights=None):
1985    """Calls metric functions for a single output.
1986
1987    Arguments:
1988      metrics_dict: A dict with metric names as keys and metric fns as values.
1989      y_true: Target output.
1990      y_pred: Predicted output.
1991      mask: Computed mask value for the current output.
1992      weights: Weights to be applied on the current output.
1993
1994    Returns:
1995      A list of metric result tensors.
1996    """
1997    metric_results = []
1998    for metric_name, metric_fn in metrics_dict.items():
1999      with K.name_scope(metric_name):
2000        metric_result = self._call_metric_fn(metric_fn, y_true, y_pred, weights,
2001                                             mask)
2002        metric_results.append(metric_result)
2003        if not self.run_eagerly:
2004          self._compile_metrics_tensors[metric_name] = metric_result
2005
2006    return metric_results
2007
2008  def _handle_metrics(self,
2009                      outputs,
2010                      skip_target_indices=None,
2011                      targets=None,
2012                      sample_weights=None,
2013                      masks=None):
2014    """Handles calling metric functions.
2015
2016    Arguments:
2017      outputs: List of outputs (predictions).
2018      skip_target_indices: Optional. List of target ids to skip.
2019      targets: List of targets.
2020      sample_weights: Optional list of sample weight arrays.
2021      masks: List of computed output mask values.
2022
2023    Returns:
2024      A list of metric result tensors.
2025    """
2026    skip_target_indices = skip_target_indices or []
2027    metric_results = []
2028    with K.name_scope('metrics'):
2029      # Invoke all metrics added using `compile`.
2030      for i in range(len(outputs)):
2031        if i in skip_target_indices:
2032          continue
2033        output = outputs[i] if outputs else None
2034        target = targets[i] if targets else None
2035        output_mask = masks[i] if masks else None
2036        metric_results.extend(
2037            self._handle_per_output_metrics(self._per_output_metrics[i], target,
2038                                            output, output_mask))
2039        metric_results.extend(
2040            self._handle_per_output_metrics(
2041                self._per_output_weighted_metrics[i],
2042                target,
2043                output,
2044                output_mask,
2045                weights=sample_weights[i]))
2046
2047    # Add metric results from the `add_metric` metrics in eager mode.
2048    if context.executing_eagerly():
2049      for m in self.metrics:
2050        if m not in self._compile_metric_functions:
2051          metric_results.append(m.result())
2052    return metric_results
2053
2054  def _check_trainable_weights_consistency(self):
2055    """Check trainable weights count consistency.
2056
2057    This will raise a warning if `trainable_weights` and
2058    `_collected_trainable_weights` are inconsistent (i.e. have different
2059    number of parameters).
2060    Inconsistency will typically arise when one modifies `model.trainable`
2061    without calling `model.compile` again.
2062    """
2063    if not hasattr(self, '_collected_trainable_weights'):
2064      return
2065
2066    if len(self.trainable_weights) != len(self._collected_trainable_weights):
2067      logging.log_first_n(
2068          logging.WARN, 'Discrepancy between trainable weights and collected'
2069          ' trainable weights, did you set `model.trainable`'
2070          ' without calling `model.compile` after ?', 1)
2071
2072  def _make_train_function(self):
2073    metrics_tensors = [
2074        self._all_metrics_tensors[m] for m in self.metrics_names[1:]
2075    ]
2076    if not self._is_compiled:
2077      raise RuntimeError('You must compile your model before using it.')
2078    self._check_trainable_weights_consistency()
2079    if getattr(self, 'train_function') is None:
2080      inputs = (self._feed_inputs +
2081                self._feed_targets +
2082                self._feed_sample_weights)
2083      if not isinstance(K.symbolic_learning_phase(), int):
2084        inputs += [K.symbolic_learning_phase()]
2085
2086      with K.get_graph().as_default():
2087        with K.name_scope('training'):
2088          with K.name_scope(self.optimizer.__class__.__name__):
2089            # Training updates
2090            updates = self.optimizer.get_updates(
2091                params=self._collected_trainable_weights, loss=self.total_loss)
2092      # Unconditional updates
2093      updates += self.get_updates_for(None)
2094      # Conditional updates relevant to this model
2095      updates += self.get_updates_for(self.inputs)
2096
2097      with K.name_scope('training'):
2098        # Gets loss and metrics. Updates weights at each call.
2099        fn = K.function(
2100            inputs, [self.total_loss] + metrics_tensors,
2101            updates=updates,
2102            name='train_function',
2103            **self._function_kwargs)
2104        setattr(self, 'train_function', fn)
2105
2106  def _make_test_function(self):
2107    metrics_tensors = [
2108        self._all_metrics_tensors[m] for m in self.metrics_names[1:]
2109    ]
2110    if not self._is_compiled:
2111      raise RuntimeError('You must compile your model before using it.')
2112    if getattr(self, 'test_function') is None:
2113      inputs = (self._feed_inputs +
2114                self._feed_targets +
2115                self._feed_sample_weights)
2116
2117      with K.name_scope('evaluation'):
2118        updates = self.state_updates
2119        # Return loss and metrics, no gradient updates.
2120        # Does update the network states.
2121        fn = K.function(
2122            inputs, [self.total_loss] + metrics_tensors,
2123            updates=updates,
2124            name='test_function',
2125            **self._function_kwargs)
2126        setattr(self, 'test_function', fn)
2127
2128  def _make_predict_function(self):
2129    if not hasattr(self, 'predict_function'):
2130      self.predict_function = None
2131    if self.predict_function is None:
2132      inputs = self._feed_inputs
2133      # Gets network outputs. Does not update weights.
2134      # Does update the network states.
2135      kwargs = getattr(self, '_function_kwargs', {})
2136      with K.name_scope(ModeKeys.PREDICT):
2137        self.predict_function = K.function(
2138            inputs,
2139            self.outputs,
2140            updates=self.state_updates,
2141            name='predict_function',
2142            **kwargs)
2143
2144  def _make_execution_function(self, mode):
2145    if mode == ModeKeys.TRAIN:
2146      self._make_train_function()
2147      return self.train_function
2148    if mode == ModeKeys.TEST:
2149      self._make_test_function()
2150      return self.test_function
2151    if mode == ModeKeys.PREDICT:
2152      self._make_predict_function()
2153      return self.predict_function
2154
2155  def _distribution_standardize_user_data(self,
2156                                          x,
2157                                          y=None,
2158                                          sample_weight=None,
2159                                          class_weight=None,
2160                                          batch_size=None,
2161                                          validation_split=0,
2162                                          shuffle=False,
2163                                          repeat=False,
2164                                          allow_partial_batch=False):
2165    """Runs validation checks on input and target data passed by the user.
2166
2167    This is called when using DistributionStrategy to train, evaluate or serve
2168    the model.
2169
2170    Args:
2171      x: Input data. A numpy array or `tf.data` dataset.
2172      y: Target data. A numpy array or None if x is a `tf.data` dataset.
2173      sample_weight: An optional sample-weight array passed by the user to
2174        weight the importance of each sample in `x`.
2175      class_weight: An optional class-weight array by the user to
2176        weight the importance of samples in `x` based on the class they belong
2177        to, as conveyed by `y`.
2178      batch_size: Integer batch size. If provided, it is used to run additional
2179        validation checks on stateful models.
2180      validation_split: Float between 0 and 1.
2181        Fraction of the training data to be used as validation data.
2182      shuffle: Boolean whether to shuffle the training data before each epoch.
2183      repeat: Boolean whether to repeat the numpy training data when converting
2184        to training dataset.
2185      allow_partial_batch: Boolean whether to enforce that all batches have the
2186        same size.
2187
2188    Returns:
2189      Dataset instance.
2190
2191    Raises:
2192      ValueError: In case of invalid user-provided data.
2193      RuntimeError: If the model was never compiled.
2194    """
2195    if class_weight:
2196      raise NotImplementedError('`class_weight` is currently not supported '
2197                                'when using DistributionStrategy.')
2198
2199    if (sample_weight is not None and sample_weight.all() and
2200        distributed_training_utils.is_tpu_strategy(
2201            self._distribution_strategy)):
2202      raise NotImplementedError('`sample_weight` is currently not supported '
2203                                'when using TPUStrategy.')
2204
2205    if (self.stateful and distributed_training_utils.is_tpu_strategy(
2206        self._distribution_strategy) and self._distribution_strategy.
2207        num_replicas_in_sync != 1):
2208      raise ValueError('Single core must be used for computation on '
2209                       'stateful models. Consider adding `device_assignment` '
2210                       'parameter to TPUStrategy using\n'
2211                       'topology = tf.contrib.distribute.'
2212                       'initialize_tpu_system()\n'
2213                       'device_assignment = tf.contrib.tpu.DeviceAssignment('
2214                       'topology, core_assignment=tf.contrib.tpu.'
2215                       'SINGLE_CORE_ASSIGNMENT)\n'
2216                       'tpu_strategy = tf.contrib.distribute.TPUStrategy('
2217                       'device_assignment=device_assignment)')
2218
2219    # Validates `steps` and `shuffle` arguments right at the beginning
2220    # since we use it to construct the dataset object.
2221    # TODO(anjalisridhar): Remove this check once we refactor the
2222    # _standardize_user_data code path. This check is already present elsewhere
2223    # in the codebase.
2224    if isinstance(x, dataset_ops.DatasetV2):
2225      if shuffle:
2226        training_utils.verify_dataset_shuffled(x)
2227
2228    strategy = self._distribution_strategy
2229    with strategy.scope():
2230      # We should be sure to call get_session() inside the strategy.scope()
2231      # so the strategy can affect the session options.
2232      if ops.executing_eagerly_outside_functions():
2233        session = None
2234      else:
2235        session = K.get_session()
2236
2237      first_x_value = nest.flatten(x)[0]
2238      if isinstance(first_x_value, np.ndarray):
2239        x = distributed_training_utils.list_to_tuple(x)
2240        if y is not None:
2241          y = distributed_training_utils.list_to_tuple(y)
2242          if sample_weight is not None:
2243            sample_weight = distributed_training_utils.list_to_tuple(
2244                sample_weight)
2245            in_tuple = (x, y, sample_weight)
2246          else:
2247            in_tuple = (x, y)
2248        else:
2249          in_tuple = x
2250
2251        ds = strategy.extended.experimental_make_numpy_dataset(in_tuple,
2252                                                               session=session)
2253        if shuffle:
2254          # We want a buffer size that is larger than the batch size provided by
2255          # the user and provides sufficient randomness. Note that larger
2256          # numbers introduce more memory usage based on the size of each
2257          # sample.
2258          ds = ds.shuffle(max(1024, batch_size * 8))
2259        if repeat:
2260          ds = ds.repeat()
2261
2262        # We need to use the drop_remainder argument to get a known static
2263        # input shape which is required for TPUs.
2264        drop_remainder = (not allow_partial_batch and
2265                          strategy.extended.experimental_require_static_shapes)
2266        x = ds.batch(batch_size, drop_remainder=drop_remainder)
2267      else:
2268        assert isinstance(x, dataset_ops.DatasetV2)
2269        training_utils.validate_dataset_input(x, y, sample_weight,
2270                                              validation_split)
2271    return x
2272
  def _standardize_user_data(self,
                             x,
                             y=None,
                             sample_weight=None,
                             class_weight=None,
                             batch_size=None,
                             check_steps=False,
                             steps_name='steps',
                             steps=None,
                             validation_split=0,
                             shuffle=False,
                             extract_tensors_from_dataset=False):
    """Runs validation checks on input and target data passed by the user.

    Also standardizes the data to lists of arrays, in order.

    Also builds and compiles the model on the fly if it is a subclassed model
    that has never been called before (and thus has no inputs/outputs).

    This is a purely internal method, subject to refactoring at any time.

    Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
          if the model has named inputs.
        - A `tf.data` dataset or a dataset iterator.
      y: Target data. Like the input data `x`,
        it could be either Numpy array(s) or TensorFlow tensor(s).
        It should be consistent with `x` (you cannot have Numpy inputs and
        tensor targets, or inversely). If `x` is a dataset or a
        dataset iterator, `y` should not be specified
        (since targets will be obtained from the iterator).
      sample_weight: An optional sample-weight array passed by the user to
        weight the importance of each sample in `x`.
      class_weight: An optional class-weight array by the user to
        weight the importance of samples in `x` based on the class they belong
        to, as conveyed by `y`. If both `sample_weight` and `class_weight` are
        provided, the weights are multiplied.
      batch_size: Integer batch size. If provided, it is used to run additional
        validation checks on stateful models.
      check_steps: boolean, True if we want to check for validity of `steps` and
        False, otherwise. For example, when we are standardizing one batch of
        data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps`
        value is not required and we should not check for its validity in these
        cases.
      steps_name: The public API's parameter name for `steps`.
      steps: Integer or `None`. Total number of steps (batches of samples) to
        execute.
      validation_split: Float between 0 and 1.
        Fraction of the training data to be used as validation data.
      shuffle: Boolean whether to shuffle the training data before each epoch.
      extract_tensors_from_dataset: Boolean. When `x` is a dataset instance,
        this indicates whether to extract actual tensors from the dataset or
        instead output the dataset instance itself.
        Set to True when calling from `train_on_batch`/etc.

    Returns:
      A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
      or not), target arrays, sample-weight arrays.
      If `x` is a dataset and `extract_tensors_from_dataset` is False, the
      dataset itself is returned in place of the inputs.
      If the model's input and targets are symbolic, these lists are empty
      (since the model takes no user-provided data, instead the data comes
      from the symbolic inputs/targets).

    Raises:
      ValueError: In case of invalid user-provided data.
      RuntimeError: If targets are provided but the model has no optimizer,
        i.e. it was never compiled.
    """
    # Step 1: unpack dataset / iterator inputs and remember, via `is_dataset`,
    # that the data came from one.
    if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
      # Graph mode dataset. We'll pass the dataset as-is (unless
      # `extract_tensors_from_dataset` is True, in which case we extract
      # the tensors from the dataset and we output them).
      training_utils.validate_dataset_input(x, y, sample_weight,
                                            validation_split)
      if shuffle:
        training_utils.verify_dataset_shuffled(x)

      is_dataset = True
      if extract_tensors_from_dataset:
        # We do this for `train_on_batch`/etc.
        x, y, sample_weight = training_utils.extract_tensors_from_dataset(x)
    elif isinstance(x, iterator_ops.Iterator):
      # Graph mode iterator. We extract the symbolic tensors.
      training_utils.validate_dataset_input(x, y, sample_weight,
                                            validation_split)
      iterator = x
      x, y, sample_weight = training_utils.unpack_iterator_input(iterator)
      is_dataset = True
    else:
      is_dataset = False

    # Validates `steps` argument based on x's type.
    if check_steps:
      training_utils.check_steps_argument(x, steps, steps_name)

    # First, we build/compile the model on the fly if necessary.
    # `all_inputs` collects every input (and, when compiling on the fly, every
    # target) so the Numpy/tensor mix check below can look at them together.
    all_inputs = []
    is_build_called = False
    is_compile_called = False
    # Whether this is a subclassed model that expects dictionary inputs
    # rather than list inputs (e.g. FeatureColumn-based models).
    dict_inputs = False
    if not self.inputs:
      # We need to use `x_input` to set the model inputs.

      # If `x` is still a dataset at this point (i.e. its tensors were not
      # extracted above), we fetch the data tensors from it and use those to
      # set the model inputs.
      if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
        x_input, y_input, _ = training_utils.extract_tensors_from_dataset(x)
      else:
        x_input = x
        y_input = y
      # We type-check that `x_input` and `y_input` are either single arrays
      # or lists of arrays.
      if isinstance(x_input, (list, tuple)):
        if not all(isinstance(v, np.ndarray) or
                   tensor_util.is_tensor(v) for v in x_input):
          raise ValueError('Please provide as model inputs either a single '
                           'array or a list of arrays. You passed: x=' + str(x))
        all_inputs += list(x_input)
      elif isinstance(x_input, dict):
        dict_inputs = True
        # Dict values are collected in sorted-key order.
        keys = sorted(x_input.keys())
        all_inputs = [x_input[k] for k in keys]
      else:
        if (not isinstance(x_input, np.ndarray) and
            not tensor_util.is_tensor(x_input)):
          raise ValueError('Please provide as model inputs either a single '
                           'array or a list of arrays. You passed: x=' + str(x))
        all_inputs.append(x_input)

      # Build the model using the retrieved inputs (value or symbolic).
      # If the inputs are values or were generated from a dataset, then in
      # symbolic-mode placeholders will be created to match the value shapes.
      is_build_called = True
      if is_dataset:
        # Data that came from a dataset is represented by its shapes only.
        cast_inputs = nest.map_structure(lambda v: v.shape, x_input)
      elif training_utils.has_tensors(x_input):
        cast_inputs = training_utils.cast_if_floating_dtype(x_input)
      else:
        cast_inputs = x_input
      self._set_inputs(cast_inputs)
    else:
      # The model inputs are already set; only the targets may still need
      # processing.
      y_input = y
      dict_inputs = isinstance(self.inputs, dict)

    if y_input is not None:
      # Targets were provided, which requires an optimizer to be set.
      if not self.optimizer:
        raise RuntimeError('You must compile a model before '
                           'training/testing. '
                           'Use `model.compile(optimizer, loss)`.')
      if not self._is_compiled:
        # On-the-fly compilation of the model.
        # We need to use `y` to set the model targets.
        if training_utils.has_tensors(y_input):
          y_input = training_utils.cast_if_floating_dtype(y_input)
        if isinstance(y_input, (list, tuple)):
          if not all(isinstance(v, np.ndarray) or
                     tensor_util.is_tensor(v) for v in y_input):
            raise ValueError('Please provide as model targets either a single '
                             'array or a list of arrays. '
                             'You passed: y=' + str(y))
          all_inputs += list(y_input)
        elif isinstance(y_input, dict):
          raise ValueError('You cannot pass a dictionary as model targets.')
        else:
          if (not isinstance(y_input, np.ndarray) and
              not tensor_util.is_tensor(y_input)):
            raise ValueError('Please provide as model targets either a single '
                             'array or a list of arrays. '
                             'You passed: y=' + str(y))
          all_inputs.append(y_input)

        # Typecheck that all inputs are *either* value *or* symbolic.
        # TODO(fchollet): this check could be removed in Eager mode?
        if any(tensor_util.is_tensor(v) for v in all_inputs):
          if not all(tensor_util.is_tensor(v) for v in all_inputs):
            raise ValueError('Do not pass inputs that mix Numpy arrays and '
                             'TensorFlow tensors. '
                             'You passed: x=' + str(x) + '; y=' + str(y))

        if is_dataset or context.executing_eagerly():
          # No explicit target tensors are passed to `compile` for dataset
          # inputs or when executing eagerly.
          target_tensors = None
        else:
          # Handle target tensors if any passed.
          if not isinstance(y_input, (list, tuple)):
            y_input = [y_input]
          target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
        is_compile_called = True
        # Re-compile with the settings already stored on the model.
        self.compile(
            optimizer=self.optimizer,
            loss=self.loss,
            metrics=self._compile_metrics,
            weighted_metrics=self._compile_weighted_metrics,
            loss_weights=self.loss_weights,
            target_tensors=target_tensors,
            run_eagerly=self.run_eagerly)

    # In graph mode, if we had just set inputs and targets as symbolic tensors
    # by invoking build and compile on the model respectively, we do not have to
    # feed anything to the model. Model already has input and target data as
    # part of the graph.
    # Note: in this case, `any` and `all` are equivalent since we disallow
    # mixed symbolic/value inputs.
    if (not self.run_eagerly and is_build_called and is_compile_called and
        not is_dataset  and any(_is_symbolic_tensor(v) for v in all_inputs)):
      return [], [], []

    # What follows is input validation and standardization to list format,
    # in the case where all inputs are value arrays.

    if self.run_eagerly:
      # In eager mode, do not do shape validation
      # since the network has no input nodes (placeholders) to be fed.
      feed_input_names = self.input_names
      feed_input_shapes = None
    elif not self._is_graph_network:
      # Case: symbolic-mode subclassed network. Do not do shape validation.
      feed_input_names = self._feed_input_names
      feed_input_shapes = None
    else:
      # Case: symbolic-mode graph network.
      # In this case, we run extensive shape validation checks.
      feed_input_names = self._feed_input_names
      feed_input_shapes = self._feed_input_shapes

    # Standardize the inputs. Dataset inputs are left untouched.
    if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
      # TODO(fchollet): run static checks with dataset output shape(s).
      x = training_utils.standardize_input_data(
          x,
          feed_input_names,
          feed_input_shapes,
          check_batch_axis=False,  # Don't enforce the batch size.
          exception_prefix='input')

    if y is not None:
      # Work out the names, expected shapes and sample-weight modes the
      # targets are validated against.
      if not self._is_graph_network:
        feed_output_names = self._feed_output_names
        feed_output_shapes = None
        # Sample weighting not supported in this case.
        # TODO(fchollet): consider supporting it.
        feed_sample_weight_modes = [None for _ in self.outputs]
      else:
        feed_output_names = self._feed_output_names
        feed_sample_weight_modes = self._feed_sample_weight_modes
        feed_output_shapes = []
        for output_shape, loss_fn in zip(self._feed_output_shapes,
                                         self._feed_loss_fns):
          if ((isinstance(loss_fn, losses.LossFunctionWrapper) and
               loss_fn.fn == losses.sparse_categorical_crossentropy)) or (
                   isinstance(loss_fn, losses.SparseCategoricalCrossentropy)):
            # Sparse categorical crossentropy: the expected target shape is
            # the output shape with axis 1 (channels_first) or the last axis
            # (otherwise) replaced by size 1.
            if K.image_data_format() == 'channels_first':
              feed_output_shapes.append(
                  (output_shape[0], 1) + output_shape[2:])
            else:
              feed_output_shapes.append(output_shape[:-1] + (1,))
          elif (not isinstance(loss_fn, losses.Loss) or
                (isinstance(loss_fn, losses.LossFunctionWrapper) and
                 (getattr(losses, loss_fn.fn.__name__, None) is None))):
            # If the given loss is not an instance of the `Loss` class (custom
            # class) or if the loss function that is wrapped is not in the
            # `losses` module, then it is a user-defined loss and we make no
            # assumptions about it.
            feed_output_shapes.append(None)
          else:
            feed_output_shapes.append(output_shape)

      # Standardize the outputs.
      y = training_utils.standardize_input_data(
          y,
          feed_output_names,
          # Don't enforce target shapes to match output shapes.
          # Precise checks will be run in `check_loss_and_target_compatibility`.
          shapes=None,
          check_batch_axis=False,  # Don't enforce the batch size.
          exception_prefix='target')

      # Generate sample-wise weight values given the `sample_weight` and
      # `class_weight` arguments.
      sample_weights = training_utils.standardize_sample_weights(
          sample_weight, feed_output_names)
      class_weights = training_utils.standardize_class_weights(
          class_weight, feed_output_names)
      sample_weights = [
          training_utils.standardize_weights(ref, sw, cw, mode)
          for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
                                         feed_sample_weight_modes)
      ]
      # Check that all arrays have the same length. These checks are skipped
      # when a distribution strategy is set.
      if not self._distribution_strategy:
        training_utils.check_array_lengths(x, y, sample_weights)
        if self._is_graph_network and not self.run_eagerly:
          # Additional checks to avoid users mistakenly using improper loss fns.
          training_utils.check_loss_and_target_compatibility(
              y, self._feed_loss_fns, feed_output_shapes)
    else:
      # No targets: return empty target and sample-weight lists.
      y = []
      sample_weights = []

    if self.stateful and batch_size:
      # Check that for stateful networks, number of samples is a multiple
      # of the static batch size.
      if x[0].shape[0] % batch_size != 0:
        raise ValueError('In a stateful network, '
                         'you should only pass inputs with '
                         'a number of samples that can be '
                         'divided by the batch size. Found: ' +
                         str(x[0].shape[0]) + ' samples')

    # If dictionary inputs were provided, we return a dictionary as well.
    if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1,
                                          dataset_ops.DatasetV2)):
      x = dict(zip(feed_input_names, x))
    return x, y, sample_weights
2592
2593  def _unpack_validation_data(self, validation_data):
2594    if (isinstance(validation_data, (iterator_ops.Iterator,
2595                                     iterator_ops.EagerIterator,
2596                                     dataset_ops.DatasetV2))):
2597      val_x = validation_data
2598      val_y = None
2599      val_sample_weight = None
2600    elif len(validation_data) == 2:
2601      val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
2602      val_sample_weight = None
2603    elif len(validation_data) == 3:
2604      val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
2605    else:
2606      raise ValueError(
2607          'When passing a `validation_data` argument, '
2608          'it must contain either 2 items (x_val, y_val), '
2609          'or 3 items (x_val, y_val, val_sample_weights), '
2610          'or alternatively it could be a dataset or a '
2611          'dataset or a dataset iterator. '
2612          'However we received `validation_data=%s`' % validation_data)
2613    return val_x, val_y, val_sample_weight
2614
2615  # TODO(omalleyt): Consider changing to a more descriptive function name.
2616  def _set_inputs(self, inputs, outputs=None, training=None):
2617    """Set model's input and output specs based on the input data received.
2618
2619    This is to be used for Model subclasses, which do not know at instantiation
2620    time what their inputs look like.
2621
2622    Args:
2623      inputs: Single array, or list of arrays. The arrays could be placeholders,
2624        Numpy arrays, data tensors, or TensorShapes.
2625        - if placeholders: the model is built on top of these placeholders,
2626          and we expect Numpy data to be fed for them when calling `fit`/etc.
2627        - if Numpy data or TensorShapes: we create placeholders matching the
2628          TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be
2629          fed for these placeholders when calling `fit`/etc.
2630        - if data tensors: the model is built on top of these tensors.
2631          We do not expect any Numpy data to be provided when calling `fit`/etc.
2632      outputs: None, a data tensor, or a list of tensors. If None, the
2633        outputs will be determined by invoking `self.call()`, otherwise the
2634        provided value will be used.
2635      training: Boolean or None. Only relevant in symbolic mode. Specifies
2636        whether to build the model's graph in inference mode (False), training
2637        mode (True), or using the Keras learning phase (None).
2638    Raises:
2639      ValueError: If dict inputs are passed to a Sequential Model where the
2640        first layer isn't FeatureLayer.
2641    """
2642    inputs = self._set_input_attrs(inputs)
2643
2644    if outputs is None:
2645      kwargs = {'training': training} if self._expects_training_arg else {}
2646      try:
2647        outputs = self(inputs, **kwargs)
2648      except NotImplementedError:
2649        # This Model or a submodel is dynamic and hasn't overridden
2650        # `compute_output_shape`.
2651        outputs = None
2652
2653    self._set_output_attrs(outputs)
2654
2655  @trackable.no_automatic_dependency_tracking
2656  def _set_input_attrs(self, inputs):
2657    """Sets attributes related to the inputs of the Model."""
2658    if self.inputs:
2659      raise ValueError('Model inputs are already set.')
2660
2661    if self.__class__.__name__ == 'Sequential' and not self.built:
2662      if tensor_util.is_tensor(inputs):
2663        input_shape = (None,) + tuple(inputs.shape.as_list()[1:])
2664      elif isinstance(inputs, tensor_shape.TensorShape):
2665        input_shape = (None,) + tuple(inputs.as_list()[1:])
2666      elif isinstance(inputs, dict):
2667        # We assert that the first layer is a FeatureLayer.
2668        if not training_utils.is_feature_layer(self.layers[0]):
2669          raise ValueError('Passing a dictionary input to a Sequential Model '
2670                           'which doesn\'t have FeatureLayer as the first layer'
2671                           ' is an error.')
2672        input_shape = (None,)
2673      else:
2674        input_shape = (None,) + tuple(inputs.shape[1:])
2675      self._build_input_shape = input_shape
2676
2677    # On-the-fly setting of symbolic model inputs (either by using the tensor
2678    # provided, or by creating a placeholder if Numpy data was provided).
2679    model_inputs = training_utils.ModelInputs(inputs)
2680    inputs = model_inputs.get_symbolic_inputs()
2681    self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
2682    self.input_names = model_inputs.get_input_names()
2683
2684    self._feed_inputs = []
2685    self._feed_input_names = []
2686    self._feed_input_shapes = []
2687
2688    for k, v in model_inputs.as_dict():
2689      if K.is_placeholder(v):
2690        self._feed_input_names.append(k)
2691        self._feed_inputs.append(v)
2692        self._feed_input_shapes.append(K.int_shape(v))
2693
2694    return inputs
2695
2696  @trackable.no_automatic_dependency_tracking
2697  def _set_output_attrs(self, outputs):
2698    """Sets attributes related to the outputs of the Model."""
2699    outputs = nest.flatten(outputs)
2700    self.outputs = outputs
2701    self.output_names = training_utils.generic_output_names(outputs)
2702    self.built = True
2703
2704
class DistributedCallbackModel(Model):
  """Model that is used for callbacks with DistributionStrategy.

  Saving and loading are routed through the original (non-distributed) model
  registered with `set_original_model`.
  """

  def __init__(self, model):
    """Creates the callback model, sharing `model`'s optimizer."""
    super(DistributedCallbackModel, self).__init__()
    self.optimizer = model.optimizer

  def set_original_model(self, orig_model):
    """Registers the original model used by `save` and `load_weights`."""
    self._original_model = orig_model

  def save_weights(self, filepath, overwrite=True, save_format=None):
    """Saves weights by delegating to `self._replicated_model`."""
    replicated_model = self._replicated_model
    replicated_model.save_weights(
        filepath, overwrite=overwrite, save_format=save_format)

  def save(self, filepath, overwrite=True, include_optimizer=True):
    """Copies this model's weights into the original model and saves it.

    The `overwrite` and `include_optimizer` arguments are not forwarded: the
    original model is always saved with `overwrite=True` and
    `include_optimizer=False`.
    """
    # Save weights from the distributed model to the original model.
    current_weights = self.get_weights()
    self._original_model.set_weights(current_weights)
    # TODO(anjalisridhar): Do we need to save the original model here?
    # Saving the first replicated model works as well.
    self._original_model.save(filepath, overwrite=True, include_optimizer=False)

  def load_weights(self, filepath, by_name=False):
    """Loads weights into the original model and copies them to this model.

    The `by_name` argument is not forwarded: weights are always loaded with
    `by_name=False`.
    """
    original_model = self._original_model
    original_model.load_weights(filepath, by_name=False)
    # Copy the weights from the original model to each of the replicated models.
    loaded_weights = original_model.get_weights()
    distributed_training_utils.set_weights(
        original_model._distribution_strategy, self,  # pylint: disable=protected-access
        loaded_weights)

  def __getattr__(self, item):
    """Handles attributes that normal lookup did not find.

    Logs a warning unless `item` is whitelisted, and always returns None.
    """
    # Whitelisted attributes of the model that can be accessed by the user
    # during a callback.
    silent_attributes = ('_setattr_tracking',)
    if item not in silent_attributes:
      logging.warning('You are accessing attribute ' + item + ' of the '
                      'DistributedCallbackModel that may not have been set '
                      'correctly.')
    return None
2742
2743
def _is_symbolic_tensor(x):
  """Returns whether `x` is a tensor that is not an `EagerTensor`."""
  looks_like_tensor = tensor_util.is_tensor(x)
  return looks_like_tensor and not isinstance(x, ops.EagerTensor)
2746