1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related part of the Keras engine.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import itertools
23import json
24import os
25import warnings
26
27import six
28
29from tensorflow.python.autograph.lang import directives
30from tensorflow.python.data.experimental.ops import distribute_options
31from tensorflow.python.data.ops import dataset_ops
32from tensorflow.python.distribute import collective_all_reduce_strategy
33from tensorflow.python.distribute import distribution_strategy_context as ds_context
34from tensorflow.python.distribute import values as ds_values
35from tensorflow.python.distribute.coordinator import cluster_coordinator
36from tensorflow.python.eager import backprop
37from tensorflow.python.eager import context
38from tensorflow.python.eager import def_function
39from tensorflow.python.framework import errors
40from tensorflow.python.framework import errors_impl
41from tensorflow.python.framework import func_graph
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import sparse_tensor
44from tensorflow.python.framework import tensor_shape
45from tensorflow.python.keras import backend
46from tensorflow.python.keras import callbacks as callbacks_module
47from tensorflow.python.keras import optimizer_v1
48from tensorflow.python.keras import optimizers
49from tensorflow.python.keras.engine import base_layer
50from tensorflow.python.keras.engine import base_layer_utils
51from tensorflow.python.keras.engine import compile_utils
52from tensorflow.python.keras.engine import data_adapter
53from tensorflow.python.keras.engine import training_utils
54from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso
55from tensorflow.python.keras.mixed_precision import policy
56from tensorflow.python.keras.saving import hdf5_format
57from tensorflow.python.keras.saving import save
58from tensorflow.python.keras.saving import saving_utils
59from tensorflow.python.keras.saving.saved_model import json_utils
60from tensorflow.python.keras.saving.saved_model import model_serialization
61from tensorflow.python.keras.utils import generic_utils
62from tensorflow.python.keras.utils import layer_utils
63from tensorflow.python.keras.utils import tf_inspect
64from tensorflow.python.keras.utils import tf_utils
65from tensorflow.python.keras.utils import version_utils
66from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
67from tensorflow.python.keras.utils.io_utils import path_to_string
68from tensorflow.python.keras.utils.mode_keys import ModeKeys
69from tensorflow.python.ops import array_ops
70from tensorflow.python.ops import math_ops
71from tensorflow.python.ops import sparse_ops
72from tensorflow.python.ops import summary_ops_v2
73from tensorflow.python.ops import variables
74from tensorflow.python.platform import tf_logging as logging
75from tensorflow.python.profiler import trace
76from tensorflow.python.saved_model import constants as sm_constants
77from tensorflow.python.saved_model import loader_impl as sm_loader
78from tensorflow.python.training import checkpoint_management
79from tensorflow.python.training import py_checkpoint_reader
80from tensorflow.python.training.tracking import base as trackable
81from tensorflow.python.training.tracking import data_structures
82from tensorflow.python.training.tracking import util as trackable_utils
83from tensorflow.python.util import nest
84from tensorflow.python.util import tf_decorator
85from tensorflow.python.util.tf_export import keras_export
86from tensorflow.tools.docs import doc_controls
87
88
89# pylint: disable=g-import-not-at-top
90try:
91  import h5py
92except ImportError:
93  h5py = None
94
95try:
96  import yaml
97except ImportError:
98  yaml = None
99# pylint: enable=g-import-not-at-top
100
101
def disable_multi_worker(method):
  """Decorates `method` so that it refuses to run in multi-worker mode.

  Args:
    method: An instance method whose owner provides `_in_multi_worker_mode()`.

  Returns:
    A wrapper built with `tf_decorator.make_decorator` that raises
    `ValueError` when `self._in_multi_worker_mode()` is truthy, and otherwise
    forwards every argument to `method` and returns its result.
  """

  def _method_wrapper(self, *args, **kwargs):
    in_multi_worker = self._in_multi_worker_mode()  # pylint: disable=protected-access
    if not in_multi_worker:
      return method(self, *args, **kwargs)
    raise ValueError('{} is not supported in multi-worker mode.'.format(
        method.__name__))

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
113
114
def inject_functional_model_class(cls):
  """Inject `Functional` into the hierarchy of this class if needed.

  Walks `cls.__bases__` recursively. Every occurrence of the Keras `Model`
  class (v2 or v1) is replaced by `functional.Functional`. The walk stops at
  `object`, so classes outside the Keras hierarchy (multiple inheritance) are
  left untouched. `cls.__bases__` is rewritten in place.

  Args:
    cls: The class whose hierarchy should be rewritten.

  Returns:
    `functional.Functional` if `cls` is itself a Keras `Model` class,
    `object` if `cls` is `object`, otherwise `cls` (with updated bases).
  """
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import training_v1  # pylint: disable=g-import-not-at-top

  # A Keras `Model` base is swapped out for `Functional` itself.
  if cls in (Model, training_v1.Model):
    return functional.Functional

  # Reached the root of a non-Keras branch: nothing to inject.
  if cls == object:
    return object

  rewritten_bases = [inject_functional_model_class(parent)
                     for parent in cls.__bases__]
  cls.__bases__ = tuple(rewritten_bases)

  # Trigger any `__new__` class swapping that needed to happen on `Functional`
  # but did not because functional was not in the class hierarchy.
  cls.__new__(cls)

  return cls
133
134
def is_functional_model_init_params(args, kwargs):
  """Returns whether constructor arguments describe a functional model.

  The arguments are treated as functional when there are exactly two
  positional arguments (`Model(inputs, outputs)`), one positional argument
  plus an `outputs` keyword, or both `inputs` and `outputs` keywords.

  Args:
    args: Tuple of positional arguments passed to the constructor.
    kwargs: Dict of keyword arguments passed to the constructor.

  Returns:
    A bool.
  """
  num_positional = len(args)
  has_outputs_kwarg = 'outputs' in kwargs
  if num_positional == 2:
    return True
  if num_positional == 1 and has_outputs_kwarg:
    return True
  return 'inputs' in kwargs and has_outputs_kwarg
139
140
141@keras_export('keras.Model', 'keras.models.Model')
142class Model(base_layer.Layer, version_utils.ModelVersionSelector):
143  """`Model` groups layers into an object with training and inference features.
144
145  Args:
146      inputs: The input(s) of the model: a `keras.Input` object or list of
147          `keras.Input` objects.
148      outputs: The output(s) of the model. See Functional API example below.
149      name: String, the name of the model.
150
151  There are two ways to instantiate a `Model`:
152
153  1 - With the "Functional API", where you start from `Input`,
154  you chain layer calls to specify the model's forward pass,
155  and finally you create your model from inputs and outputs:
156
157  ```python
158  import tensorflow as tf
159
160  inputs = tf.keras.Input(shape=(3,))
161  x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
162  outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
163  model = tf.keras.Model(inputs=inputs, outputs=outputs)
164  ```
165
166  2 - By subclassing the `Model` class: in that case, you should define your
167  layers in `__init__` and you should implement the model's forward pass
168  in `call`.
169
170  ```python
171  import tensorflow as tf
172
173  class MyModel(tf.keras.Model):
174
175    def __init__(self):
176      super(MyModel, self).__init__()
177      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
178      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
179
180    def call(self, inputs):
181      x = self.dense1(inputs)
182      return self.dense2(x)
183
184  model = MyModel()
185  ```
186
187  If you subclass `Model`, you can optionally have
188  a `training` argument (boolean) in `call`, which you can use to specify
189  a different behavior in training and inference:
190
191  ```python
192  import tensorflow as tf
193
194  class MyModel(tf.keras.Model):
195
196    def __init__(self):
197      super(MyModel, self).__init__()
198      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
199      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
200      self.dropout = tf.keras.layers.Dropout(0.5)
201
202    def call(self, inputs, training=False):
203      x = self.dense1(inputs)
204      if training:
205        x = self.dropout(x, training=training)
206      return self.dense2(x)
207
208  model = MyModel()
209  ```
210
  Once the model is created, you can configure the model with losses and
  metrics with `model.compile()`, train the model with `model.fit()`, or use
  the model to make predictions with `model.predict()`.
214  """
  # Attribute names ignored as `tf.Module` properties. This is the base `Layer`
  # set plus the step counters and `_steps_per_execution`, which this class
  # creates as untracked variables (`_init_batch_counters` /
  # `_configure_steps_per_execution`). Presumably listed here so `tf.Module`
  # attribute traversal skips them — TODO confirm against tf.Module.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(
      itertools.chain(('_train_counter', '_test_counter', '_predict_counter',
                       '_steps_per_execution'),
                      base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES))  # pylint: disable=protected-access
219
220  def __new__(cls, *args, **kwargs):
221    # Signature detection
222    if is_functional_model_init_params(args, kwargs) and cls == Model:
223      # Functional model
224      from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
225      return functional.Functional(skip_init=True, *args, **kwargs)
226    else:
227      return super(Model, cls).__new__(cls, *args, **kwargs)
228
  @trackable.no_automatic_dependency_tracking
  def __init__(self, *args, **kwargs):
    """Initializes the model.

    There are two paths:

    * Functional-style arguments (see `is_functional_model_init_params`) on an
      instance that is not yet a `Functional`, i.e. a subclass of `Model`
      called with `inputs`/`outputs`. `Functional` is injected into the class
      hierarchy and `Functional.__init__` runs with the supported keyword
      arguments (`inputs`, `outputs`, `name`, `trainable`, `skip_init`). The
      remaining keyword arguments go to every base class listed after the
      `Functional` base. If there is no such base, leftover keyword arguments
      raise `TypeError`. The subclassed-model setup below is skipped.
    * Otherwise (a subclassed model), `kwargs` are checked by
      `generic_utils.validate_kwargs` against `trainable`, `dtype`,
      `dynamic`, `name`, `autocast`, `inputs` and `outputs`, forwarded to
      `Layer.__init__`, and the compile/fit bookkeeping attributes are set.

    Args:
      *args: Positional constructor arguments.
      **kwargs: Keyword constructor arguments.

    Raises:
      TypeError: On the functional path, if unsupported keyword arguments are
        passed and no later base class is there to receive them.
    """
    self._is_model_for_instrumentation = True
    base_layer.keras_api_gauge.get_cell('model').set(True)

    # Special case for Subclassed Functional Model, which we couldn't detect
    # when __new__ is called. We only realize it is a functional model when it
    # calls super.__init__ with input and output tensor.
    from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
    if (is_functional_model_init_params(args, kwargs) and
        not isinstance(self, functional.Functional)):
      # Filter the kwargs for multiple inheritance.
      supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
      model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
      other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
      # Rewrites `self.__class__.__bases__` in place so that `Functional`
      # replaces `Model` in the hierarchy.
      inject_functional_model_class(self.__class__)
      functional.Functional.__init__(self, *args, **model_kwargs)

      # In case there is any multiple inheritance here, we need to call the
      # __init__ for any class that appears after the Functional class.
      clz_to_init = []
      found_functional_class = False
      for clz in self.__class__.__bases__:
        if issubclass(clz, functional.Functional):
          found_functional_class = True
          continue
        if found_functional_class:
          clz_to_init.append(clz)

      if clz_to_init:
        for clz in clz_to_init:
          clz.__init__(self, *args, **other_kwargs)
      elif other_kwargs:
        # In case there are unused kwargs, we should raise an error to user, in
        # case they have a typo in the param name.
        raise TypeError(
            'The following keyword arguments aren\'t supported: {}'.format(
                other_kwargs))
      # Functional setup is complete; skip the subclassed-model setup below.
      return

    # From here on, `self` is being set up as a subclassed (non-graph) model.
    base_layer.keras_api_gauge.get_cell('Model subclass').set(True)
    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # `inputs` / `outputs` will only appear in kwargs if either are misspelled.
    generic_utils.validate_kwargs(kwargs, {
        'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs'
    })
    super(Model, self).__init__(**kwargs)
    # By default, Model is a subclass model, which is not in graph network.
    self._is_graph_network = False

    self.inputs = None
    self.outputs = None
    self.input_names = None
    self.output_names = None
    # stop_training is used by callback to stop training when error happens
    self.stop_training = False
    self.history = None
    # These objects are used in the default `Model.compile`. They are not
    # guaranteed to be set after `Model.compile` is called, as users can
    # override compile with custom logic.
    self.compiled_loss = None
    self.compiled_metrics = None

    # This is True for Sequential networks and Functional networks.
    self._compute_output_and_mask_jointly = False

    # Don't reset compilation if already done. This may occur if calling
    # `__init__` (or `_init_graph_network`) on an already-compiled model
    # such as a Sequential model. Sequential models may need to rebuild
    # themselves after compilation.
    self._maybe_create_attribute('_is_compiled', False)
    self._maybe_create_attribute('optimizer', None)

    # Model must be created under scope of DistStrat it will be trained with.
    if ds_context.has_strategy():
      self._distribution_strategy = ds_context.get_strategy()
    else:
      self._distribution_strategy = None

    self._cluster_coordinator = None

    # Defaults to value of `tf.config.experimental_functions_run_eagerly`.
    self._run_eagerly = None
    # Initialize cache attrs.
    self._reset_compile_cache()

    # Fault-tolerance handler. Set in `ModelCheckpoint`.
    self._training_state = None
    self._saved_model_inputs_spec = None
    self._trackable_saver = (
        trackable_utils.saver_with_op_caching(self))

    self._steps_per_execution = None

    self._init_batch_counters()
    # Marker checked by `__setattr__` to detect a missing `super().__init__()`.
    self._base_model_initialized = True
327
328  @trackable.no_automatic_dependency_tracking
329  def _init_batch_counters(self):
330    # Untracked Variables, used to keep track of mini-batches seen in `fit`,
331    # `evaluate`, and `predict`.
332    agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
333    self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg)
334    self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg)
335    self._predict_counter = variables.Variable(
336        0, dtype='int64', aggregation=agg)
337
338  def __setattr__(self, name, value):
339    if not getattr(self, '_self_setattr_tracking', True):
340      super(Model, self).__setattr__(name, value)
341      return
342
343    if all(
344        isinstance(v, (base_layer.Layer,
345                       data_structures.TrackableDataStructure)) or
346        base_layer_utils.has_weights(v) for v in nest.flatten(value)):
347      try:
348        self._base_model_initialized
349      except AttributeError:
350        # six.raise_from supresses the original AttributeError from being raised
351        six.raise_from(
352            RuntimeError('It looks like you are subclassing `Model` and you '
353                         'forgot to call `super(YourClass, self).__init__()`.'
354                         ' Always start with this line.'), None)
355
356    super(Model, self).__setattr__(name, value)
357
  @generic_utils.default
  def build(self, input_shape):
    """Builds the model based on input shapes received.

    This is to be used for subclassed models, which do not know at instantiation
    time what their inputs look like.

    This method only exists for users who want to call `model.build()` in a
    standalone way (as a substitute for calling the model on real data to
    build it). It will never be called by the framework (and thus it will
    never throw unexpected errors in an unrelated workflow).

    Graph networks (`self._is_graph_network`) defer straight to `Layer.build`.
    Otherwise, placeholders matching `input_shape` are created inside a graph
    (a fresh `FuncGraph` when executing eagerly, else the Keras backend graph)
    and `self.call` is traced on them so that the model's variables get
    created.

    Args:
      input_shape: Single tuple, TensorShape, or list/dict of shapes, where
        shapes are tuples, integers, or TensorShapes. A list whose entries are
        all `None` or ints is treated as one shape, not a list of shapes.

    Raises:
      ValueError:
        1. In case of invalid user-provided data (`None`, or not of type
           tuple, list, TensorShape, or dict).
        2. If the model requires call arguments that are agnostic
           to the input shapes (positional or kwarg in call signature).
        3. If not all layers were properly built.
        4. If float type inputs are not supported within the layers.
        5. If `call` does not accept an `inputs` argument.

      In each of these cases, the user should build their model by calling it
      on real tensor data.
    """
    if self._is_graph_network:
      super(Model, self).build(input_shape)
      return

    if input_shape is None:
      raise ValueError('Input shape must be defined when calling build on a '
                       'model subclass network.')
    valid_types = (tuple, list, tensor_shape.TensorShape, dict)
    if not isinstance(input_shape, valid_types):
      raise ValueError('Specified input shape is not one of the valid types. '
                       'Please specify a batch input shape of type tuple or '
                       'list of input shapes. User provided '
                       'input type: {}'.format(type(input_shape)))

    # An empty shape, or a model that already has `inputs`, skips tracing and
    # goes straight to `Layer.build` below.
    if input_shape and not self.inputs:
      # We create placeholders for the `None`s in the shape and build the model
      # in a Graph. Since tf.Variable is compatible with both eager execution
      # and graph building, the variables created after building the model in
      # a Graph are still valid when executing eagerly.
      if context.executing_eagerly():
        graph = func_graph.FuncGraph('build_graph')
      else:
        graph = backend.get_graph()
      with graph.as_default():
        # A flat list of dims (e.g. `[None, 3]`) is a single shape; normalize
        # it to a tuple. Note that this normalized value is also what is
        # passed to `Layer.build` at the end.
        if (isinstance(input_shape, list) and
            all(d is None or isinstance(d, int) for d in input_shape)):
          input_shape = tuple(input_shape)
        if isinstance(input_shape, list):
          x = [base_layer_utils.generate_placeholders_from_shape(shape)
               for shape in input_shape]
        elif isinstance(input_shape, dict):
          x = {
              k: base_layer_utils.generate_placeholders_from_shape(shape)
              for k, shape in input_shape.items()
          }
        else:
          x = base_layer_utils.generate_placeholders_from_shape(input_shape)

        kwargs = {}
        call_signature = self._call_full_argspec
        call_args = call_signature.args
        # Exclude `self`, `inputs`, and any argument with a default value.
        if len(call_args) > 2:
          # `defaults` line up with the trailing positional args, so slicing
          # them off leaves only the required extra arguments.
          if call_signature.defaults:
            call_args = call_args[2:-len(call_signature.defaults)]
          else:
            call_args = call_args[2:]
          for arg in call_args:
            if arg == 'training':
              # Case where `training` is a positional arg with no default.
              kwargs['training'] = False
            else:
              # Has invalid call signature with unknown positional arguments.
              raise ValueError(
                  'Currently, you cannot build your model if it has '
                  'positional or keyword arguments that are not '
                  'inputs to the model, but are required for its '
                  '`call` method. Instead, in order to instantiate '
                  'and build your model, `call` your model on real '
                  'tensor data with all expected call arguments.')
        elif len(call_args) < 2:
          # Signature without `inputs`.
          raise ValueError('You can only call `build` on a model if its `call` '
                           'method accepts an `inputs` argument.')
        try:
          self.call(x, **kwargs)
        except (errors.InvalidArgumentError, TypeError):
          # Placeholders are generated without user dtype info, so dtype
          # mismatches surface here; re-raise with actionable guidance.
          raise ValueError('You cannot build your model by calling `build` '
                           'if your layers do not support float type inputs. '
                           'Instead, in order to instantiate and build your '
                           'model, `call` your model on real tensor data (of '
                           'the correct dtype).')
    super(Model, self).build(input_shape)
459
460  @doc_controls.doc_in_current_and_subclasses
461  def call(self, inputs, training=None, mask=None):
462    """Calls the model on new inputs.
463
464    In this case `call` just reapplies
465    all ops in the graph to the new inputs
466    (e.g. build a new computational graph from the provided inputs).
467
468    Note: This method should not be called directly. It is only meant to be
469    overridden when subclassing `tf.keras.Model`.
470    To call a model on an input, always use the `__call__` method,
471    i.e. `model(inputs)`, which relies on the underlying `call` method.
472
473    Args:
474        inputs: A tensor or list of tensors.
475        training: Boolean or boolean scalar tensor, indicating whether to run
476          the `Network` in training mode or inference mode.
477        mask: A mask or list of masks. A mask can be
478            either a tensor or None (no mask).
479
480    Returns:
481        A tensor if there is a single output, or
482        a list of tensors if there are more than one outputs.
483    """
484    raise NotImplementedError('When subclassing the `Model` class, you should '
485                              'implement a `call` method.')
486
487  def compile(self,
488              optimizer='rmsprop',
489              loss=None,
490              metrics=None,
491              loss_weights=None,
492              weighted_metrics=None,
493              run_eagerly=None,
494              steps_per_execution=None,
495              **kwargs):
496    """Configures the model for training.
497
498    Args:
499        optimizer: String (name of optimizer) or optimizer instance. See
500          `tf.keras.optimizers`.
501        loss: String (name of objective function), objective function or
502          `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective
503          function is any callable with the signature `loss = fn(y_true,
504          y_pred)`, where y_true = ground truth values with shape =
505          `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse
506          categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`.
507          y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It
508          returns a weighted loss float tensor. If a custom `Loss` instance is
509          used and reduction is set to NONE, return value has the shape
510          [batch_size, d0, .. dN-1] ie. per-sample or per-timestep loss values;
511          otherwise, it is a scalar. If the model has multiple outputs, you can
512          use a different loss on each output by passing a dictionary or a list
513          of losses. The loss value that will be minimized by the model will
514          then be the sum of all individual losses.
515        metrics: List of metrics to be evaluated by the model during training
516          and testing. Each of this can be a string (name of a built-in
517          function), function or a `tf.keras.metrics.Metric` instance. See
518          `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A
519          function is any callable with the signature `result = fn(y_true,
520          y_pred)`. To specify different metrics for different outputs of a
521          multi-output model, you could also pass a dictionary, such as
522            `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
523              You can also pass a list (len = len(outputs)) of lists of metrics
524              such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or
525              `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the
526              strings 'accuracy' or 'acc', we convert this to one of
527              `tf.keras.metrics.BinaryAccuracy`,
528              `tf.keras.metrics.CategoricalAccuracy`,
529              `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss
530              function used and the model output shape. We do a similar
531              conversion for the strings 'crossentropy' and 'ce' as well.
532        loss_weights: Optional list or dictionary specifying scalar coefficients
533          (Python floats) to weight the loss contributions of different model
534          outputs. The loss value that will be minimized by the model will then
535          be the *weighted sum* of all individual losses, weighted by the
536          `loss_weights` coefficients.
537            If a list, it is expected to have a 1:1 mapping to the model's
538              outputs. If a dict, it is expected to map output names (strings)
539              to scalar coefficients.
540        weighted_metrics: List of metrics to be evaluated and weighted by
541          sample_weight or class_weight during training and testing.
542        run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s
543          logic will not be wrapped in a `tf.function`. Recommended to leave
544          this as `None` unless your `Model` cannot be run inside a
545          `tf.function`.
546        steps_per_execution: Int. Defaults to 1. The number of batches to
547          run during each `tf.function` call. Running multiple batches
548          inside a single `tf.function` call can greatly improve performance
549          on TPUs or small models with a large Python overhead.
550          At most, one full epoch will be run each
551          execution. If a number larger than the size of the epoch is passed,
552          the execution will be truncated to the size of the epoch.
553          Note that if `steps_per_execution` is set to `N`,
554          `Callback.on_batch_begin` and `Callback.on_batch_end` methods
555          will only be called every `N` batches
556          (i.e. before/after each `tf.function` execution).
557        **kwargs: Arguments supported for backwards compatibility only.
558
559    Raises:
560        ValueError: In case of invalid arguments for
561            `optimizer`, `loss` or `metrics`.
562    """
563    base_layer.keras_api_gauge.get_cell('compile').set(True)
564    with self.distribute_strategy.scope():
565      if 'experimental_steps_per_execution' in kwargs:
566        logging.warn('The argument `steps_per_execution` is no longer '
567                     'experimental. Pass `steps_per_execution` instead of '
568                     '`experimental_steps_per_execution`.')
569        if not steps_per_execution:
570          steps_per_execution = kwargs.pop('experimental_steps_per_execution')
571
572      self._validate_compile(optimizer, metrics, **kwargs)
573      self._run_eagerly = run_eagerly
574
575      self.optimizer = self._get_optimizer(optimizer)
576      self.compiled_loss = compile_utils.LossesContainer(
577          loss, loss_weights, output_names=self.output_names)
578      self.compiled_metrics = compile_utils.MetricsContainer(
579          metrics, weighted_metrics, output_names=self.output_names)
580
581      self._configure_steps_per_execution(steps_per_execution or 1)
582
583      # Initializes attrs that are reset each time `compile` is called.
584      self._reset_compile_cache()
585      self._is_compiled = True
586
587      self.loss = loss or {}  # Backwards compat.
588
589  def _get_optimizer(self, optimizer):
590    """Wraps `optimizer` in `LossScaleOptimizer` if necessary."""
591    # The deprecated PolicyV1 has a loss_scale, which we use for backwards
592    # compatibility to match TF 2.3 behavior. The new Policy does not have a
593    # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is
594    # used.
595    if isinstance(self._dtype_policy, policy.PolicyV1):
596      loss_scale = self._dtype_policy.loss_scale
597    elif self._dtype_policy.name == 'mixed_float16':
598      loss_scale = 'dynamic'
599    else:
600      loss_scale = None
601
602    def _get_single_optimizer(opt):
603      opt = optimizers.get(opt)
604      if (loss_scale is not None and
605          not isinstance(opt, lso.LossScaleOptimizer)):
606        if loss_scale == 'dynamic':
607          opt = lso.LossScaleOptimizer(opt)
608        else:
609          opt = lso.LossScaleOptimizerV1(opt, loss_scale)
610      return opt
611
612    return nest.map_structure(_get_single_optimizer, optimizer)
613
614  @trackable.no_automatic_dependency_tracking
615  def _reset_compile_cache(self):
616    self.train_function = None
617    self.test_function = None
618    self.predict_function = None
619
620    # Used to cache `trainable` attr of `Layer`s for `fit`.
621    self._compiled_trainable_state = self._get_trainable_state()
622
623  @trackable.no_automatic_dependency_tracking
624  def _configure_steps_per_execution(self, steps_per_execution):
625    self._steps_per_execution = variables.Variable(
626        steps_per_execution,
627        dtype='int64',
628        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)
629
630  @property
631  def _should_compute_mask(self):
632    return False
633
634  @property
635  def metrics(self):
636    """Returns the model's metrics added using `compile`, `add_metric` APIs.
637
638    Note: Metrics passed to `compile()` are available only after a `keras.Model`
639    has been trained/evaluated on actual data.
640
641    Examples:
642
643    >>> inputs = tf.keras.layers.Input(shape=(3,))
644    >>> outputs = tf.keras.layers.Dense(2)(inputs)
645    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
646    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
647    >>> [m.name for m in model.metrics]
648    []
649
650    >>> x = np.random.random((2, 3))
651    >>> y = np.random.randint(0, 2, (2, 2))
652    >>> model.fit(x, y)
653    >>> [m.name for m in model.metrics]
654    ['loss', 'mae']
655
656    >>> inputs = tf.keras.layers.Input(shape=(3,))
657    >>> d = tf.keras.layers.Dense(2, name='out')
658    >>> output_1 = d(inputs)
659    >>> output_2 = d(inputs)
660    >>> model = tf.keras.models.Model(
661    ...    inputs=inputs, outputs=[output_1, output_2])
662    >>> model.add_metric(
663    ...    tf.reduce_sum(output_2), name='mean', aggregation='mean')
664    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
665    >>> model.fit(x, (y, y))
666    >>> [m.name for m in model.metrics]
667    ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
668    'out_1_acc', 'mean']
669
670    """
671    metrics = []
672    if self._is_compiled:
673      # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
674      # so that attr names are not load-bearing.
675      if self.compiled_loss is not None:
676        metrics += self.compiled_loss.metrics
677      if self.compiled_metrics is not None:
678        metrics += self.compiled_metrics.metrics
679
680    for l in self._flatten_layers():
681      metrics.extend(l._metrics)  # pylint: disable=protected-access
682    return metrics
683
684  @property
685  def metrics_names(self):
686    """Returns the model's display labels for all outputs.
687
688    Note: `metrics_names` are available only after a `keras.Model` has been
689    trained/evaluated on actual data.
690
691    Examples:
692
693    >>> inputs = tf.keras.layers.Input(shape=(3,))
694    >>> outputs = tf.keras.layers.Dense(2)(inputs)
695    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
696    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
697    >>> model.metrics_names
698    []
699
700    >>> x = np.random.random((2, 3))
701    >>> y = np.random.randint(0, 2, (2, 2))
702    >>> model.fit(x, y)
703    >>> model.metrics_names
704    ['loss', 'mae']
705
706    >>> inputs = tf.keras.layers.Input(shape=(3,))
707    >>> d = tf.keras.layers.Dense(2, name='out')
708    >>> output_1 = d(inputs)
709    >>> output_2 = d(inputs)
710    >>> model = tf.keras.models.Model(
711    ...    inputs=inputs, outputs=[output_1, output_2])
712    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
713    >>> model.fit(x, (y, y))
714    >>> model.metrics_names
715    ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
716    'out_1_acc']
717
718    """
719
720    # This property includes all output names including `loss` and per-output
721    # losses for backward compatibility.
722    return [m.name for m in self.metrics]
723
724  @property
725  def distribute_strategy(self):
726    """The `tf.distribute.Strategy` this model was created under."""
727    return self._distribution_strategy or ds_context.get_strategy()
728
729  @property
730  def run_eagerly(self):
731    """Settable attribute indicating whether the model should run eagerly.
732
733    Running eagerly means that your model will be run step by step,
734    like Python code. Your model might run slower, but it should become easier
735    for you to debug it by stepping into individual layer calls.
736
737    By default, we will attempt to compile your model to a static graph to
738    deliver the best execution performance.
739
740    Returns:
741      Boolean, whether the model should run eagerly.
742    """
743    if self.dynamic and self._run_eagerly is False:  # pylint:disable=g-bool-id-comparison
744      # TODO(fchollet): consider using py_func to enable this.
745      raise ValueError('Your model contains layers that can only be '
746                       'successfully run in eager execution (layers '
747                       'constructed with `dynamic=True`). '
748                       'You cannot set `run_eagerly=False`.')
749
750    if self._cluster_coordinator and self._run_eagerly:
751      raise ValueError('When using `Model` with `ParameterServerStrategy`, '
752                       '`run_eagerly` is not supported.')
753
754    # Run eagerly logic, by priority:
755    # (1) Dynamic models must be run eagerly.
756    # (2) Explicitly setting run_eagerly causes a Model to be run eagerly.
757    # (3) Not explicitly setting run_eagerly defaults to TF's global setting.
758    return (self.dynamic or self._run_eagerly or
759            (def_function.functions_run_eagerly() and
760             self._run_eagerly is None))
761
762  @run_eagerly.setter
763  def run_eagerly(self, value):
764    self._run_eagerly = value
765
766  def train_step(self, data):
767    """The logic for one training step.
768
769    This method can be overridden to support custom training logic.
770    This method is called by `Model.make_train_function`.
771
772    This method should contain the mathematical logic for one step of training.
773    This typically includes the forward pass, loss calculation, backpropagation,
774    and metric updates.
775
776    Configuration details for *how* this logic is run (e.g. `tf.function` and
777    `tf.distribute.Strategy` settings), should be left to
778    `Model.make_train_function`, which can also be overridden.
779
780    Args:
781      data: A nested structure of `Tensor`s.
782
783    Returns:
784      A `dict` containing values that will be passed to
785      `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
786      values of the `Model`'s metrics are returned. Example:
787      `{'loss': 0.2, 'accuracy': 0.7}`.
788
789    """
790    # These are the only transformations `Model.fit` applies to user-input
791    # data when a `tf.data.Dataset` is provided.
792    data = data_adapter.expand_1d(data)
793    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
794
795    with backprop.GradientTape() as tape:
796      y_pred = self(x, training=True)
797      loss = self.compiled_loss(
798          y, y_pred, sample_weight, regularization_losses=self.losses)
799    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
800    self.compiled_metrics.update_state(y, y_pred, sample_weight)
801    return {m.name: m.result() for m in self.metrics}
802
803  def make_train_function(self):
804    """Creates a function that executes one step of training.
805
806    This method can be overridden to support custom training logic.
807    This method is called by `Model.fit` and `Model.train_on_batch`.
808
809    Typically, this method directly controls `tf.function` and
810    `tf.distribute.Strategy` settings, and delegates the actual training
811    logic to `Model.train_step`.
812
813    This function is cached the first time `Model.fit` or
814    `Model.train_on_batch` is called. The cache is cleared whenever
815    `Model.compile` is called.
816
817    Returns:
818      Function. The function created by this method should accept a
819      `tf.data.Iterator`, and return a `dict` containing values that will
820      be passed to `tf.keras.Callbacks.on_train_batch_end`, such as
821      `{'loss': 0.2, 'accuracy': 0.7}`.
822    """
823    if self.train_function is not None:
824      return self.train_function
825
826    def step_function(model, iterator):
827      """Runs a single training step."""
828
829      def run_step(data):
830        outputs = model.train_step(data)
831        # Ensure counter is updated only if `train_step` succeeds.
832        with ops.control_dependencies(_minimum_control_deps(outputs)):
833          model._train_counter.assign_add(1)  # pylint: disable=protected-access
834        return outputs
835
836      data = next(iterator)
837      outputs = model.distribute_strategy.run(run_step, args=(data,))
838      outputs = reduce_per_replica(
839          outputs, self.distribute_strategy, reduction='first')
840      write_scalar_summaries(outputs, step=model._train_counter)  # pylint: disable=protected-access
841      return outputs
842
843    if self._steps_per_execution.numpy().item() == 1:
844
845      def train_function(iterator):
846        """Runs a training execution with one step."""
847        return step_function(self, iterator)
848
849    else:
850
851      def train_function(iterator):
852        """Runs a training execution with multiple steps."""
853        for _ in math_ops.range(self._steps_per_execution):
854          outputs = step_function(self, iterator)
855        return outputs
856
857    if not self.run_eagerly:
858      train_function = def_function.function(
859          train_function, experimental_relax_shapes=True)
860
861    self.train_function = train_function
862
863    if self._cluster_coordinator:
864      self.train_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
865          train_function, args=(iterator,))
866
867    return self.train_function
868
869  def fit(self,
870          x=None,
871          y=None,
872          batch_size=None,
873          epochs=1,
874          verbose=1,
875          callbacks=None,
876          validation_split=0.,
877          validation_data=None,
878          shuffle=True,
879          class_weight=None,
880          sample_weight=None,
881          initial_epoch=0,
882          steps_per_epoch=None,
883          validation_steps=None,
884          validation_batch_size=None,
885          validation_freq=1,
886          max_queue_size=10,
887          workers=1,
888          use_multiprocessing=False):
889    """Trains the model for a fixed number of epochs (iterations on a dataset).
890
891    Args:
892        x: Input data. It could be:
893          - A Numpy array (or array-like), or a list of arrays
894            (in case the model has multiple inputs).
895          - A TensorFlow tensor, or a list of tensors
896            (in case the model has multiple inputs).
897          - A dict mapping input names to the corresponding array/tensors,
898            if the model has named inputs.
899          - A `tf.data` dataset. Should return a tuple
900            of either `(inputs, targets)` or
901            `(inputs, targets, sample_weights)`.
902          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
903            or `(inputs, targets, sample_weights)`.
904          A more detailed description of unpacking behavior for iterator types
905          (Dataset, generator, Sequence) is given below.
906        y: Target data. Like the input data `x`,
907          it could be either Numpy array(s) or TensorFlow tensor(s).
908          It should be consistent with `x` (you cannot have Numpy inputs and
909          tensor targets, or inversely). If `x` is a dataset, generator,
910          or `keras.utils.Sequence` instance, `y` should
911          not be specified (since targets will be obtained from `x`).
912        batch_size: Integer or `None`.
913            Number of samples per gradient update.
914            If unspecified, `batch_size` will default to 32.
915            Do not specify the `batch_size` if your data is in the
916            form of datasets, generators, or `keras.utils.Sequence` instances
917            (since they generate batches).
918        epochs: Integer. Number of epochs to train the model.
919            An epoch is an iteration over the entire `x` and `y`
920            data provided.
921            Note that in conjunction with `initial_epoch`,
922            `epochs` is to be understood as "final epoch".
923            The model is not trained for a number of iterations
924            given by `epochs`, but merely until the epoch
925            of index `epochs` is reached.
926        verbose: 0, 1, or 2. Verbosity mode.
927            0 = silent, 1 = progress bar, 2 = one line per epoch.
928            Note that the progress bar is not particularly useful when
929            logged to a file, so verbose=2 is recommended when not running
930            interactively (eg, in a production environment).
931        callbacks: List of `keras.callbacks.Callback` instances.
932            List of callbacks to apply during training.
933            See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
934            and `tf.keras.callbacks.History` callbacks are created automatically
935            and need not be passed into `model.fit`.
936            `tf.keras.callbacks.ProgbarLogger` is created or not based on
937            `verbose` argument to `model.fit`.
938        validation_split: Float between 0 and 1.
939            Fraction of the training data to be used as validation data.
940            The model will set apart this fraction of the training data,
941            will not train on it, and will evaluate
942            the loss and any model metrics
943            on this data at the end of each epoch.
944            The validation data is selected from the last samples
945            in the `x` and `y` data provided, before shuffling. This argument is
946            not supported when `x` is a dataset, generator or
947           `keras.utils.Sequence` instance.
948        validation_data: Data on which to evaluate
949            the loss and any model metrics at the end of each epoch.
950            The model will not be trained on this data. Thus, note the fact
951            that the validation loss of data provided using `validation_split`
952            or `validation_data` is not affected by regularization layers like
953            noise and dropout.
954            `validation_data` will override `validation_split`.
955            `validation_data` could be:
956              - tuple `(x_val, y_val)` of Numpy arrays or tensors
957              - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
958              - dataset
959            For the first two cases, `batch_size` must be provided.
960            For the last case, `validation_steps` could be provided.
961            Note that `validation_data` does not support all the data types that
962            are supported in `x`, eg, dict, generator or `keras.utils.Sequence`.
963        shuffle: Boolean (whether to shuffle the training data
964            before each epoch) or str (for 'batch'). This argument is ignored
965            when `x` is a generator or an object of tf.data.Dataset.
966            'batch' is a special option for dealing
967            with the limitations of HDF5 data; it shuffles in batch-sized
968            chunks. Has no effect when `steps_per_epoch` is not `None`.
969        class_weight: Optional dictionary mapping class indices (integers)
970            to a weight (float) value, used for weighting the loss function
971            (during training only).
972            This can be useful to tell the model to
973            "pay more attention" to samples from
974            an under-represented class.
975        sample_weight: Optional Numpy array of weights for
976            the training samples, used for weighting the loss function
977            (during training only). You can either pass a flat (1D)
978            Numpy array with the same length as the input samples
979            (1:1 mapping between weights and samples),
980            or in the case of temporal data,
981            you can pass a 2D array with shape
982            `(samples, sequence_length)`,
983            to apply a different weight to every timestep of every sample. This
984            argument is not supported when `x` is a dataset, generator, or
985           `keras.utils.Sequence` instance, instead provide the sample_weights
986            as the third element of `x`.
987        initial_epoch: Integer.
988            Epoch at which to start training
989            (useful for resuming a previous training run).
990        steps_per_epoch: Integer or `None`.
991            Total number of steps (batches of samples)
992            before declaring one epoch finished and starting the
993            next epoch. When training with input tensors such as
994            TensorFlow data tensors, the default `None` is equal to
995            the number of samples in your dataset divided by
996            the batch size, or 1 if that cannot be determined. If x is a
997            `tf.data` dataset, and 'steps_per_epoch'
998            is None, the epoch will run until the input dataset is exhausted.
999            When passing an infinitely repeating dataset, you must specify the
1000            `steps_per_epoch` argument. This argument is not supported with
1001            array inputs.
1002        validation_steps: Only relevant if `validation_data` is provided and
1003            is a `tf.data` dataset. Total number of steps (batches of
1004            samples) to draw before stopping when performing validation
1005            at the end of every epoch. If 'validation_steps' is None, validation
1006            will run until the `validation_data` dataset is exhausted. In the
1007            case of an infinitely repeated dataset, it will run into an
1008            infinite loop. If 'validation_steps' is specified and only part of
1009            the dataset will be consumed, the evaluation will start from the
1010            beginning of the dataset at each epoch. This ensures that the same
1011            validation samples are used every time.
1012        validation_batch_size: Integer or `None`.
1013            Number of samples per validation batch.
1014            If unspecified, will default to `batch_size`.
1015            Do not specify the `validation_batch_size` if your data is in the
1016            form of datasets, generators, or `keras.utils.Sequence` instances
1017            (since they generate batches).
1018        validation_freq: Only relevant if validation data is provided. Integer
1019            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
1020            If an integer, specifies how many training epochs to run before a
1021            new validation run is performed, e.g. `validation_freq=2` runs
1022            validation every 2 epochs. If a Container, specifies the epochs on
1023            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
1024            validation at the end of the 1st, 2nd, and 10th epochs.
1025        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1026            input only. Maximum size for the generator queue.
1027            If unspecified, `max_queue_size` will default to 10.
1028        workers: Integer. Used for generator or `keras.utils.Sequence` input
1029            only. Maximum number of processes to spin up
1030            when using process-based threading. If unspecified, `workers`
1031            will default to 1. If 0, will execute the generator on the main
1032            thread.
1033        use_multiprocessing: Boolean. Used for generator or
1034            `keras.utils.Sequence` input only. If `True`, use process-based
1035            threading. If unspecified, `use_multiprocessing` will default to
1036            `False`. Note that because this implementation relies on
1037            multiprocessing, you should not pass non-picklable arguments to
1038            the generator as they can't be passed easily to children processes.
1039
1040    Unpacking behavior for iterator-like inputs:
1041        A common pattern is to pass a tf.data.Dataset, generator, or
1042      tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
1043      yield not only features (x) but optionally targets (y) and sample weights.
1044      Keras requires that the output of such iterator-likes be unambiguous. The
1045      iterator should return a tuple of length 1, 2, or 3, where the optional
1046      second and third elements will be used for y and sample_weight
1047      respectively. Any other type provided will be wrapped in a length one
1048      tuple, effectively treating everything as 'x'. When yielding dicts, they
1049      should still adhere to the top-level tuple structure.
1050      e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
1051      features, targets, and weights from the keys of a single dict.
1052        A notable unsupported data type is the namedtuple. The reason is that
1053      it behaves like both an ordered datatype (tuple) and a mapping
1054      datatype (dict). So given a namedtuple of the form:
1055          `namedtuple("example_tuple", ["y", "x"])`
1056      it is ambiguous whether to reverse the order of the elements when
1057      interpreting the value. Even worse is a tuple of the form:
1058          `namedtuple("other_tuple", ["x", "y", "z"])`
1059      where it is unclear if the tuple was intended to be unpacked into x, y,
1060      and sample_weight or passed through as a single element to `x`. As a
1061      result the data processing code will simply raise a ValueError if it
1062      encounters a namedtuple. (Along with instructions to remedy the issue.)
1063
1064    Returns:
1065        A `History` object. Its `History.history` attribute is
1066        a record of training loss values and metrics values
1067        at successive epochs, as well as validation loss values
1068        and validation metrics values (if applicable).
1069
1070    Raises:
1071        RuntimeError: 1. If the model was never compiled or,
1072        2. If `model.fit` is  wrapped in `tf.function`.
1073
1074        ValueError: In case of mismatch between the provided input data
1075            and what the model expects or when the input data is empty.
1076    """
1077    base_layer.keras_api_gauge.get_cell('fit').set(True)
1078    # Legacy graph support is contained in `training_v1.Model`.
1079    version_utils.disallow_legacy_graph('Model', 'fit')
1080    self._assert_compile_was_called()
1081    self._check_call_args('fit')
1082    _disallow_inside_tf_function('fit')
1083
1084    if validation_split:
1085      # Create the validation data using the training data. Only supported for
1086      # `Tensor` and `NumPy` input.
1087      (x, y, sample_weight), validation_data = (
1088          data_adapter.train_validation_split(
1089              (x, y, sample_weight), validation_split=validation_split))
1090
1091    if validation_data:
1092      val_x, val_y, val_sample_weight = (
1093          data_adapter.unpack_x_y_sample_weight(validation_data))
1094
1095    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1096      self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
1097          self.distribute_strategy)
1098
1099    with self.distribute_strategy.scope(), \
1100         training_utils.RespectCompiledTrainableState(self):
1101      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1102      data_handler = data_adapter.get_data_handler(
1103          x=x,
1104          y=y,
1105          sample_weight=sample_weight,
1106          batch_size=batch_size,
1107          steps_per_epoch=steps_per_epoch,
1108          initial_epoch=initial_epoch,
1109          epochs=epochs,
1110          shuffle=shuffle,
1111          class_weight=class_weight,
1112          max_queue_size=max_queue_size,
1113          workers=workers,
1114          use_multiprocessing=use_multiprocessing,
1115          model=self,
1116          steps_per_execution=self._steps_per_execution)
1117
1118      # Container that configures and calls `tf.keras.Callback`s.
1119      if not isinstance(callbacks, callbacks_module.CallbackList):
1120        callbacks = callbacks_module.CallbackList(
1121            callbacks,
1122            add_history=True,
1123            add_progbar=verbose != 0,
1124            model=self,
1125            verbose=verbose,
1126            epochs=epochs,
1127            steps=data_handler.inferred_steps)
1128
1129      self.stop_training = False
1130      self.train_function = self.make_train_function()
1131      self._train_counter.assign(0)
1132      callbacks.on_train_begin()
1133      training_logs = None
1134      # Handle fault-tolerance for multi-worker.
1135      # TODO(omalleyt): Fix the ordering issues that mean this has to
1136      # happen after `callbacks.on_train_begin`.
1137      data_handler._initial_epoch = (  # pylint: disable=protected-access
1138          self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
1139      logs = None
1140      for epoch, iterator in data_handler.enumerate_epochs():
1141        self.reset_metrics()
1142        callbacks.on_epoch_begin(epoch)
1143        with data_handler.catch_stop_iteration():
1144          for step in data_handler.steps():
1145            with trace.Trace(
1146                'train',
1147                epoch_num=epoch,
1148                step_num=step,
1149                batch_size=batch_size,
1150                _r=1):
1151              callbacks.on_train_batch_begin(step)
1152              tmp_logs = self.train_function(iterator)
1153              if data_handler.should_sync:
1154                context.async_wait()
1155              logs = tmp_logs  # No error, now safe to assign to logs.
1156              end_step = step + data_handler.step_increment
1157              callbacks.on_train_batch_end(end_step, logs)
1158              if self.stop_training:
1159                break
1160
1161        logs = data_handler.resolve_logs(logs)
1162        if logs is None:
1163          raise ValueError('Expect x to be a non-empty array or dataset.')
1164        epoch_logs = copy.copy(logs)
1165
1166        # Run validation.
1167        if validation_data and self._should_eval(epoch, validation_freq):
1168          # Create data_handler for evaluation and cache it.
1169          if getattr(self, '_eval_data_handler', None) is None:
1170            self._fit_frame = tf_inspect.currentframe()
1171            self._eval_data_handler = data_adapter.get_data_handler(
1172                x=val_x,
1173                y=val_y,
1174                sample_weight=val_sample_weight,
1175                batch_size=validation_batch_size or batch_size,
1176                steps_per_epoch=validation_steps,
1177                initial_epoch=0,
1178                epochs=1,
1179                max_queue_size=max_queue_size,
1180                workers=workers,
1181                use_multiprocessing=use_multiprocessing,
1182                model=self,
1183                steps_per_execution=self._steps_per_execution)
1184          val_logs = self.evaluate(
1185              x=val_x,
1186              y=val_y,
1187              sample_weight=val_sample_weight,
1188              batch_size=validation_batch_size or batch_size,
1189              steps=validation_steps,
1190              callbacks=callbacks,
1191              max_queue_size=max_queue_size,
1192              workers=workers,
1193              use_multiprocessing=use_multiprocessing,
1194              return_dict=True)
1195          val_logs = {'val_' + name: val for name, val in val_logs.items()}
1196          epoch_logs.update(val_logs)
1197
1198        callbacks.on_epoch_end(epoch, epoch_logs)
1199        training_logs = epoch_logs
1200        if self.stop_training:
1201          break
1202
1203      # If eval data_hanlder exists, delete it after all epochs are done.
1204      if getattr(self, '_eval_data_handler', None) is not None:
1205        del self._eval_data_handler
1206        del self._fit_frame
1207      callbacks.on_train_end(logs=training_logs)
1208      return self.history
1209
1210  def test_step(self, data):
1211    """The logic for one evaluation step.
1212
1213    This method can be overridden to support custom evaluation logic.
1214    This method is called by `Model.make_test_function`.
1215
1216    This function should contain the mathematical logic for one step of
1217    evaluation.
1218    This typically includes the forward pass, loss calculation, and metrics
1219    updates.
1220
1221    Configuration details for *how* this logic is run (e.g. `tf.function` and
1222    `tf.distribute.Strategy` settings), should be left to
1223    `Model.make_test_function`, which can also be overridden.
1224
1225    Args:
1226      data: A nested structure of `Tensor`s.
1227
1228    Returns:
1229      A `dict` containing values that will be passed to
1230      `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
1231      values of the `Model`'s metrics are returned.
1232    """
1233    data = data_adapter.expand_1d(data)
1234    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
1235
1236    y_pred = self(x, training=False)
1237    # Updates stateful loss metrics.
1238    self.compiled_loss(
1239        y, y_pred, sample_weight, regularization_losses=self.losses)
1240
1241    self.compiled_metrics.update_state(y, y_pred, sample_weight)
1242    return {m.name: m.result() for m in self.metrics}
1243
1244  def make_test_function(self):
1245    """Creates a function that executes one step of evaluation.
1246
1247    This method can be overridden to support custom evaluation logic.
1248    This method is called by `Model.evaluate` and `Model.test_on_batch`.
1249
1250    Typically, this method directly controls `tf.function` and
1251    `tf.distribute.Strategy` settings, and delegates the actual evaluation
1252    logic to `Model.test_step`.
1253
1254    This function is cached the first time `Model.evaluate` or
1255    `Model.test_on_batch` is called. The cache is cleared whenever
1256    `Model.compile` is called.
1257
1258    Returns:
1259      Function. The function created by this method should accept a
1260      `tf.data.Iterator`, and return a `dict` containing values that will
1261      be passed to `tf.keras.Callbacks.on_test_batch_end`.
1262    """
1263    if self.test_function is not None:
1264      return self.test_function
1265
1266    def step_function(model, iterator):
1267      """Runs a single evaluation step."""
1268
1269      def run_step(data):
1270        outputs = model.test_step(data)
1271        # Ensure counter is updated only if `test_step` succeeds.
1272        with ops.control_dependencies(_minimum_control_deps(outputs)):
1273          model._test_counter.assign_add(1)  # pylint: disable=protected-access
1274        return outputs
1275
1276      data = next(iterator)
1277      outputs = model.distribute_strategy.run(run_step, args=(data,))
1278      outputs = reduce_per_replica(
1279          outputs, self.distribute_strategy, reduction='first')
1280      return outputs
1281
1282    if self._steps_per_execution.numpy().item() == 1:
1283
1284      def test_function(iterator):
1285        """Runs an evaluation execution with one step."""
1286        return step_function(self, iterator)
1287
1288    else:
1289
1290      def test_function(iterator):
1291        """Runs an evaluation execution with multiple steps."""
1292        for _ in math_ops.range(self._steps_per_execution):
1293          outputs = step_function(self, iterator)
1294        return outputs
1295
1296    if not self.run_eagerly:
1297      test_function = def_function.function(
1298          test_function, experimental_relax_shapes=True)
1299
1300    self.test_function = test_function
1301    return self.test_function
1302
1303  def evaluate(self,
1304               x=None,
1305               y=None,
1306               batch_size=None,
1307               verbose=1,
1308               sample_weight=None,
1309               steps=None,
1310               callbacks=None,
1311               max_queue_size=10,
1312               workers=1,
1313               use_multiprocessing=False,
1314               return_dict=False):
1315    """Returns the loss value & metrics values for the model in test mode.
1316
1317    Computation is done in batches (see the `batch_size` arg.)
1318
1319    Args:
1320        x: Input data. It could be:
1321          - A Numpy array (or array-like), or a list of arrays
1322            (in case the model has multiple inputs).
1323          - A TensorFlow tensor, or a list of tensors
1324            (in case the model has multiple inputs).
1325          - A dict mapping input names to the corresponding array/tensors,
1326            if the model has named inputs.
1327          - A `tf.data` dataset. Should return a tuple
1328            of either `(inputs, targets)` or
1329            `(inputs, targets, sample_weights)`.
1330          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
1331            or `(inputs, targets, sample_weights)`.
1332          A more detailed description of unpacking behavior for iterator types
1333          (Dataset, generator, Sequence) is given in the `Unpacking behavior
1334          for iterator-like inputs` section of `Model.fit`.
1335        y: Target data. Like the input data `x`, it could be either Numpy
1336          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1337          (you cannot have Numpy inputs and tensor targets, or inversely). If
1338          `x` is a dataset, generator or `keras.utils.Sequence` instance, `y`
1339          should not be specified (since targets will be obtained from the
1340          iterator/dataset).
1341        batch_size: Integer or `None`. Number of samples per batch of
1342          computation. If unspecified, `batch_size` will default to 32. Do not
1343          specify the `batch_size` if your data is in the form of a dataset,
1344          generators, or `keras.utils.Sequence` instances (since they generate
1345          batches).
1346        verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar.
1347        sample_weight: Optional Numpy array of weights for the test samples,
1348          used for weighting the loss function. You can either pass a flat (1D)
1349          Numpy array with the same length as the input samples
1350            (1:1 mapping between weights and samples), or in the case of
1351              temporal data, you can pass a 2D array with shape `(samples,
1352              sequence_length)`, to apply a different weight to every timestep
1353              of every sample. This argument is not supported when `x` is a
1354              dataset, instead pass sample weights as the third element of `x`.
1355        steps: Integer or `None`. Total number of steps (batches of samples)
1356          before declaring the evaluation round finished. Ignored with the
1357          default value of `None`. If x is a `tf.data` dataset and `steps` is
1358          None, 'evaluate' will run until the dataset is exhausted. This
1359          argument is not supported with array inputs.
1360        callbacks: List of `keras.callbacks.Callback` instances. List of
1361          callbacks to apply during evaluation. See
1362          [callbacks](/api_docs/python/tf/keras/callbacks).
1363        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1364          input only. Maximum size for the generator queue. If unspecified,
1365          `max_queue_size` will default to 10.
1366        workers: Integer. Used for generator or `keras.utils.Sequence` input
1367          only. Maximum number of processes to spin up when using process-based
1368          threading. If unspecified, `workers` will default to 1. If 0, will
1369          execute the generator on the main thread.
1370        use_multiprocessing: Boolean. Used for generator or
1371          `keras.utils.Sequence` input only. If `True`, use process-based
1372          threading. If unspecified, `use_multiprocessing` will default to
1373          `False`. Note that because this implementation relies on
1374          multiprocessing, you should not pass non-picklable arguments to the
1375          generator as they can't be passed easily to children processes.
1376        return_dict: If `True`, loss and metric results are returned as a dict,
1377          with each key being the name of the metric. If `False`, they are
1378          returned as a list.
1379
1380    See the discussion of `Unpacking behavior for iterator-like inputs` for
1381    `Model.fit`.
1382
1383    Returns:
1384        Scalar test loss (if the model has a single output and no metrics)
1385        or list of scalars (if the model has multiple outputs
1386        and/or metrics). The attribute `model.metrics_names` will give you
1387        the display labels for the scalar outputs.
1388
1389    Raises:
1390        RuntimeError: If `model.evaluate` is wrapped in `tf.function`.
1391        ValueError: in case of invalid arguments.
1392    """
1393    base_layer.keras_api_gauge.get_cell('evaluate').set(True)
1394    version_utils.disallow_legacy_graph('Model', 'evaluate')
1395    self._assert_compile_was_called()
1396    self._check_call_args('evaluate')
1397    _disallow_inside_tf_function('evaluate')
1398
1399    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1400      raise NotImplementedError('`model.evaluate` is not yet supported with '
1401                                '`ParameterServerStrategy`.')
1402
1403    with self.distribute_strategy.scope():
1404      # Use cached evaluation data only when it's called in `Model.fit`
1405      if (getattr(self, '_fit_frame', None) is not None
1406          and tf_inspect.currentframe().f_back is self._fit_frame
1407          and getattr(self, '_eval_data_handler', None) is not None):
1408        data_handler = self._eval_data_handler
1409      else:
1410        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1411        data_handler = data_adapter.get_data_handler(
1412            x=x,
1413            y=y,
1414            sample_weight=sample_weight,
1415            batch_size=batch_size,
1416            steps_per_epoch=steps,
1417            initial_epoch=0,
1418            epochs=1,
1419            max_queue_size=max_queue_size,
1420            workers=workers,
1421            use_multiprocessing=use_multiprocessing,
1422            model=self,
1423            steps_per_execution=self._steps_per_execution)
1424
1425      # Container that configures and calls `tf.keras.Callback`s.
1426      if not isinstance(callbacks, callbacks_module.CallbackList):
1427        callbacks = callbacks_module.CallbackList(
1428            callbacks,
1429            add_history=True,
1430            add_progbar=verbose != 0,
1431            model=self,
1432            verbose=verbose,
1433            epochs=1,
1434            steps=data_handler.inferred_steps)
1435
1436      logs = {}
1437      self.test_function = self.make_test_function()
1438      self._test_counter.assign(0)
1439      callbacks.on_test_begin()
1440      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
1441        self.reset_metrics()
1442        with data_handler.catch_stop_iteration():
1443          for step in data_handler.steps():
1444            with trace.Trace('test', step_num=step, _r=1):
1445              callbacks.on_test_batch_begin(step)
1446              tmp_logs = self.test_function(iterator)
1447              if data_handler.should_sync:
1448                context.async_wait()
1449              logs = tmp_logs  # No error, now safe to assign to logs.
1450              end_step = step + data_handler.step_increment
1451              callbacks.on_test_batch_end(end_step, logs)
1452      logs = tf_utils.to_numpy_or_python_type(logs)
1453      callbacks.on_test_end(logs=logs)
1454
1455      if return_dict:
1456        return logs
1457      else:
1458        results = []
1459        for name in self.metrics_names:
1460          if name in logs:
1461            results.append(logs[name])
1462        for key in sorted(logs.keys()):
1463          if key not in self.metrics_names:
1464            results.append(logs[key])
1465        if len(results) == 1:
1466          return results[0]
1467        return results
1468
1469  def predict_step(self, data):
1470    """The logic for one inference step.
1471
1472    This method can be overridden to support custom inference logic.
1473    This method is called by `Model.make_predict_function`.
1474
1475    This method should contain the mathematical logic for one step of inference.
1476    This typically includes the forward pass.
1477
1478    Configuration details for *how* this logic is run (e.g. `tf.function` and
1479    `tf.distribute.Strategy` settings), should be left to
1480    `Model.make_predict_function`, which can also be overridden.
1481
1482    Args:
1483      data: A nested structure of `Tensor`s.
1484
1485    Returns:
1486      The result of one inference step, typically the output of calling the
1487      `Model` on data.
1488    """
1489    data = data_adapter.expand_1d(data)
1490    x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
1491    return self(x, training=False)
1492
  def make_predict_function(self):
    """Creates a function that executes one step of inference.

    This method can be overridden to support custom inference logic.
    This method is called by `Model.predict` and `Model.predict_on_batch`.

    Typically, this method directly controls `tf.function` and
    `tf.distribute.Strategy` settings, and delegates the actual inference
    logic to `Model.predict_step`.

    This function is cached the first time `Model.predict` or
    `Model.predict_on_batch` is called. The cache is cleared whenever
    `Model.compile` is called.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, and return the outputs of the `Model`.
    """
    # Reuse the previously built function if there is one.
    if self.predict_function is not None:
      return self.predict_function

    def step_function(model, iterator):
      """Runs a single prediction step."""

      def run_step(data):
        outputs = model.predict_step(data)
        # Ensure counter is updated only if `predict_step` succeeds.
        with ops.control_dependencies(_minimum_control_deps(outputs)):
          model._predict_counter.assign_add(1)  # pylint: disable=protected-access
        return outputs

      data = next(iterator)
      outputs = model.distribute_strategy.run(run_step, args=(data,))
      # Combine the per-replica results into one value per output by
      # concatenation.
      outputs = reduce_per_replica(
          outputs, self.distribute_strategy, reduction='concat')
      return outputs

    if (self._steps_per_execution is None or
        self._steps_per_execution.numpy().item() == 1):

      def predict_function(iterator):
        """Runs a prediction execution with one step."""
        return step_function(self, iterator)

    else:

      def predict_function(iterator):
        """Runs a prediction execution with multiple steps."""
        outputs = step_function(self, iterator)
        for _ in math_ops.range(self._steps_per_execution - 1):
          # Each iteration concatenates another step's outputs onto
          # `outputs`, so their shapes change across iterations. The shape
          # invariants (built with `dynamic_batch=True`) let the converted
          # loop accept that; presumably only the batch dimension varies.
          directives.set_loop_options(
              shape_invariants=[(
                  t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
                                for t in nest.flatten(outputs)])
          step_outputs = step_function(self, iterator)
          outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
                                       step_outputs)
        return outputs

    # Compile into a `tf.function` unless the model runs eagerly.
    if not self.run_eagerly:
      predict_function = def_function.function(
          predict_function, experimental_relax_shapes=True)

    self.predict_function = predict_function
    return self.predict_function
1558
  def predict(self,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    """Generates output predictions for the input samples.

    Computation is done in batches. This method is designed for performance in
    large-scale inputs. For a small number of inputs that fit in one batch,
    directly using `__call__` is recommended for faster execution, e.g.,
    `model(x)`, or `model(x, training=False)` if you have layers such as
    `tf.keras.layers.BatchNormalization` that behave differently during
    inference. Also, note the fact that test loss is not affected by
    regularization layers like noise and dropout.

    Args:
        x: Input samples. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A `tf.data` dataset.
          - A generator or `keras.utils.Sequence` instance.
          A more detailed description of unpacking behavior for iterator types
          (Dataset, generator, Sequence) is given in the `Unpacking behavior
          for iterator-like inputs` section of `Model.fit`.
        batch_size: Integer or `None`.
            Number of samples per batch.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of dataset, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
        verbose: Verbosity mode, 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring the prediction round finished.
            Ignored with the default value of `None`. If x is a `tf.data`
            dataset and `steps` is None, `predict` will
            run until the input dataset is exhausted.
        callbacks: List of `keras.callbacks.Callback` instances.
            List of callbacks to apply during prediction.
            See [callbacks](/api_docs/python/tf/keras/callbacks).
        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
            input only. Maximum size for the generator queue.
            If unspecified, `max_queue_size` will default to 10.
        workers: Integer. Used for generator or `keras.utils.Sequence` input
            only. Maximum number of processes to spin up when using
            process-based threading. If unspecified, `workers` will default
            to 1. If 0, will execute the generator on the main thread.
        use_multiprocessing: Boolean. Used for generator or
            `keras.utils.Sequence` input only. If `True`, use process-based
            threading. If unspecified, `use_multiprocessing` will default to
            `False`. Note that because this implementation relies on
            multiprocessing, you should not pass non-picklable arguments to
            the generator as they can't be passed easily to children processes.

    See the discussion of `Unpacking behavior for iterator-like inputs` for
    `Model.fit`. Note that Model.predict uses the same interpretation rules as
    `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all
    three methods.

    Returns:
        Numpy array(s) of predictions.

    Raises:
        RuntimeError: If `model.predict` is wrapped in `tf.function`.
        ValueError: In case of mismatch between the provided
            input data and the model's expectations,
            or in case a stateful model receives a number of samples
            that is not a multiple of the batch size.
    """
    base_layer.keras_api_gauge.get_cell('predict').set(True)
    version_utils.disallow_legacy_graph('Model', 'predict')
    self._check_call_args('predict')
    _disallow_inside_tf_function('predict')

    # Prediction is not implemented for coordinator-based
    # (`ParameterServerStrategy`) setups.
    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
      raise NotImplementedError('`model.predict` is not yet supported with '
                                '`ParameterServerStrategy`.')

    # Nested structure of per-step output lists; filled in the step loop.
    outputs = None
    with self.distribute_strategy.scope():
      # For multi-worker or multi-host TPU runs, ask for DATA auto-sharding
      # on dataset inputs. As the warning below notes, FILE sharding may
      # return predictions out of order.
      dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
      if (self._in_multi_worker_mode() or _is_tpu_multi_host(
          self.distribute_strategy)) and isinstance(x, dataset_types):
        try:
          options = dataset_ops.Options()
          data_option = distribute_options.AutoShardPolicy.DATA
          options.experimental_distribute.auto_shard_policy = data_option
          x = x.with_options(options)
        except ValueError:
          # NOTE(review): presumably raised when the dataset already carries
          # a conflicting auto-shard policy; the user's policy is kept and
          # only a warning is emitted. Confirm this cause.
          warnings.warn('Using Model.predict with '
                        'MultiWorkerDistributionStrategy or TPUStrategy and '
                        'AutoShardPolicy.FILE might lead to out-of-order result'
                        '. Consider setting it to AutoShardPolicy.DATA.')

      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
      data_handler = data_adapter.get_data_handler(
          x=x,
          batch_size=batch_size,
          steps_per_epoch=steps,
          initial_epoch=0,
          epochs=1,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          model=self,
          steps_per_execution=self._steps_per_execution)

      # Container that configures and calls `tf.keras.Callback`s.
      if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=self,
            verbose=verbose,
            epochs=1,
            steps=data_handler.inferred_steps)

      self.predict_function = self.make_predict_function()
      self._predict_counter.assign(0)
      callbacks.on_predict_begin()
      # Stays None when no step runs; that case is reported as an error below.
      batch_outputs = None
      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            callbacks.on_predict_batch_begin(step)
            tmp_batch_outputs = self.predict_function(iterator)
            if data_handler.should_sync:
              context.async_wait()
            batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
            # Accumulate: each leaf of `outputs` is a list of that leaf's
            # per-step results, in step order.
            if outputs is None:
              outputs = nest.map_structure(lambda batch_output: [batch_output],
                                           batch_outputs)
            else:
              nest.map_structure_up_to(
                  batch_outputs,
                  lambda output, batch_output: output.append(batch_output),
                  outputs, batch_outputs)
            end_step = step + data_handler.step_increment
            callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
      if batch_outputs is None:
        raise ValueError('Expect x to be a non-empty array or dataset.')
      callbacks.on_predict_end()
    # Concatenate each leaf's per-step results into a single value.
    all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
    return tf_utils.to_numpy_or_python_type(all_outputs)
1709
1710  def reset_metrics(self):
1711    """Resets the state of all the metrics in the model.
1712
1713    Examples:
1714
1715    >>> inputs = tf.keras.layers.Input(shape=(3,))
1716    >>> outputs = tf.keras.layers.Dense(2)(inputs)
1717    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
1718    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
1719
1720    >>> x = np.random.random((2, 3))
1721    >>> y = np.random.randint(0, 2, (2, 2))
1722    >>> _ = model.fit(x, y, verbose=0)
1723    >>> assert all(float(m.result()) for m in model.metrics)
1724
1725    >>> model.reset_metrics()
1726    >>> assert all(float(m.result()) == 0 for m in model.metrics)
1727
1728    """
1729    for m in self.metrics:
1730      m.reset_states()
1731
1732  def train_on_batch(self,
1733                     x,
1734                     y=None,
1735                     sample_weight=None,
1736                     class_weight=None,
1737                     reset_metrics=True,
1738                     return_dict=False):
1739    """Runs a single gradient update on a single batch of data.
1740
1741    Args:
1742        x: Input data. It could be:
1743          - A Numpy array (or array-like), or a list of arrays
1744              (in case the model has multiple inputs).
1745          - A TensorFlow tensor, or a list of tensors
1746              (in case the model has multiple inputs).
1747          - A dict mapping input names to the corresponding array/tensors,
1748              if the model has named inputs.
1749        y: Target data. Like the input data `x`, it could be either Numpy
1750          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1751          (you cannot have Numpy inputs and tensor targets, or inversely).
1752        sample_weight: Optional array of the same length as x, containing
1753          weights to apply to the model's loss for each sample. In the case of
1754          temporal data, you can pass a 2D array with shape (samples,
1755          sequence_length), to apply a different weight to every timestep of
1756          every sample.
1757        class_weight: Optional dictionary mapping class indices (integers) to a
1758          weight (float) to apply to the model's loss for the samples from this
1759          class during training. This can be useful to tell the model to "pay
1760          more attention" to samples from an under-represented class.
1761        reset_metrics: If `True`, the metrics returned will be only for this
1762          batch. If `False`, the metrics will be statefully accumulated across
1763          batches.
1764        return_dict: If `True`, loss and metric results are returned as a dict,
1765          with each key being the name of the metric. If `False`, they are
1766          returned as a list.
1767
1768    Returns:
1769        Scalar training loss
1770        (if the model has a single output and no metrics)
1771        or list of scalars (if the model has multiple outputs
1772        and/or metrics). The attribute `model.metrics_names` will give you
1773        the display labels for the scalar outputs.
1774
1775    Raises:
1776      RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`.
1777      ValueError: In case of invalid user-provided arguments.
1778    """
1779    self._assert_compile_was_called()
1780    self._check_call_args('train_on_batch')
1781    _disallow_inside_tf_function('train_on_batch')
1782    with self.distribute_strategy.scope(), \
1783         training_utils.RespectCompiledTrainableState(self):
1784      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x,
1785                                                    y, sample_weight,
1786                                                    class_weight)
1787      self.train_function = self.make_train_function()
1788      logs = self.train_function(iterator)
1789
1790    if reset_metrics:
1791      self.reset_metrics()
1792    logs = tf_utils.to_numpy_or_python_type(logs)
1793    if return_dict:
1794      return logs
1795    else:
1796      results = [logs.get(name, None) for name in self.metrics_names]
1797      if len(results) == 1:
1798        return results[0]
1799      return results
1800
1801  def test_on_batch(self,
1802                    x,
1803                    y=None,
1804                    sample_weight=None,
1805                    reset_metrics=True,
1806                    return_dict=False):
1807    """Test the model on a single batch of samples.
1808
1809    Args:
1810        x: Input data. It could be: - A Numpy array (or array-like), or a list
1811          of arrays (in case the model has multiple inputs). - A TensorFlow
1812          tensor, or a list of tensors (in case the model has multiple inputs).
1813          - A dict mapping input names to the corresponding array/tensors, if
1814          the model has named inputs.
1815        y: Target data. Like the input data `x`, it could be either Numpy
1816          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1817          (you cannot have Numpy inputs and tensor targets, or inversely).
1818        sample_weight: Optional array of the same length as x, containing
1819          weights to apply to the model's loss for each sample. In the case of
1820          temporal data, you can pass a 2D array with shape (samples,
1821          sequence_length), to apply a different weight to every timestep of
1822          every sample.
1823        reset_metrics: If `True`, the metrics returned will be only for this
1824          batch. If `False`, the metrics will be statefully accumulated across
1825          batches.
1826        return_dict: If `True`, loss and metric results are returned as a dict,
1827          with each key being the name of the metric. If `False`, they are
1828          returned as a list.
1829
1830    Returns:
1831        Scalar test loss (if the model has a single output and no metrics)
1832        or list of scalars (if the model has multiple outputs
1833        and/or metrics). The attribute `model.metrics_names` will give you
1834        the display labels for the scalar outputs.
1835
1836    Raises:
1837        RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`.
1838        ValueError: In case of invalid user-provided arguments.
1839    """
1840    self._assert_compile_was_called()
1841    self._check_call_args('test_on_batch')
1842    _disallow_inside_tf_function('test_on_batch')
1843    with self.distribute_strategy.scope():
1844      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x,
1845                                                    y, sample_weight)
1846      self.test_function = self.make_test_function()
1847      logs = self.test_function(iterator)
1848
1849    if reset_metrics:
1850      self.reset_metrics()
1851    logs = tf_utils.to_numpy_or_python_type(logs)
1852    if return_dict:
1853      return logs
1854    else:
1855      results = [logs.get(name, None) for name in self.metrics_names]
1856      if len(results) == 1:
1857        return results[0]
1858      return results
1859
1860  def predict_on_batch(self, x):
1861    """Returns predictions for a single batch of samples.
1862
1863    Args:
1864        x: Input data. It could be: - A Numpy array (or array-like), or a list
1865          of arrays (in case the model has multiple inputs). - A TensorFlow
1866          tensor, or a list of tensors (in case the model has multiple inputs).
1867
1868    Returns:
1869        Numpy array(s) of predictions.
1870
1871    Raises:
1872        RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`.
1873        ValueError: In case of mismatch between given number of inputs and
1874          expectations of the model.
1875    """
1876    self._check_call_args('predict_on_batch')
1877    _disallow_inside_tf_function('predict_on_batch')
1878    with self.distribute_strategy.scope():
1879      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x)
1880      self.predict_function = self.make_predict_function()
1881      outputs = self.predict_function(iterator)
1882    return tf_utils.to_numpy_or_python_type(outputs)
1883
1884  def fit_generator(self,
1885                    generator,
1886                    steps_per_epoch=None,
1887                    epochs=1,
1888                    verbose=1,
1889                    callbacks=None,
1890                    validation_data=None,
1891                    validation_steps=None,
1892                    validation_freq=1,
1893                    class_weight=None,
1894                    max_queue_size=10,
1895                    workers=1,
1896                    use_multiprocessing=False,
1897                    shuffle=True,
1898                    initial_epoch=0):
1899    """Fits the model on data yielded batch-by-batch by a Python generator.
1900
1901    DEPRECATED:
1902      `Model.fit` now supports generators, so there is no longer any need to use
1903      this endpoint.
1904    """
1905    warnings.warn('`Model.fit_generator` is deprecated and '
1906                  'will be removed in a future version. '
1907                  'Please use `Model.fit`, which supports generators.')
1908    return self.fit(
1909        generator,
1910        steps_per_epoch=steps_per_epoch,
1911        epochs=epochs,
1912        verbose=verbose,
1913        callbacks=callbacks,
1914        validation_data=validation_data,
1915        validation_steps=validation_steps,
1916        validation_freq=validation_freq,
1917        class_weight=class_weight,
1918        max_queue_size=max_queue_size,
1919        workers=workers,
1920        use_multiprocessing=use_multiprocessing,
1921        shuffle=shuffle,
1922        initial_epoch=initial_epoch)
1923
1924  def evaluate_generator(self,
1925                         generator,
1926                         steps=None,
1927                         callbacks=None,
1928                         max_queue_size=10,
1929                         workers=1,
1930                         use_multiprocessing=False,
1931                         verbose=0):
1932    """Evaluates the model on a data generator.
1933
1934    DEPRECATED:
1935      `Model.evaluate` now supports generators, so there is no longer any need
1936      to use this endpoint.
1937    """
1938    warnings.warn('`Model.evaluate_generator` is deprecated and '
1939                  'will be removed in a future version. '
1940                  'Please use `Model.evaluate`, which supports generators.')
1941    self._check_call_args('evaluate_generator')
1942
1943    return self.evaluate(
1944        generator,
1945        steps=steps,
1946        max_queue_size=max_queue_size,
1947        workers=workers,
1948        use_multiprocessing=use_multiprocessing,
1949        verbose=verbose,
1950        callbacks=callbacks)
1951
1952  def predict_generator(self,
1953                        generator,
1954                        steps=None,
1955                        callbacks=None,
1956                        max_queue_size=10,
1957                        workers=1,
1958                        use_multiprocessing=False,
1959                        verbose=0):
1960    """Generates predictions for the input samples from a data generator.
1961
1962    DEPRECATED:
1963      `Model.predict` now supports generators, so there is no longer any need
1964      to use this endpoint.
1965    """
1966    warnings.warn('`Model.predict_generator` is deprecated and '
1967                  'will be removed in a future version. '
1968                  'Please use `Model.predict`, which supports generators.')
1969    return self.predict(
1970        generator,
1971        steps=steps,
1972        max_queue_size=max_queue_size,
1973        workers=workers,
1974        use_multiprocessing=use_multiprocessing,
1975        verbose=verbose,
1976        callbacks=callbacks)
1977
1978  ######################################################################
1979  # Functions below are not training related. They are for model weights
1980  # tracking, save/load, serialization, etc.
1981  ######################################################################
1982
1983  @property
1984  def trainable_weights(self):
1985    self._assert_weights_created()
1986    if not self._trainable:
1987      return []
1988    trainable_variables = []
1989    for trackable_obj in self._self_tracked_trackables:
1990      trainable_variables += trackable_obj.trainable_variables
1991    trainable_variables += self._trainable_weights
1992    return self._dedup_weights(trainable_variables)
1993
1994  @property
1995  def non_trainable_weights(self):
1996    self._assert_weights_created()
1997    non_trainable_variables = []
1998    for trackable_obj in self._self_tracked_trackables:
1999      non_trainable_variables += trackable_obj.non_trainable_variables
2000
2001    if not self._trainable:
2002      # Return order is all trainable vars, then all non-trainable vars.
2003      trainable_variables = []
2004      for trackable_obj in self._self_tracked_trackables:
2005        trainable_variables += trackable_obj.trainable_variables
2006
2007      non_trainable_variables = (
2008          trainable_variables + self._trainable_weights +
2009          non_trainable_variables + self._non_trainable_weights)
2010    else:
2011      non_trainable_variables = (
2012          non_trainable_variables + self._non_trainable_weights)
2013
2014    return self._dedup_weights(non_trainable_variables)
2015
2016  def get_weights(self):
2017    """Retrieves the weights of the model.
2018
2019    Returns:
2020        A flat list of Numpy arrays.
2021    """
2022    with self.distribute_strategy.scope():
2023      return super(Model, self).get_weights()
2024
2025  def save(self,
2026           filepath,
2027           overwrite=True,
2028           include_optimizer=True,
2029           save_format=None,
2030           signatures=None,
2031           options=None,
2032           save_traces=True):
2033    # pylint: disable=line-too-long
2034    """Saves the model to Tensorflow SavedModel or a single HDF5 file.
2035
2036    Please see `tf.keras.models.save_model` or the
2037    [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
2038    for details.
2039
2040    Args:
2041        filepath: String, PathLike, path to SavedModel or H5 file to save the
2042            model.
2043        overwrite: Whether to silently overwrite any existing file at the
2044            target location, or provide the user with a manual prompt.
2045        include_optimizer: If True, save optimizer's state together.
2046        save_format: Either `'tf'` or `'h5'`, indicating whether to save the
2047            model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
2048            and 'h5' in TF 1.X.
2049        signatures: Signatures to save with the SavedModel. Applicable to the
2050            'tf' format only. Please see the `signatures` argument in
2051            `tf.saved_model.save` for details.
2052        options: (only applies to SavedModel format)
2053            `tf.saved_model.SaveOptions` object that specifies options for
2054            saving to SavedModel.
2055        save_traces: (only applies to SavedModel format) When enabled, the
2056            SavedModel will store the function traces for each layer. This
2057            can be disabled, so that only the configs of each layer are stored.
2058            Defaults to `True`. Disabling this will decrease serialization time
2059            and reduce file size, but it requires that all custom layers/models
2060            implement a `get_config()` method.
2061
2062    Example:
2063
2064    ```python
2065    from keras.models import load_model
2066
2067    model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
2068    del model  # deletes the existing model
2069
2070    # returns a compiled model
2071    # identical to the previous one
2072    model = load_model('my_model.h5')
2073    ```
2074    """
2075    # pylint: enable=line-too-long
2076    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
2077                    signatures, options, save_traces)
2078
  def save_weights(self,
                   filepath,
                   overwrite=True,
                   save_format=None,
                   options=None):
    """Saves all layer weights.

    Either saves in HDF5 or in TensorFlow format based on the `save_format`
    argument.

    When saving in HDF5 format, the weight file has:
      - `layer_names` (attribute), a list of strings
          (ordered names of model layers).
      - For every layer, a `group` named `layer.name`
          - For every such layer group, a group attribute `weight_names`,
              a list of strings
              (ordered names of weights tensor of the layer).
          - For every weight in the layer, a dataset
              storing the weight value, named after the weight tensor.

    When saving in TensorFlow format, all objects referenced by the network are
    saved in the same format as `tf.train.Checkpoint`, including any `Layer`
    instances or `Optimizer` instances assigned to object attributes. For
    networks constructed from inputs and outputs using `tf.keras.Model(inputs,
    outputs)`, `Layer` instances used by the network are tracked/saved
    automatically. For user-defined classes which inherit from `tf.keras.Model`,
    `Layer` instances must be assigned to object attributes, typically in the
    constructor. See the documentation of `tf.train.Checkpoint` and
    `tf.keras.Model` for details.

    While the formats are the same, do not mix `save_weights` and
    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
    loaded using `Model.load_weights`. Checkpoints saved using
    `tf.train.Checkpoint.save` should be restored using the corresponding
    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
    `save_weights` for training checkpoints.

    The TensorFlow format matches objects and variables by starting at a root
    object, `self` for `save_weights`, and greedily matching attribute
    names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
    is the `Checkpoint` even if the `Checkpoint` has a model attached. This
    means saving a `tf.keras.Model` using `save_weights` and loading into a
    `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
    the `Model`'s variables. See the [guide to training
    checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
    on the TensorFlow format.

    Args:
        filepath: String or PathLike, path to the file to save the weights to.
            When saving in TensorFlow format, this is the prefix used for
            checkpoint files (multiple files are generated). When
            `save_format` is None, a path that looks like an HDF5 file (e.g.
            ending in '.h5' or '.keras') causes weights to be saved in HDF5
            format.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        save_format: Either 'tf' or 'h5' (case-insensitive, surrounding
            whitespace ignored; 'tensorflow' is accepted as an alias for 'tf',
            and 'hdf5'/'keras' as aliases for 'h5'). A `filepath` ending in
            '.h5' or '.keras' will default to HDF5 if `save_format` is `None`.
            Otherwise `None` defaults to 'tf'.
        options: Optional `tf.train.CheckpointOptions` object that specifies
            options for saving weights (TensorFlow format only).

    Returns:
        None. If `overwrite` is False, the target exists, and the user
        declines the overwrite prompt, nothing is written.

    Raises:
        ImportError: If h5py is not available when attempting to save in HDF5
            format.
        ValueError: For invalid/unknown format arguments, or when the format
            resolves to 'tf' but `filepath` looks like an HDF5 file.
    """
    self._assert_weights_created()
    filepath = path_to_string(filepath)
    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
    # Resolve `save_format` to exactly 'tf' or 'h5'.
    if save_format is None:
      if filepath_is_h5:
        save_format = 'h5'
      else:
        save_format = 'tf'
    else:
      user_format = save_format.lower().strip()
      if user_format in ('tensorflow', 'tf'):
        save_format = 'tf'
      elif user_format in ('hdf5', 'h5', 'keras'):
        save_format = 'h5'
      else:
        raise ValueError(
            'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
                save_format,))
    if save_format == 'tf' and filepath_is_h5:
      raise ValueError(
          ('save_weights got save_format="tf"/"tensorflow", but the '
           'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
           'when saving in TensorFlow format.')
          % filepath)

    if save_format == 'h5' and h5py is None:
      raise ImportError(
          '`save_weights` requires h5py when saving in hdf5.')
    # A TF checkpoint is a set of files sharing the `filepath` prefix; its
    # '.index' file is the one probed to decide whether a checkpoint exists.
    if save_format == 'tf':
      check_filepath = filepath + '.index'
    else:
      check_filepath = filepath
    # If file exists and should not be overwritten:
    if not overwrite and os.path.isfile(check_filepath):
      proceed = ask_to_proceed_with_overwrite(check_filepath)
      if not proceed:
        return
    if save_format == 'h5':
      with h5py.File(filepath, 'w') as f:
        hdf5_format.save_weights_to_hdf5_group(f, self.layers)
    else:
      # In graph mode the saver is given the Keras backend session.
      if context.executing_eagerly():
        session = None
      else:
        session = backend.get_session()
      self._trackable_saver.save(filepath, session=session, options=options)
      # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
      checkpoint_management.update_checkpoint_state_internal(
          save_dir=os.path.dirname(filepath),
          model_checkpoint_path=filepath,
          save_relative_paths=True,
          all_model_checkpoint_paths=[filepath])
2196
2197  def load_weights(self,
2198                   filepath,
2199                   by_name=False,
2200                   skip_mismatch=False,
2201                   options=None):
2202    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
2203
2204    If `by_name` is False weights are loaded based on the network's
2205    topology. This means the architecture should be the same as when the weights
2206    were saved.  Note that layers that don't have weights are not taken into
2207    account in the topological ordering, so adding or removing layers is fine as
2208    long as they don't have weights.
2209
2210    If `by_name` is True, weights are loaded into layers only if they share the
2211    same name. This is useful for fine-tuning or transfer-learning models where
2212    some of the layers have changed.
2213
2214    Only topological loading (`by_name=False`) is supported when loading weights
2215    from the TensorFlow format. Note that topological loading differs slightly
2216    between TensorFlow and HDF5 formats for user-defined classes inheriting from
2217    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
2218    TensorFlow format loads based on the object-local names of attributes to
2219    which layers are assigned in the `Model`'s constructor.
2220
2221    Args:
2222        filepath: String, path to the weights file to load. For weight files in
2223            TensorFlow format, this is the file prefix (the same as was passed
2224            to `save_weights`). This can also be a path to a SavedModel
2225            saved from `model.save`.
2226        by_name: Boolean, whether to load weights by name or by topological
2227            order. Only topological loading is supported for weight files in
2228            TensorFlow format.
2229        skip_mismatch: Boolean, whether to skip loading of layers where there is
2230            a mismatch in the number of weights, or a mismatch in the shape of
2231            the weight (only valid when `by_name=True`).
2232        options: Optional `tf.train.CheckpointOptions` object that specifies
2233            options for loading weights.
2234
2235    Returns:
2236        When loading a weight file in TensorFlow format, returns the same status
2237        object as `tf.train.Checkpoint.restore`. When graph building, restore
2238        ops are run automatically as soon as the network is built (on first call
2239        for user-defined classes inheriting from `Model`, immediately if it is
2240        already built).
2241
2242        When loading weights in HDF5 format, returns `None`.
2243
2244    Raises:
2245        ImportError: If h5py is not available and the weight file is in HDF5
2246            format.
2247        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
2248          `False`.
2249    """
2250    if backend.is_tpu_strategy(self._distribution_strategy):
2251      if (self._distribution_strategy.extended.steps_per_run > 1 and
2252          (not saving_utils.is_hdf5_filepath(filepath))):
2253        raise ValueError('Load weights is not yet supported with TPUStrategy '
2254                         'with steps_per_run greater than 1.')
2255    if skip_mismatch and not by_name:
2256      raise ValueError(
2257          'When calling model.load_weights, skip_mismatch can only be set to '
2258          'True when by_name is True.')
2259
2260    filepath, save_format = _detect_save_format(filepath)
2261    if save_format == 'tf':
2262      status = self._trackable_saver.restore(filepath, options)
2263      if by_name:
2264        raise NotImplementedError(
2265            'Weights may only be loaded based on topology into Models when '
2266            'loading TensorFlow-formatted weights (got by_name=True to '
2267            'load_weights).')
2268      if not context.executing_eagerly():
2269        session = backend.get_session()
2270        # Restore existing variables (if any) immediately, and set up a
2271        # streaming restore for any variables created in the future.
2272        trackable_utils.streaming_restore(status=status, session=session)
2273      status.assert_nontrivial_match()
2274      return status
2275    if h5py is None:
2276      raise ImportError(
2277          '`load_weights` requires h5py when loading weights from HDF5.')
2278    if not self._is_graph_network and not self.built:
2279      raise ValueError(
2280          'Unable to load weights saved in HDF5 format into a subclassed '
2281          'Model which has not created its variables yet. Call the Model '
2282          'first, then load the weights.')
2283    self._assert_weights_created()
2284    with h5py.File(filepath, 'r') as f:
2285      if 'layer_names' not in f.attrs and 'model_weights' in f:
2286        f = f['model_weights']
2287      if by_name:
2288        hdf5_format.load_weights_from_hdf5_group_by_name(
2289            f, self.layers, skip_mismatch=skip_mismatch)
2290      else:
2291        hdf5_format.load_weights_from_hdf5_group(f, self.layers)
2292
2293  def _updated_config(self):
2294    """Util shared between different serialization methods.
2295
2296    Returns:
2297        Model config with Keras version information added.
2298    """
2299    from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
2300
2301    config = self.get_config()
2302    model_config = {
2303        'class_name': self.__class__.__name__,
2304        'config': config,
2305        'keras_version': keras_version,
2306        'backend': backend.backend()
2307    }
2308    return model_config
2309
2310  def get_config(self):
2311    raise NotImplementedError
2312
2313  @classmethod
2314  def from_config(cls, config, custom_objects=None):
2315    # Since only FunctionalModel produces config, the model can only
2316    # be constructed for FunctionalModel
2317    from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
2318    return functional.Functional.from_config(
2319        config, custom_objects=custom_objects)
2320
2321  def to_json(self, **kwargs):
2322    """Returns a JSON string containing the network configuration.
2323
2324    To load a network from a JSON save file, use
2325    `keras.models.model_from_json(json_string, custom_objects={})`.
2326
2327    Args:
2328        **kwargs: Additional keyword arguments
2329            to be passed to `json.dumps()`.
2330
2331    Returns:
2332        A JSON string.
2333    """
2334    model_config = self._updated_config()
2335    return json.dumps(
2336        model_config, default=json_utils.get_json_type, **kwargs)
2337
2338  def to_yaml(self, **kwargs):
2339    """Returns a yaml string containing the network configuration.
2340
2341    To load a network from a yaml save file, use
2342    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
2343
2344    `custom_objects` should be a dictionary mapping
2345    the names of custom losses / layers / etc to the corresponding
2346    functions / classes.
2347
2348    Args:
2349        **kwargs: Additional keyword arguments
2350            to be passed to `yaml.dump()`.
2351
2352    Returns:
2353        A YAML string.
2354
2355    Raises:
2356        ImportError: if yaml module is not found.
2357    """
2358    if yaml is None:
2359      raise ImportError(
2360          'Requires yaml module installed (`pip install pyyaml`).')
2361    return yaml.dump(self._updated_config(), **kwargs)
2362
2363  def reset_states(self):
2364    for layer in self.layers:
2365      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
2366        layer.reset_states()
2367
2368  @property
2369  @doc_controls.do_not_generate_docs
2370  def state_updates(self):
2371    """Deprecated, do NOT use!
2372
2373    Returns the `updates` from all layers that are stateful.
2374
2375    This is useful for separating training updates and
2376    state updates, e.g. when we need to update a layer's internal state
2377    during prediction.
2378
2379    Returns:
2380        A list of update ops.
2381    """
2382    warnings.warn('`Model.state_updates` will be removed in a future version. '
2383                  'This property should not be used in TensorFlow 2.0, '
2384                  'as `updates` are applied automatically.')
2385    state_updates = []
2386    for layer in self.layers:
2387      if getattr(layer, 'stateful', False):
2388        if hasattr(layer, 'updates'):
2389          state_updates += layer.updates
2390    return state_updates
2391
2392  @property
2393  def weights(self):
2394    """Returns the list of all layer variables/weights.
2395
2396    Note: This will not track the weights of nested `tf.Modules` that are not
2397    themselves Keras layers.
2398
2399    Returns:
2400      A list of variables.
2401    """
2402    return self._dedup_weights(self._undeduplicated_weights)
2403
2404  @property
2405  def _undeduplicated_weights(self):
2406    """Returns the undeduplicated list of all layer variables/weights."""
2407    self._assert_weights_created()
2408    weights = []
2409    for layer in self._self_tracked_trackables:
2410      weights += layer.variables
2411    weights += (self._trainable_weights + self._non_trainable_weights)
2412    return weights
2413
2414  def summary(self, line_length=None, positions=None, print_fn=None):
2415    """Prints a string summary of the network.
2416
2417    Args:
2418        line_length: Total length of printed lines
2419            (e.g. set this to adapt the display to different
2420            terminal window sizes).
2421        positions: Relative or absolute positions of log elements
2422            in each line. If not provided,
2423            defaults to `[.33, .55, .67, 1.]`.
2424        print_fn: Print function to use. Defaults to `print`.
2425            It will be called on each line of the summary.
2426            You can set it to a custom function
2427            in order to capture the string summary.
2428
2429    Raises:
2430        ValueError: if `summary()` is called before the model is built.
2431    """
2432    if not self.built:
2433      raise ValueError('This model has not yet been built. '
2434                       'Build the model first by calling `build()` or calling '
2435                       '`fit()` with some data, or specify '
2436                       'an `input_shape` argument in the first layer(s) for '
2437                       'automatic build.')
2438    layer_utils.print_summary(self,
2439                              line_length=line_length,
2440                              positions=positions,
2441                              print_fn=print_fn)
2442
2443  @property
2444  def layers(self):
2445    return list(self._flatten_layers(include_self=False, recursive=False))
2446
2447  def get_layer(self, name=None, index=None):
2448    """Retrieves a layer based on either its name (unique) or index.
2449
2450    If `name` and `index` are both provided, `index` will take precedence.
2451    Indices are based on order of horizontal graph traversal (bottom-up).
2452
2453    Args:
2454        name: String, name of layer.
2455        index: Integer, index of layer.
2456
2457    Returns:
2458        A layer instance.
2459
2460    Raises:
2461        ValueError: In case of invalid layer name or index.
2462    """
2463    # TODO(fchollet): We could build a dictionary based on layer names
2464    # since they are constant, but we have not done that yet.
2465    if index is not None and name is not None:
2466      raise ValueError('Provide only a layer name or a layer index.')
2467
2468    if index is not None:
2469      if len(self.layers) <= index:
2470        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
2471                         ' but model only has ' + str(len(self.layers)) +
2472                         ' layers.')
2473      else:
2474        return self.layers[index]
2475
2476    if name is not None:
2477      for layer in self.layers:
2478        if layer.name == name:
2479          return layer
2480      raise ValueError('No such layer: ' + name + '.')
2481    raise ValueError('Provide either a layer name or layer index.')
2482
2483  @trackable.no_automatic_dependency_tracking
2484  def _set_save_spec(self, inputs):
2485    if self._saved_model_inputs_spec is not None:
2486      return  # Already set.
2487
2488    input_names = self.input_names
2489    if not input_names:
2490      input_names = compile_utils.create_pseudo_input_names(inputs)
2491
2492    flat_inputs = nest.flatten(inputs)
2493    specs = []
2494    for name, tensor in zip(input_names, flat_inputs):
2495      specs.append(
2496          tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name))
2497    specs = nest.pack_sequence_as(inputs, specs)
2498
2499    self._saved_model_inputs_spec = specs
2500
2501    # Store the input shapes
2502    if (self.__class__.__name__ == 'Sequential' and
2503        self._build_input_shape is None):
2504      self._build_input_shape = nest.map_structure(
2505          lambda x: None if x is None else x.shape, specs)
2506
2507  def _assert_weights_created(self):
2508    """Asserts that all the weights for the model have been created.
2509
2510    For a non-dynamic model, the weights must already be created after the
2511    layer has been called. For a dynamic model, the exact list of weights can
2512    never be known for certain since it may change at any time during execution.
2513
2514    We run this check right before accessing weights or getting the Numpy value
2515    for the current weights. Otherwise, if the layer has never been called,
2516    the user would just get an empty list, which is misleading.
2517
2518    Raises:
2519      ValueError: if the weights of the network has not yet been created.
2520    """
2521    if self.dynamic:
2522      return
2523
2524    if ('build' in self.__class__.__dict__ and
2525        self.__class__ != Model and
2526        not self.built):
2527      # For any model that has customized build() method but hasn't
2528      # been invoked yet, this will cover both sequential and subclass model.
2529      # Also make sure to exclude Model class itself which has build() defined.
2530      raise ValueError('Weights for model %s have not yet been created. '
2531                       'Weights are created when the Model is first called on '
2532                       'inputs or `build()` is called with an `input_shape`.' %
2533                       self.name)
2534
2535  def _check_call_args(self, method_name):
2536    """Check that `call` has only one positional arg."""
2537    # Always allow first arg, regardless of arg name.
2538    fullargspec = self._call_full_argspec
2539    if fullargspec.defaults:
2540      positional_args = fullargspec.args[:-len(fullargspec.defaults)]
2541    else:
2542      positional_args = fullargspec.args
2543    if 'training' in positional_args:
2544      positional_args.remove('training')
2545
2546    # self and first arg can be positional.
2547    if len(positional_args) > 2:
2548      extra_args = positional_args[2:]
2549      raise ValueError(
2550          'Models passed to `' + method_name + '` can only have `training` '
2551          'and the first argument in `call` as positional arguments, '
2552          'found: ' + str(extra_args) + '.')
2553
2554  def _validate_compile(self, optimizer, metrics, **kwargs):
2555    """Performs validation checks for the default `compile`."""
2556    if any(
2557        isinstance(opt, optimizer_v1.Optimizer)
2558        for opt in nest.flatten(optimizer)):
2559      raise ValueError(
2560          '`tf.compat.v1.keras` Optimizer (', optimizer, ') is '
2561          'not supported when eager execution is enabled. Use a '
2562          '`tf.keras` Optimizer instead, or disable eager '
2563          'execution.')
2564
2565    kwargs.pop('cloning', None)  # Legacy DistStrat argument, never used.
2566    kwargs.pop('experimental_run_tf_function', None)  # Always `True`.
2567    if kwargs.pop('distribute', None) is not None:
2568      raise ValueError(
2569          'Distribute argument in compile is not available in TF 2.0 please '
2570          'create the model under the distribution strategy scope.')
2571    if kwargs.pop('target_tensors', None) is not None:
2572      raise ValueError(
2573          'target_tensors argument is not supported when executing eagerly.')
2574    invalid_kwargs = set(kwargs) - {'sample_weight_mode'}
2575    if invalid_kwargs:
2576      raise TypeError('Invalid keyword argument(s) in `compile`: %s' %
2577                      (invalid_kwargs,))
2578
2579    # Model must be created and compiled with the same DistStrat.
2580    if self.built and ds_context.has_strategy():
2581      strategy = ds_context.get_strategy()
2582      for v in self.variables:
2583        if not strategy.extended.variable_created_in_scope(v):
2584          raise ValueError(
2585              'Variable (%s) was not created in the distribution strategy '
2586              'scope of (%s). It is most likely due to not all layers or '
2587              'the model or optimizer being created outside the distribution '
2588              'strategy scope. Try to make sure your code looks similar '
2589              'to the following.\n'
2590              'with strategy.scope():\n'
2591              '  model=_create_model()\n'
2592              '  model.compile(...)' % (v, strategy))
2593
2594    # Model metrics must be created in the same distribution strategy scope
2595    # as the model.
2596    strategy = self.distribute_strategy
2597    for metric in nest.flatten(metrics):
2598      for v in getattr(metric, 'variables', []):
2599        if not strategy.extended.variable_created_in_scope(v):
2600          raise ValueError(
2601              'Metric (%s) passed to model.compile was created inside of a '
2602              'different distribution strategy scope than the model. All '
2603              'metrics must be created in the same distribution strategy '
2604              'scope as the model (in this case %s). If you pass in a string '
2605              'identifier for a metric to compile the metric will '
2606              'automatically be created in the correct distribution '
2607              'strategy scope.' % (metric, strategy)
2608          )
2609
2610    # Model metrics must be created in the same distribution strategy scope
2611    # as the model.
2612    for opt in nest.flatten(optimizer):
2613      for v in getattr(opt, '_weights', []):
2614        if not strategy.extended.variable_created_in_scope(v):
2615          raise ValueError(
2616              'Optimizer (%s) passed to model.compile was created inside of a '
2617              'different distribution strategy scope than the model. All '
2618              'optimizers must be created in the same distribution strategy '
2619              'scope as the model (in this case %s). If you pass in a string '
2620              'identifier for an optimizer to compile the optimizer will '
2621              'automatically be created in the correct distribution '
2622              'strategy scope.' % (opt, strategy))
2623
2624  def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch):
2625    """Maybe load initial epoch from ckpt considering possible worker recovery.
2626
2627    Refer to tensorflow/python/keras/distribute/worker_training_state.py
2628    for more information.
2629
2630    Args:
2631      initial_epoch: The original initial_epoch user passes in in `fit()`.
2632
2633    Returns:
2634      If the training is recovering from previous failure under multi-worker
2635      training setting, return the epoch the training is supposed to continue
2636      at. Otherwise, return the `initial_epoch` the user passes in.
2637    """
2638    if self._training_state is not None:
2639      return self._training_state.maybe_load_initial_epoch_from_ckpt(
2640          initial_epoch, mode=ModeKeys.TRAIN)
2641    return initial_epoch
2642
2643  def _assert_compile_was_called(self):
2644    # Checks whether `compile` has been called. If it has been called,
2645    # then the optimizer is set. This is different from whether the
2646    # model is compiled
2647    # (i.e. whether the model is built and its inputs/outputs are set).
2648    if not self._is_compiled:
2649      raise RuntimeError('You must compile your model before '
2650                         'training/testing. '
2651                         'Use `model.compile(optimizer, loss)`.')
2652
2653  def _set_inputs(self, inputs, outputs=None, training=None):
2654    """This method is for compat with Modelv1. Only inputs are needed here."""
2655    self._set_save_spec(inputs)
2656
2657  @property
2658  def _trackable_saved_model_saver(self):
2659    return model_serialization.ModelSavedModelSaver(self)
2660
2661  def _list_functions_for_serialization(self, serialization_cache):
2662    # SavedModel needs to ignore the execution functions.
2663    train_function = self.train_function
2664    test_function = self.test_function
2665    predict_function = self.predict_function
2666    self.train_function = None
2667    self.test_function = None
2668    self.predict_function = None
2669    functions = super(
2670        Model, self)._list_functions_for_serialization(serialization_cache)
2671    self.train_function = train_function
2672    self.test_function = test_function
2673    self.predict_function = predict_function
2674    return functions
2675
2676  def _should_eval(self, epoch, validation_freq):
2677    if self._cluster_coordinator:
2678      raise NotImplementedError(
2679          'Evaluation in `model.fit` with '
2680          '`ParameterServerStrategy` is not yet supported.')
2681    epoch = epoch + 1  # one-index the user-facing epoch.
2682    if isinstance(validation_freq, int):
2683      return epoch % validation_freq == 0
2684    elif isinstance(validation_freq, list):
2685      return epoch in validation_freq
2686    else:
2687      raise ValueError('Expected `validation_freq` to be a list or int.')
2688
2689  ######################################################################
2690  # Functions below exist only as v1 / v2 compatibility shims.
2691  ######################################################################
2692
2693  def _get_compile_args(self, user_metrics=True):
2694    """Used for saving or cloning a Model.
2695
2696    Args:
2697      user_metrics: Whether to return user-supplied metrics or `Metric` objects.
2698        Defaults to returning the user-supplied metrics.
2699
2700    Returns:
2701      Dictionary of arguments that were used when compiling the model.
2702    """
2703    self._assert_compile_was_called()
2704    # pylint: disable=protected-access
2705
2706    saved_metrics = self.compiled_metrics._user_metrics
2707    saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics
2708
2709    if not user_metrics:
2710      if saved_metrics is not None:
2711        saved_metrics = self.compiled_metrics._metrics
2712      if saved_weighted_metrics is not None:
2713        saved_weighted_metrics = self.compiled_metrics._weighted_metrics
2714
2715    compile_args = {
2716        'optimizer': self.optimizer,
2717        'loss': self.compiled_loss._user_losses,
2718        'metrics': saved_metrics,
2719        'weighted_metrics': saved_weighted_metrics,
2720        'loss_weights': self.compiled_loss._user_loss_weights,
2721    }
2722    # pylint: enable=protected-access
2723    return compile_args
2724
2725  def _get_callback_model(self):
2726    return self
2727
2728  def _in_multi_worker_mode(self):
2729    return self.distribute_strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2730
2731  @property
2732  def _compile_was_called(self):
2733    return self._is_compiled
2734
2735
def reduce_per_replica(values, strategy, reduction='first'):
  """Reduce PerReplica objects.

  Args:
    values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are
      returned as-is.
    strategy: `tf.distribute.Strategy` object.
    reduction: One of 'first', 'concat'.

  Returns:
    Structure of `Tensor`s.
  """

  def _reduce(value):
    """Reduce a single `PerReplica` object."""
    # Under multi-worker collective all-reduce, 'concat' handles every leaf.
    if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(value, strategy)
    if not isinstance(value, ds_values.PerReplica):
      return value
    if reduction == 'first':
      return strategy.unwrap(value)[0]
    if reduction != 'concat':
      raise ValueError('`reduction` must be "first" or "concat".')
    if _is_tpu_multi_host(strategy):
      return _tpu_multi_host_concat(value, strategy)
    return concat(strategy.unwrap(value))

  return nest.map_structure(_reduce, values)
2766
2767
def concat(tensors, axis=0):
  """Concatenates `tensors` along `axis`.

  The first element decides the kind of concatenation: if it is a
  `SparseTensor`, `sparse_ops.sparse_concat_v2` is used, otherwise
  `array_ops.concat`.
  """
  head = tensors[0]
  if not isinstance(head, sparse_tensor.SparseTensor):
    return array_ops.concat(tensors, axis=axis)
  return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)
2773
2774
def _is_tpu_multi_host(strategy):
  """Returns whether `strategy` is a TPU strategy spanning more than one host."""
  if not backend.is_tpu_strategy(strategy):
    return False
  return strategy.extended.num_hosts > 1
2778
2779
def _tpu_multi_host_concat(v, strategy):
  """Correctly order TPU PerReplica objects, then concatenate them.

  The unwrapped per-replica values are regrouped with a stride of
  `strategy.extended.num_replicas_per_host` before being passed to `concat`.
  """
  per_replica = strategy.unwrap(v)
  # When distributed datasets are created from Tensors / NumPy,
  # TPUStrategy.experimental_distribute_dataset shards data in
  # (Replica, Host) order, and TPUStrategy.unwrap returns it in
  # (Host, Replica) order.
  # TODO(b/150317897): Figure out long-term plan here.
  stride = strategy.extended.num_replicas_per_host
  reordered = [
      value
      for offset in range(stride)
      for value in per_replica[offset::stride]
  ]
  return concat(reordered)
2793
2794
def _collective_all_reduce_multi_worker(strategy):
  """Returns whether `strategy` is a `CollectiveAllReduceStrategy` in multi-worker mode."""
  is_collective = isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)
  if not is_collective:
    return False
  return strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2799
2800
# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather
# for all strategies
def _multi_worker_concat(v, strategy):
  """Order PerReplica objects for CollectiveAllReduceStrategy and concat.

  Steps:
  1. Gather `v` from every replica along axis 0.
  2. Split the gathered result back into `strategy.num_replicas_in_sync`
     pieces. The split uses each replica's gathered leading-dimension size,
     because replicas may hold differently sized values.
  3. Regroup the pieces with a stride of
     `len(strategy.extended.worker_devices)`.
  4. Concatenate the regrouped pieces with `concat`.

  Args:
    v: A `PerReplica` object, or a plain `Tensor` (which may happen, e.g., with
      a 2x1 multi-worker setup). In this chunk it is only called from
      `reduce_per_replica` when `_collective_all_reduce_multi_worker(strategy)`
      holds.
    strategy: The distribution strategy used to gather values across replicas.

  Returns:
    The result of `concat` on the reordered per-replica pieces.
  """
  # NOTE(review): `strategy.gather` is a cross-replica collective; keep the
  # number and order of gather calls in this function unchanged so every
  # worker issues them identically.
  replicas = strategy.gather(v, axis=0)
  # v might not have the same shape on different replicas
  if isinstance(v, ds_values.PerReplica):
    # Leading-dimension size of each local replica's value, as a 1-D tensor.
    shapes = array_ops.concat([
        array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
        for single_value in v.values
    ],
                              axis=0)
    all_shapes = strategy.gather(shapes, axis=0)
  else:
    # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
    all_shapes = strategy.gather(
        array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)

  # Undo the gather: one piece per replica, sized by `all_shapes`.
  replicas = array_ops.split(
      replicas,
      num_or_size_splits=all_shapes,
      num=strategy.num_replicas_in_sync)
  ordered_replicas = []
  # NOTE(review): presumably gather yields pieces in (worker, local replica)
  # order while the input was sharded in (local replica, worker) order,
  # analogous to `_tpu_multi_host_concat`; confirm.
  num_replicas_per_worker = len(strategy.extended.worker_devices)
  for replica_id in range(num_replicas_per_worker):
    ordered_replicas += replicas[replica_id::num_replicas_per_worker]
  return concat(ordered_replicas)
2828
2829
def _is_scalar(x):
  """Returns whether `x` is a rank-0 `Tensor` or `Variable`."""
  if not isinstance(x, (ops.Tensor, variables.Variable)):
    return False
  return x.shape.rank == 0
2832
2833
def write_scalar_summaries(logs, step):
  """Writes each scalar entry of `logs` as a 'batch_'-prefixed summary at `step`.

  Entries that are not scalars (per `_is_scalar`) are skipped.
  """
  scalar_items = ((key, val) for key, val in logs.items() if _is_scalar(val))
  for key, val in scalar_items:
    summary_ops_v2.scalar('batch_' + key, val, step=step)
2838
2839
def _minimum_control_deps(outputs):
  """Returns the minimum control dependencies to ensure step succeeded.

  When executing eagerly, an empty list is returned. Otherwise the result is a
  list holding the first entry of `nest.flatten(outputs,
  expand_composites=True)` that is not a `Variable`. If there is no such
  entry, the result is an empty list.
  """
  if context.executing_eagerly():
    return []  # Control dependencies not needed.
  flat_outputs = nest.flatten(outputs, expand_composites=True)
  # Variables can't be control dependencies, so they are filtered out.
  usable = [
      item for item in flat_outputs
      if not isinstance(item, variables.Variable)
  ]
  return usable[:1]
2850
2851
def _disallow_inside_tf_function(method_name):
  """Raises if called while `ops.inside_function()` is True.

  High-level `Model` endpoints manage their own `tf.function`, so calling them
  from inside another `tf.function` is rejected.

  Args:
    method_name: Name of the `Model` method being guarded. It is used only to
      build the error message.

  Raises:
    RuntimeError: If `ops.inside_function()` returns True.
  """
  if not ops.inside_function():
    return
  # Fixed: the second sentence was missing the closing backtick after
  # `Model.{method_name}`.
  error_msg = (
      'Detected a call to `Model.{method_name}` inside a `tf.function`. '
      '`Model.{method_name}` is a high-level endpoint that manages its own '
      '`tf.function`. Please move the call to `Model.{method_name}` outside '
      'of all enclosing `tf.function`s. Note that you can call a `Model` '
      'directly on `Tensor`s inside a `tf.function` like: `model(x)`.'
  ).format(method_name=method_name)
  raise RuntimeError(error_msg)
2862
2863
def _detect_save_format(filepath):
  """Returns path to weights file and save format.

  Args:
    filepath: Location of the weights. It is normalized with `path_to_string`.

  Returns:
    A `(path, format)` tuple, where `format` is 'h5' or 'tf'. For a SavedModel
    directory, `path` is its variables checkpoint inside the directory.

  Raises:
    ValueError: If `filepath` is a SavedModel directory whose variables
      checkpoint cannot be read.
  """
  filepath = path_to_string(filepath)
  if saving_utils.is_hdf5_filepath(filepath):
    return filepath, 'h5'

  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
  # directory. It's possible for filepath to be both a prefix and directory.
  # Prioritize checkpoint over SavedModel.
  if _is_readable_tf_checkpoint(filepath):
    return filepath, 'tf'

  if not sm_loader.contains_saved_model(filepath):
    # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
    # doesn't have the hdf5/keras extensions.
    return filepath, 'h5'

  variables_ckpt = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
                                sm_constants.VARIABLES_FILENAME)
  if not _is_readable_tf_checkpoint(variables_ckpt):
    raise ValueError('Unable to load weights. filepath {} appears to be a '
                     'SavedModel directory, but checkpoint either doesn\'t '
                     'exist, or is incorrectly formatted.'.format(filepath))
  return variables_ckpt, 'tf'
2891
2892
def _is_readable_tf_checkpoint(filepath):
  """Returns whether `filepath` can be opened as a TensorFlow checkpoint.

  Only `DataLossError` is treated as "not readable"; any other error raised by
  `py_checkpoint_reader.NewCheckpointReader` propagates to the caller.
  """
  try:
    py_checkpoint_reader.NewCheckpointReader(filepath)
  except errors_impl.DataLossError:
    # The checkpoint is not readable in TensorFlow format.
    return False
  return True
2900