1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""V1 Training-related part of the Keras engine."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import warnings
22
23import numpy as np
24
25from tensorflow.python import tf2
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.data.ops import iterator_ops
28from tensorflow.python.distribute import distribution_strategy_context
29from tensorflow.python.distribute import parameter_server_strategy
30from tensorflow.python.distribute import parameter_server_strategy_v2
31from tensorflow.python.eager import context
32from tensorflow.python.eager import def_function
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import sparse_tensor
36from tensorflow.python.framework import tensor_shape
37from tensorflow.python.framework import tensor_spec
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.framework import type_spec
40from tensorflow.python.keras import backend as K
41from tensorflow.python.keras import losses
42from tensorflow.python.keras import metrics as metrics_module
43from tensorflow.python.keras import optimizer_v1
44from tensorflow.python.keras import optimizers
45from tensorflow.python.keras.distribute import distributed_training_utils
46from tensorflow.python.keras.distribute import distributed_training_utils_v1
47from tensorflow.python.keras.engine import base_layer
48from tensorflow.python.keras.engine import training as training_lib
49from tensorflow.python.keras.engine import training_arrays_v1
50from tensorflow.python.keras.engine import training_distributed_v1
51from tensorflow.python.keras.engine import training_eager_v1
52from tensorflow.python.keras.engine import training_generator_v1
53from tensorflow.python.keras.engine import training_utils
54from tensorflow.python.keras.engine import training_utils_v1
55from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
56from tensorflow.python.keras.mixed_precision import policy
57from tensorflow.python.keras.optimizer_v2 import optimizer_v2
58from tensorflow.python.keras.saving import saving_utils
59from tensorflow.python.keras.saving.saved_model import model_serialization
60from tensorflow.python.keras.utils import data_utils
61from tensorflow.python.keras.utils import layer_utils
62from tensorflow.python.keras.utils import losses_utils
63from tensorflow.python.keras.utils import tf_inspect
64from tensorflow.python.keras.utils import tf_utils
65from tensorflow.python.keras.utils.mode_keys import ModeKeys
66from tensorflow.python.ops import array_ops
67from tensorflow.python.ops import math_ops
68from tensorflow.python.platform import tf_logging as logging
69from tensorflow.python.training.tracking import base as trackable
70from tensorflow.python.util import nest
71
72try:
73  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
74except ImportError:
75  issparse = None
76
77
78class Model(training_lib.Model):
79  """`Model` groups layers into an object with training and inference features.
80
81  There are two ways to instantiate a `Model`:
82
83  1 - With the "functional API", where you start from `Input`,
84  you chain layer calls to specify the model's forward pass,
85  and finally you create your model from inputs and outputs:
86
87  ```python
88  import tensorflow as tf
89
90  inputs = tf.keras.Input(shape=(3,))
91  x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
92  outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
93  model = tf.keras.Model(inputs=inputs, outputs=outputs)
94  ```
95
96  2 - By subclassing the `Model` class: in that case, you should define your
97  layers in `__init__` and you should implement the model's forward pass
98  in `call`.
99
100  ```python
101  import tensorflow as tf
102
103  class MyModel(tf.keras.Model):
104
105    def __init__(self):
106      super(MyModel, self).__init__()
107      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
108      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
109
110    def call(self, inputs):
111      x = self.dense1(inputs)
112      return self.dense2(x)
113
114  model = MyModel()
115  ```
116
117  If you subclass `Model`, you can optionally have
118  a `training` argument (boolean) in `call`, which you can use to specify
119  a different behavior in training and inference:
120
121  ```python
122  import tensorflow as tf
123
124  class MyModel(tf.keras.Model):
125
126    def __init__(self):
127      super(MyModel, self).__init__()
128      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
129      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
130      self.dropout = tf.keras.layers.Dropout(0.5)
131
132    def call(self, inputs, training=False):
133      x = self.dense1(inputs)
134      if training:
135        x = self.dropout(x, training=training)
136      return self.dense2(x)
137
138  model = MyModel()
139  ```
140  """
141
142  def __init__(self, *args, **kwargs):
143    super(Model, self).__init__(*args, **kwargs)
144    # initializing _distribution_strategy here since it is possible to call
145    # predict on a model without compiling it.
146    self._distribution_strategy = None
147    self._compile_time_distribution_strategy = None
148    if (ops.executing_eagerly_outside_functions() and
149        distribution_strategy_context.has_strategy()):
150      self._set_strategy(
151          distribution_strategy_context.get_strategy())
152
153    # This flag is used to track if the user is using the deprecated path of
154    # passing distribution strategy to compile rather than creating the model
155    # under distribution strategy scope.
156    self._compile_distribution = False
157
158    self._run_eagerly = None
159    self._experimental_run_tf_function = (
160        ops.executing_eagerly_outside_functions())
161
162    self._v1_compile_was_called = False
163
164  def _init_batch_counters(self):
165    pass  # Batch counters should not be created in legacy graph mode.
166
167  @trackable.no_automatic_dependency_tracking
168  def _set_strategy(self, strategy):
169    self._compile_time_distribution_strategy = strategy
170
171  def get_weights(self):
172    """Retrieves the weights of the model.
173
174    Returns:
175        A flat list of Numpy arrays.
176    """
177    strategy = (self._distribution_strategy or
178                self._compile_time_distribution_strategy)
179    if strategy:
180      with strategy.scope():
181        return base_layer.Layer.get_weights(self)
182    return base_layer.Layer.get_weights(self)
183
184  def load_weights(self, filepath, by_name=False, skip_mismatch=False):
185    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
186
187    If `by_name` is False weights are loaded based on the network's
188    topology. This means the architecture should be the same as when the weights
189    were saved.  Note that layers that don't have weights are not taken into
190    account in the topological ordering, so adding or removing layers is fine as
191    long as they don't have weights.
192
193    If `by_name` is True, weights are loaded into layers only if they share the
194    same name. This is useful for fine-tuning or transfer-learning models where
195    some of the layers have changed.
196
197    Only topological loading (`by_name=False`) is supported when loading weights
198    from the TensorFlow format. Note that topological loading differs slightly
199    between TensorFlow and HDF5 formats for user-defined classes inheriting from
200    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
201    TensorFlow format loads based on the object-local names of attributes to
202    which layers are assigned in the `Model`'s constructor.
203
204    Args:
205        filepath: String, path to the weights file to load. For weight files in
206            TensorFlow format, this is the file prefix (the same as was passed
207            to `save_weights`).
208        by_name: Boolean, whether to load weights by name or by topological
209            order. Only topological loading is supported for weight files in
210            TensorFlow format.
211        skip_mismatch: Boolean, whether to skip loading of layers where there is
212            a mismatch in the number of weights, or a mismatch in the shape of
213            the weight (only valid when `by_name=True`).
214
215    Returns:
216        When loading a weight file in TensorFlow format, returns the same status
217        object as `tf.train.Checkpoint.restore`. When graph building, restore
218        ops are run automatically as soon as the network is built (on first call
219        for user-defined classes inheriting from `Model`, immediately if it is
220        already built).
221
222        When loading weights in HDF5 format, returns `None`.
223
224    Raises:
225        ImportError: If h5py is not available and the weight file is in HDF5
226            format.
227        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
228          `False`.
229    """
230    if K.is_tpu_strategy(self._distribution_strategy):
231      if (self._distribution_strategy.extended.steps_per_run > 1 and
232          (not saving_utils.is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
233        raise ValueError('Load weights is not yet supported with TPUStrategy '
234                         'with steps_per_run greater than 1.')
235    return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
236
237  @trackable.no_automatic_dependency_tracking
238  def compile(self,
239              optimizer='rmsprop',
240              loss=None,
241              metrics=None,
242              loss_weights=None,
243              sample_weight_mode=None,
244              weighted_metrics=None,
245              target_tensors=None,
246              distribute=None,
247              **kwargs):
248    """Configures the model for training.
249
250    Args:
251        optimizer: String (name of optimizer) or optimizer instance.
252            See `tf.keras.optimizers`.
253        loss: String (name of objective function), objective function or
254            `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective
255            function is any callable with the signature
256            `scalar_loss = fn(y_true, y_pred)`. If the model has multiple
257            outputs, you can use a different loss on each output by passing a
258            dictionary or a list of losses. The loss value that will be
259            minimized by the model will then be the sum of all individual
260            losses.
261        metrics: List of metrics to be evaluated by the model during training
262            and testing. Typically you will use `metrics=['accuracy']`.
263            To specify different metrics for different outputs of a
264            multi-output model, you could also pass a dictionary, such as
265            `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
266            You can also pass a list (len = len(outputs)) of lists of metrics
267            such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or
268            `metrics=['accuracy', ['accuracy', 'mse']]`.
269        loss_weights: Optional list or dictionary specifying scalar
270            coefficients (Python floats) to weight the loss contributions
271            of different model outputs.
272            The loss value that will be minimized by the model
273            will then be the *weighted sum* of all individual losses,
274            weighted by the `loss_weights` coefficients.
275            If a list, it is expected to have a 1:1 mapping
276            to the model's outputs. If a tensor, it is expected to map
277            output names (strings) to scalar coefficients.
278        sample_weight_mode: If you need to do timestep-wise
279            sample weighting (2D weights), set this to `"temporal"`.
280            `None` defaults to sample-wise weights (1D).
281            If the model has multiple outputs, you can use a different
282            `sample_weight_mode` on each output by passing a
283            dictionary or a list of modes.
284        weighted_metrics: List of metrics to be evaluated and weighted
285            by sample_weight or class_weight during training and testing.
286        target_tensors: By default, Keras will create placeholders for the
287            model's target, which will be fed with the target data during
288            training. If instead you would like to use your own
289            target tensors (in turn, Keras will not expect external
290            Numpy data for these targets at training time), you
291            can specify them via the `target_tensors` argument. It can be
292            a single tensor (for a single-output model), a list of tensors,
293            or a dict mapping output names to target tensors.
294        distribute: NOT SUPPORTED IN TF 2.0, please create and compile the
295            model under distribution strategy scope instead of passing it to
296            compile.
297        **kwargs: Any additional arguments.
298
299    Raises:
300        ValueError: In case of invalid arguments for
301            `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
302    """
303    self._assert_built_as_v1()
304    self._run_eagerly = kwargs.pop('run_eagerly', None)
305    self._experimental_run_tf_function = kwargs.pop(
306        'experimental_run_tf_function', True)
307    self._v1_compile_was_called = True
308
309    # Prepare Session arguments (legacy).
310    kwargs.pop('cloning', None)  # Legacy DistStrat argument, never used.
311    allowed_kwargs = {'feed_dict', 'fetches', 'options', 'run_metadata'}
312    unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
313    if unknown_kwargs:
314      raise TypeError(
315          'Invalid keyword argument(s) in `compile`: %s' % (unknown_kwargs,))
316    self._function_kwargs = kwargs
317    if self._function_kwargs:
318      self._experimental_run_tf_function = False
319      if self.run_eagerly:
320        raise ValueError(
321            'Session keyword arguments are not supported '
322            'when `run_eagerly=True`. You passed the following '
323            'Session arguments: %s' % (self._function_kwargs,))
324
325    self._set_optimizer(optimizer)
326    is_any_keras_optimizer_v1 = any(
327        (isinstance(opt, optimizer_v1.Optimizer)
328         and not isinstance(opt, optimizer_v1.TFOptimizer)
329        ) for opt in nest.flatten(self.optimizer))
330
331    if is_any_keras_optimizer_v1 and ops.executing_eagerly_outside_functions():
332      raise ValueError('`tf.compat.v1.keras` Optimizer (', optimizer, ') is '
333                       'not supported when eager execution is enabled. Use a '
334                       '`tf.keras` Optimizer instead, or disable eager '
335                       'execution.')
336
337    if ((target_tensors is not None)
338        or not ops.executing_eagerly_outside_functions()):
339      # Fallback out of things that aren't supported with v2 loops
340      self._experimental_run_tf_function = False
341
342    if distribute is not None:
343      if tf2.enabled() or self._experimental_run_tf_function:
344        raise ValueError(
345            'Distribute argument in compile is not available in TF 2.0 please '
346            'create the model under the distribution strategy scope.')
347      logging.warning('Distribute argument in compile is deprecated please '
348                      'create the model under the distribution strategy scope.')
349      self._distribution_strategy = distribute
350      self._compile_distribution = True
351    else:
352      if distribution_strategy_context.has_strategy():
353        # When the user builds the model in the DS scope and cross replica
354        # context we want distribution strategy to be set but when building the
355        # replica copies of the models internally we should not be compiling
356        # with distribution strategy and use the default compilation path.
357        if distribution_strategy_context.in_cross_replica_context():
358          self._distribution_strategy = (
359              distribution_strategy_context.get_strategy())
360
361    if isinstance(self._distribution_strategy,
362                  parameter_server_strategy.ParameterServerStrategyV1):
363      raise NotImplementedError(
364          '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` '
365          'currently only works with the tf.Estimator API')
366
367    if isinstance(self._distribution_strategy,
368                  parameter_server_strategy_v2.ParameterServerStrategyV2):
369      raise NotImplementedError(
370          '`tf.distribute.experimental.ParameterServerStrategy` is only '
371          'supported in TF2.')
372
373    if not self._experimental_run_tf_function:
374      self._validate_compile_param_for_distribution_strategy(self.run_eagerly,
375                                                             sample_weight_mode,
376                                                             target_tensors,
377                                                             weighted_metrics)
378    # We've disabled automatic dependency tracking for this method, but do want
379    # to add a checkpoint dependency on the optimizer if it's trackable.
380    if isinstance(self.optimizer, trackable.Trackable):
381      self._track_trackable(
382          self.optimizer, name='optimizer', overwrite=True)
383    self.loss = loss or {}
384    self.loss_weights = loss_weights
385    self.sample_weight_mode = sample_weight_mode
386    self._compile_metrics = metrics or []
387    self._compile_weighted_metrics = weighted_metrics
388    if self.run_eagerly and target_tensors is not None:
389      raise ValueError(
390          'target_tensors argument is not supported when '
391          'running a model eagerly.')
392
393    # _training_endpoints contains a list of _TrainingEndpoint object, which has
394    # all the model output/target/loss and related metadata.
395    self._training_endpoints = []
396
397    # Used to freeze the behavior of the Model once `compile` has been called.
398    self._compiled_trainable_state = self._get_trainable_state()
399
400    # Set tf.distribute.Strategy specific parameters.
401    self._distributed_model_cache = {}
402    self._distributed_function_cache = {}
403
404    # Clear any `_eager_losses` that was added.
405    self._clear_losses()
406
407    if (not context.executing_eagerly() and
408        self._distribution_strategy is not None):
409      # Ensures a Session is created and configured correctly for Distribution
410      # Strategy.
411      K.configure_and_create_distributed_session(self._distribution_strategy)
412    # Initialize model metric attributes.
413    self._init_metric_attributes()
414    if not self.built or not self.inputs or not self.outputs:
415      # Model is not compilable because it does not know its number of inputs
416      # and outputs, nor their shapes and names. We will compile after the first
417      # time the model gets called on training data.
418      return
419    self._is_compiled = True
420    base_layer.keras_api_gauge.get_cell('compile').set(True)
421
422    # Prepare list of loss functions, same size of model outputs.
423    self.loss_functions = training_utils_v1.prepare_loss_functions(
424        self.loss, self.output_names)
425
426    target_tensors = self._process_target_tensor_for_compile(target_tensors)
427
428    for o, n, l, t in zip(self.outputs, self.output_names,
429                          self.loss_functions, target_tensors):
430      endpoint = _TrainingEndpoint(o, n, l)
431      endpoint.create_training_target(t, run_eagerly=self.run_eagerly)
432      self._training_endpoints.append(endpoint)
433
434    # Prepare list loss weights, same size of model outputs.
435    training_utils_v1.prepare_loss_weights(self._training_endpoints,
436                                           loss_weights)
437
438    # Initialization for Eager mode execution.
439    if self.run_eagerly:
440      self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode)
441      return
442
443    with K.get_graph().as_default():
444      # Save all metric attributes per output of the model.
445      self._cache_output_metric_attributes(metrics, weighted_metrics)
446
447      # Set metric attributes on model.
448      self._set_metric_attributes()
449
450      # Invoke metric functions (unweighted) for all the outputs.
451      self._handle_metrics(
452          self.outputs,
453          targets=self._targets,
454          skip_target_masks=self._prepare_skip_target_masks(),
455          masks=self._prepare_output_masks())
456
457      # Prepare sample weight modes. List with the same length as model outputs.
458      training_utils_v1.prepare_sample_weight_modes(
459          self._training_endpoints, sample_weight_mode)
460
461      # Creates the model loss and weighted metrics sub-graphs.
462      self._compile_weights_loss_and_weighted_metrics()
463
464      # Functions for train, test and predict will
465      # be compiled lazily when required.
466      # This saves time when the user is not using all functions.
467      self.train_function = None
468      self.test_function = None
469      self.predict_function = None
470
471      # Collected trainable weights, sorted in topological order.
472      self._collected_trainable_weights = self.trainable_weights
473
474      # Validate all variables were correctly created in distribution scope.
475      if self._distribution_strategy and not self._compile_distribution:
476        for v in self.variables:
477          strategy = self._distribution_strategy
478          if not strategy.extended.variable_created_in_scope(v):
479            raise ValueError(
480                'Variable (%s) was not created in the distribution strategy '
481                'scope of (%s). It is most likely due to not all layers or '
482                'the model or optimizer being created outside the distribution '
483                'strategy scope. Try to make sure your code looks similar '
484                'to the following.\n'
485                'with strategy.scope():\n'
486                '  model=_create_model()\n'
487                '  model.compile(...)'% (v, strategy))
488
489  @trackable.no_automatic_dependency_tracking
490  def _init_distributed_function_cache_if_not_compiled(self):
491    if not hasattr(self, '_distributed_function_cache'):
492      self._distributed_function_cache = {}
493
494  @property
495  def metrics(self):
496    """Returns the model's metrics added using `compile`, `add_metric` APIs."""
497    metrics = []
498    if self._is_compiled:
499      if not hasattr(self, '_v1_compile_was_called'):
500        # See b/155687393 for more details, the model is created as a v2
501        # instance but converted to v1. Fallback to use base Model to retrieve
502        # the metrics.
503        return super(Model, self).metrics
504      metrics += self._compile_metric_functions
505    metrics.extend(self._metrics)
506    metrics.extend(
507        _get_metrics_from_layers(
508            list(self._flatten_layers(include_self=False, recursive=False))))
509    return metrics
510
511  @property
512  def metrics_names(self):
513    """Returns the model's display labels for all outputs."""
514
515    # This property includes all output names including `loss` and per-output
516    # losses for backward compatibility.
517    metrics_names = ['loss']
518    if self._is_compiled:
519      if not hasattr(self, '_v1_compile_was_called'):
520        # See b/155687393 for more details, the model is created as a v2
521        # instance but converted to v1. Fallback to use base Model to retrieve
522        # the metrics name
523        return super(Model, self).metrics_names
524
525      # Add output loss metric names to the metric names list.
526      if len(self._training_endpoints) > 1:
527        metrics_names.extend([
528            e.loss_name()
529            for e in self._training_endpoints
530            if not e.should_skip_target()
531        ])
532
533    # Add all metric names.
534    metrics_names += [m.name for m in self.metrics]
535    return metrics_names
536
537  @property
538  def run_eagerly(self):
539    """Settable attribute indicating whether the model should run eagerly.
540
541    Running eagerly means that your model will be run step by step,
542    like Python code. Your model might run slower, but it should become easier
543    for you to debug it by stepping into individual layer calls.
544
545    By default, we will attempt to compile your model to a static graph to
546    deliver the best execution performance.
547
548    Returns:
549      Boolean, whether the model should run eagerly.
550    """
551    if self._run_eagerly is True and not context.executing_eagerly():
552      raise ValueError('You can only set `run_eagerly=True` if eager execution '
553                       'is enabled.')
554    if not self.dynamic:
555      if self._run_eagerly is None:
556        # Respect `tf.config.run_functions_eagerly` unless
557        # `run_eagerly` was explicitly passed to `compile`.
558        return def_function.functions_run_eagerly()
559      else:
560        return self._run_eagerly
561    else:
562      if not context.executing_eagerly():
563        raise ValueError('Your model contains layers that can only be '
564                         'successfully run in eager execution (layers '
565                         'constructed with `dynamic=True`). '
566                         'You must enable eager execution with '
567                         '`tf.enable_eager_execution()`.')
568      if self._run_eagerly is False:
569        # TODO(fchollet): consider using py_func to enable this.
570        raise ValueError('Your model contains layers that can only be '
571                         'successfully run in eager execution (layers '
572                         'constructed with `dynamic=True`). '
573                         'You cannot set `run_eagerly=False`.')
574      return context.executing_eagerly()
575
576  @run_eagerly.setter
577  def run_eagerly(self, value):
578    self._run_eagerly = value
579
580  def _select_training_loop(self, inputs):
581    """Select training loop for fit/eval/predict based on the inputs."""
582    # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
583    #  integrated into the data adapters in the v2 loop. We can't do this yet
584    #  because we currently have to fall back for unhandled data types.
585    if isinstance(inputs, (iterator_ops.Iterator,
586                           iterator_ops.IteratorBase)):
587      raise ValueError('For performance reasons Keras `fit`, `evaluate` and'
588                       '`predict` accept tf.data `Datasets` as input but not '
589                       'iterators that have been manually generated from '
590                       'Datasets by users. Please directly pass in the '
591                       'original `Dataset` object instead of passing in '
592                       '`iter(dataset)`.')
593
594    # Case 1: distribution strategy.
595    if self._distribution_strategy:
596      if self._in_multi_worker_mode():
597        return training_distributed_v1.DistributionMultiWorkerTrainingLoop(
598            training_distributed_v1.DistributionSingleWorkerTrainingLoop())
599      else:
600        return training_distributed_v1.DistributionSingleWorkerTrainingLoop()
601
602    # Case 2: generator-like. Input is Python generator, or Sequence object,
603    # or a non-distributed Dataset or iterator in eager execution.
604    if data_utils.is_generator_or_sequence(inputs):
605      return training_generator_v1.GeneratorOrSequenceTrainingLoop()
606    if training_utils_v1.is_eager_dataset_or_iterator(inputs):
607      return training_generator_v1.EagerDatasetOrIteratorTrainingLoop()
608
609    # Case 3: Symbolic tensors or Numpy array-like.
610    # This includes Datasets and iterators in graph mode (since they
611    # generate symbolic tensors).
612    if self.run_eagerly:
613      return training_generator_v1.GeneratorLikeTrainingLoop()
614    else:
615      return training_arrays_v1.ArrayLikeTrainingLoop()
616
617  def fit(self,
618          x=None,
619          y=None,
620          batch_size=None,
621          epochs=1,
622          verbose=1,
623          callbacks=None,
624          validation_split=0.,
625          validation_data=None,
626          shuffle=True,
627          class_weight=None,
628          sample_weight=None,
629          initial_epoch=0,
630          steps_per_epoch=None,
631          validation_steps=None,
632          validation_freq=1,
633          max_queue_size=10,
634          workers=1,
635          use_multiprocessing=False,
636          **kwargs):
637    """Trains the model for a fixed number of epochs (iterations on a dataset).
638
639    Args:
640        x: Input data. It could be:
641          - A Numpy array (or array-like), or a list of arrays
642            (in case the model has multiple inputs).
643          - A TensorFlow tensor, or a list of tensors
644            (in case the model has multiple inputs).
645          - A dict mapping input names to the corresponding array/tensors,
646            if the model has named inputs.
647          - A `tf.data` dataset. Should return a tuple
648            of either `(inputs, targets)` or
649            `(inputs, targets, sample_weights)`.
650          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
651            or `(inputs, targets, sample weights)`.
652        y: Target data. Like the input data `x`,
653          it could be either Numpy array(s) or TensorFlow tensor(s).
654          It should be consistent with `x` (you cannot have Numpy inputs and
655          tensor targets, or inversely). If `x` is a dataset, generator,
656          or `keras.utils.Sequence` instance, `y` should
657          not be specified (since targets will be obtained from `x`).
658        batch_size: Integer or `None`.
659            Number of samples per gradient update.
660            If unspecified, `batch_size` will default to 32.
661            Do not specify the `batch_size` if your data is in the
662            form of symbolic tensors, datasets,
663            generators, or `keras.utils.Sequence` instances (since they generate
664            batches).
665        epochs: Integer. Number of epochs to train the model.
666            An epoch is an iteration over the entire `x` and `y`
667            data provided.
668            Note that in conjunction with `initial_epoch`,
669            `epochs` is to be understood as "final epoch".
670            The model is not trained for a number of iterations
671            given by `epochs`, but merely until the epoch
672            of index `epochs` is reached.
673        verbose: 0, 1, or 2. Verbosity mode.
674            0 = silent, 1 = progress bar, 2 = one line per epoch.
675            Note that the progress bar is not particularly useful when
676            logged to a file, so verbose=2 is recommended when not running
677            interactively (eg, in a production environment).
678        callbacks: List of `keras.callbacks.Callback` instances.
679            List of callbacks to apply during training.
680            See `tf.keras.callbacks`.
681        validation_split: Float between 0 and 1.
682            Fraction of the training data to be used as validation data.
683            The model will set apart this fraction of the training data,
684            will not train on it, and will evaluate
685            the loss and any model metrics
686            on this data at the end of each epoch.
687            The validation data is selected from the last samples
688            in the `x` and `y` data provided, before shuffling. This argument is
689            not supported when `x` is a dataset, generator or
690           `keras.utils.Sequence` instance.
691        validation_data: Data on which to evaluate
692            the loss and any model metrics at the end of each epoch.
693            The model will not be trained on this data.
694            `validation_data` will override `validation_split`.
695            `validation_data` could be:
696              - tuple `(x_val, y_val)` of Numpy arrays or tensors
697              - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
698              - dataset
699            For the first two cases, `batch_size` must be provided.
700            For the last case, `validation_steps` could be provided.
701        shuffle: Boolean (whether to shuffle the training data
702            before each epoch) or str (for 'batch').
703            'batch' is a special option for dealing with the
704            limitations of HDF5 data; it shuffles in batch-sized chunks.
705            Has no effect when `steps_per_epoch` is not `None`.
706        class_weight: Optional dictionary mapping class indices (integers)
707            to a weight (float) value, used for weighting the loss function
708            (during training only).
709            This can be useful to tell the model to
710            "pay more attention" to samples from
711            an under-represented class.
712        sample_weight: Optional Numpy array of weights for
713            the training samples, used for weighting the loss function
714            (during training only). You can either pass a flat (1D)
715            Numpy array with the same length as the input samples
716            (1:1 mapping between weights and samples),
717            or in the case of temporal data,
718            you can pass a 2D array with shape
719            `(samples, sequence_length)`,
720            to apply a different weight to every timestep of every sample.
721            In this case you should make sure to specify
722            `sample_weight_mode="temporal"` in `compile()`. This argument is not
723            supported when `x` is a dataset, generator, or
724           `keras.utils.Sequence` instance, instead provide the sample_weights
725            as the third element of `x`.
726        initial_epoch: Integer.
727            Epoch at which to start training
728            (useful for resuming a previous training run).
729        steps_per_epoch: Integer or `None`.
730            Total number of steps (batches of samples)
731            before declaring one epoch finished and starting the
732            next epoch. When training with input tensors such as
733            TensorFlow data tensors, the default `None` is equal to
734            the number of samples in your dataset divided by
735            the batch size, or 1 if that cannot be determined. If x is a
736            `tf.data` dataset, and 'steps_per_epoch'
737            is None, the epoch will run until the input dataset is exhausted.
738            This argument is not supported with array inputs.
739        validation_steps: Only relevant if `validation_data` is provided and
740            is a `tf.data` dataset. Total number of steps (batches of
741            samples) to draw before stopping when performing validation
742            at the end of every epoch. If 'validation_steps' is None, validation
743            will run until the `validation_data` dataset is exhausted. In the
744            case of a infinite dataset, it will run into a infinite loop.
745            If 'validation_steps' is specified and only part of the dataset
746            will be consumed, the evaluation will start from the beginning of
747            the dataset at each epoch. This ensures that the same validation
748            samples are used every time.
749        validation_freq: Only relevant if validation data is provided. Integer
750            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
751            If an integer, specifies how many training epochs to run before a
752            new validation run is performed, e.g. `validation_freq=2` runs
753            validation every 2 epochs. If a Container, specifies the epochs on
754            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
755            validation at the end of the 1st, 2nd, and 10th epochs.
756        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
757            input only. Maximum size for the generator queue.
758            If unspecified, `max_queue_size` will default to 10.
759        workers: Integer. Used for generator or `keras.utils.Sequence` input
760            only. Maximum number of processes to spin up
761            when using process-based threading. If unspecified, `workers`
762            will default to 1. If 0, will execute the generator on the main
763            thread.
764        use_multiprocessing: Boolean. Used for generator or
765            `keras.utils.Sequence` input only. If `True`, use process-based
766            threading. If unspecified, `use_multiprocessing` will default to
767            `False`. Note that because this implementation relies on
768            multiprocessing, you should not pass non-picklable arguments to
769            the generator as they can't be passed easily to children processes.
770        **kwargs: Used for backwards compatibility.
771
772    Returns:
773        A `History` object. Its `History.history` attribute is
774        a record of training loss values and metrics values
775        at successive epochs, as well as validation loss values
776        and validation metrics values (if applicable).
777
778    Raises:
779        RuntimeError: If the model was never compiled.
780        ValueError: In case of mismatch between the provided input data
781            and what the model expects.
782    """
783    self._assert_built_as_v1()
784    base_layer.keras_api_gauge.get_cell('fit').set(True)
785    # Legacy support
786    if 'nb_epoch' in kwargs:
787      logging.warning(
788          'The `nb_epoch` argument in `fit` has been renamed `epochs`.')
789      epochs = kwargs.pop('nb_epoch')
790    if kwargs:
791      raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
792    self._assert_compile_was_called()
793    self._check_call_args('fit')
794
795    func = self._select_training_loop(x)
796    return func.fit(
797        self,
798        x=x,
799        y=y,
800        batch_size=batch_size,
801        epochs=epochs,
802        verbose=verbose,
803        callbacks=callbacks,
804        validation_split=validation_split,
805        validation_data=validation_data,
806        shuffle=shuffle,
807        class_weight=class_weight,
808        sample_weight=sample_weight,
809        initial_epoch=initial_epoch,
810        steps_per_epoch=steps_per_epoch,
811        validation_steps=validation_steps,
812        validation_freq=validation_freq,
813        max_queue_size=max_queue_size,
814        workers=workers,
815        use_multiprocessing=use_multiprocessing)
816
817  def evaluate(self,
818               x=None,
819               y=None,
820               batch_size=None,
821               verbose=1,
822               sample_weight=None,
823               steps=None,
824               callbacks=None,
825               max_queue_size=10,
826               workers=1,
827               use_multiprocessing=False):
828    """Returns the loss value & metrics values for the model in test mode.
829
830    Computation is done in batches (see the `batch_size` arg.)
831
832    Args:
833        x: Input data. It could be:
834          - A Numpy array (or array-like), or a list of arrays
835            (in case the model has multiple inputs).
836          - A TensorFlow tensor, or a list of tensors
837            (in case the model has multiple inputs).
838          - A dict mapping input names to the corresponding array/tensors,
839            if the model has named inputs.
840          - A `tf.data` dataset.
841          - A generator or `keras.utils.Sequence` instance.
842        y: Target data. Like the input data `x`,
843          it could be either Numpy array(s) or TensorFlow tensor(s).
844          It should be consistent with `x` (you cannot have Numpy inputs and
845          tensor targets, or inversely).
846          If `x` is a dataset, generator or
847          `keras.utils.Sequence` instance, `y` should not be specified (since
848          targets will be obtained from the iterator/dataset).
849        batch_size: Integer or `None`.
850            Number of samples per batch of computation.
851            If unspecified, `batch_size` will default to 32.
852            Do not specify the `batch_size` if your data is in the
853            form of symbolic tensors, dataset,
854            generators, or `keras.utils.Sequence` instances (since they generate
855            batches).
856        verbose: 0 or 1. Verbosity mode.
857            0 = silent, 1 = progress bar.
858        sample_weight: Optional Numpy array of weights for
859            the test samples, used for weighting the loss function.
860            You can either pass a flat (1D)
861            Numpy array with the same length as the input samples
862            (1:1 mapping between weights and samples),
863            or in the case of temporal data,
864            you can pass a 2D array with shape
865            `(samples, sequence_length)`,
866            to apply a different weight to every timestep of every sample.
867            In this case you should make sure to specify
868            `sample_weight_mode="temporal"` in `compile()`. This argument is not
869            supported when `x` is a dataset, instead pass
870            sample weights as the third element of `x`.
871        steps: Integer or `None`.
872            Total number of steps (batches of samples)
873            before declaring the evaluation round finished.
874            Ignored with the default value of `None`.
875            If x is a `tf.data` dataset and `steps` is
876            None, 'evaluate' will run until the dataset is exhausted.
877            This argument is not supported with array inputs.
878        callbacks: List of `keras.callbacks.Callback` instances.
879            List of callbacks to apply during evaluation.
880            See [callbacks](/api_docs/python/tf/keras/callbacks).
881        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
882            input only. Maximum size for the generator queue.
883            If unspecified, `max_queue_size` will default to 10.
884        workers: Integer. Used for generator or `keras.utils.Sequence` input
885            only. Maximum number of processes to spin up when using
886            process-based threading. If unspecified, `workers` will default
887            to 1. If 0, will execute the generator on the main thread.
888        use_multiprocessing: Boolean. Used for generator or
889            `keras.utils.Sequence` input only. If `True`, use process-based
890            threading. If unspecified, `use_multiprocessing` will default to
891            `False`. Note that because this implementation relies on
892            multiprocessing, you should not pass non-picklable arguments to
893            the generator as they can't be passed easily to children processes.
894
895    Returns:
896        Scalar test loss (if the model has a single output and no metrics)
897        or list of scalars (if the model has multiple outputs
898        and/or metrics). The attribute `model.metrics_names` will give you
899        the display labels for the scalar outputs.
900
901    Raises:
902        ValueError: in case of invalid arguments.
903    """
904    self._assert_built_as_v1()
905    base_layer.keras_api_gauge.get_cell('evaluate').set(True)
906    self._assert_compile_was_called()
907    self._check_call_args('evaluate')
908
909    func = self._select_training_loop(x)
910    return func.evaluate(
911        self,
912        x=x,
913        y=y,
914        batch_size=batch_size,
915        verbose=verbose,
916        sample_weight=sample_weight,
917        steps=steps,
918        callbacks=callbacks,
919        max_queue_size=max_queue_size,
920        workers=workers,
921        use_multiprocessing=use_multiprocessing)
922
923  def predict(self,
924              x,
925              batch_size=None,
926              verbose=0,
927              steps=None,
928              callbacks=None,
929              max_queue_size=10,
930              workers=1,
931              use_multiprocessing=False):
932    """Generates output predictions for the input samples.
933
934    Computation is done in batches (see the `batch_size` arg.)
935
936    Args:
937        x: Input samples. It could be:
938          - A Numpy array (or array-like), or a list of arrays
939            (in case the model has multiple inputs).
940          - A TensorFlow tensor, or a list of tensors
941            (in case the model has multiple inputs).
942          - A `tf.data` dataset.
943          - A generator or `keras.utils.Sequence` instance.
944        batch_size: Integer or `None`.
945            Number of samples per batch of computation.
946            If unspecified, `batch_size` will default to 32.
947            Do not specify the `batch_size` if your data is in the
948            form of symbolic tensors, dataset,
949            generators, or `keras.utils.Sequence` instances (since they generate
950            batches).
951        verbose: Verbosity mode, 0 or 1.
952        steps: Total number of steps (batches of samples)
953            before declaring the prediction round finished.
954            Ignored with the default value of `None`. If x is a `tf.data`
955            dataset and `steps` is None, `predict` will
956            run until the input dataset is exhausted.
957        callbacks: List of `keras.callbacks.Callback` instances.
958            List of callbacks to apply during prediction.
959            See [callbacks](/api_docs/python/tf/keras/callbacks).
960        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
961            input only. Maximum size for the generator queue.
962            If unspecified, `max_queue_size` will default to 10.
963        workers: Integer. Used for generator or `keras.utils.Sequence` input
964            only. Maximum number of processes to spin up when using
965            process-based threading. If unspecified, `workers` will default
966            to 1. If 0, will execute the generator on the main thread.
967        use_multiprocessing: Boolean. Used for generator or
968            `keras.utils.Sequence` input only. If `True`, use process-based
969            threading. If unspecified, `use_multiprocessing` will default to
970            `False`. Note that because this implementation relies on
971            multiprocessing, you should not pass non-picklable arguments to
972            the generator as they can't be passed easily to children processes.
973
974
975    Returns:
976        Numpy array(s) of predictions.
977
978    Raises:
979        ValueError: In case of mismatch between the provided
980            input data and the model's expectations,
981            or in case a stateful model receives a number of samples
982            that is not a multiple of the batch size.
983    """
984    self._assert_built_as_v1()
985    base_layer.keras_api_gauge.get_cell('predict').set(True)
986    self._check_call_args('predict')
987
988    func = self._select_training_loop(x)
989    return func.predict(
990        self,
991        x=x,
992        batch_size=batch_size,
993        verbose=verbose,
994        steps=steps,
995        callbacks=callbacks,
996        max_queue_size=max_queue_size,
997        workers=workers,
998        use_multiprocessing=use_multiprocessing)
999
1000  def reset_metrics(self):
1001    """Resets the state of metrics."""
1002    metrics = self._get_training_eval_metrics()
1003    for m in metrics:
1004      m.reset_states()
1005
1006    # Reset metrics on all the distributed (cloned) models.
1007    if self._distribution_strategy:
1008      distributed_training_utils_v1._reset_metrics(self)  # pylint: disable=protected-access
1009
1010  def train_on_batch(self,
1011                     x,
1012                     y=None,
1013                     sample_weight=None,
1014                     class_weight=None,
1015                     reset_metrics=True):
1016    """Runs a single gradient update on a single batch of data.
1017
1018    Args:
1019        x: Input data. It could be:
1020          - A Numpy array (or array-like), or a list of arrays
1021              (in case the model has multiple inputs).
1022          - A TensorFlow tensor, or a list of tensors
1023              (in case the model has multiple inputs).
1024          - A dict mapping input names to the corresponding array/tensors,
1025              if the model has named inputs.
1026          - A `tf.data` dataset.
1027        y: Target data. Like the input data `x`, it could be either Numpy
1028          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1029          (you cannot have Numpy inputs and tensor targets, or inversely). If
1030          `x` is a dataset, `y` should not be specified
1031          (since targets will be obtained from the iterator).
1032        sample_weight: Optional array of the same length as x, containing
1033          weights to apply to the model's loss for each sample. In the case of
1034          temporal data, you can pass a 2D array with shape (samples,
1035          sequence_length), to apply a different weight to every timestep of
1036          every sample. In this case you should make sure to specify
1037          sample_weight_mode="temporal" in compile(). This argument is not
1038          supported when `x` is a dataset.
1039        class_weight: Optional dictionary mapping class indices (integers) to a
1040          weight (float) to apply to the model's loss for the samples from this
1041          class during training. This can be useful to tell the model to "pay
1042          more attention" to samples from an under-represented class.
1043        reset_metrics: If `True`, the metrics returned will be only for this
1044          batch. If `False`, the metrics will be statefully accumulated across
1045          batches.
1046
1047    Returns:
1048        Scalar training loss
1049        (if the model has a single output and no metrics)
1050        or list of scalars (if the model has multiple outputs
1051        and/or metrics). The attribute `model.metrics_names` will give you
1052        the display labels for the scalar outputs.
1053
1054    Raises:
1055      ValueError: In case of invalid user-provided arguments.
1056    """
1057    self._assert_compile_was_called()
1058    self._check_call_args('train_on_batch')
1059
1060    # If at this point we are in the replica context, then it is okay to execute
1061    # the Eager code path.  The expected way to get here is to call `fit` that
1062    # calls `train_on_batch` on each replica.
1063    if (self._distribution_strategy and
1064        distribution_strategy_context.in_cross_replica_context()):
1065      raise NotImplementedError('`train_on_batch` is not supported for models '
1066                                'distributed with tf.distribute.Strategy.')
1067    # Validate and standardize user data.
1068    x, y, sample_weights = self._standardize_user_data(
1069        x, y, sample_weight=sample_weight, class_weight=class_weight,
1070        extract_tensors_from_dataset=True)
1071
1072    # If `self._distribution_strategy` is True, then we are in a replica context
1073    # at this point because of the check above.  `train_on_batch` is being run
1074    # for each replica by `self._distribution_strategy` and the same code path
1075    # as Eager is expected to be taken.
1076    if self.run_eagerly or self._distribution_strategy:
1077      output_dict = training_eager_v1.train_on_batch(
1078          self,
1079          x,
1080          y,
1081          sample_weights=sample_weights,
1082          output_loss_metrics=self._output_loss_metrics)
1083      outputs = (output_dict['total_loss'] + output_dict['output_losses']
1084                 + output_dict['metrics'])
1085      outputs = [_non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
1086    else:
1087      x = training_utils_v1.ModelInputs(x).as_list()
1088      ins = x + list(y or []) + list(sample_weights or [])
1089
1090      if not isinstance(K.symbolic_learning_phase(), int):
1091        ins += [True]  # Add learning phase value.
1092
1093      self._update_sample_weight_modes(sample_weights=sample_weights)
1094      self._make_train_function()
1095      outputs = self.train_function(ins)  # pylint: disable=not-callable
1096
1097    if reset_metrics:
1098      self.reset_metrics()
1099
1100    if len(outputs) == 1:
1101      return outputs[0]
1102    return outputs
1103
1104  def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True):
1105    """Test the model on a single batch of samples.
1106
1107    Args:
1108        x: Input data. It could be:
1109          - A Numpy array (or array-like), or a list of arrays
1110            (in case the model has multiple inputs).
1111          - A TensorFlow tensor, or a list of tensors
1112            (in case the model has multiple inputs).
1113          - A dict mapping input names to the corresponding array/tensors,
1114            if the model has named inputs.
1115          - A `tf.data` dataset.
1116        y: Target data. Like the input data `x`,
1117          it could be either Numpy array(s) or TensorFlow tensor(s).
1118          It should be consistent with `x` (you cannot have Numpy inputs and
1119          tensor targets, or inversely). If `x` is a dataset `y` should
1120          not be specified (since targets will be obtained from the iterator).
1121        sample_weight: Optional array of the same length as x, containing
1122            weights to apply to the model's loss for each sample.
1123            In the case of temporal data, you can pass a 2D array
1124            with shape (samples, sequence_length),
1125            to apply a different weight to every timestep of every sample.
1126            In this case you should make sure to specify
1127            sample_weight_mode="temporal" in compile(). This argument is not
1128            supported when `x` is a dataset.
1129        reset_metrics: If `True`, the metrics returned will be only for this
1130          batch. If `False`, the metrics will be statefully accumulated across
1131          batches.
1132
1133    Returns:
1134        Scalar test loss (if the model has a single output and no metrics)
1135        or list of scalars (if the model has multiple outputs
1136        and/or metrics). The attribute `model.metrics_names` will give you
1137        the display labels for the scalar outputs.
1138
1139    Raises:
1140        ValueError: In case of invalid user-provided arguments.
1141    """
1142    self._assert_compile_was_called()
1143    self._check_call_args('test_on_batch')
1144
1145    if (self._distribution_strategy and
1146        distribution_strategy_context.in_cross_replica_context()):
1147      raise NotImplementedError('`test_on_batch` is not supported for models '
1148                                'distributed with tf.distribute.Strategy.')
1149    # Validate and standardize user data.
1150    x, y, sample_weights = self._standardize_user_data(
1151        x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)
1152
1153    # If `self._distribution_strategy` is True, then we are in a replica context
1154    # at this point.
1155    if self.run_eagerly or self._distribution_strategy:
1156      output_dict = training_eager_v1.test_on_batch(
1157          self,
1158          x,
1159          y,
1160          sample_weights=sample_weights,
1161          output_loss_metrics=self._output_loss_metrics)
1162      outputs = (output_dict['total_loss'] + output_dict['output_losses']
1163                 + output_dict['metrics'])
1164      outputs = [_non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
1165    else:
1166      x = training_utils_v1.ModelInputs(x).as_list()
1167      inputs = x + list(y or []) + list(sample_weights or [])
1168
1169      self._update_sample_weight_modes(sample_weights=sample_weights)
1170      self._make_test_function()
1171      outputs = self.test_function(inputs)  # pylint: disable=not-callable
1172
1173    if reset_metrics:
1174      self.reset_metrics()
1175
1176    if len(outputs) == 1:
1177      return outputs[0]
1178    return outputs
1179
1180  def predict_on_batch(self, x):
1181    """Returns predictions for a single batch of samples.
1182
1183    Args:
1184        x: Input data. It could be:
1185          - A Numpy array (or array-like), or a list of arrays
1186            (in case the model has multiple inputs).
1187          - A TensorFlow tensor, or a list of tensors
1188            (in case the model has multiple inputs).
1189          - A `tf.data` dataset.
1190
1191    Returns:
1192        Numpy array(s) of predictions.
1193
1194    Raises:
1195        ValueError: In case of mismatch between given number of inputs and
1196          expectations of the model.
1197    """
1198    self._check_call_args('predict_on_batch')
1199
1200    if (self._distribution_strategy and
1201        distribution_strategy_context.in_cross_replica_context()):
1202      raise NotImplementedError(
1203          '`predict_on_batch` is not supported for models distributed with'
1204          ' tf.distribute.Strategy.')
1205    # Validate and standardize user data.
1206    inputs, _, _ = self._standardize_user_data(
1207        x, extract_tensors_from_dataset=True)
1208    # If `self._distribution_strategy` is True, then we are in a replica context
1209    # at this point.
1210    if self.run_eagerly or self._distribution_strategy:
1211      inputs = training_utils_v1.cast_if_floating_dtype(inputs)
1212      if isinstance(inputs, collections.abc.Sequence):
1213        # Unwrap lists with only one input, as we do when training on batch
1214        if len(inputs) == 1:
1215          inputs = inputs[0]
1216
1217      return self(inputs)  # pylint: disable=not-callable
1218
1219    self._make_predict_function()
1220    outputs = self.predict_function(inputs)
1221
1222    if len(outputs) == 1:
1223      return outputs[0]
1224    return outputs
1225
1226  def fit_generator(self,
1227                    generator,
1228                    steps_per_epoch=None,
1229                    epochs=1,
1230                    verbose=1,
1231                    callbacks=None,
1232                    validation_data=None,
1233                    validation_steps=None,
1234                    validation_freq=1,
1235                    class_weight=None,
1236                    max_queue_size=10,
1237                    workers=1,
1238                    use_multiprocessing=False,
1239                    shuffle=True,
1240                    initial_epoch=0):
1241    """Fits the model on data yielded batch-by-batch by a Python generator.
1242
1243    DEPRECATED:
1244      `Model.fit` now supports generators, so there is no longer any need to use
1245      this endpoint.
1246    """
1247    warnings.warn('`model.fit_generator` is deprecated and '
1248                  'will be removed in a future version. '
1249                  'Please use `Model.fit`, which supports generators.')
1250    return self.fit(
1251        generator,
1252        steps_per_epoch=steps_per_epoch,
1253        epochs=epochs,
1254        verbose=verbose,
1255        callbacks=callbacks,
1256        validation_data=validation_data,
1257        validation_steps=validation_steps,
1258        validation_freq=validation_freq,
1259        class_weight=class_weight,
1260        max_queue_size=max_queue_size,
1261        workers=workers,
1262        use_multiprocessing=use_multiprocessing,
1263        shuffle=shuffle,
1264        initial_epoch=initial_epoch)
1265
1266  def evaluate_generator(self,
1267                         generator,
1268                         steps=None,
1269                         callbacks=None,
1270                         max_queue_size=10,
1271                         workers=1,
1272                         use_multiprocessing=False,
1273                         verbose=0):
1274    """Evaluates the model on a data generator.
1275
1276    DEPRECATED:
1277      `Model.evaluate` now supports generators, so there is no longer any need
1278      to use this endpoint.
1279    """
1280    warnings.warn('`Model.evaluate_generator` is deprecated and '
1281                  'will be removed in a future version. '
1282                  'Please use `Model.evaluate`, which supports generators.')
1283    self._check_call_args('evaluate_generator')
1284
1285    return self.evaluate(
1286        generator,
1287        steps=steps,
1288        max_queue_size=max_queue_size,
1289        workers=workers,
1290        use_multiprocessing=use_multiprocessing,
1291        verbose=verbose,
1292        callbacks=callbacks)
1293
1294  def predict_generator(self,
1295                        generator,
1296                        steps=None,
1297                        callbacks=None,
1298                        max_queue_size=10,
1299                        workers=1,
1300                        use_multiprocessing=False,
1301                        verbose=0):
1302    """Generates predictions for the input samples from a data generator.
1303
1304    DEPRECATED:
1305      `Model.predict` now supports generators, so there is no longer any need
1306      to use this endpoint.
1307    """
1308    warnings.warn('`Model.predict_generator` is deprecated and '
1309                  'will be removed in a future version. '
1310                  'Please use `Model.predict`, which supports generators.')
1311    return self.predict(
1312        generator,
1313        steps=steps,
1314        max_queue_size=max_queue_size,
1315        workers=workers,
1316        use_multiprocessing=use_multiprocessing,
1317        verbose=verbose,
1318        callbacks=callbacks)
1319
1320  def _check_call_args(self, method_name):
1321    """Check that `call` has only one positional arg."""
1322    # Always allow first arg, regardless of arg name.
1323    fullargspec = self._call_full_argspec
1324    if fullargspec.defaults:
1325      positional_args = fullargspec.args[:-len(fullargspec.defaults)]
1326    else:
1327      positional_args = fullargspec.args
1328    if 'training' in positional_args:
1329      positional_args.remove('training')
1330
1331    # self and first arg can be positional.
1332    if len(positional_args) > 2:
1333      extra_args = positional_args[2:]
1334      raise ValueError(
1335          'Models passed to `' + method_name + '` can only have `training` '
1336          'and the first argument in `call` as positional arguments, '
1337          'found: ' + str(extra_args) + '.')
1338
1339  def _set_optimizer(self, optimizer):
1340    """Sets self.optimizer.
1341
1342    Sets self.optimizer to `optimizer`, potentially wrapping it with a
1343    LossScaleOptimizer.
1344
1345    Args:
1346      optimizer: The optimizer(s) to assign to self.optimizer.
1347    """
1348    if isinstance(optimizer, (list, tuple)):
1349      self.optimizer = [optimizers.get(opt) for opt in optimizer]
1350    else:
1351      self.optimizer = optimizers.get(optimizer)
1352
1353    if isinstance(self._dtype_policy, policy.PolicyV1):
1354      loss_scale = self._dtype_policy.loss_scale
1355    elif self._dtype_policy.name == 'mixed_float16':
1356      loss_scale = 'dynamic'
1357    else:
1358      loss_scale = None
1359
1360    if (loss_scale is not None and
1361        not isinstance(self.optimizer,
1362                       loss_scale_optimizer.LossScaleOptimizer)):
1363      if isinstance(self.optimizer, list):
1364        raise ValueError('When a dtype policy with a loss scale is used, you '
1365                         'can only pass a single optimizer. Using policy %s '
1366                         'and got optimizers: %s' %
1367                         self._dtype_policy, self.optimizer)
1368      if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
1369        raise ValueError('"optimizer" must be an instance of '
1370                         'tf.keras.optimizers.Optimizer when a dype policy '
1371                         'with a loss scale  used, but got: %s. Using policy: '
1372                         '%s' %
1373                         (self.optimizer, self._dtype_policy))
1374      if loss_scale == 'dynamic':
1375        self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer)
1376      else:
1377        self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1(
1378            self.optimizer, loss_scale)
1379
1380  def _prepare_validation_data(self, validation_data, batch_size,
1381                               validation_steps):
1382    """Unpack and check the validation data."""
1383    val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data(
1384        validation_data)
1385    return self._standardize_user_data(
1386        val_x,
1387        val_y,
1388        sample_weight=val_sample_weights,
1389        batch_size=batch_size,
1390        steps=validation_steps,
1391        steps_name='validation_steps')
1392
1393  def _validate_compile_param_for_distribution_strategy(
1394      self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics):
1395    # Validate that arguments passed by the user to `compile` are supported by
1396    # tf.distribute.Strategy.
1397    if self._distribution_strategy:
1398      if sample_weight_mode:
1399        raise NotImplementedError('sample_weight_mode is not supported with '
1400                                  'tf.distribute.Strategy.')
1401      if weighted_metrics:
1402        raise NotImplementedError('weighted_metrics is not supported with '
1403                                  'tf.distribute.Strategy.')
1404      if target_tensors:
1405        raise ValueError('target_tensors is not supported with '
1406                         'tf.distribute.Strategy.')
1407
1408      if run_eagerly:
1409        raise ValueError(
1410            'We currently do not support enabling `run_eagerly` with '
1411            'distribution strategy.')
1412
1413      if (distributed_training_utils_v1.is_distributing_by_cloning(self) and
1414          (not self.built or not self.inputs or not self.outputs)):
1415        raise ValueError(
1416            'We currently do not support distribution strategy with a '
1417            '`Sequential` model that is created without `input_shape`/'
1418            '`input_dim` set in its first layer or a subclassed model.')
1419
1420  def _process_target_tensor_for_compile(self, target_tensors):
1421    if self.run_eagerly:
1422      # target tensor is not supported with run_eagerly. Create a list with None
1423      # as placeholder for each output.
1424      return [None for _ in self.output_names]
1425
1426    if target_tensors is not None and not (isinstance(target_tensors, list) and
1427                                           target_tensors == []):  # pylint: disable=g-explicit-bool-comparison
1428      if isinstance(target_tensors, list):
1429        if len(target_tensors) != len(self.outputs):
1430          raise ValueError(
1431              'When passing a list as `target_tensors`, '
1432              'it should have one entry per model output. '
1433              'The model has %s outputs, but you passed target_tensors=%s' %
1434              (len(self.outputs), target_tensors))
1435      elif isinstance(target_tensors, dict):
1436        unexpected_target_tensor_names = set(target_tensors.keys()).difference(
1437            self.output_names)
1438        if unexpected_target_tensor_names:
1439          raise ValueError(
1440              'Unknown entry in `target_tensors` dictionary: "{name}". '
1441              'Only expected the following keys: {keys}'.format(
1442                  name=unexpected_target_tensor_names,
1443                  keys=str(self.output_names)))
1444        tmp_target_tensors = []
1445        for name in self.output_names:
1446          tmp_target_tensors.append(target_tensors.get(name, None))
1447        target_tensors = tmp_target_tensors
1448      elif tensor_util.is_tf_type(target_tensors):
1449        target_tensors = [target_tensors]
1450      else:
1451        raise TypeError('Expected `target_tensors` to be a list or tuple or '
1452                        'dict or a single tensor, but got:', target_tensors)
1453    else:
1454      # In case target tensor is empty or None, create a list with Nones
1455      # that has same length as self.output_names. With that, the None check of
1456      # target tensor can be skipped downstream.
1457      target_tensors = [None for _ in self.output_names]
1458    return target_tensors
1459
  def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode):
    """Sets up sample weights and metric attributes without building a loss.

    Unlike the graph path, no total loss tensor is created here:
    `self.total_loss` is explicitly set to `None`.

    Args:
      metrics: Metrics forwarded to `_cache_output_metric_attributes`.
      weighted_metrics: Weighted metrics forwarded to
        `_cache_output_metric_attributes`.
      sample_weight_mode: Sample weight mode(s) applied to the training
        endpoints via `training_utils_v1.prepare_sample_weight_modes`.
    """
    # Prepare sample weight modes. List with the same length as model outputs.
    training_utils_v1.prepare_sample_weight_modes(
        self._training_endpoints, sample_weight_mode)
    # Prepare sample weights.
    self._prepare_sample_weights()
    # Save all metric attributes per output of the model.
    self._cache_output_metric_attributes(metrics, weighted_metrics)
    self.total_loss = None
    # Set metric attributes on model.
    self._set_metric_attributes()

    # Snapshot the trainable weights as of compile time.
    self._collected_trainable_weights = self.trainable_weights
1473
1474  def _update_sample_weight_modes(self, sample_weights=None):
1475    """Updates sample weight modes based on training/eval inputs.
1476
1477    Sample weight placeholders will be created for all or no outputs
1478    based on whether sample_weight is provided for any output.
1479
1480    If model contains `_sample_weight_modes` we check if the input
1481    `sample_weights` corresponds to the sample weight modes.
1482      1. Set sample weight mode to be 'temporal' for output i, if `compile`
1483        sample_weight_mode was set to `temporal` and sample weight inputs
1484        are given for one or more outputs.
1485      2. Set sample weight mode to be 'samplewise' for output i, if `compile`
1486        sample_weight_mode was not set and sample weight inputs are given for
1487        one or more outputs.
1488      3. Reset sample weight mode to None for output i if sample weight mode
1489        was set but there is no sample weight input.
1490
1491    Args:
1492      sample_weights: List of sample weights of the same length as model outputs
1493        or None.
1494    """
1495    if not self._is_compiled:
1496      return
1497    if sample_weights and any(s is not None for s in sample_weights):
1498      for endpoint in self._training_endpoints:
1499        endpoint.sample_weight_mode = (
1500            endpoint.sample_weight_mode or 'samplewise')
1501    else:
1502      for endpoint in self._training_endpoints:
1503        endpoint.sample_weight_mode = None
1504
1505  def _recompile_weights_loss_and_weighted_metrics(self):
1506    if not self._is_compiled:
1507      return False
1508    recompile = any(
1509        e.sample_weights_mismatch() for e in self._training_endpoints)
1510
1511    if recompile:
1512      self._compile_weights_loss_and_weighted_metrics()
1513    return recompile
1514
1515  @trackable.no_automatic_dependency_tracking
1516  def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None):
1517    """Compiles the model loss and weighted metric sub-graphs.
1518
1519    This may be used to set graph tensors as sample weights (instead of creating
1520    placeholders). This functionality is necessary for
1521    `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1
1522    graph, and creates iterator tensors for inputs, targets, and sample weights.
1523
1524    Args:
1525      sample_weights: List of tensors to use as the sample weights. Must be the
1526        same length as the number of outputs. If left as `None`, placeholders
1527        are used instead.
1528    """
1529    with K.get_graph().as_default():
1530      if sample_weights is not None:
1531        self._update_sample_weight_modes(sample_weights)
1532      self._prepare_sample_weights(sample_weights)
1533
1534      masks = self._prepare_output_masks()
1535
1536      # Compute weighted metrics.
1537      self._handle_metrics(
1538          self.outputs,
1539          targets=self._targets,
1540          skip_target_masks=self._prepare_skip_target_masks(),
1541          sample_weights=self.sample_weights,
1542          masks=masks,
1543          return_weighted_metrics=True)
1544
1545      # Compute total loss.
1546      # Used to keep track of the total loss value (stateless).
1547      # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
1548      #                   loss_weight_2 * output_2_loss_fn(...) +
1549      #                   layer losses.
1550      self.total_loss = self._prepare_total_loss(masks)
1551
1552  def _prepare_skip_target_masks(self):
1553    """Boolean mask for whether the target in the output list should be skipped.
1554
1555    If the loss function corresponding to a model output is None, then this
1556    output will be skipped during total loss calculation and feed targets
1557    preparation.
1558
1559    Returns:
1560      A boolean list for whether the corresponding target in the output list
1561      should be skipped during loss calculation.
1562    """
1563    return [l is None for l in self.loss_functions]
1564
1565  def _prepare_output_masks(self):
1566    """Returns masks corresponding to model outputs."""
1567    return [getattr(x, '_keras_mask', None) for x in self.outputs]
1568
  def _prepare_total_loss(self, masks):
    """Computes total loss from loss functions.

    The total loss is the sum of the weighted per-output losses (outputs whose
    endpoint reports `should_skip_target()` are left out) and of the
    layer-specific losses returned by `get_losses_for`.

    Args:
        masks: List of mask values corresponding to each model output.

    Returns:
        The total loss: the result of `math_ops.add_n` over all collected loss
        terms, or the python float `0.` if there are no loss terms.

    Raises:
        TypeError: If model run_eagerly is True.
        ValueError: If no output contributes a loss and `self.losses` is
          empty.
    """
    if self.run_eagerly:
      raise TypeError('total loss can not be computed when compiled with '
                      'run_eagerly = True.')
    loss_list = []
    with K.name_scope('loss'):
      for endpoint, mask in zip(self._training_endpoints, masks):
        if endpoint.should_skip_target():
          continue
        y_true = endpoint.training_target.target
        y_pred = endpoint.output
        loss_fn = endpoint.loss_fn
        loss_weight = endpoint.loss_weight
        loss_name = endpoint.loss_name()
        sample_weight = endpoint.sample_weight

        with K.name_scope(loss_name):
          if mask is not None:
            mask = math_ops.cast(mask, y_pred.dtype)
            # Update weights with mask.
            if sample_weight is None:
              sample_weight = mask
            else:
              # Update dimensions of weights to match with mask if possible.
              mask, _, sample_weight = (
                  losses_utils.squeeze_or_expand_dimensions(
                      mask, sample_weight=sample_weight))
              sample_weight *= mask

          # Loss objects exposing `reduction` are evaluated unreduced first so
          # that weighting and reduction can be applied separately below.
          if hasattr(loss_fn, 'reduction'):
            per_sample_losses = loss_fn.call(y_true, y_pred)
            weighted_losses = losses_utils.compute_weighted_loss(
                per_sample_losses,
                sample_weight=sample_weight,
                reduction=losses_utils.ReductionV2.NONE)
            loss_reduction = loss_fn.reduction

            # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
            # compile use cases.
            if loss_reduction == losses_utils.ReductionV2.AUTO:
              loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

            # Compute the stateless loss value.
            output_loss = losses_utils.reduce_weighted_loss(
                weighted_losses, reduction=loss_reduction)
          else:
            # Compute the stateless loss value for a custom loss class.
            # Here we assume that the class takes care of loss reduction
            # because if this class returns a vector value we cannot
            # differentiate between use case where a custom optimizer
            # expects a vector loss value vs unreduced per-sample loss value.
            output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

        # Per-output loss metrics are only tracked for multi-output models.
        if len(self.outputs) > 1:
          # Keep track of stateful result tensor for the loss.
          endpoint.output_loss_metric(output_loss)

        # Scale output loss for distribution. For custom losses we assume
        # reduction was mean.
        if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
          output_loss = losses_utils.scale_loss_for_distribution(output_loss)

        loss_list.append(loss_weight * output_loss)
      if not loss_list and not self.losses:
        raise ValueError('The model cannot be compiled '
                         'because it has no loss to optimize.')

      # Add regularization penalties and other layer-specific losses.
      custom_losses = self.get_losses_for(None) + self.get_losses_for(
          self.inputs)
      if custom_losses:
        total_custom_loss = math_ops.add_n(
            losses_utils.cast_losses_to_common_dtype(custom_losses))
        loss_list.append(
            losses_utils.scale_loss_for_distribution(total_custom_loss))

      loss_list = losses_utils.cast_losses_to_common_dtype(loss_list)
      if loss_list:
        total_loss = math_ops.add_n(loss_list)
      else:
        # No loss term was collected; fall back to a constant zero.
        total_loss = 0.
    return total_loss
1663
1664  def _get_callback_model(self):
1665    """Returns the Callback Model for this Model."""
1666
1667    if hasattr(self, '_replicated_model') and self._replicated_model:
1668      # When using training_distributed, we set the callback model
1669      # to an instance of the `DistributedModel` that we create in
1670      # the `compile` call. The `DistributedModel` is initialized
1671      # with the first replicated model. We need to set the callback
1672      # model to a DistributedModel to allow us to override saving
1673      # and loading weights when we checkpoint the model during training.
1674      return self._replicated_model
1675    if hasattr(self, 'callback_model') and self.callback_model:
1676      return self.callback_model
1677    return self
1678
1679  @trackable.no_automatic_dependency_tracking
1680  def _make_callback_model(self, grouped_model):
1681    first_replicated_model = self._distribution_strategy.unwrap(
1682        grouped_model)[0]
1683    # We initialize the callback model with the first replicated model.
1684    self._replicated_model = DistributedCallbackModel(first_replicated_model)
1685    self._replicated_model.set_original_model(self)
1686
1687  def _validate_or_infer_batch_size(self, batch_size, steps, x):
1688    """Validates that the `batch_size` provided is consistent with InputLayer.
1689
1690    It's possible that the user specified a static batch size in their
1691    InputLayer. If so, this method checks the provided `batch_size` and `x`
1692    arguments are consistent with this static batch size. Also, if
1693    `batch_size` is `None`, this method will attempt to infer the batch size
1694    from the static batch size of the InputLayer. Lastly, ValueError will be
1695    raised if `x` is a tf.data.Dataset and `batch_size` is specified as we
1696    expect users to provide batched datasets.
1697
1698    Args:
1699      batch_size: The batch_size provided as an argument to
1700        fit/evaluate/predict.
1701      steps: The steps provided as an argument to fit/evaluate/predict.
1702      x: The data passed as `x` to fit/evaluate/predict.
1703
1704    Returns:
1705      The validated batch_size, auto-inferred from the first layer if not
1706      provided.
1707    """
1708    if (isinstance(x, (dataset_ops.DatasetV1,
1709                       dataset_ops.DatasetV2,
1710                       data_utils.Sequence)) or
1711        tf_inspect.isgenerator(x)):
1712      if batch_size is not None:
1713        raise ValueError(
1714            'The `batch_size` argument must not be specified for the given '
1715            'input type. Received input: {}, batch_size: {}'.format(
1716                x, batch_size))
1717      return
1718
1719    # Avoids the override in Sequential.layers which filters Input layers.
1720    # (Which are often the very layers that we're after.)
1721    layers = self._flatten_layers(include_self=False, recursive=False)
1722    first_layer = next(layers, None)
1723    if first_layer:
1724      # The per-replica static batch size.
1725      static_batch_size = training_utils.get_static_batch_size(first_layer)
1726      if static_batch_size is not None:
1727
1728        # Determine number of times the user-supplied batch size will be split.
1729        if (self._distribution_strategy and
1730            distributed_training_utils.global_batch_size_supported(
1731                self._distribution_strategy)):
1732          num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync
1733        else:
1734          num_splits_for_ds = 1
1735
1736        # Check `batch_size` argument is consistent with InputLayer.
1737        if batch_size is not None:
1738          if batch_size % num_splits_for_ds != 0:
1739            raise ValueError('The `batch_size` argument ({}) must be divisible '
1740                             'the by number of replicas ({})'.format(
1741                                 batch_size, num_splits_for_ds))
1742          per_replica_batch_size = batch_size // num_splits_for_ds
1743
1744          if per_replica_batch_size != static_batch_size:
1745            raise ValueError('The `batch_size` argument value {} is '
1746                             'incompatible with the specified batch size of '
1747                             'your Input Layer: {}'.format(
1748                                 per_replica_batch_size, static_batch_size))
1749
1750        # Check Dataset/Iterator batch size is consistent with InputLayer.
1751        if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
1752                          iterator_ops.IteratorBase)):
1753          ds_batch_size = tensor_shape.Dimension(
1754              nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
1755          if ds_batch_size is not None:
1756            if ds_batch_size % num_splits_for_ds != 0:
1757              raise ValueError(
1758                  'The batch output shape of your `Dataset` {} '
1759                  'cannot be divisible by number of replicas {}'.format(
1760                      ds_batch_size, num_splits_for_ds))
1761
1762            ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds
1763            if ds_per_replica_batch_size != static_batch_size:
1764              raise ValueError('The batch output shape of your `Dataset` is '
1765                               '{}, which is incompatible with the specified '
1766                               'batch size of your Input Layer: {}'.format(
1767                                   ds_per_replica_batch_size,
1768                                   static_batch_size))
1769
1770        # Set inferred batch size from the InputLayer.
1771        if steps is None:
1772          batch_size = static_batch_size * num_splits_for_ds
1773
1774    if batch_size is None and steps is None:
1775      # Backwards compatibility
1776      batch_size = 32
1777    return batch_size
1778
1779  def _prepare_sample_weights(self, sample_weights=None):
1780    """Sets sample weight attribute on the model."""
1781    # List with the same length as model outputs.
1782    if sample_weights is not None:
1783      if len(sample_weights) != len(self._training_endpoints):
1784        raise ValueError('Provided sample weights must have same length as the '
1785                         'number of outputs. Expected: {}, got: {}.'.format(
1786                             len(self._training_endpoints),
1787                             len(sample_weights)))
1788    else:
1789      sample_weights = [None] * len(self._training_endpoints)
1790    for endpoint, weight in zip(self._training_endpoints, sample_weights):
1791      endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode)
1792
1793  def _cache_output_metric_attributes(self, metrics, weighted_metrics):
1794    """Caches metric name and function attributes for every model output."""
1795    output_shapes = []
1796    for output in self.outputs:
1797      if output is None or output.shape.rank is None:
1798        output_shapes.append(None)
1799      else:
1800        output_shapes.append(output.shape.as_list())
1801    self._per_output_metrics = training_utils_v1.collect_per_output_metric_info(
1802        metrics, self.output_names, output_shapes, self.loss_functions)
1803    self._per_output_weighted_metrics = (
1804        training_utils_v1.collect_per_output_metric_info(
1805            weighted_metrics,
1806            self.output_names,
1807            output_shapes,
1808            self.loss_functions,
1809            is_weighted=True))
1810
1811  def _add_unique_metric_name(self, metric_name, output_index):
1812    """Makes the metric name unique and adds it to the model's metric name list.
1813
1814      If there are multiple outputs for which the metrics are calculated, the
1815      metric names have to be made unique by appending an integer.
1816
1817    Args:
1818      metric_name: Metric name that corresponds to the metric specified by the
1819          user. For example: 'acc'.
1820      output_index: The index of the model output for which the metric name is
1821        being added.
1822
1823    Returns:
1824      string, name of the model's unique metric name
1825    """
1826    if len(self.output_names) > 1:
1827      metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
1828    j = 1
1829    base_metric_name = metric_name
1830    while metric_name in self.metrics_names:
1831      metric_name = '%s_%d' % (base_metric_name, j)
1832      j += 1
1833
1834    return metric_name
1835
1836  def _init_metric_attributes(self):
1837    """Initialized model metric attributes."""
1838    # List of stateful metric functions. Used for resetting metric state during
1839    # training/eval.
1840    self._compile_metric_functions = []
1841
1842  def _set_per_output_metric_attributes(self, metrics_dict, output_index):
1843    """Sets the metric attributes on the model for the given output.
1844
1845    Args:
1846      metrics_dict: A dict with metric names as keys and metric fns as values.
1847      output_index: The index of the model output for which the metric
1848        attributes are added.
1849
1850    Returns:
1851      Metrics dict updated with unique metric names as keys.
1852    """
1853    updated_metrics_dict = collections.OrderedDict()
1854    for metric_name, metric_fn in metrics_dict.items():
1855      metric_name = self._add_unique_metric_name(metric_name, output_index)
1856
1857      # Update the name on the metric class to be the unique generated name.
1858      metric_fn._name = metric_name  # pylint: disable=protected-access
1859      updated_metrics_dict[metric_name] = metric_fn
1860      # Keep track of metric name and function.
1861      self._compile_metric_functions.append(metric_fn)
1862    return updated_metrics_dict
1863
1864  def _set_metric_attributes(self):
1865    """Sets the metric attributes on the model for all the model outputs."""
1866    updated_per_output_metrics = []
1867    updated_per_output_weighted_metrics = []
1868    for i, endpoint in enumerate(self._training_endpoints):
1869      if endpoint.should_skip_target():
1870        updated_per_output_metrics.append(self._per_output_metrics[i])
1871        updated_per_output_weighted_metrics.append(
1872            self._per_output_weighted_metrics[i])
1873        continue
1874      updated_per_output_metrics.append(
1875          self._set_per_output_metric_attributes(self._per_output_metrics[i],
1876                                                 i))
1877      updated_per_output_weighted_metrics.append(
1878          self._set_per_output_metric_attributes(
1879              self._per_output_weighted_metrics[i], i))
1880
1881    # Create a metric wrapper for each output loss. This computes mean of an
1882    # output loss across mini-batches (irrespective of how we reduce within a
1883    # batch).
1884    if len(self._training_endpoints) > 1:
1885      for endpoint in self._training_endpoints:
1886        if not endpoint.should_skip_target():
1887          endpoint.output_loss_metric = metrics_module.Mean(
1888              name=endpoint.loss_name())
1889
1890    self._per_output_metrics = updated_per_output_metrics
1891    self._per_output_weighted_metrics = updated_per_output_weighted_metrics
1892
1893  def _handle_per_output_metrics(self,
1894                                 metrics_dict,
1895                                 y_true,
1896                                 y_pred,
1897                                 mask,
1898                                 weights=None):
1899    """Calls metric functions for a single output.
1900
1901    Args:
1902      metrics_dict: A dict with metric names as keys and metric fns as values.
1903      y_true: Target output.
1904      y_pred: Predicted output.
1905      mask: Computed mask value for the current output.
1906      weights: Weights to be applied on the current output.
1907
1908    Returns:
1909      A list of metric result tensors.
1910    """
1911    metric_results = []
1912    for metric_name, metric_fn in metrics_dict.items():
1913      with K.name_scope(metric_name):
1914        metric_result = training_utils_v1.call_metric_function(
1915            metric_fn, y_true, y_pred, weights=weights, mask=mask)
1916        metric_results.append(metric_result)
1917    return metric_results
1918
1919  def _handle_metrics(self,
1920                      outputs,
1921                      targets=None,
1922                      skip_target_masks=None,
1923                      sample_weights=None,
1924                      masks=None,
1925                      return_weighted_metrics=False,
1926                      return_weighted_and_unweighted_metrics=False):
1927    """Handles calling metric functions.
1928
1929    Args:
1930      outputs: List of outputs (predictions).
1931      targets: List of targets.
1932      skip_target_masks: Optional. List of boolean for whether the corresponding
1933        target should be ignored or not.
1934      sample_weights: Optional list of sample weight arrays.
1935      masks: List of computed output mask values.
1936      return_weighted_metrics: Flag that indicates whether weighted metrics
1937        should be computed instead of unweighted metrics. This flag is ignored
1938        when `return_weighted_and_unweighted_metrics` is enabled.
1939      return_weighted_and_unweighted_metrics: Flag that is used to indicate
1940        whether both weighted and unweighted metrics should be computed. When
1941        this is not enabled, we use `return_weighted_metrics` param to indicate
1942        whether weighted or unweighted metrics should be returned.
1943
1944    Returns:
1945      A list of metric result tensors.
1946    """
1947    # TODO(scottzhu): Update this to use the new training_endpoints. Currently
1948    # the eager and graph logic is bit different.
1949    skip_target_masks = skip_target_masks or [False] * len(outputs)
1950    metric_results = []
1951    with K.name_scope('metrics'):
1952      # Invoke all metrics added using `compile`.
1953      for i in range(len(outputs)):
1954        if skip_target_masks[i]:
1955          continue
1956        output = outputs[i] if outputs else None
1957        target = targets[i] if targets else None
1958        output_mask = masks[i] if masks else None
1959
1960        if (return_weighted_and_unweighted_metrics or
1961            not return_weighted_metrics):
1962          metric_results.extend(
1963              self._handle_per_output_metrics(self._per_output_metrics[i],
1964                                              target, output, output_mask))
1965        if return_weighted_and_unweighted_metrics or return_weighted_metrics:
1966          metric_results.extend(
1967              self._handle_per_output_metrics(
1968                  self._per_output_weighted_metrics[i],
1969                  target,
1970                  output,
1971                  output_mask,
1972                  weights=sample_weights[i] if sample_weights else None))
1973    return metric_results
1974
1975  def _check_trainable_weights_consistency(self):
1976    """Check trainable weights count consistency.
1977
1978    This will raise a warning if `trainable_weights` and
1979    `_collected_trainable_weights` are inconsistent (i.e. have different
1980    number of parameters).
1981    Inconsistency will typically arise when one modifies `model.trainable`
1982    without calling `model.compile` again.
1983    """
1984    if not hasattr(self, '_collected_trainable_weights'):
1985      return
1986
1987    if len(self.trainable_weights) != len(self._collected_trainable_weights):
1988      logging.log_first_n(
1989          logging.WARN, 'Discrepancy between trainable weights and collected'
1990          ' trainable weights, did you set `model.trainable`'
1991          ' without calling `model.compile` after ?', 1)
1992
  def _make_train_function(self):
    """Creates `self.train_function` if it is missing or out of date.

    The loss / weighted metric sub-graphs are recompiled first if needed. The
    backend function is then (re)built when no train function exists yet or
    when a recompile just happened. It takes the feed inputs, feed targets and
    feed sample weights (plus the symbolic learning phase when that is not an
    int) and returns the total loss followed by the metric result tensors,
    while applying the optimizer updates and the layer updates.

    Raises:
      ValueError: If `self.optimizer` is a list of optimizers.
    """
    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
    self._check_trainable_weights_consistency()
    if isinstance(self.optimizer, list):
      raise ValueError('The `optimizer` in `compile` should be a single '
                       'optimizer.')
    # If we have re-compiled the loss/weighted metric sub-graphs then create
    # train function even if one exists already. This is because
    # `_feed_sample_weights` list has been updated on re-compile.
    if getattr(self, 'train_function', None) is None or has_recompiled:
      # Restore the compiled trainable state.
      current_trainable_state = self._get_trainable_state()
      self._set_trainable_state(self._compiled_trainable_state)

      inputs = (self._feed_inputs +
                self._feed_targets +
                self._feed_sample_weights)
      # The learning phase is only fed when it is symbolic, not a fixed int.
      if not isinstance(K.symbolic_learning_phase(), int):
        inputs += [K.symbolic_learning_phase()]

      with K.get_graph().as_default():
        with K.name_scope('training'):
          # Training updates
          updates = self.optimizer.get_updates(
              params=self._collected_trainable_weights, loss=self.total_loss)
          # Unconditional updates
          updates += self.get_updates_for(None)
          # Conditional updates relevant to this model
          updates += self.get_updates_for(self.inputs)

        metrics = self._get_training_eval_metrics()
        # Only metrics that already hold a call result contribute an output.
        metrics_tensors = [
            m._call_result for m in metrics if hasattr(m, '_call_result')  # pylint: disable=protected-access
        ]

      with K.name_scope('training'):
        # Gets loss and metrics. Updates weights at each call.
        fn = K.function(
            inputs, [self.total_loss] + metrics_tensors,
            updates=updates,
            name='train_function',
            **self._function_kwargs)
        setattr(self, 'train_function', fn)

      # Restore the current trainable state
      self._set_trainable_state(current_trainable_state)
2039
2040  def _make_test_function(self):
2041    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
2042    # If we have re-compiled the loss/weighted metric sub-graphs then create
2043    # test function even if one exists already. This is because
2044    # `_feed_sample_weights` list has been updated on re-compile.
2045    if getattr(self, 'test_function', None) is None or has_recompiled:
2046      inputs = (self._feed_inputs +
2047                self._feed_targets +
2048                self._feed_sample_weights)
2049
2050      with K.get_graph().as_default():
2051        metrics = self._get_training_eval_metrics()
2052        metrics_tensors = [
2053            m._call_result for m in metrics if hasattr(m, '_call_result')  # pylint: disable=protected-access
2054        ]
2055
2056      with K.name_scope('evaluation'):
2057        updates = self.state_updates
2058        # Return loss and metrics, no gradient updates.
2059        # Does update the network states.
2060        fn = K.function(
2061            inputs, [self.total_loss] + metrics_tensors,
2062            updates=updates,
2063            name='test_function',
2064            **self._function_kwargs)
2065        setattr(self, 'test_function', fn)
2066
2067  def _make_predict_function(self):
2068    if not hasattr(self, 'predict_function'):
2069      self.predict_function = None
2070    if self.predict_function is None:
2071      inputs = self._feed_inputs
2072      # Gets network outputs. Does not update weights.
2073      # Does update the network states.
2074      kwargs = getattr(self, '_function_kwargs', {})
2075      with K.name_scope(ModeKeys.PREDICT):
2076        self.predict_function = K.function(
2077            inputs,
2078            self.outputs,
2079            updates=self.state_updates,
2080            name='predict_function',
2081            **kwargs)
2082
2083  def _make_execution_function(self, mode):
2084    if mode == ModeKeys.TRAIN:
2085      self._make_train_function()
2086      return self.train_function
2087    if mode == ModeKeys.TEST:
2088      self._make_test_function()
2089      return self.test_function
2090    if mode == ModeKeys.PREDICT:
2091      self._make_predict_function()
2092      return self.predict_function
2093
  def _distribution_standardize_user_data(self,
                                          x,
                                          y=None,
                                          sample_weight=None,
                                          class_weight=None,
                                          batch_size=None,
                                          validation_split=0,
                                          shuffle=False,
                                          epochs=1,
                                          allow_partial_batch=False):
    """Validates user data and turns numpy inputs into a batched dataset.

    This is called when using tf.distribute.Strategy to train, evaluate or serve
    the model. When the first flattened element of `x` is a numpy array, the
    data is converted to a dataset through the strategy, optionally shuffled
    and repeated, and then batched. Otherwise `x` must be a dataset; it is
    validated and returned unchanged.

    Args:
      x: Input data. A numpy array or `tf.data` dataset.
      y: Target data. A numpy array or None if x is a `tf.data` dataset.
      sample_weight: An optional sample-weight array passed by the user to
        weight the importance of each sample in `x`.
      class_weight: An optional class-weight array by the user to
        weight the importance of samples in `x` based on the class they belong
        to, as conveyed by `y`. Any truthy value is rejected here.
      batch_size: Integer batch size. Used to size the shuffle buffer and to
        batch the dataset built from numpy inputs.
      validation_split: Float between 0 and 1.
        Fraction of the training data to be used as validation data. Only
        forwarded to the dataset input validation in this method.
      shuffle: Boolean whether to shuffle the training data before each epoch.
      epochs: Integer epochs. If > 1, repeat the numpy training data epochs
        times when converting to training dataset.
      allow_partial_batch: Boolean. When False and the strategy requires
        static shapes, the remainder is dropped when batching. When True the
        remainder is kept, except on a TPU strategy when the number of examples
        is divisible by the batch size.

    Returns:
      Dataset instance: the batched dataset built from numpy inputs, or `x`
      itself when `x` already is a dataset.

    Raises:
      NotImplementedError: If `class_weight` is passed, or if the
        `sample_weight` check below triggers on a TPU strategy.
      ValueError: In case of invalid user-provided data (presumably raised by
        the validation helpers; not raised directly in this method).
    """
    if class_weight:
      raise NotImplementedError('`class_weight` is currently not supported '
                                'when using tf.distribute.Strategy.')

    # NOTE(review): `.all()` is only True when every weight is non-zero, and it
    # requires an array-like `sample_weight` - confirm that this is intended.
    if (sample_weight is not None and sample_weight.all() and
        K.is_tpu_strategy(self._distribution_strategy)):
      raise NotImplementedError('`sample_weight` is currently not supported '
                                'when using TPUStrategy.')

    # Validates `steps` and `shuffle` arguments right at the beginning
    # since we use it to construct the dataset object.
    # TODO(anjalisridhar): Remove this check once we refactor the
    # _standardize_user_data code path. This check is already present elsewhere
    # in the codebase.
    if isinstance(x, dataset_ops.DatasetV2):
      if shuffle:
        training_utils_v1.verify_dataset_shuffled(x)

    strategy = self._distribution_strategy
    with strategy.scope():
      # We should be sure to call get_session() inside the strategy.scope()
      # so the strategy can affect the session options.
      if ops.executing_eagerly_outside_functions():
        session = None
      else:
        session = K.get_session()

      # The type of the first flattened element decides which path is taken.
      first_x_value = nest.flatten(x)[0]
      if isinstance(first_x_value, np.ndarray):
        x = training_utils.list_to_tuple(x)
        # Sample weights are only packed when targets are present.
        if y is not None:
          y = training_utils.list_to_tuple(y)
          if sample_weight is not None:
            sample_weight = training_utils.list_to_tuple(sample_weight)
            in_tuple = (x, y, sample_weight)
          else:
            in_tuple = (x, y)
        else:
          in_tuple = x

        ds = strategy.extended.experimental_make_numpy_dataset(in_tuple,
                                                               session=session)
        if shuffle:
          # We want a buffer size that is larger than the batch size provided by
          # the user and provides sufficient randomness. Note that larger
          # numbers introduce more memory usage based on the size of each
          # sample.
          # NOTE(review): `batch_size` must not be None on this path.
          ds = ds.shuffle(max(1024, batch_size * 8))
        if epochs > 1:
          ds = ds.repeat(epochs)

        # We need to use the drop_remainder argument to get a known static
        # input shape which is required for TPUs.
        drop_remainder = (not allow_partial_batch and
                          strategy.extended.experimental_require_static_shapes)

        # TODO(b/131720208): We still drop remainder here if number of examples
        # is divisible by batch size, as sometimes dynamic padder will time out
        # with keras.metrics.CategoricalAccuracy() metric.
        if K.is_tpu_strategy(strategy) and not drop_remainder:
          dataset_size = first_x_value.shape[0]
          if dataset_size % batch_size == 0:
            drop_remainder = True

        x = ds.batch(batch_size, drop_remainder=drop_remainder)
      else:
        # Anything that is not numpy data has to be a dataset already.
        assert isinstance(x, dataset_ops.DatasetV2)
        training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                                 validation_split)
    return x
2204
  def _standardize_user_data(self,
                             x,
                             y=None,
                             sample_weight=None,
                             class_weight=None,
                             batch_size=None,
                             check_steps=False,
                             steps_name='steps',
                             steps=None,
                             validation_split=0,
                             shuffle=False,
                             extract_tensors_from_dataset=False):
    """Runs validation checks on input and target data passed by the user.

    Also standardizes the data to lists of arrays, in order.

    Also builds and compiles the model on the fly if it is a subclassed model
    that has never been called before (and thus has no inputs/outputs).

    The steps run in a fixed order: dataset/iterator handling, optional
    `steps` validation, on-the-fly build, on-the-fly compile, and finally the
    tensor standardization done by `_standardize_tensors`.

    This is a purely internal method, subject to refactoring at any time.

    Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
          if the model has named inputs.
        - A `tf.data` dataset.
      y: Target data. Like the input data `x`,
        it could be either Numpy array(s) or TensorFlow tensor(s).
        It should be consistent with `x` (you cannot have Numpy inputs and
        tensor targets, or inversely). If `x` is a dataset, `y` should not be
        specified (since targets will be obtained from the iterator).
      sample_weight: An optional sample-weight array passed by the user to
        weight the importance of each sample in `x`.
      class_weight: An optional class-weight array by the user to
        weight the importance of samples in `x` based on the class they belong
        to, as conveyed by `y`. If both `sample_weight` and `class_weight` are
        provided, the weights are multiplied.
      batch_size: Integer batch size. If provided, it is used to run additional
        validation checks on stateful models.
      check_steps: boolean, True if we want to check for validity of `steps` and
        False, otherwise. For example, when we are standardizing one batch of
        data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps`
        value is not required and we should not check for its validity in these
        cases.
      steps_name: The public API's parameter name for `steps`.
      steps: Integer or `None`. Total number of steps (batches of samples) to
        execute.
      validation_split: Float between 0 and 1.
        Fraction of the training data to be used as validation data.
      shuffle: Boolean whether to shuffle the training data before each epoch.
      extract_tensors_from_dataset: Boolean. When `x` is a dataset instance,
        this indicates whether to extract actual tensors from the dataset or
        instead output the dataset instance itself.
        Set to True when calling from `train_on_batch`/etc.

    Returns:
      A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
      or not), target arrays, sample-weight arrays.
      If the model's input and targets are symbolic, `([], [], None)` is
      returned (since the model takes no user-provided data, instead the data
      comes from the symbolic inputs/targets).

    Raises:
      ValueError: In case of invalid user-provided data.
      RuntimeError: If the model was never compiled (not raised directly in
        this method; presumably comes from a helper - TODO confirm).
    """
    if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
      # Graph mode dataset. We'll pass the dataset as-is (unless
      # `extract_tensors_from_dataset` is True, in which case we extract
      # the tensors from the dataset and we output them).
      training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                               validation_split)
      if shuffle:
        training_utils_v1.verify_dataset_shuffled(x)

      is_dataset = True
      if extract_tensors_from_dataset:
        # We do this for `train_on_batch`/etc.
        x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x)
    elif isinstance(x, iterator_ops.Iterator):
      # Graph mode iterator. We extract the symbolic tensors.
      training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                               validation_split)
      iterator = x
      x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator)
      is_dataset = True
    else:
      is_dataset = False

    # Validates `steps` argument based on x's type.
    if check_steps:
      training_utils_v1.check_steps_argument(x, steps, steps_name)

    # First, we build the model on the fly if necessary.
    if not self.inputs:
      all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
      is_build_called = True
    else:
      all_inputs = []
      # Whether this is a subclassed model that expects dictionary inputs
      # rather than list inputs (e.g. FeatureColumn-based models).
      dict_inputs = isinstance(self.inputs, dict)
      is_build_called = False
      y_input = y

    # Second, we compile the model on the fly if necessary, mostly for subclass
    # models.
    is_compile_called = False
    if not self._is_compiled and self.optimizer:
      # `_compile_from_inputs` extends `all_inputs` in place with the targets;
      # the symbolic check below relies on that.
      self._compile_from_inputs(all_inputs, y_input, x, y)
      is_compile_called = True

    # In graph mode, if we had just set inputs and targets as symbolic tensors
    # by invoking build and compile on the model respectively, we do not have to
    # feed anything to the model. Model already has input and target data as
    # part of the graph.
    # Note: in this case, `any` and `all` are equivalent since we disallow
    # mixed symbolic/value inputs.

    # self.run_eagerly is not free to compute, so we want to reuse the value.
    run_eagerly = self.run_eagerly

    if (not run_eagerly and is_build_called and is_compile_called and
        not is_dataset  and any(_is_symbolic_tensor(v) for v in all_inputs)):
      return [], [], None

    return self._standardize_tensors(
        x, y, sample_weight,
        run_eagerly=run_eagerly,
        dict_inputs=dict_inputs,
        is_dataset=is_dataset,
        class_weight=class_weight,
        batch_size=batch_size)
2342
  def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,
                           is_dataset, class_weight=None, batch_size=None):
    """Standardizes inputs, targets and sample weights for this model.

    Inputs that are not datasets are standardized against the model's feed
    input names (and shapes, for symbolic-mode graph networks). The structure
    of the inputs is then checked against `self.inputs`. When targets are
    given, targets and per-output sample weights are standardized as well.

    Args:
      x: Input data, or a dataset instance.
      y: Target data, or None when there are no targets to standardize.
      sample_weight: Optional sample weights passed by the user.
      run_eagerly: Boolean, whether the model runs eagerly. Shape validation of
        the inputs is skipped in that case.
      dict_inputs: Boolean, whether the inputs are returned as a dict keyed by
        the feed input names.
      is_dataset: Boolean, whether the user data came from a dataset or an
        iterator. Disables the stateful batch-size check.
      class_weight: Optional class weights passed by the user.
      batch_size: Optional integer batch size, only used for the stateful
        batch-size check.

    Returns:
      A tuple `(x, y, sample_weights)`. `y` is `[]` and `sample_weights` is
      None when no targets were passed.

    Raises:
      ValueError: For a stateful model, if the number of samples is not a
        multiple of `batch_size`.
    """
    if run_eagerly:
      # In eager mode, do not do shape validation
      # since the network has no input nodes (placeholders) to be fed.
      feed_input_names = self.input_names
      feed_input_shapes = None
    elif not self._is_graph_network:
      # Case: symbolic-mode subclassed network. Do not do shape validation.
      feed_input_names = self._feed_input_names
      feed_input_shapes = None
    else:
      # Case: symbolic-mode graph network.
      # In this case, we run extensive shape validation checks.
      feed_input_names = self._feed_input_names
      feed_input_shapes = self._feed_input_shapes

    # Standardize the inputs.
    if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
      # TODO(fchollet): run static checks with dataset output shape(s).
      x = training_utils_v1.standardize_input_data(
          x,
          feed_input_names,
          feed_input_shapes,
          check_batch_axis=False,  # Don't enforce the batch size.
          exception_prefix='input')

    # Get typespecs for the input data and sanitize it if necessary.
    # TODO(momernick): This should be capable of doing full input validation
    # at all times - validate that this is so and refactor the standardization
    # code.
    if isinstance(x, dataset_ops.DatasetV2):
      x_shapes = dataset_ops.get_structure(x)
      if isinstance(x_shapes, tuple):
        # If the output of a Dataset is a tuple, we assume it's either of the
        # form (x_data, y_data) or (x_data, y_data, sample_weights). In either
        # case, we only care about x_data here.
        x_shapes = x_shapes[0]
    else:
      # Convert each input against the matching expected model input, then
      # restore the original nesting of `x`.
      flat_inputs = nest.flatten(x, expand_composites=False)
      flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
      converted_x = []
      for (a, b) in zip(flat_inputs, flat_expected_inputs):
        converted_x.append(_convert_scipy_sparse_tensor(a, b))
      x = nest.pack_sequence_as(x, converted_x, expand_composites=False)

      def _type_spec_from_value(value):
        """Grab type_spec without converting array-likes to tensors."""
        if tf_utils.is_extension_type(value):
          return value._type_spec  # pylint: disable=protected-access
        # Get a TensorSpec for array-like data without
        # converting the data to a Tensor
        if hasattr(value, 'shape') and hasattr(value, 'dtype'):
          return tensor_spec.TensorSpec(value.shape, value.dtype)
        else:
          return type_spec.type_spec_from_value(value)

      x_shapes = nest.map_structure(_type_spec_from_value, x)

    # Each input spec must have the same (composite) structure as the model
    # input it is paired with. `zip` stops at the shorter of the two lists.
    flat_inputs = nest.flatten(x_shapes, expand_composites=False)
    flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
    for (a, b) in zip(flat_inputs, flat_expected_inputs):
      nest.assert_same_structure(a, b, expand_composites=True)

    if y is not None:
      # Prepare self._sample_weight_modes. List with the same length as
      # model outputs.
      training_utils_v1.prepare_sample_weight_modes(self._training_endpoints,
                                                    self.sample_weight_mode)
      feed_output_names = self._feed_output_names
      feed_sample_weight_modes = self._sample_weight_modes
      if not self._is_graph_network:
        feed_output_shapes = None
      else:
        feed_output_shapes = self._feed_output_shapes

      # Standardize the outputs.
      y = training_utils_v1.standardize_input_data(
          y,
          feed_output_names,
          # Don't enforce target shapes to match output shapes.
          # Precise checks will be run in `check_loss_and_target_compatibility`.
          shapes=None,
          check_batch_axis=False,  # Don't enforce the batch size.
          exception_prefix='target')

      # Generate sample-wise weight values given the `sample_weight` and
      # `class_weight` arguments.
      sample_weights = training_utils_v1.standardize_sample_weights(
          sample_weight, feed_output_names)
      class_weights = training_utils_v1.standardize_class_weights(
          class_weight, feed_output_names)

      # One standardized weight entry per target output.
      sample_weights = [
          training_utils_v1.standardize_weights(ref, sw, cw, mode)
          for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
                                         feed_sample_weight_modes)
      ]
      # Check that all arrays have the same length.
      if not self._distribution_strategy:
        training_utils_v1.check_array_lengths(x, y, sample_weights)
        if self._is_graph_network and not run_eagerly:
          # Additional checks to avoid users mistakenly using improper loss fns.
          training_utils_v1.check_loss_and_target_compatibility(
              y, self._feed_loss_fns, feed_output_shapes)

      sample_weights, _, _ = training_utils.handle_partial_sample_weights(
          y, sample_weights, feed_sample_weight_modes, check_all_flat=True)
    else:
      y = []
      sample_weights = None

    if self.stateful and batch_size and not is_dataset:
      # Check that for stateful networks, number of samples is a multiple
      # of the static batch size.
      # Only the first input's leading dimension is inspected.
      if x[0].shape[0] % batch_size != 0:
        raise ValueError('In a stateful network, '
                         'you should only pass inputs with '
                         'a number of samples that can be '
                         'divided by the batch size. Found: ' +
                         str(x[0].shape[0]) + ' samples')

    # If dictionary inputs were provided, we return a dictionary as well.
    if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1,
                                          dataset_ops.DatasetV2)):
      x = dict(zip(feed_input_names, x))
    return x, y, sample_weights
2470
2471  def _build_model_with_inputs(self, inputs, targets):
2472    """Build the model (set model inputs/outputs), mainly for subclass model."""
2473    processed_inputs = []
2474    is_dict_inputs = False
2475    orig_inputs = inputs
2476    # We need to use `inputs` to set the model inputs.
2477    # If input data is a dataset iterator in graph mode or if it is an eager
2478    # iterator and only one batch of samples is required, we fetch the data
2479    # tensors from the iterator and then standardize them.
2480    if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
2481      inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset(
2482          inputs)
2483    # We type-check that `inputs` and `targets` are either single arrays
2484    # or lists of arrays, and extract a flat list of inputs from the passed
2485    # structure.
2486    training_utils_v1.validate_input_types(inputs, orig_inputs)
2487
2488    if isinstance(inputs, (list, tuple)):
2489      processed_inputs += list(inputs)
2490    elif isinstance(inputs, dict):
2491      is_dict_inputs = True
2492      keys = sorted(inputs.keys())
2493      processed_inputs = [inputs[k] for k in keys]
2494    else:
2495      processed_inputs.append(inputs)
2496    # Now that we have a flat set of inputs, we make sure that none of them
2497    # are CompositeTensors or CompositeTensorValues of any type (or scipy
2498    # sparse arrays, which we treat as SparseTensor values). We cannot safely
2499    # infer input data from an arbitrary composite tensor, so we don't try -
2500    # users should explicitly add composite tensor inputs to their subclassed
2501    # models.
2502    for input_tensor in processed_inputs:
2503      if training_utils_v1.is_composite_or_composite_value(input_tensor):
2504        # TODO(b/132691975): Document subclass-model CT input handling.
2505        raise ValueError(
2506            'All SparseTensor and RaggedTensor inputs must be explicitly '
2507            'declared using a keras.Input() with sparse=True or ragged=True. '
2508            'We found an undeclared input %s. For Sequential models, please '
2509            'add a keras.Input() as your first Layer. For subclassed models, '
2510            'please call self._set_inputs() on your input set, which you can '
2511            'create using keras.Input() for each input to your model.' %
2512            (input_tensor,))
2513    # Build the model using the retrieved inputs (value or symbolic).
2514    # If values are generated from a dataset, then in symbolic-mode
2515    # placeholders will be created to match the value shapes.
2516    if isinstance(orig_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
2517                                iterator_ops.Iterator)):
2518      if not self.inputs:
2519        # For subclassed models, a robust input spec is not available so we
2520        # must cast to the model dtype.
2521        inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype)
2522
2523      def create_tensor_spec(t):
2524        return tensor_spec.TensorSpec(t.shape, t.dtype)
2525
2526      cast_inputs = nest.map_structure(create_tensor_spec, inputs)
2527    elif training_utils_v1.has_tensors(inputs):
2528      cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs)
2529    else:
2530      cast_inputs = inputs
2531    self._set_inputs(cast_inputs)
2532    return processed_inputs, targets, is_dict_inputs
2533
2534  def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target):
2535    if target is not None:
2536      # We need to use `y` to set the model targets.
2537      if training_utils_v1.has_tensors(target):
2538        target = training_utils_v1.cast_if_floating_dtype_and_mismatch(
2539            target, self.outputs)
2540      training_utils_v1.validate_input_types(
2541          target, orig_target, allow_dict=False, field_name='target')
2542      if isinstance(target, (list, tuple)):
2543        all_inputs += list(target)
2544      else:
2545        all_inputs.append(target)
2546    # Type check that all inputs are *either* value *or* symbolic.
2547    # TODO(fchollet): this check could be removed in Eager mode?
2548    if any(tensor_util.is_tf_type(v) for v in all_inputs):
2549      if not all(tensor_util.is_tf_type(v) for v in all_inputs):
2550        raise ValueError('Do not pass inputs that mix Numpy arrays and '
2551                         'TensorFlow tensors. '
2552                         'You passed: x=' + str(orig_inputs) +
2553                         '; y=' + str(orig_target))
2554    is_dataset = isinstance(orig_inputs, (dataset_ops.DatasetV1,
2555                                          dataset_ops.DatasetV2,
2556                                          iterator_ops.Iterator))
2557    if is_dataset or context.executing_eagerly():
2558      target_tensors = None
2559    else:
2560      # Handle target tensors if any passed.
2561      if target is not None:
2562        if not isinstance(target, (list, tuple)):
2563          target = [target]
2564        target_tensors = [v for v in target if _is_symbolic_tensor(v)]
2565      else:
2566        target_tensors = None
2567
2568    self.compile(
2569        optimizer=self.optimizer,
2570        loss=self.loss,
2571        metrics=self._compile_metrics,
2572        weighted_metrics=self._compile_weighted_metrics,
2573        loss_weights=self.loss_weights,
2574        target_tensors=target_tensors,
2575        sample_weight_mode=self.sample_weight_mode,
2576        run_eagerly=self.run_eagerly,
2577        experimental_run_tf_function=self._experimental_run_tf_function)
2578
2579  # TODO(omalleyt): Consider changing to a more descriptive function name.
2580  def _set_inputs(self, inputs, outputs=None, training=None):
2581    """Set model's input and output specs based on the input data received.
2582
2583    This is to be used for Model subclasses, which do not know at instantiation
2584    time what their inputs look like.
2585
2586    Args:
2587      inputs: Single array, or list of arrays. The arrays could be placeholders,
2588        Numpy arrays, data tensors, or TensorSpecs.
2589        - if placeholders: the model is built on top of these placeholders,
2590          and we expect Numpy data to be fed for them when calling `fit`/etc.
2591        - if Numpy data or TensorShapes: we create placeholders matching the
2592          TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be
2593          fed for these placeholders when calling `fit`/etc.
2594        - if data tensors: the model is built on top of these tensors.
2595          We do not expect any Numpy data to be provided when calling `fit`/etc.
2596      outputs: None, a data tensor, or a list of tensors. If None, the
2597        outputs will be determined by invoking `self.call()`, otherwise the
2598        provided value will be used.
2599      training: Boolean or None. Only relevant in symbolic mode. Specifies
2600        whether to build the model's graph in inference mode (False), training
2601        mode (True), or using the Keras learning phase (None).
2602    Raises:
2603      ValueError: If dict inputs are passed to a Sequential Model where the
2604        first layer isn't FeatureLayer.
2605    """
2606    self._set_save_spec(inputs)
2607    inputs = self._set_input_attrs(inputs)
2608
2609    if outputs is None:
2610      kwargs = {}
2611      if self._expects_training_arg:
2612        # In V2 mode, feeding `training=None` is not allowed because any value
2613        # explicitly passed by the user is respected, even `None`.`
2614        if training is None and not ops.executing_eagerly_outside_functions():
2615          training = K.learning_phase()
2616        if training is not None:
2617          kwargs['training'] = training
2618      try:
2619        outputs = self(inputs, **kwargs)
2620      except NotImplementedError:
2621        # This Model or a submodel is dynamic and hasn't overridden
2622        # `compute_output_shape`.
2623        outputs = None
2624
2625    self._set_output_attrs(outputs)
2626
2627  @trackable.no_automatic_dependency_tracking
2628  def _set_input_attrs(self, inputs):
2629    """Sets attributes related to the inputs of the Model."""
2630    if self.inputs:
2631      raise ValueError('Model inputs are already set.')
2632
2633    if self.__class__.__name__ == 'Sequential' and not self.built:
2634      if tensor_util.is_tf_type(inputs):
2635        input_shape = (None,) + tuple(inputs.shape.as_list()[1:])
2636      elif isinstance(inputs, tensor_shape.TensorShape):
2637        input_shape = (None,) + tuple(inputs.as_list()[1:])
2638      elif isinstance(inputs, dict):
2639        # We assert that the first layer is a FeatureLayer.
2640        if not training_utils_v1.is_feature_layer(self.layers[0]):
2641          raise ValueError('Passing a dictionary input to a Sequential Model '
2642                           'which doesn\'t have FeatureLayer as the first layer'
2643                           ' is an error.')
2644        input_shape = (None,)
2645      else:
2646        input_shape = (None,) + tuple(inputs.shape[1:])
2647      self._build_input_shape = input_shape
2648
2649    # Cast inputs to the compute dtype. This is primarily used
2650    # when saving to determine the correct dtype in the input signature.
2651    inputs = self._maybe_cast_inputs(inputs)
2652
2653    # On-the-fly setting of symbolic model inputs (either by using the tensor
2654    # provided, or by creating a placeholder if Numpy data was provided).
2655    model_inputs = training_utils_v1.ModelInputs(inputs)
2656    inputs = model_inputs.get_symbolic_inputs()
2657    self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
2658    self.input_names = model_inputs.get_input_names()
2659
2660    self._feed_inputs = []
2661    self._feed_input_names = []
2662    self._feed_input_shapes = []
2663
2664    for k, v in model_inputs.as_dict():
2665      if K.is_placeholder(v):
2666        self._feed_input_names.append(k)
2667        self._feed_inputs.append(v)
2668        self._feed_input_shapes.append(K.int_shape(v))
2669
2670    return inputs
2671
2672  @trackable.no_automatic_dependency_tracking
2673  def _set_output_attrs(self, outputs):
2674    """Sets attributes related to the outputs of the Model."""
2675    # NOTE(taylorrobie): This convention cannot be changed without updating the
2676    #                    data adapter since it assumes nest.flatten ordering.
2677    outputs = nest.flatten(outputs)
2678    self.outputs = outputs
2679    self.output_names = training_utils_v1.generic_output_names(outputs)
2680    # TODO(scottzhu): Should we cleanup the self._training_endpoints here?
2681    self.built = True
2682
2683  @property
2684  def _targets(self):
2685    """The output target tensors for the model."""
2686    return [
2687        e.training_target.target
2688        for e in self._training_endpoints
2689        if e.has_training_target()
2690    ]
2691
2692  @property
2693  def _feed_targets(self):
2694    return [
2695        e.training_target.target
2696        for e in self._training_endpoints
2697        if e.has_feedable_training_target()
2698    ]
2699
2700  @property
2701  def _feed_output_names(self):
2702    return [
2703        e.output_name
2704        for e in self._training_endpoints
2705        if e.has_feedable_training_target()
2706    ]
2707
2708  @property
2709  def _feed_output_shapes(self):
2710    return [
2711        e.feed_output_shape
2712        for e in self._training_endpoints
2713        if e.has_feedable_training_target()
2714    ]
2715
2716  @property
2717  def _feed_loss_fns(self):
2718    return [
2719        e.loss_fn
2720        for e in self._training_endpoints
2721        if e.has_feedable_training_target()
2722    ]
2723
2724  @property
2725  def _loss_weights_list(self):
2726    return [e.loss_weight for e in self._training_endpoints]
2727
2728  @property
2729  def _output_loss_metrics(self):
2730    if hasattr(self, '_training_endpoints'):
2731      return [
2732          e.output_loss_metric
2733          for e in self._training_endpoints
2734          if e.output_loss_metric is not None
2735      ]
2736    return None
2737
2738  @property
2739  def sample_weights(self):
2740    return [e.sample_weight for e in self._training_endpoints]
2741
2742  @property
2743  def _sample_weight_modes(self):
2744    return [e.sample_weight_mode for e in self._training_endpoints]
2745
2746  @property
2747  def _feed_sample_weights(self):
2748    return [e.sample_weight for e in self._training_endpoints
2749            if e.sample_weight is not None]
2750
2751  def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
2752    """Maybe load initial epoch from ckpt considering possible worker recovery.
2753
2754    Refer to tensorflow/python/keras/distribute/worker_training_state.py
2755    for more information.
2756
2757    Args:
2758      initial_epoch: The original initial_epoch user passes in in `fit()`.
2759      mode: The mode for running `model.fit()`.
2760
2761    Returns:
2762      If the training is recovering from previous failure under multi-worker
2763      training setting, return the epoch the training is supposed to continue
2764      at. Otherwise, return the `initial_epoch` the user passes in.
2765    """
2766    if self._training_state is not None:
2767      return self._training_state.maybe_load_initial_epoch_from_ckpt(
2768          initial_epoch, mode)
2769    return initial_epoch
2770
2771  def _get_training_eval_metrics(self):
2772    """Returns all the metrics that are to be reported.
2773
2774    This includes the output loss metrics, compile metrics/weighted metrics,
2775    add_metric metrics.
2776    """
2777    metrics = []
2778    metrics.extend(getattr(self, '_output_loss_metrics', None) or [])
2779    metrics.extend(getattr(self, 'metrics', None) or [])
2780    return metrics
2781
2782  def _assert_compile_was_called(self):
2783    # Checks whether `compile` has been called. If it has been called,
2784    # then the optimizer is set. This is different from whether the
2785    # model is compiled
2786    # (i.e. whether the model is built and its inputs/outputs are set).
2787    if not self._compile_was_called:
2788      raise RuntimeError('You must compile your model before '
2789                         'training/testing. '
2790                         'Use `model.compile(optimizer, loss)`.')
2791
2792  def _in_multi_worker_mode(self):
2793    """Method to infer if this `Model` is working in multi-worker settings.
2794
2795    Multi-worker training refers to the setup where the training is
2796    distributed across multiple workers, as opposed to the case where
2797    only a local process performs the training. This function is
2798    used to infer for example whether or not a distribute coordinator
2799    should be run, and thus TensorFlow servers should be started for
2800    communication with other servers in the cluster, or whether or not
2801    saving/restoring checkpoints is relevant for preemption fault tolerance.
2802
2803    Experimental. Signature and implementation are subject to change.
2804
2805    Returns:
2806      Whether this model indicates it's working in multi-worker settings.
2807    """
2808    strategy = self._distribution_strategy
2809
2810    # Otherwise, use the strategy whose scope this is in.
2811    if not strategy and distribution_strategy_context.has_strategy():
2812      strategy = distribution_strategy_context.get_strategy()
2813    return strategy and strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2814
2815  @property
2816  def _trackable_saved_model_saver(self):
2817    return model_serialization.ModelSavedModelSaver(self)
2818
2819  def _get_compile_args(self, user_metrics=True):
2820    del user_metrics
2821    self._assert_compile_was_called()
2822    kwargs = {
2823        'loss': self.loss,
2824        'metrics': self._compile_metrics,
2825        'loss_weights': self.loss_weights,
2826        'sample_weight_mode': self.sample_weight_mode,
2827        'weighted_metrics': self._compile_weighted_metrics,
2828    }
2829    return kwargs
2830
2831  @property
2832  def _compile_was_called(self):
2833    return self._v1_compile_was_called
2834
2835
class DistributedCallbackModel(Model):
  """Model that is used for callbacks with tf.distribute.Strategy."""

  def __init__(self, model):
    """Creates the callback model.

    Args:
      model: The model whose `optimizer` is reused by this instance.
    """
    super(DistributedCallbackModel, self).__init__()
    self.optimizer = model.optimizer

  def set_original_model(self, orig_model):
    """Stores the model that `save` and `load_weights` operate on.

    Args:
      orig_model: The original model.
    """
    self._original_model = orig_model

  def save_weights(self, filepath, overwrite=True, save_format=None):
    """Saves weights by delegating to `self._replicated_model`.

    Args:
      filepath: Forwarded to the replicated model's `save_weights`.
      overwrite: Forwarded to the replicated model's `save_weights`.
      save_format: Forwarded to the replicated model's `save_weights`.
    """
    # NOTE(review): `_replicated_model` is not assigned anywhere in this
    # class; presumably it is attached by the caller -- confirm.
    self._replicated_model.save_weights(filepath, overwrite=overwrite,
                                        save_format=save_format)

  def save(self, filepath, overwrite=True, include_optimizer=True):
    """Copies this model's weights to the original model and saves that model.

    Args:
      filepath: Forwarded to the original model's `save`.
      overwrite: Forwarded to the original model's `save`.
      include_optimizer: Accepted for signature compatibility only; the
        original model is always saved with `include_optimizer=False`.
    """
    # save weights from the distributed model to the original model
    distributed_model_weights = self.get_weights()
    self._original_model.set_weights(distributed_model_weights)
    # TODO(anjalisridhar): Do we need to save the original model here?
    # Saving the first replicated model works as well.
    # NOTE(review): `include_optimizer` is deliberately not forwarded, since
    # doing so would change what a default `save(filepath)` call writes.
    self._original_model.save(filepath, overwrite=overwrite,
                              include_optimizer=False)

  def load_weights(self, filepath, by_name=False):
    """Loads weights into the original model and copies them to this model.

    Args:
      filepath: Forwarded to the original model's `load_weights`.
      by_name: Forwarded to the original model's `load_weights`.
    """
    self._original_model.load_weights(filepath, by_name=by_name)
    # Copy the weights from the original model to each of the replicated models.
    orig_model_weights = self._original_model.get_weights()
    distributed_training_utils_v1.set_weights(
        self._original_model._distribution_strategy, self,  # pylint: disable=protected-access
        orig_model_weights)

  def __getattr__(self, item):
    """Warns about the attribute lookup, then defers to the parent class.

    No warning is logged for `_setattr_tracking` and `_layers`.

    Args:
      item: Name of the attribute being looked up.

    Returns:
      Whatever the parent class's `__getattr__` returns for `item`.
    """
    # Allowed attributes of the model that can be accessed by the user
    # during a callback.
    if item not in ('_setattr_tracking', '_layers'):
      logging.warning('You are accessing attribute ' + item + ' of the '
                      'DistributedCallbackModel that may not have been set '
                      'correctly.')
    return super(DistributedCallbackModel, self).__getattr__(item)
2874
2875
class _TrainingEndpoint(object):
  """Bundles one model output with everything needed to train against it.

  For a model with several outputs there is a one-to-one mapping between each
  output (y_pred) and its target (y_true), loss, loss weight, loss metric and
  sample weight. Keeping them together in one object lets these entities reach
  each other directly instead of going through parallel lists on the model.
  """

  def __init__(self,
               output,
               output_name,
               loss_fn,
               loss_weight=None,
               training_target=None,
               output_loss_metric=None,
               sample_weight=None,
               sample_weight_mode=None):
    """Initialize the _TrainingEndpoint.

    `output` and `output_name` are read-only once set; the remaining fields
    except `loss_fn` can be reassigned through their property setters.

    Args:
      output: the output tensor of the model.
      output_name: the unique name of the output tensor.
      loss_fn: the loss function for the output tensor.
      loss_weight: float, the weights for the loss.
      training_target: the _TrainingTarget for the model.
      output_loss_metric: the metric object for the loss function.
      sample_weight: the weights for how a sample is weighted during metric and
        loss calculation. Could be None.
      sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode for
        how the sample_weight is populated.
    """
    self._output = output
    self._output_name = output_name
    self._loss_fn = loss_fn
    self._loss_weight = loss_weight
    self._training_target = training_target
    self._output_loss_metric = output_loss_metric
    self._sample_weight = sample_weight
    self._sample_weight_mode = sample_weight_mode

  # Read-only attributes.

  @property
  def output(self):
    """The output tensor of the model."""
    return self._output

  @property
  def output_name(self):
    """The name of the output tensor."""
    return self._output_name

  @property
  def shape(self):
    """The result of `K.int_shape` applied to the output."""
    return K.int_shape(self.output)

  @property
  def loss_fn(self):
    """The loss function for this output; may be None."""
    return self._loss_fn

  # Read-write attributes.

  @property
  def loss_weight(self):
    """The weight applied to the loss of this output."""
    return self._loss_weight

  @loss_weight.setter
  def loss_weight(self, value):
    self._loss_weight = value

  @property
  def training_target(self):
    """The training target of this output, or None if not populated."""
    return self._training_target

  @training_target.setter
  def training_target(self, value):
    self._training_target = value

  @property
  def output_loss_metric(self):
    """The metric object for the loss function."""
    return self._output_loss_metric

  @output_loss_metric.setter
  def output_loss_metric(self, value):
    self._output_loss_metric = value

  @property
  def sample_weight(self):
    """The sample weight for this output; may be None."""
    return self._sample_weight

  @sample_weight.setter
  def sample_weight(self, value):
    self._sample_weight = value

  @property
  def sample_weight_mode(self):
    """The mode describing how the sample weight is populated."""
    return self._sample_weight_mode

  @sample_weight_mode.setter
  def sample_weight_mode(self, value):
    self._sample_weight_mode = value

  def create_training_target(self, target, run_eagerly=False):
    """Builds the `_TrainingTarget` for this endpoint and stores it.

    Args:
      target: the target tensor for the current output. Could be None, in
        which case a placeholder is created from the output and the loss
        function (unless the endpoint has no loss or runs eagerly).
      run_eagerly: boolean, whether the model is in run_eagerly mode.

    Raises:
      ValueError if the training_target field for the current instance has
      already been populated.
    """
    if self.has_training_target():
      raise ValueError('The training_target field for the _TrainingEndpoint '
                       'instance has already been populated')

    if run_eagerly:
      # When run_eagerly, the target tensor is ignored, and the None placeholder
      # is created instead.
      self.training_target = _TrainingTarget(
          None, feedable=True, skip_target_weights=False)
      return

    if self.should_skip_target():
      self.training_target = _TrainingTarget(None)
      return

    # A tensor supplied by the caller that is not a placeholder cannot be fed,
    # and its weights are skipped.
    is_fixed_tensor = target is not None and not K.is_placeholder(target)

    if target is None:
      target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
          self.loss_fn, K.dtype(self.output))
      target = K.placeholder(
          ndim=len(self.shape),
          name=self.output_name + '_target',
          sparse=K.is_sparse(self.output),
          dtype=target_dtype)

    self.training_target = _TrainingTarget(
        target,
        feedable=not is_fixed_tensor,
        skip_target_weights=is_fixed_tensor)

  def should_skip_target(self):
    """Whether this endpoint has no loss function."""
    return self._loss_fn is None

  def should_skip_target_weights(self):
    """Whether target weights should be skipped for this endpoint."""
    if self.should_skip_target():
      return True
    training_target = self.training_target
    if training_target is None:
      return True
    return training_target.skip_target_weights

  def has_training_target(self):
    """Whether the training target has been populated."""
    return self.training_target is not None

  def has_feedable_training_target(self):
    """Whether this endpoint has a loss and a feedable training target."""
    if self.should_skip_target():
      return False
    training_target = self.training_target
    if training_target is None:
      return False
    return training_target.feedable

  def loss_name(self):
    """The output name suffixed with '_loss', or None without a loss."""
    if self._loss_fn is None:
      return None
    return self._output_name + '_loss'

  @property
  def feed_output_shape(self):
    """The output shape for the feedable target."""
    if not self.has_feedable_training_target():
      return None

    loss_fn = self.loss_fn
    wraps_function = isinstance(loss_fn, losses.LossFunctionWrapper)
    uses_sparse_labels = (
        (wraps_function and
         loss_fn.fn == losses.sparse_categorical_crossentropy) or
        isinstance(loss_fn, losses.SparseCategoricalCrossentropy))
    if uses_sparse_labels:
      # The class axis of the output shape is replaced by a size-1 axis.
      if K.image_data_format() == 'channels_first':
        return (self.shape[0], 1) + self.shape[2:]
      return self.shape[:-1] + (1,)

    # If the given loss is not an instance of the `Loss` class (custom
    # class) or if the loss function that is wrapped is not in the
    # `losses` module, then it is a user-defined loss and we make no
    # assumptions about it.
    if not isinstance(loss_fn, losses.Loss):
      return None
    if wraps_function and getattr(losses, loss_fn.fn.__name__, None) is None:
      return None
    return self.shape

  def sample_weights_mismatch(self):
    """Check if the sample weight and the mode match or not."""
    # If there is a mismatch between sample weight mode and the placeholders
    # created, then recompile the sub-graphs that depend on sample weights.
    # A mismatch means exactly one of the two is None.
    mode_missing = self.sample_weight_mode is None
    weight_missing = self.sample_weight is None
    return mode_missing != weight_missing

  def populate_sample_weight(self, sample_weight, sample_weight_mode):
    """Sets the sample weight from the given tensor or a default placeholder.

    Args:
      sample_weight: the sample weight tensor to use, or None.
      sample_weight_mode: 'temporal', 'samplewise' or None.

    Raises:
      ValueError: if `sample_weight` is given and its shape is not compatible
        with the shape implied by `sample_weight_mode`.
    """
    if sample_weight is None:
      nothing_to_create = (
          self.should_skip_target_weights() or sample_weight_mode is None or
          context.executing_eagerly())
      if nothing_to_create:
        self._sample_weight = None
        return

    assert sample_weight_mode in ['temporal', 'samplewise']
    is_temporal = sample_weight_mode == 'temporal'
    # 'temporal' weights are rank 2, 'samplewise' weights are rank 1.
    shape = [None, None] if is_temporal else [None]

    if sample_weight is None:
      default_value = [[1.]] if is_temporal else [1.]
      self._sample_weight = array_ops.placeholder_with_default(
          constant_op.constant(default_value, dtype=K.floatx()),
          shape=shape,
          name=self.output_name + '_sample_weights')
      return

    if not sample_weight.shape.is_compatible_with(shape):
      raise ValueError('Received sample weight with shape {}. Expected shape '
                       '{}.'.format(sample_weight.shape, shape))
    self._sample_weight = sample_weight
3106
3107
class _TrainingTarget(object):
  """Read-only holder for a target tensor (y_true) and two flags about it.

  Args:
    target: A target tensor for the model. It may be `None` if the
      output is excluded from loss computation; it is still stored so that
      each output of the model has a corresponding target.
    feedable: Boolean, whether the target is feedable (requires data to be
      passed in `fit` or `train_on_batch`), or not (model compiled with
      `target_tensors` argument).
    skip_target_weights: Boolean, whether the target should be skipped during
      weights calculation.
  """

  def __init__(self, target, feedable=False, skip_target_weights=True):
    """Stores the target and its flags exactly as given."""
    self._skip_target_weights = skip_target_weights
    self._feedable = feedable
    self._target = target

  @property
  def target(self):
    """The target tensor, or None."""
    return self._target

  @property
  def feedable(self):
    """Whether the target is feedable."""
    return self._feedable

  @property
  def skip_target_weights(self):
    """Whether the target is skipped during weights calculation."""
    return self._skip_target_weights
3139
3140
def _is_symbolic_tensor(x):
  """Returns the result of `tensor_util.is_tf_type(x)`."""
  is_tf_value = tensor_util.is_tf_type(x)
  return is_tf_value
3143
3144
def _convert_scipy_sparse_tensor(value, expected_input):
  """Handle scipy sparse tensor conversions.

  This method takes a value 'value' and returns the proper conversion. If
  value is a scipy sparse tensor and the expected input is a dense tensor,
  we densify 'value' (or raise when executing eagerly outside functions). If
  value is a scipy sparse tensor and the expected input is a TF SparseTensor,
  we convert 'value' to a SparseTensor. If 'value' is not a scipy sparse
  tensor, or scipy is not imported, we pass it through unchanged.

  Args:
    value: An object that may be a scipy sparse tensor
    expected_input: The expected input placeholder.

  Returns:
    The possibly-converted 'value'.

  Raises:
    ValueError: If 'value' is a scipy sparse tensor, the expected input is
      dense, and `ops.executing_eagerly_outside_functions()` is true.
  """
  # `issparse` is None when scipy is unavailable; nothing to convert then.
  if issparse is None or not issparse(value):
    return value

  if K.is_sparse(expected_input):
    sparse_coo = value.tocoo()
    row, col = sparse_coo.row, sparse_coo.col
    data, shape = sparse_coo.data, sparse_coo.shape
    # Pair the row and column coordinates into one (row, col) entry per
    # stored value.
    indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)),
                             1)
    return sparse_tensor.SparseTensor(indices, data, shape)

  if ops.executing_eagerly_outside_functions():
    # In TF2 we do not silently densify sparse matrices.
    raise ValueError('A SciPy sparse matrix was passed to a model '
                     'that expects dense inputs. Please densify your '
                     'inputs first, such as by calling `x.toarray()`.')
  return value.toarray()
3179
3180
def _get_metrics_from_layers(layers):
  """Collects the metrics attached to the given layers.

  Nested models contribute their `_metrics` plus, recursively, the metrics of
  their own layers, so metrics added through a nested model's `compile` are
  not included.

  Args:
    layers: List of layers.

  Returns:
    List of metrics.
  """
  collected = []
  for layer in layer_utils.filter_empty_layer_containers(layers):
    if not isinstance(layer, Model):
      collected += layer.metrics
      continue
    # We cannot call 'metrics' on the model because we do not want to
    # include the metrics that were added in compile API of a nested model.
    collected += layer._metrics  # pylint: disable=protected-access
    collected += _get_metrics_from_layers(layer.layers)
  return collected
3203
3204
def _non_none_constant_value(v):
  """Returns `tensor_util.constant_value(v)`, or `v` itself if that is None."""
  static_value = tensor_util.constant_value(v)
  if static_value is None:
    return v
  return static_value
3208