1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training-related part of the Keras engine. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import itertools 23import json 24import os 25import warnings 26 27import six 28 29from tensorflow.python.autograph.lang import directives 30from tensorflow.python.data.experimental.ops import distribute_options 31from tensorflow.python.data.ops import dataset_ops 32from tensorflow.python.distribute import collective_all_reduce_strategy 33from tensorflow.python.distribute import distribution_strategy_context as ds_context 34from tensorflow.python.distribute import values as ds_values 35from tensorflow.python.distribute.coordinator import cluster_coordinator 36from tensorflow.python.eager import backprop 37from tensorflow.python.eager import context 38from tensorflow.python.eager import def_function 39from tensorflow.python.framework import errors 40from tensorflow.python.framework import errors_impl 41from tensorflow.python.framework import func_graph 42from tensorflow.python.framework import ops 43from tensorflow.python.framework import sparse_tensor 44from tensorflow.python.framework import tensor_shape 45from tensorflow.python.keras import backend 46from tensorflow.python.keras import 
callbacks as callbacks_module 47from tensorflow.python.keras import optimizer_v1 48from tensorflow.python.keras import optimizers 49from tensorflow.python.keras.engine import base_layer 50from tensorflow.python.keras.engine import base_layer_utils 51from tensorflow.python.keras.engine import compile_utils 52from tensorflow.python.keras.engine import data_adapter 53from tensorflow.python.keras.engine import training_utils 54from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso 55from tensorflow.python.keras.mixed_precision import policy 56from tensorflow.python.keras.saving import hdf5_format 57from tensorflow.python.keras.saving import save 58from tensorflow.python.keras.saving import saving_utils 59from tensorflow.python.keras.saving.saved_model import json_utils 60from tensorflow.python.keras.saving.saved_model import model_serialization 61from tensorflow.python.keras.utils import generic_utils 62from tensorflow.python.keras.utils import layer_utils 63from tensorflow.python.keras.utils import tf_inspect 64from tensorflow.python.keras.utils import tf_utils 65from tensorflow.python.keras.utils import version_utils 66from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 67from tensorflow.python.keras.utils.io_utils import path_to_string 68from tensorflow.python.keras.utils.mode_keys import ModeKeys 69from tensorflow.python.ops import array_ops 70from tensorflow.python.ops import math_ops 71from tensorflow.python.ops import sparse_ops 72from tensorflow.python.ops import summary_ops_v2 73from tensorflow.python.ops import variables 74from tensorflow.python.platform import tf_logging as logging 75from tensorflow.python.profiler import trace 76from tensorflow.python.saved_model import constants as sm_constants 77from tensorflow.python.saved_model import loader_impl as sm_loader 78from tensorflow.python.training import checkpoint_management 79from tensorflow.python.training import py_checkpoint_reader 80from 
tensorflow.python.training.tracking import base as trackable 81from tensorflow.python.training.tracking import data_structures 82from tensorflow.python.training.tracking import util as trackable_utils 83from tensorflow.python.util import nest 84from tensorflow.python.util import tf_decorator 85from tensorflow.python.util.tf_export import keras_export 86from tensorflow.tools.docs import doc_controls 87 88 89# pylint: disable=g-import-not-at-top 90try: 91 import h5py 92except ImportError: 93 h5py = None 94 95try: 96 import yaml 97except ImportError: 98 yaml = None 99# pylint: enable=g-import-not-at-top 100 101 102def disable_multi_worker(method): 103 """Decorator that disallows multi-worker use of `method`.""" 104 105 def _method_wrapper(self, *args, **kwargs): 106 if self._in_multi_worker_mode(): # pylint: disable=protected-access 107 raise ValueError('{} is not supported in multi-worker mode.'.format( 108 method.__name__)) 109 return method(self, *args, **kwargs) 110 111 return tf_decorator.make_decorator( 112 target=method, decorator_func=_method_wrapper) 113 114 115def inject_functional_model_class(cls): 116 """Inject `Functional` into the hierarchy of this class if needed.""" 117 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 118 from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top 119 if cls == Model or cls == training_v1.Model: 120 return functional.Functional 121 # In case there is any multiple inheritance, we stop injecting the 122 # class if keras model is not in its class hierarchy. 123 if cls == object: 124 return object 125 126 cls.__bases__ = tuple(inject_functional_model_class(base) 127 for base in cls.__bases__) 128 # Trigger any `__new__` class swapping that needed to happen on `Functional` 129 # but did not because functional was not in the class hierarchy. 
130 cls.__new__(cls) 131 132 return cls 133 134 135def is_functional_model_init_params(args, kwargs): 136 return (len(args) == 2 or 137 len(args) == 1 and 'outputs' in kwargs or 138 'inputs' in kwargs and 'outputs' in kwargs) 139 140 141@keras_export('keras.Model', 'keras.models.Model') 142class Model(base_layer.Layer, version_utils.ModelVersionSelector): 143 """`Model` groups layers into an object with training and inference features. 144 145 Args: 146 inputs: The input(s) of the model: a `keras.Input` object or list of 147 `keras.Input` objects. 148 outputs: The output(s) of the model. See Functional API example below. 149 name: String, the name of the model. 150 151 There are two ways to instantiate a `Model`: 152 153 1 - With the "Functional API", where you start from `Input`, 154 you chain layer calls to specify the model's forward pass, 155 and finally you create your model from inputs and outputs: 156 157 ```python 158 import tensorflow as tf 159 160 inputs = tf.keras.Input(shape=(3,)) 161 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 162 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 163 model = tf.keras.Model(inputs=inputs, outputs=outputs) 164 ``` 165 166 2 - By subclassing the `Model` class: in that case, you should define your 167 layers in `__init__` and you should implement the model's forward pass 168 in `call`. 
169 170 ```python 171 import tensorflow as tf 172 173 class MyModel(tf.keras.Model): 174 175 def __init__(self): 176 super(MyModel, self).__init__() 177 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 178 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 179 180 def call(self, inputs): 181 x = self.dense1(inputs) 182 return self.dense2(x) 183 184 model = MyModel() 185 ``` 186 187 If you subclass `Model`, you can optionally have 188 a `training` argument (boolean) in `call`, which you can use to specify 189 a different behavior in training and inference: 190 191 ```python 192 import tensorflow as tf 193 194 class MyModel(tf.keras.Model): 195 196 def __init__(self): 197 super(MyModel, self).__init__() 198 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 199 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 200 self.dropout = tf.keras.layers.Dropout(0.5) 201 202 def call(self, inputs, training=False): 203 x = self.dense1(inputs) 204 if training: 205 x = self.dropout(x, training=training) 206 return self.dense2(x) 207 208 model = MyModel() 209 ``` 210 211 Once the model is created, you can config the model with losses and metrics 212 with `model.compile()`, train the model with `model.fit()`, or use the model 213 to do prediction with `model.predict()`. 
214 """ 215 _TF_MODULE_IGNORED_PROPERTIES = frozenset( 216 itertools.chain(('_train_counter', '_test_counter', '_predict_counter', 217 '_steps_per_execution'), 218 base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access 219 220 def __new__(cls, *args, **kwargs): 221 # Signature detection 222 if is_functional_model_init_params(args, kwargs) and cls == Model: 223 # Functional model 224 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 225 return functional.Functional(skip_init=True, *args, **kwargs) 226 else: 227 return super(Model, cls).__new__(cls, *args, **kwargs) 228 229 @trackable.no_automatic_dependency_tracking 230 def __init__(self, *args, **kwargs): 231 self._is_model_for_instrumentation = True 232 base_layer.keras_api_gauge.get_cell('model').set(True) 233 234 # Special case for Subclassed Functional Model, which we couldn't detect 235 # when __new__ is called. We only realize it is a functional model when it 236 # calls super.__init__ with input and output tensor. 237 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 238 if (is_functional_model_init_params(args, kwargs) and 239 not isinstance(self, functional.Functional)): 240 # Filter the kwargs for multiple inheritance. 241 supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init'] 242 model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs} 243 other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs} 244 inject_functional_model_class(self.__class__) 245 functional.Functional.__init__(self, *args, **model_kwargs) 246 247 # In case there is any multiple inheritance here, we need to call the 248 # __init__ for any class that appears after the Functional class. 
249 clz_to_init = [] 250 found_functional_class = False 251 for clz in self.__class__.__bases__: 252 if issubclass(clz, functional.Functional): 253 found_functional_class = True 254 continue 255 if found_functional_class: 256 clz_to_init.append(clz) 257 258 if clz_to_init: 259 for clz in clz_to_init: 260 clz.__init__(self, *args, **other_kwargs) 261 elif other_kwargs: 262 # In case there are unused kwargs, we should raise an error to user, in 263 # case they have a typo in the param name. 264 raise TypeError( 265 'The following keyword arguments aren\'t supported: {}'.format( 266 other_kwargs)) 267 return 268 269 base_layer.keras_api_gauge.get_cell('Model subclass').set(True) 270 # The following are implemented as property functions: 271 # self.trainable_weights 272 # self.non_trainable_weights 273 # `inputs` / `outputs` will only appear in kwargs if either are misspelled. 274 generic_utils.validate_kwargs(kwargs, { 275 'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs' 276 }) 277 super(Model, self).__init__(**kwargs) 278 # By default, Model is a subclass model, which is not in graph network. 279 self._is_graph_network = False 280 281 self.inputs = None 282 self.outputs = None 283 self.input_names = None 284 self.output_names = None 285 # stop_training is used by callback to stop training when error happens 286 self.stop_training = False 287 self.history = None 288 # These objects are used in the default `Model.compile`. They are not 289 # guaranteed to be set after `Model.compile` is called, as users can 290 # override compile with custom logic. 291 self.compiled_loss = None 292 self.compiled_metrics = None 293 294 # This is True for Sequential networks and Functional networks. 295 self._compute_output_and_mask_jointly = False 296 297 # Don't reset compilation if already done. This may occur if calling 298 # `__init__` (or `_init_graph_network`) on an already-compiled model 299 # such as a Sequential model. 
Sequential models may need to rebuild 300 # themselves after compilation. 301 self._maybe_create_attribute('_is_compiled', False) 302 self._maybe_create_attribute('optimizer', None) 303 304 # Model must be created under scope of DistStrat it will be trained with. 305 if ds_context.has_strategy(): 306 self._distribution_strategy = ds_context.get_strategy() 307 else: 308 self._distribution_strategy = None 309 310 self._cluster_coordinator = None 311 312 # Defaults to value of `tf.config.experimental_functions_run_eagerly`. 313 self._run_eagerly = None 314 # Initialize cache attrs. 315 self._reset_compile_cache() 316 317 # Fault-tolerance handler. Set in `ModelCheckpoint`. 318 self._training_state = None 319 self._saved_model_inputs_spec = None 320 self._trackable_saver = ( 321 trackable_utils.saver_with_op_caching(self)) 322 323 self._steps_per_execution = None 324 325 self._init_batch_counters() 326 self._base_model_initialized = True 327 328 @trackable.no_automatic_dependency_tracking 329 def _init_batch_counters(self): 330 # Untracked Variables, used to keep track of mini-batches seen in `fit`, 331 # `evaluate`, and `predict`. 
332 agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA 333 self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg) 334 self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg) 335 self._predict_counter = variables.Variable( 336 0, dtype='int64', aggregation=agg) 337 338 def __setattr__(self, name, value): 339 if not getattr(self, '_self_setattr_tracking', True): 340 super(Model, self).__setattr__(name, value) 341 return 342 343 if all( 344 isinstance(v, (base_layer.Layer, 345 data_structures.TrackableDataStructure)) or 346 base_layer_utils.has_weights(v) for v in nest.flatten(value)): 347 try: 348 self._base_model_initialized 349 except AttributeError: 350 # six.raise_from supresses the original AttributeError from being raised 351 six.raise_from( 352 RuntimeError('It looks like you are subclassing `Model` and you ' 353 'forgot to call `super(YourClass, self).__init__()`.' 354 ' Always start with this line.'), None) 355 356 super(Model, self).__setattr__(name, value) 357 358 @generic_utils.default 359 def build(self, input_shape): 360 """Builds the model based on input shapes received. 361 362 This is to be used for subclassed models, which do not know at instantiation 363 time what their inputs look like. 364 365 This method only exists for users who want to call `model.build()` in a 366 standalone way (as a substitute for calling the model on real data to 367 build it). It will never be called by the framework (and thus it will 368 never throw unexpected errors in an unrelated workflow). 369 370 Args: 371 input_shape: Single tuple, TensorShape, or list/dict of shapes, where 372 shapes are tuples, integers, or TensorShapes. 373 374 Raises: 375 ValueError: 376 1. In case of invalid user-provided data (not of type tuple, 377 list, TensorShape, or dict). 378 2. If the model requires call arguments that are agnostic 379 to the input shapes (positional or kwarg in call signature). 380 3. If not all layers were properly built. 
381 4. If float type inputs are not supported within the layers. 382 383 In each of these cases, the user should build their model by calling it 384 on real tensor data. 385 """ 386 if self._is_graph_network: 387 super(Model, self).build(input_shape) 388 return 389 390 if input_shape is None: 391 raise ValueError('Input shape must be defined when calling build on a ' 392 'model subclass network.') 393 valid_types = (tuple, list, tensor_shape.TensorShape, dict) 394 if not isinstance(input_shape, valid_types): 395 raise ValueError('Specified input shape is not one of the valid types. ' 396 'Please specify a batch input shape of type tuple or ' 397 'list of input shapes. User provided ' 398 'input type: {}'.format(type(input_shape))) 399 400 if input_shape and not self.inputs: 401 # We create placeholders for the `None`s in the shape and build the model 402 # in a Graph. Since tf.Variable is compatible with both eager execution 403 # and graph building, the variables created after building the model in 404 # a Graph are still valid when executing eagerly. 405 if context.executing_eagerly(): 406 graph = func_graph.FuncGraph('build_graph') 407 else: 408 graph = backend.get_graph() 409 with graph.as_default(): 410 if (isinstance(input_shape, list) and 411 all(d is None or isinstance(d, int) for d in input_shape)): 412 input_shape = tuple(input_shape) 413 if isinstance(input_shape, list): 414 x = [base_layer_utils.generate_placeholders_from_shape(shape) 415 for shape in input_shape] 416 elif isinstance(input_shape, dict): 417 x = { 418 k: base_layer_utils.generate_placeholders_from_shape(shape) 419 for k, shape in input_shape.items() 420 } 421 else: 422 x = base_layer_utils.generate_placeholders_from_shape(input_shape) 423 424 kwargs = {} 425 call_signature = self._call_full_argspec 426 call_args = call_signature.args 427 # Exclude `self`, `inputs`, and any argument with a default value. 
428 if len(call_args) > 2: 429 if call_signature.defaults: 430 call_args = call_args[2:-len(call_signature.defaults)] 431 else: 432 call_args = call_args[2:] 433 for arg in call_args: 434 if arg == 'training': 435 # Case where `training` is a positional arg with no default. 436 kwargs['training'] = False 437 else: 438 # Has invalid call signature with unknown positional arguments. 439 raise ValueError( 440 'Currently, you cannot build your model if it has ' 441 'positional or keyword arguments that are not ' 442 'inputs to the model, but are required for its ' 443 '`call` method. Instead, in order to instantiate ' 444 'and build your model, `call` your model on real ' 445 'tensor data with all expected call arguments.') 446 elif len(call_args) < 2: 447 # Signature without `inputs`. 448 raise ValueError('You can only call `build` on a model if its `call` ' 449 'method accepts an `inputs` argument.') 450 try: 451 self.call(x, **kwargs) 452 except (errors.InvalidArgumentError, TypeError): 453 raise ValueError('You cannot build your model by calling `build` ' 454 'if your layers do not support float type inputs. ' 455 'Instead, in order to instantiate and build your ' 456 'model, `call` your model on real tensor data (of ' 457 'the correct dtype).') 458 super(Model, self).build(input_shape) 459 460 @doc_controls.doc_in_current_and_subclasses 461 def call(self, inputs, training=None, mask=None): 462 """Calls the model on new inputs. 463 464 In this case `call` just reapplies 465 all ops in the graph to the new inputs 466 (e.g. build a new computational graph from the provided inputs). 467 468 Note: This method should not be called directly. It is only meant to be 469 overridden when subclassing `tf.keras.Model`. 470 To call a model on an input, always use the `__call__` method, 471 i.e. `model(inputs)`, which relies on the underlying `call` method. 472 473 Args: 474 inputs: A tensor or list of tensors. 
475 training: Boolean or boolean scalar tensor, indicating whether to run 476 the `Network` in training mode or inference mode. 477 mask: A mask or list of masks. A mask can be 478 either a tensor or None (no mask). 479 480 Returns: 481 A tensor if there is a single output, or 482 a list of tensors if there are more than one outputs. 483 """ 484 raise NotImplementedError('When subclassing the `Model` class, you should ' 485 'implement a `call` method.') 486 487 def compile(self, 488 optimizer='rmsprop', 489 loss=None, 490 metrics=None, 491 loss_weights=None, 492 weighted_metrics=None, 493 run_eagerly=None, 494 steps_per_execution=None, 495 **kwargs): 496 """Configures the model for training. 497 498 Args: 499 optimizer: String (name of optimizer) or optimizer instance. See 500 `tf.keras.optimizers`. 501 loss: String (name of objective function), objective function or 502 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective 503 function is any callable with the signature `loss = fn(y_true, 504 y_pred)`, where y_true = ground truth values with shape = 505 `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse 506 categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`. 507 y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It 508 returns a weighted loss float tensor. If a custom `Loss` instance is 509 used and reduction is set to NONE, return value has the shape 510 [batch_size, d0, .. dN-1] ie. per-sample or per-timestep loss values; 511 otherwise, it is a scalar. If the model has multiple outputs, you can 512 use a different loss on each output by passing a dictionary or a list 513 of losses. The loss value that will be minimized by the model will 514 then be the sum of all individual losses. 515 metrics: List of metrics to be evaluated by the model during training 516 and testing. Each of this can be a string (name of a built-in 517 function), function or a `tf.keras.metrics.Metric` instance. 
See 518 `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A 519 function is any callable with the signature `result = fn(y_true, 520 y_pred)`. To specify different metrics for different outputs of a 521 multi-output model, you could also pass a dictionary, such as 522 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 523 You can also pass a list (len = len(outputs)) of lists of metrics 524 such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or 525 `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the 526 strings 'accuracy' or 'acc', we convert this to one of 527 `tf.keras.metrics.BinaryAccuracy`, 528 `tf.keras.metrics.CategoricalAccuracy`, 529 `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss 530 function used and the model output shape. We do a similar 531 conversion for the strings 'crossentropy' and 'ce' as well. 532 loss_weights: Optional list or dictionary specifying scalar coefficients 533 (Python floats) to weight the loss contributions of different model 534 outputs. The loss value that will be minimized by the model will then 535 be the *weighted sum* of all individual losses, weighted by the 536 `loss_weights` coefficients. 537 If a list, it is expected to have a 1:1 mapping to the model's 538 outputs. If a dict, it is expected to map output names (strings) 539 to scalar coefficients. 540 weighted_metrics: List of metrics to be evaluated and weighted by 541 sample_weight or class_weight during training and testing. 542 run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s 543 logic will not be wrapped in a `tf.function`. Recommended to leave 544 this as `None` unless your `Model` cannot be run inside a 545 `tf.function`. 546 steps_per_execution: Int. Defaults to 1. The number of batches to 547 run during each `tf.function` call. Running multiple batches 548 inside a single `tf.function` call can greatly improve performance 549 on TPUs or small models with a large Python overhead. 
550 At most, one full epoch will be run each 551 execution. If a number larger than the size of the epoch is passed, 552 the execution will be truncated to the size of the epoch. 553 Note that if `steps_per_execution` is set to `N`, 554 `Callback.on_batch_begin` and `Callback.on_batch_end` methods 555 will only be called every `N` batches 556 (i.e. before/after each `tf.function` execution). 557 **kwargs: Arguments supported for backwards compatibility only. 558 559 Raises: 560 ValueError: In case of invalid arguments for 561 `optimizer`, `loss` or `metrics`. 562 """ 563 base_layer.keras_api_gauge.get_cell('compile').set(True) 564 with self.distribute_strategy.scope(): 565 if 'experimental_steps_per_execution' in kwargs: 566 logging.warn('The argument `steps_per_execution` is no longer ' 567 'experimental. Pass `steps_per_execution` instead of ' 568 '`experimental_steps_per_execution`.') 569 if not steps_per_execution: 570 steps_per_execution = kwargs.pop('experimental_steps_per_execution') 571 572 self._validate_compile(optimizer, metrics, **kwargs) 573 self._run_eagerly = run_eagerly 574 575 self.optimizer = self._get_optimizer(optimizer) 576 self.compiled_loss = compile_utils.LossesContainer( 577 loss, loss_weights, output_names=self.output_names) 578 self.compiled_metrics = compile_utils.MetricsContainer( 579 metrics, weighted_metrics, output_names=self.output_names) 580 581 self._configure_steps_per_execution(steps_per_execution or 1) 582 583 # Initializes attrs that are reset each time `compile` is called. 584 self._reset_compile_cache() 585 self._is_compiled = True 586 587 self.loss = loss or {} # Backwards compat. 588 589 def _get_optimizer(self, optimizer): 590 """Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" 591 # The deprecated PolicyV1 has a loss_scale, which we use for backwards 592 # compatibility to match TF 2.3 behavior. 
The new Policy does not have a 593 # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is 594 # used. 595 if isinstance(self._dtype_policy, policy.PolicyV1): 596 loss_scale = self._dtype_policy.loss_scale 597 elif self._dtype_policy.name == 'mixed_float16': 598 loss_scale = 'dynamic' 599 else: 600 loss_scale = None 601 602 def _get_single_optimizer(opt): 603 opt = optimizers.get(opt) 604 if (loss_scale is not None and 605 not isinstance(opt, lso.LossScaleOptimizer)): 606 if loss_scale == 'dynamic': 607 opt = lso.LossScaleOptimizer(opt) 608 else: 609 opt = lso.LossScaleOptimizerV1(opt, loss_scale) 610 return opt 611 612 return nest.map_structure(_get_single_optimizer, optimizer) 613 614 @trackable.no_automatic_dependency_tracking 615 def _reset_compile_cache(self): 616 self.train_function = None 617 self.test_function = None 618 self.predict_function = None 619 620 # Used to cache `trainable` attr of `Layer`s for `fit`. 621 self._compiled_trainable_state = self._get_trainable_state() 622 623 @trackable.no_automatic_dependency_tracking 624 def _configure_steps_per_execution(self, steps_per_execution): 625 self._steps_per_execution = variables.Variable( 626 steps_per_execution, 627 dtype='int64', 628 aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA) 629 630 @property 631 def _should_compute_mask(self): 632 return False 633 634 @property 635 def metrics(self): 636 """Returns the model's metrics added using `compile`, `add_metric` APIs. 637 638 Note: Metrics passed to `compile()` are available only after a `keras.Model` 639 has been trained/evaluated on actual data. 
640 641 Examples: 642 643 >>> inputs = tf.keras.layers.Input(shape=(3,)) 644 >>> outputs = tf.keras.layers.Dense(2)(inputs) 645 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 646 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 647 >>> [m.name for m in model.metrics] 648 [] 649 650 >>> x = np.random.random((2, 3)) 651 >>> y = np.random.randint(0, 2, (2, 2)) 652 >>> model.fit(x, y) 653 >>> [m.name for m in model.metrics] 654 ['loss', 'mae'] 655 656 >>> inputs = tf.keras.layers.Input(shape=(3,)) 657 >>> d = tf.keras.layers.Dense(2, name='out') 658 >>> output_1 = d(inputs) 659 >>> output_2 = d(inputs) 660 >>> model = tf.keras.models.Model( 661 ... inputs=inputs, outputs=[output_1, output_2]) 662 >>> model.add_metric( 663 ... tf.reduce_sum(output_2), name='mean', aggregation='mean') 664 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) 665 >>> model.fit(x, (y, y)) 666 >>> [m.name for m in model.metrics] 667 ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 668 'out_1_acc', 'mean'] 669 670 """ 671 metrics = [] 672 if self._is_compiled: 673 # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects 674 # so that attr names are not load-bearing. 675 if self.compiled_loss is not None: 676 metrics += self.compiled_loss.metrics 677 if self.compiled_metrics is not None: 678 metrics += self.compiled_metrics.metrics 679 680 for l in self._flatten_layers(): 681 metrics.extend(l._metrics) # pylint: disable=protected-access 682 return metrics 683 684 @property 685 def metrics_names(self): 686 """Returns the model's display labels for all outputs. 687 688 Note: `metrics_names` are available only after a `keras.Model` has been 689 trained/evaluated on actual data. 
@property
def distribute_strategy(self):
  """Returns the `tf.distribute.Strategy` associated with this model.

  The strategy stored on the model is preferred; when that value is falsy the
  strategy reported by the current distribution context is returned instead.
  """
  stored_strategy = self._distribution_strategy
  if stored_strategy:
    return stored_strategy
  return ds_context.get_strategy()

@property
def run_eagerly(self):
  """Settable flag telling whether the model should execute eagerly.

  Eager execution runs the model step by step, like ordinary Python code.
  That may be slower, but makes it easier to debug by stepping into
  individual layer calls. When nothing forces eager execution, the model is
  compiled to a static graph for best performance.

  Returns:
    The first truthy value among: `self.dynamic`, the explicitly assigned
    `run_eagerly` value, and (only when nothing was assigned) TF's global
    "functions run eagerly" setting. Falsy when none of them applies.

  Raises:
    ValueError: If the model is dynamic but `run_eagerly` was explicitly set
      to `False`, or if `run_eagerly` is enabled while a cluster coordinator
      is attached to the model.
  """
  if self.dynamic:
    if self._run_eagerly is False:  # pylint:disable=g-bool-id-comparison
      # TODO(fchollet): consider using py_func to enable this.
      raise ValueError('Your model contains layers that can only be '
                       'successfully run in eager execution (layers '
                       'constructed with `dynamic=True`). '
                       'You cannot set `run_eagerly=False`.')

  if self._cluster_coordinator:
    if self._run_eagerly:
      raise ValueError('When using `Model` with `ParameterServerStrategy`, '
                       '`run_eagerly` is not supported.')

  # Resolution order, highest priority first:
  # (1) Dynamic models must be run eagerly.
  eager = self.dynamic
  # (2) An explicit `run_eagerly` assignment comes next.
  if not eager:
    eager = self._run_eagerly
  # (3) Without an explicit assignment, defer to TF's global setting.
  if not eager:
    eager = (def_function.functions_run_eagerly() and
             self._run_eagerly is None)
  return eager

@run_eagerly.setter
def run_eagerly(self, value):
  # Stored as-is; it is validated when the property is read.
  self._run_eagerly = value

def train_step(self, data):
  """Runs the computation for a single training step.

  Override this method to customize what happens during one step of training;
  it is invoked from the step function built by `Model.make_train_function`.

  Only the mathematical logic of a step belongs here: the forward pass, the
  loss computation, the optimizer update and the metric updates. How the step
  is executed (e.g. `tf.function` and `tf.distribute.Strategy` settings) is
  the responsibility of `Model.make_train_function`, which can be overridden
  as well.

  Args:
    data: A nested structure of `Tensor`s.

  Returns:
    A `dict` mapping each metric name to its current result, e.g.
    `{'loss': 0.2, 'accuracy': 0.7}`. These values are handed to
    `tf.keras.callbacks.CallbackList.on_train_batch_end`.
  """
  # Expanding 1D data and unpacking the tuple are the only transformations
  # `Model.fit` applies to user-input data when a `tf.data.Dataset` is
  # provided.
  features, targets, weights = data_adapter.unpack_x_y_sample_weight(
      data_adapter.expand_1d(data))

  with backprop.GradientTape() as tape:
    predictions = self(features, training=True)
    total_loss = self.compiled_loss(
        targets,
        predictions,
        weights,
        regularization_losses=self.losses)

  # The optimizer computes gradients from the recorded tape and applies them.
  self.optimizer.minimize(total_loss, self.trainable_variables, tape=tape)
  self.compiled_metrics.update_state(targets, predictions, weights)
  return {metric.name: metric.result() for metric in self.metrics}
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False):
  """Trains the model for a fixed number of epochs (iterations on a dataset).

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
          if the model has named inputs.
        - A `tf.data` dataset. Should return a tuple
          of either `(inputs, targets)` or
          `(inputs, targets, sample_weights)`.
        - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
          or `(inputs, targets, sample_weights)`.
        A more detailed description of unpacking behavior for iterator types
        (Dataset, generator, Sequence) is given below.
      y: Target data. Like the input data `x`,
        it could be either Numpy array(s) or TensorFlow tensor(s).
        It should be consistent with `x` (you cannot have Numpy inputs and
        tensor targets, or inversely). If `x` is a dataset, generator,
        or `keras.utils.Sequence` instance, `y` should
        not be specified (since targets will be obtained from `x`).
      batch_size: Integer or `None`.
          Number of samples per gradient update.
          If unspecified, `batch_size` will default to 32.
          Do not specify the `batch_size` if your data is in the
          form of datasets, generators, or `keras.utils.Sequence` instances
          (since they generate batches).
      epochs: Integer. Number of epochs to train the model.
          An epoch is an iteration over the entire `x` and `y`
          data provided.
          Note that in conjunction with `initial_epoch`,
          `epochs` is to be understood as "final epoch".
          The model is not trained for a number of iterations
          given by `epochs`, but merely until the epoch
          of index `epochs` is reached.
      verbose: 0, 1, or 2. Verbosity mode.
          0 = silent, 1 = progress bar, 2 = one line per epoch.
          Note that the progress bar is not particularly useful when
          logged to a file, so verbose=2 is recommended when not running
          interactively (e.g., in a production environment).
      callbacks: List of `keras.callbacks.Callback` instances.
          List of callbacks to apply during training.
          See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
          and `tf.keras.callbacks.History` callbacks are created automatically
          and need not be passed into `model.fit`.
          `tf.keras.callbacks.ProgbarLogger` is created or not based on
          `verbose` argument to `model.fit`.
      validation_split: Float between 0 and 1.
          Fraction of the training data to be used as validation data.
          The model will set apart this fraction of the training data,
          will not train on it, and will evaluate
          the loss and any model metrics
          on this data at the end of each epoch.
          The validation data is selected from the last samples
          in the `x` and `y` data provided, before shuffling. This argument is
          not supported when `x` is a dataset, generator or
          `keras.utils.Sequence` instance.
      validation_data: Data on which to evaluate
          the loss and any model metrics at the end of each epoch.
          The model will not be trained on this data. Thus, note the fact
          that the validation loss of data provided using `validation_split`
          or `validation_data` is not affected by regularization layers like
          noise and dropout.
          `validation_data` will override `validation_split`.
          `validation_data` could be:
            - tuple `(x_val, y_val)` of Numpy arrays or tensors
            - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
            - dataset
          For the first two cases, `batch_size` must be provided.
          For the last case, `validation_steps` could be provided.
          Note that `validation_data` does not support all the data types that
          are supported in `x`, e.g., dict, generator or
          `keras.utils.Sequence`.
      shuffle: Boolean (whether to shuffle the training data
          before each epoch) or str (for 'batch'). This argument is ignored
          when `x` is a generator or an object of tf.data.Dataset.
          'batch' is a special option for dealing
          with the limitations of HDF5 data; it shuffles in batch-sized
          chunks. Has no effect when `steps_per_epoch` is not `None`.
      class_weight: Optional dictionary mapping class indices (integers)
          to a weight (float) value, used for weighting the loss function
          (during training only).
          This can be useful to tell the model to
          "pay more attention" to samples from
          an under-represented class.
      sample_weight: Optional Numpy array of weights for
          the training samples, used for weighting the loss function
          (during training only). You can either pass a flat (1D)
          Numpy array with the same length as the input samples
          (1:1 mapping between weights and samples),
          or in the case of temporal data,
          you can pass a 2D array with shape
          `(samples, sequence_length)`,
          to apply a different weight to every timestep of every sample. This
          argument is not supported when `x` is a dataset, generator, or
         `keras.utils.Sequence` instance, instead provide the sample_weights
          as the third element of `x`.
      initial_epoch: Integer.
          Epoch at which to start training
          (useful for resuming a previous training run).
      steps_per_epoch: Integer or `None`.
          Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. When training with input tensors such as
          TensorFlow data tensors, the default `None` is equal to
          the number of samples in your dataset divided by
          the batch size, or 1 if that cannot be determined. If x is a
          `tf.data` dataset, and 'steps_per_epoch'
          is None, the epoch will run until the input dataset is exhausted.
          When passing an infinitely repeating dataset, you must specify the
          `steps_per_epoch` argument. This argument is not supported with
          array inputs.
      validation_steps: Only relevant if `validation_data` is provided and
          is a `tf.data` dataset. Total number of steps (batches of
          samples) to draw before stopping when performing validation
          at the end of every epoch. If 'validation_steps' is None, validation
          will run until the `validation_data` dataset is exhausted. In the
          case of an infinitely repeated dataset, it will run into an
          infinite loop. If 'validation_steps' is specified and only part of
          the dataset will be consumed, the evaluation will start from the
          beginning of the dataset at each epoch. This ensures that the same
          validation samples are used every time.
      validation_batch_size: Integer or `None`.
          Number of samples per validation batch.
          If unspecified, will default to `batch_size`.
          Do not specify the `validation_batch_size` if your data is in the
          form of datasets, generators, or `keras.utils.Sequence` instances
          (since they generate batches).
      validation_freq: Only relevant if validation data is provided. Integer
          or `collections.abc.Container` instance (e.g. list, tuple, etc.).
          If an integer, specifies how many training epochs to run before a
          new validation run is performed, e.g. `validation_freq=2` runs
          validation every 2 epochs. If a Container, specifies the epochs on
          which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
          validation at the end of the 1st, 2nd, and 10th epochs.
      max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
          input only. Maximum size for the generator queue.
          If unspecified, `max_queue_size` will default to 10.
      workers: Integer. Used for generator or `keras.utils.Sequence` input
          only. Maximum number of processes to spin up
          when using process-based threading. If unspecified, `workers`
          will default to 1. If 0, will execute the generator on the main
          thread.
      use_multiprocessing: Boolean. Used for generator or
          `keras.utils.Sequence` input only. If `True`, use process-based
          threading. If unspecified, `use_multiprocessing` will default to
          `False`. Note that because this implementation relies on
          multiprocessing, you should not pass non-picklable arguments to
          the generator as they can't be passed easily to children processes.

  Unpacking behavior for iterator-like inputs:
      A common pattern is to pass a tf.data.Dataset, generator, or
    tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
    yield not only features (x) but optionally targets (y) and sample weights.
    Keras requires that the output of such iterator-likes be unambiguous. The
    iterator should return a tuple of length 1, 2, or 3, where the optional
    second and third elements will be used for y and sample_weight
    respectively. Any other type provided will be wrapped in a length one
    tuple, effectively treating everything as 'x'. When yielding dicts, they
    should still adhere to the top-level tuple structure.
    e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
    features, targets, and weights from the keys of a single dict.
      A notable unsupported data type is the namedtuple. The reason is that
    it behaves like both an ordered datatype (tuple) and a mapping
    datatype (dict). So given a namedtuple of the form:
        `namedtuple("example_tuple", ["y", "x"])`
    it is ambiguous whether to reverse the order of the elements when
    interpreting the value. Even worse is a tuple of the form:
        `namedtuple("other_tuple", ["x", "y", "z"])`
    where it is unclear if the tuple was intended to be unpacked into x, y,
    and sample_weight or passed through as a single element to `x`. As a
    result the data processing code will simply raise a ValueError if it
    encounters a namedtuple. (Along with instructions to remedy the issue.)

  Returns:
      A `History` object. Its `History.history` attribute is
      a record of training loss values and metrics values
      at successive epochs, as well as validation loss values
      and validation metrics values (if applicable).

  Raises:
      RuntimeError: 1. If the model was never compiled or,
      2. If `model.fit` is wrapped in `tf.function`.

      ValueError: In case of mismatch between the provided input data
          and what the model expects or when the input data is empty.
  """
  base_layer.keras_api_gauge.get_cell('fit').set(True)
  # Legacy graph support is contained in `training_v1.Model`.
  version_utils.disallow_legacy_graph('Model', 'fit')
  self._assert_compile_was_called()
  self._check_call_args('fit')
  _disallow_inside_tf_function('fit')

  if validation_split:
    # Create the validation data using the training data. Only supported for
    # `Tensor` and `NumPy` input.
    # This rebinds `validation_data`, replacing any value the caller passed.
    (x, y, sample_weight), validation_data = (
        data_adapter.train_validation_split(
            (x, y, sample_weight), validation_split=validation_split))

  if validation_data:
    # Split the validation data into its components once, up front; they are
    # reused for every validation run below.
    val_x, val_y, val_sample_weight = (
        data_adapter.unpack_x_y_sample_weight(validation_data))

  if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
    # Once set, `make_train_function` schedules the train function through
    # this coordinator instead of calling it directly.
    self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
        self.distribute_strategy)

  with self.distribute_strategy.scope(), \
       training_utils.RespectCompiledTrainableState(self):
    # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
    data_handler = data_adapter.get_data_handler(
        x=x,
        y=y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        initial_epoch=initial_epoch,
        epochs=epochs,
        shuffle=shuffle,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self,
        steps_per_execution=self._steps_per_execution)

    # Container that configures and calls `tf.keras.Callback`s.
    # A caller-supplied `CallbackList` is used as-is.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=verbose != 0,
          model=self,
          verbose=verbose,
          epochs=epochs,
          steps=data_handler.inferred_steps)

    self.stop_training = False
    self.train_function = self.make_train_function()
    # Restart the training step counter for this call to `fit`.
    self._train_counter.assign(0)
    callbacks.on_train_begin()
    # Logs of the last completed epoch; passed to `on_train_end`.
    training_logs = None
    # Handle fault-tolerance for multi-worker.
    # TODO(omalleyt): Fix the ordering issues that mean this has to
    # happen after `callbacks.on_train_begin`.
    data_handler._initial_epoch = (  # pylint: disable=protected-access
        self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
    # Stays `None` if no training step ever completes; checked after each
    # epoch to detect empty input.
    logs = None
    for epoch, iterator in data_handler.enumerate_epochs():
      self.reset_metrics()
      callbacks.on_epoch_begin(epoch)
      with data_handler.catch_stop_iteration():
        for step in data_handler.steps():
          with trace.Trace(
              'train',
              epoch_num=epoch,
              step_num=step,
              batch_size=batch_size,
              _r=1):
            callbacks.on_train_batch_begin(step)
            tmp_logs = self.train_function(iterator)
            if data_handler.should_sync:
              context.async_wait()
            logs = tmp_logs  # No error, now safe to assign to logs.
            # One execution may cover several steps, so report the index of
            # the last step that was run.
            end_step = step + data_handler.step_increment
            callbacks.on_train_batch_end(end_step, logs)
            if self.stop_training:
              break

      logs = data_handler.resolve_logs(logs)
      if logs is None:
        raise ValueError('Expect x to be a non-empty array or dataset.')
      # Shallow copy so that adding validation entries below does not modify
      # `logs`.
      epoch_logs = copy.copy(logs)

      # Run validation.
      if validation_data and self._should_eval(epoch, validation_freq):
        # Create data_handler for evaluation and cache it.
        if getattr(self, '_eval_data_handler', None) is None:
          # `evaluate` compares its caller's frame against `_fit_frame` and
          # only reuses the cached handler when it was called from here.
          self._fit_frame = tf_inspect.currentframe()
          self._eval_data_handler = data_adapter.get_data_handler(
              x=val_x,
              y=val_y,
              sample_weight=val_sample_weight,
              batch_size=validation_batch_size or batch_size,
              steps_per_epoch=validation_steps,
              initial_epoch=0,
              epochs=1,
              max_queue_size=max_queue_size,
              workers=workers,
              use_multiprocessing=use_multiprocessing,
              model=self,
              steps_per_execution=self._steps_per_execution)
        val_logs = self.evaluate(
            x=val_x,
            y=val_y,
            sample_weight=val_sample_weight,
            batch_size=validation_batch_size or batch_size,
            steps=validation_steps,
            callbacks=callbacks,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            return_dict=True)
        # Prefix validation results so they do not collide with the training
        # entries of the same name.
        val_logs = {'val_' + name: val for name, val in val_logs.items()}
        epoch_logs.update(val_logs)

      callbacks.on_epoch_end(epoch, epoch_logs)
      training_logs = epoch_logs
      if self.stop_training:
        break

    # If eval data_handler exists, delete it after all epochs are done.
    if getattr(self, '_eval_data_handler', None) is not None:
      del self._eval_data_handler
      del self._fit_frame
    callbacks.on_train_end(logs=training_logs)
    return self.history
def make_test_function(self):
  """Creates a function that executes one step of evaluation.

  This method can be overridden to support custom evaluation logic.
  This method is called by `Model.evaluate` and `Model.test_on_batch`.

  Typically, this method directly controls `tf.function` and
  `tf.distribute.Strategy` settings, and delegates the actual evaluation
  logic to `Model.test_step`.

  This function is cached the first time `Model.evaluate` or
  `Model.test_on_batch` is called. The cache is cleared whenever
  `Model.compile` is called.

  Returns:
    Function. The function created by this method should accept a
    `tf.data.Iterator`, and return a `dict` containing values that will
    be passed to `tf.keras.Callbacks.on_test_batch_end`.
  """
  # Reuse the cached function if one was already built.
  if self.test_function is not None:
    return self.test_function

  def step_function(model, iterator):
    """Runs a single evaluation step."""

    def run_step(data):
      outputs = model.test_step(data)
      # Ensure counter is updated only if `test_step` succeeds.
      with ops.control_dependencies(_minimum_control_deps(outputs)):
        model._test_counter.assign_add(1)  # pylint: disable=protected-access
      return outputs

    # Pull one batch and run the step through the distribution strategy.
    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    # NOTE(review): 'first' presumably keeps only the first replica's
    # outputs; confirm against `reduce_per_replica`.
    outputs = reduce_per_replica(
        outputs, self.distribute_strategy, reduction='first')
    return outputs

  # `_steps_per_execution` is read eagerly here to pick, at build time,
  # between the single-step and the multi-step variant.
  if self._steps_per_execution.numpy().item() == 1:

    def test_function(iterator):
      """Runs an evaluation execution with one step."""
      return step_function(self, iterator)

  else:

    def test_function(iterator):
      """Runs an evaluation execution with multiple steps."""
      # Only the outputs of the last step in the execution are returned.
      for _ in math_ops.range(self._steps_per_execution):
        outputs = step_function(self, iterator)
      return outputs

  # Wrap in a `tf.function` unless the model is configured to run eagerly.
  if not self.run_eagerly:
    test_function = def_function.function(
        test_function, experimental_relax_shapes=True)

  self.test_function = test_function
  return self.test_function
1325 - A dict mapping input names to the corresponding array/tensors, 1326 if the model has named inputs. 1327 - A `tf.data` dataset. Should return a tuple 1328 of either `(inputs, targets)` or 1329 `(inputs, targets, sample_weights)`. 1330 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 1331 or `(inputs, targets, sample_weights)`. 1332 A more detailed description of unpacking behavior for iterator types 1333 (Dataset, generator, Sequence) is given in the `Unpacking behavior 1334 for iterator-like inputs` section of `Model.fit`. 1335 y: Target data. Like the input data `x`, it could be either Numpy 1336 array(s) or TensorFlow tensor(s). It should be consistent with `x` 1337 (you cannot have Numpy inputs and tensor targets, or inversely). If 1338 `x` is a dataset, generator or `keras.utils.Sequence` instance, `y` 1339 should not be specified (since targets will be obtained from the 1340 iterator/dataset). 1341 batch_size: Integer or `None`. Number of samples per batch of 1342 computation. If unspecified, `batch_size` will default to 32. Do not 1343 specify the `batch_size` if your data is in the form of a dataset, 1344 generators, or `keras.utils.Sequence` instances (since they generate 1345 batches). 1346 verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. 1347 sample_weight: Optional Numpy array of weights for the test samples, 1348 used for weighting the loss function. You can either pass a flat (1D) 1349 Numpy array with the same length as the input samples 1350 (1:1 mapping between weights and samples), or in the case of 1351 temporal data, you can pass a 2D array with shape `(samples, 1352 sequence_length)`, to apply a different weight to every timestep 1353 of every sample. This argument is not supported when `x` is a 1354 dataset, instead pass sample weights as the third element of `x`. 1355 steps: Integer or `None`. Total number of steps (batches of samples) 1356 before declaring the evaluation round finished. 
Ignored with the 1357 default value of `None`. If x is a `tf.data` dataset and `steps` is 1358 None, 'evaluate' will run until the dataset is exhausted. This 1359 argument is not supported with array inputs. 1360 callbacks: List of `keras.callbacks.Callback` instances. List of 1361 callbacks to apply during evaluation. See 1362 [callbacks](/api_docs/python/tf/keras/callbacks). 1363 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 1364 input only. Maximum size for the generator queue. If unspecified, 1365 `max_queue_size` will default to 10. 1366 workers: Integer. Used for generator or `keras.utils.Sequence` input 1367 only. Maximum number of processes to spin up when using process-based 1368 threading. If unspecified, `workers` will default to 1. If 0, will 1369 execute the generator on the main thread. 1370 use_multiprocessing: Boolean. Used for generator or 1371 `keras.utils.Sequence` input only. If `True`, use process-based 1372 threading. If unspecified, `use_multiprocessing` will default to 1373 `False`. Note that because this implementation relies on 1374 multiprocessing, you should not pass non-picklable arguments to the 1375 generator as they can't be passed easily to children processes. 1376 return_dict: If `True`, loss and metric results are returned as a dict, 1377 with each key being the name of the metric. If `False`, they are 1378 returned as a list. 1379 1380 See the discussion of `Unpacking behavior for iterator-like inputs` for 1381 `Model.fit`. 1382 1383 Returns: 1384 Scalar test loss (if the model has a single output and no metrics) 1385 or list of scalars (if the model has multiple outputs 1386 and/or metrics). The attribute `model.metrics_names` will give you 1387 the display labels for the scalar outputs. 1388 1389 Raises: 1390 RuntimeError: If `model.evaluate` is wrapped in `tf.function`. 1391 ValueError: in case of invalid arguments. 
1392 """ 1393 base_layer.keras_api_gauge.get_cell('evaluate').set(True) 1394 version_utils.disallow_legacy_graph('Model', 'evaluate') 1395 self._assert_compile_was_called() 1396 self._check_call_args('evaluate') 1397 _disallow_inside_tf_function('evaluate') 1398 1399 if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access 1400 raise NotImplementedError('`model.evaluate` is not yet supported with ' 1401 '`ParameterServerStrategy`.') 1402 1403 with self.distribute_strategy.scope(): 1404 # Use cached evaluation data only when it's called in `Model.fit` 1405 if (getattr(self, '_fit_frame', None) is not None 1406 and tf_inspect.currentframe().f_back is self._fit_frame 1407 and getattr(self, '_eval_data_handler', None) is not None): 1408 data_handler = self._eval_data_handler 1409 else: 1410 # Creates a `tf.data.Dataset` and handles batch and epoch iteration. 1411 data_handler = data_adapter.get_data_handler( 1412 x=x, 1413 y=y, 1414 sample_weight=sample_weight, 1415 batch_size=batch_size, 1416 steps_per_epoch=steps, 1417 initial_epoch=0, 1418 epochs=1, 1419 max_queue_size=max_queue_size, 1420 workers=workers, 1421 use_multiprocessing=use_multiprocessing, 1422 model=self, 1423 steps_per_execution=self._steps_per_execution) 1424 1425 # Container that configures and calls `tf.keras.Callback`s. 1426 if not isinstance(callbacks, callbacks_module.CallbackList): 1427 callbacks = callbacks_module.CallbackList( 1428 callbacks, 1429 add_history=True, 1430 add_progbar=verbose != 0, 1431 model=self, 1432 verbose=verbose, 1433 epochs=1, 1434 steps=data_handler.inferred_steps) 1435 1436 logs = {} 1437 self.test_function = self.make_test_function() 1438 self._test_counter.assign(0) 1439 callbacks.on_test_begin() 1440 for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 
1441 self.reset_metrics() 1442 with data_handler.catch_stop_iteration(): 1443 for step in data_handler.steps(): 1444 with trace.Trace('test', step_num=step, _r=1): 1445 callbacks.on_test_batch_begin(step) 1446 tmp_logs = self.test_function(iterator) 1447 if data_handler.should_sync: 1448 context.async_wait() 1449 logs = tmp_logs # No error, now safe to assign to logs. 1450 end_step = step + data_handler.step_increment 1451 callbacks.on_test_batch_end(end_step, logs) 1452 logs = tf_utils.to_numpy_or_python_type(logs) 1453 callbacks.on_test_end(logs=logs) 1454 1455 if return_dict: 1456 return logs 1457 else: 1458 results = [] 1459 for name in self.metrics_names: 1460 if name in logs: 1461 results.append(logs[name]) 1462 for key in sorted(logs.keys()): 1463 if key not in self.metrics_names: 1464 results.append(logs[key]) 1465 if len(results) == 1: 1466 return results[0] 1467 return results 1468 1469 def predict_step(self, data): 1470 """The logic for one inference step. 1471 1472 This method can be overridden to support custom inference logic. 1473 This method is called by `Model.make_predict_function`. 1474 1475 This method should contain the mathematical logic for one step of inference. 1476 This typically includes the forward pass. 1477 1478 Configuration details for *how* this logic is run (e.g. `tf.function` and 1479 `tf.distribute.Strategy` settings), should be left to 1480 `Model.make_predict_function`, which can also be overridden. 1481 1482 Args: 1483 data: A nested structure of `Tensor`s. 1484 1485 Returns: 1486 The result of one inference step, typically the output of calling the 1487 `Model` on data. 1488 """ 1489 data = data_adapter.expand_1d(data) 1490 x, _, _ = data_adapter.unpack_x_y_sample_weight(data) 1491 return self(x, training=False) 1492 1493 def make_predict_function(self): 1494 """Creates a function that executes one step of inference. 1495 1496 This method can be overridden to support custom inference logic. 
1497 This method is called by `Model.predict` and `Model.predict_on_batch`. 1498 1499 Typically, this method directly controls `tf.function` and 1500 `tf.distribute.Strategy` settings, and delegates the actual evaluation 1501 logic to `Model.predict_step`. 1502 1503 This function is cached the first time `Model.predict` or 1504 `Model.predict_on_batch` is called. The cache is cleared whenever 1505 `Model.compile` is called. 1506 1507 Returns: 1508 Function. The function created by this method should accept a 1509 `tf.data.Iterator`, and return the outputs of the `Model`. 1510 """ 1511 if self.predict_function is not None: 1512 return self.predict_function 1513 1514 def step_function(model, iterator): 1515 """Runs a single evaluation step.""" 1516 1517 def run_step(data): 1518 outputs = model.predict_step(data) 1519 # Ensure counter is updated only if `test_step` succeeds. 1520 with ops.control_dependencies(_minimum_control_deps(outputs)): 1521 model._predict_counter.assign_add(1) # pylint: disable=protected-access 1522 return outputs 1523 1524 data = next(iterator) 1525 outputs = model.distribute_strategy.run(run_step, args=(data,)) 1526 outputs = reduce_per_replica( 1527 outputs, self.distribute_strategy, reduction='concat') 1528 return outputs 1529 1530 if (self._steps_per_execution is None or 1531 self._steps_per_execution.numpy().item() == 1): 1532 1533 def predict_function(iterator): 1534 """Runs an evaluation execution with one step.""" 1535 return step_function(self, iterator) 1536 1537 else: 1538 1539 def predict_function(iterator): 1540 """Runs an evaluation execution with multiple steps.""" 1541 outputs = step_function(self, iterator) 1542 for _ in math_ops.range(self._steps_per_execution - 1): 1543 directives.set_loop_options( 1544 shape_invariants=[( 1545 t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape) 1546 for t in nest.flatten(outputs)]) 1547 step_outputs = step_function(self, iterator) 1548 outputs = nest.map_structure(lambda t1, t2: 
def predict(self,
            x,
            batch_size=None,
            verbose=0,
            steps=None,
            callbacks=None,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
  """Generates output predictions for the input samples.

  Computation is done in batches. This method is designed for performance in
  large scale inputs. For small amount of inputs that fit in one batch,
  directly using `__call__` is recommended for faster execution, e.g.,
  `model(x)`, or `model(x, training=False)` if you have layers such as
  `tf.keras.layers.BatchNormalization` that behaves differently during
  inference. Also, note the fact that test loss is not affected by
  regularization layers like noise and dropout.

  Args:
      x: Input samples. It could be:
        - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
        - A `tf.data` dataset.
        - A generator or `keras.utils.Sequence` instance.
        A more detailed description of unpacking behavior for iterator types
        (Dataset, generator, Sequence) is given in the `Unpacking behavior
        for iterator-like inputs` section of `Model.fit`.
      batch_size: Integer or `None`.
          Number of samples per batch.
          If unspecified, `batch_size` will default to 32.
          Do not specify the `batch_size` if your data is in the
          form of dataset, generators, or `keras.utils.Sequence` instances
          (since they generate batches).
      verbose: Verbosity mode, 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring the prediction round finished.
          Ignored with the default value of `None`. If x is a `tf.data`
          dataset and `steps` is None, `predict` will
          run until the input dataset is exhausted.
      callbacks: List of `keras.callbacks.Callback` instances.
          List of callbacks to apply during prediction.
          See [callbacks](/api_docs/python/tf/keras/callbacks).
      max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
          input only. Maximum size for the generator queue.
          If unspecified, `max_queue_size` will default to 10.
      workers: Integer. Used for generator or `keras.utils.Sequence` input
          only. Maximum number of processes to spin up when using
          process-based threading. If unspecified, `workers` will default
          to 1. If 0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. Used for generator or
          `keras.utils.Sequence` input only. If `True`, use process-based
          threading. If unspecified, `use_multiprocessing` will default to
          `False`. Note that because this implementation relies on
          multiprocessing, you should not pass non-picklable arguments to
          the generator as they can't be passed easily to children processes.

  See the discussion of `Unpacking behavior for iterator-like inputs` for
  `Model.fit`. Note that Model.predict uses the same interpretation rules as
  `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all
  three methods.

  Returns:
      Numpy array(s) of predictions.

  Raises:
      RuntimeError: If `model.predict` is wrapped in `tf.function`.
      NotImplementedError: If the model's distribution strategy is one that
          must be used with a coordinator (`ParameterServerStrategy`).
      ValueError: In case of mismatch between the provided
          input data and the model's expectations,
          or in case a stateful model receives a number of samples
          that is not a multiple of the batch size, or if no batch was
          produced from `x`.
  """
  base_layer.keras_api_gauge.get_cell('predict').set(True)
  version_utils.disallow_legacy_graph('Model', 'predict')
  self._check_call_args('predict')
  _disallow_inside_tf_function('predict')

  if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
    raise NotImplementedError('`model.predict` is not yet supported with '
                              '`ParameterServerStrategy`.')

  # Per-output lists of batch results, created on the first batch.
  outputs = None
  with self.distribute_strategy.scope():
    dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
    # For dataset inputs in multi-worker / multi-host TPU mode, request the
    # DATA auto-shard policy. If that raises a `ValueError`, only warn, as
    # results may then come back out of order.
    if (self._in_multi_worker_mode() or _is_tpu_multi_host(
        self.distribute_strategy)) and isinstance(x, dataset_types):
      try:
        options = dataset_ops.Options()
        data_option = distribute_options.AutoShardPolicy.DATA
        options.experimental_distribute.auto_shard_policy = data_option
        x = x.with_options(options)
      except ValueError:
        warnings.warn('Using Model.predict with '
                      'MultiWorkerDistributionStrategy or TPUStrategy and '
                      'AutoShardPolicy.FILE might lead to out-of-order result'
                      '. Consider setting it to AutoShardPolicy.DATA.')

    # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
    data_handler = data_adapter.get_data_handler(
        x=x,
        batch_size=batch_size,
        steps_per_epoch=steps,
        initial_epoch=0,
        epochs=1,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self,
        steps_per_execution=self._steps_per_execution)

    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=verbose != 0,
          model=self,
          verbose=verbose,
          epochs=1,
          steps=data_handler.inferred_steps)

    self.predict_function = self.make_predict_function()
    self._predict_counter.assign(0)
    callbacks.on_predict_begin()
    # Stays `None` if the loop below never completes a step; used afterwards
    # to detect empty input.
    batch_outputs = None
    for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
      with data_handler.catch_stop_iteration():
        for step in data_handler.steps():
          callbacks.on_predict_batch_begin(step)
          tmp_batch_outputs = self.predict_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
          batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
          if outputs is None:
            # First batch: start one list per output leaf.
            outputs = nest.map_structure(lambda batch_output: [batch_output],
                                         batch_outputs)
          else:
            # Later batches: append each leaf to its matching list.
            nest.map_structure_up_to(
                batch_outputs,
                lambda output, batch_output: output.append(batch_output),
                outputs, batch_outputs)
          end_step = step + data_handler.step_increment
          callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
    if batch_outputs is None:
      raise ValueError('Expect x to be a non-empty array or dataset.')
    callbacks.on_predict_end()
  # Join the collected per-batch results of every output with `concat`.
  all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
  return tf_utils.to_numpy_or_python_type(all_outputs)
def reset_metrics(self):
  """Clears the accumulated state of every metric attached to the model.

  Each entry of `self.metrics` has its `reset_states()` method called.

  Examples:

  >>> inputs = tf.keras.layers.Input(shape=(3,))
  >>> outputs = tf.keras.layers.Dense(2)(inputs)
  >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
  >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])

  >>> x = np.random.random((2, 3))
  >>> y = np.random.randint(0, 2, (2, 2))
  >>> _ = model.fit(x, y, verbose=0)
  >>> assert all(float(m.result()) for m in model.metrics)

  >>> model.reset_metrics()
  >>> assert all(float(m.result()) == 0 for m in model.metrics)

  """
  for metric in self.metrics:
    metric.reset_states()
def train_on_batch(self,
                   x,
                   y=None,
                   sample_weight=None,
                   class_weight=None,
                   reset_metrics=True,
                   return_dict=False):
  """Performs one gradient update using a single batch of data.

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
      y: Target data. Like the input data `x`, it could be either Numpy
        array(s) or TensorFlow tensor(s). It should be consistent with `x`
        (you cannot have Numpy inputs and tensor targets, or inversely).
      sample_weight: Optional array of the same length as x, containing
        weights to apply to the model's loss for each sample. In the case of
        temporal data, you can pass a 2D array with shape (samples,
        sequence_length), to apply a different weight to every timestep of
        every sample.
      class_weight: Optional dictionary mapping class indices (integers) to a
        weight (float) to apply to the model's loss for the samples from this
        class during training. This can be useful to tell the model to "pay
        more attention" to samples from an under-represented class.
      reset_metrics: If `True`, the metrics returned will be only for this
        batch. If `False`, the metrics will be statefully accumulated across
        batches.
      return_dict: If `True`, loss and metric results are returned as a dict,
        with each key being the name of the metric. If `False`, they are
        returned as a list.

  Returns:
      Scalar training loss
      (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.

  Raises:
    RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`.
    ValueError: In case of invalid user-provided arguments.
  """
  self._assert_compile_was_called()
  self._check_call_args('train_on_batch')
  _disallow_inside_tf_function('train_on_batch')

  with self.distribute_strategy.scope():
    with training_utils.RespectCompiledTrainableState(self):
      batch_iterator = data_adapter.single_batch_iterator(
          self.distribute_strategy, x, y, sample_weight, class_weight)
      self.train_function = self.make_train_function()
      logs = self.train_function(batch_iterator)

  if reset_metrics:
    self.reset_metrics()
  logs = tf_utils.to_numpy_or_python_type(logs)

  if return_dict:
    return logs
  # One entry per metric name, in `metrics_names` order; names missing from
  # `logs` yield `None`.
  values = [logs.get(name, None) for name in self.metrics_names]
  return values[0] if len(values) == 1 else values
def test_on_batch(self,
                  x,
                  y=None,
                  sample_weight=None,
                  reset_metrics=True,
                  return_dict=False):
  """Evaluates the model on a single batch of samples.

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
      y: Target data. Like the input data `x`, it could be either Numpy
        array(s) or TensorFlow tensor(s). It should be consistent with `x`
        (you cannot have Numpy inputs and tensor targets, or inversely).
      sample_weight: Optional array of the same length as x, containing
        weights to apply to the model's loss for each sample. In the case of
        temporal data, you can pass a 2D array with shape (samples,
        sequence_length), to apply a different weight to every timestep of
        every sample.
      reset_metrics: If `True`, the metrics returned will be only for this
        batch. If `False`, the metrics will be statefully accumulated across
        batches.
      return_dict: If `True`, loss and metric results are returned as a dict,
        with each key being the name of the metric. If `False`, they are
        returned as a list.

  Returns:
      Scalar test loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.

  Raises:
      RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`.
      ValueError: In case of invalid user-provided arguments.
  """
  self._assert_compile_was_called()
  self._check_call_args('test_on_batch')
  _disallow_inside_tf_function('test_on_batch')

  with self.distribute_strategy.scope():
    batch_iterator = data_adapter.single_batch_iterator(
        self.distribute_strategy, x, y, sample_weight)
    self.test_function = self.make_test_function()
    logs = self.test_function(batch_iterator)

  if reset_metrics:
    self.reset_metrics()
  logs = tf_utils.to_numpy_or_python_type(logs)

  if return_dict:
    return logs
  # One entry per metric name, in `metrics_names` order; names missing from
  # `logs` yield `None`.
  values = [logs.get(name, None) for name in self.metrics_names]
  return values[0] if len(values) == 1 else values
def predict_on_batch(self, x):
  """Computes predictions for a single batch of samples.

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).

  Returns:
      Numpy array(s) of predictions.

  Raises:
      RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`.
      ValueError: In case of mismatch between given number of inputs and
        expectations of the model.
  """
  self._check_call_args('predict_on_batch')
  _disallow_inside_tf_function('predict_on_batch')

  with self.distribute_strategy.scope():
    batch_iterator = data_adapter.single_batch_iterator(
        self.distribute_strategy, x)
    self.predict_function = self.make_predict_function()
    predictions = self.predict_function(batch_iterator)

  return tf_utils.to_numpy_or_python_type(predictions)
def fit_generator(self,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  validation_freq=1,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
  """Fits the model on data yielded batch-by-batch by a Python generator.

  DEPRECATED:
    `Model.fit` now supports generators, so there is no longer any need to use
    this endpoint. This method emits a deprecation warning and forwards every
    argument to `Model.fit`.
  """
  warnings.warn('`Model.fit_generator` is deprecated and '
                'will be removed in a future version. '
                'Please use `Model.fit`, which supports generators.')
  # Every option is forwarded by keyword; only `generator` is positional.
  forwarded = dict(
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      verbose=verbose,
      callbacks=callbacks,
      validation_data=validation_data,
      validation_steps=validation_steps,
      validation_freq=validation_freq,
      class_weight=class_weight,
      max_queue_size=max_queue_size,
      workers=workers,
      use_multiprocessing=use_multiprocessing,
      shuffle=shuffle,
      initial_epoch=initial_epoch)
  return self.fit(generator, **forwarded)
def evaluate_generator(self,
                       generator,
                       steps=None,
                       callbacks=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
  """Evaluates the model on a data generator.

  DEPRECATED:
    `Model.evaluate` now supports generators, so there is no longer any need
    to use this endpoint. This method emits a deprecation warning, checks the
    call arguments, and forwards everything to `Model.evaluate`.
  """
  warnings.warn('`Model.evaluate_generator` is deprecated and '
                'will be removed in a future version. '
                'Please use `Model.evaluate`, which supports generators.')
  self._check_call_args('evaluate_generator')

  forwarded = dict(
      steps=steps,
      max_queue_size=max_queue_size,
      workers=workers,
      use_multiprocessing=use_multiprocessing,
      verbose=verbose,
      callbacks=callbacks)
  return self.evaluate(generator, **forwarded)


def predict_generator(self,
                      generator,
                      steps=None,
                      callbacks=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      verbose=0):
  """Generates predictions for the input samples from a data generator.

  DEPRECATED:
    `Model.predict` now supports generators, so there is no longer any need
    to use this endpoint. This method emits a deprecation warning and
    forwards every argument to `Model.predict`.
  """
  warnings.warn('`Model.predict_generator` is deprecated and '
                'will be removed in a future version. '
                'Please use `Model.predict`, which supports generators.')
  forwarded = dict(
      steps=steps,
      max_queue_size=max_queue_size,
      workers=workers,
      use_multiprocessing=use_multiprocessing,
      verbose=verbose,
      callbacks=callbacks)
  return self.predict(generator, **forwarded)
@property
def trainable_weights(self):
  """Deduplicated trainable variables of the model and its tracked objects.

  Returns an empty list when the model itself is marked non-trainable.
  """
  self._assert_weights_created()
  if not self._trainable:
    return []
  collected = []
  for tracked in self._self_tracked_trackables:
    collected += tracked.trainable_variables
  # The model's own trainable weights come after those of tracked objects.
  collected += self._trainable_weights
  return self._dedup_weights(collected)


@property
def non_trainable_weights(self):
  """Deduplicated non-trainable variables of the model and tracked objects.

  When the model itself is marked non-trainable, every trainable variable is
  reported here as well, ahead of the non-trainable ones.
  """
  self._assert_weights_created()
  nested_frozen = []
  for tracked in self._self_tracked_trackables:
    nested_frozen += tracked.non_trainable_variables

  if self._trainable:
    combined = nested_frozen + self._non_trainable_weights
  else:
    # Return order is all trainable vars, then all non-trainable vars.
    nested_trainable = []
    for tracked in self._self_tracked_trackables:
      nested_trainable += tracked.trainable_variables
    combined = (
        nested_trainable + self._trainable_weights +
        nested_frozen + self._non_trainable_weights)

  return self._dedup_weights(combined)


def get_weights(self):
  """Retrieves the weights of the model.

  The parent implementation is invoked inside the model's distribution
  strategy scope.

  Returns:
      A flat list of Numpy arrays.
  """
  with self.distribute_strategy.scope():
    return super(Model, self).get_weights()
def save(self,
         filepath,
         overwrite=True,
         include_optimizer=True,
         save_format=None,
         signatures=None,
         options=None,
         save_traces=True):
  # pylint: disable=line-too-long
  """Saves the model to Tensorflow SavedModel or a single HDF5 file.

  This is a thin wrapper: all arguments are forwarded unchanged to
  `save.save_model`.

  Please see `tf.keras.models.save_model` or the
  [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
  for details.

  Args:
      filepath: String, PathLike, path to SavedModel or H5 file to save the
          model.
      overwrite: Whether to silently overwrite any existing file at the
          target location, or provide the user with a manual prompt.
      include_optimizer: If True, save optimizer's state together.
      save_format: Either `'tf'` or `'h5'`, indicating whether to save the
          model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
          and 'h5' in TF 1.X.
      signatures: Signatures to save with the SavedModel. Applicable to the
          'tf' format only. Please see the `signatures` argument in
          `tf.saved_model.save` for details.
      options: (only applies to SavedModel format)
          `tf.saved_model.SaveOptions` object that specifies options for
          saving to SavedModel.
      save_traces: (only applies to SavedModel format) When enabled, the
          SavedModel will store the function traces for each layer. This
          can be disabled, so that only the configs of each layer are stored.
          Defaults to `True`. Disabling this will decrease serialization time
          and reduce file size, but it requires that all custom layers/models
          implement a `get_config()` method.

  Example:

  ```python
  from keras.models import load_model

  model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
  del model  # deletes the existing model

  # returns a compiled model
  # identical to the previous one
  model = load_model('my_model.h5')
  ```
  """
  # pylint: enable=line-too-long
  # Arguments are passed positionally, so their order here must match the
  # parameter order of `save.save_model`.
  save.save_model(self, filepath, overwrite, include_optimizer, save_format,
                  signatures, options, save_traces)
def save_weights(self,
                 filepath,
                 overwrite=True,
                 save_format=None,
                 options=None):
  """Saves all layer weights.

  Either saves in HDF5 or in TensorFlow format based on the `save_format`
  argument.

  When saving in HDF5 format, the weight file has:
    - `layer_names` (attribute), a list of strings
        (ordered names of model layers).
    - For every layer, a `group` named `layer.name`
        - For every such layer group, a group attribute `weight_names`,
            a list of strings
            (ordered names of weights tensor of the layer).
        - For every weight in the layer, a dataset
            storing the weight value, named after the weight tensor.

  When saving in TensorFlow format, all objects referenced by the network are
  saved in the same format as `tf.train.Checkpoint`, including any `Layer`
  instances or `Optimizer` instances assigned to object attributes. For
  networks constructed from inputs and outputs using `tf.keras.Model(inputs,
  outputs)`, `Layer` instances used by the network are tracked/saved
  automatically. For user-defined classes which inherit from `tf.keras.Model`,
  `Layer` instances must be assigned to object attributes, typically in the
  constructor. See the documentation of `tf.train.Checkpoint` and
  `tf.keras.Model` for details.

  While the formats are the same, do not mix `save_weights` and
  `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
  loaded using `Model.load_weights`. Checkpoints saved using
  `tf.train.Checkpoint.save` should be restored using the corresponding
  `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
  `save_weights` for training checkpoints.

  The TensorFlow format matches objects and variables by starting at a root
  object, `self` for `save_weights`, and greedily matching attribute
  names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
  is the `Checkpoint` even if the `Checkpoint` has a model attached. This
  means saving a `tf.keras.Model` using `save_weights` and loading into a
  `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
  the `Model`'s variables. See the [guide to training
  checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
  on the TensorFlow format.

  Args:
      filepath: String or PathLike, path to the file to save the weights to.
          When saving in TensorFlow format, this is the prefix used for
          checkpoint files (multiple files are generated). Note that the '.h5'
          suffix causes weights to be saved in HDF5 format.
      overwrite: Whether to silently overwrite any existing file at the
          target location, or provide the user with a manual prompt. If the
          user declines the prompt, nothing is saved and the method returns.
      save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
          '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
          `None` defaults to 'tf'. The value is matched case-insensitively
          after stripping whitespace; 'tensorflow' is accepted as an alias
          of 'tf', and 'hdf5' and 'keras' as aliases of 'h5'.
      options: Optional `tf.train.CheckpointOptions` object that specifies
          options for saving weights.

  Raises:
      ImportError: If h5py is not available when attempting to save in HDF5
          format.
      ValueError: For invalid/unknown format arguments, or if the TensorFlow
          format is requested for a `filepath` that looks like an HDF5 file.
  """
  self._assert_weights_created()
  filepath = path_to_string(filepath)
  filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
  # Resolve `save_format` to exactly 'tf' or 'h5'.
  if save_format is None:
    if filepath_is_h5:
      save_format = 'h5'
    else:
      save_format = 'tf'
  else:
    user_format = save_format.lower().strip()
    if user_format in ('tensorflow', 'tf'):
      save_format = 'tf'
    elif user_format in ('hdf5', 'h5', 'keras'):
      save_format = 'h5'
    else:
      raise ValueError(
          'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
              save_format,))
  # An explicit TensorFlow format contradicts an HDF5-looking path.
  if save_format == 'tf' and filepath_is_h5:
    raise ValueError(
        ('save_weights got save_format="tf"/"tensorflow", but the '
         'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
         'when saving in TensorFlow format.')
        % filepath)

  if save_format == 'h5' and h5py is None:
    raise ImportError(
        '`save_weights` requires h5py when saving in hdf5.')
  # In TensorFlow format `filepath` is only a prefix, so the existence check
  # looks at the '.index' file instead.
  if save_format == 'tf':
    check_filepath = filepath + '.index'
  else:
    check_filepath = filepath
  # If file exists and should not be overwritten:
  if not overwrite and os.path.isfile(check_filepath):
    proceed = ask_to_proceed_with_overwrite(check_filepath)
    if not proceed:
      return
  if save_format == 'h5':
    with h5py.File(filepath, 'w') as f:
      hdf5_format.save_weights_to_hdf5_group(f, self.layers)
  else:
    # A session is only passed when not executing eagerly.
    if context.executing_eagerly():
      session = None
    else:
      session = backend.get_session()
    self._trackable_saver.save(filepath, session=session, options=options)
    # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=os.path.dirname(filepath),
        model_checkpoint_path=filepath,
        save_relative_paths=True,
        all_model_checkpoint_paths=[filepath])
def load_weights(self,
                 filepath,
                 by_name=False,
                 skip_mismatch=False,
                 options=None):
  """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.

  If `by_name` is False weights are loaded based on the network's
  topology. This means the architecture should be the same as when the weights
  were saved.  Note that layers that don't have weights are not taken into
  account in the topological ordering, so adding or removing layers is fine as
  long as they don't have weights.

  If `by_name` is True, weights are loaded into layers only if they share the
  same name. This is useful for fine-tuning or transfer-learning models where
  some of the layers have changed.

  Only topological loading (`by_name=False`) is supported when loading weights
  from the TensorFlow format. Note that topological loading differs slightly
  between TensorFlow and HDF5 formats for user-defined classes inheriting from
  `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
  TensorFlow format loads based on the object-local names of attributes to
  which layers are assigned in the `Model`'s constructor.

  Args:
      filepath: String, path to the weights file to load. For weight files in
          TensorFlow format, this is the file prefix (the same as was passed
          to `save_weights`). This can also be a path to a SavedModel
          saved from `model.save`.
      by_name: Boolean, whether to load weights by name or by topological
          order. Only topological loading is supported for weight files in
          TensorFlow format.
      skip_mismatch: Boolean, whether to skip loading of layers where there is
          a mismatch in the number of weights, or a mismatch in the shape of
          the weight (only valid when `by_name=True`).
      options: Optional `tf.train.CheckpointOptions` object that specifies
          options for loading weights.

  Returns:
      When loading a weight file in TensorFlow format, returns the same status
      object as `tf.train.Checkpoint.restore`. When graph building, restore
      ops are run automatically as soon as the network is built (on first call
      for user-defined classes inheriting from `Model`, immediately if it is
      already built).

      When loading weights in HDF5 format, returns `None`.

  Raises:
      ImportError: If h5py is not available and the weight file is in HDF5
          format.
      NotImplementedError: If `by_name` is `True` and the weight file is in
          TensorFlow format.
      ValueError: If `skip_mismatch` is set to `True` when `by_name` is
        `False`, if a TensorFlow-format file is loaded under a TPU strategy
        with `steps_per_run` greater than 1, or if HDF5 weights are loaded
        into a subclassed model that has not been built.
  """
  if backend.is_tpu_strategy(self._distribution_strategy):
    if (self._distribution_strategy.extended.steps_per_run > 1 and
        (not saving_utils.is_hdf5_filepath(filepath))):
      raise ValueError('Load weights is not yet supported with TPUStrategy '
                       'with steps_per_run greater than 1.')
  if skip_mismatch and not by_name:
    raise ValueError(
        'When calling model.load_weights, skip_mismatch can only be set to '
        'True when by_name is True.')

  filepath, save_format = _detect_save_format(filepath)
  if save_format == 'tf':
    # Reject the unsupported combination before restoring anything, so a
    # failing call does not leave the model partially restored.
    if by_name:
      raise NotImplementedError(
          'Weights may only be loaded based on topology into Models when '
          'loading TensorFlow-formatted weights (got by_name=True to '
          'load_weights).')
    status = self._trackable_saver.restore(filepath, options)
    if not context.executing_eagerly():
      session = backend.get_session()
      # Restore existing variables (if any) immediately, and set up a
      # streaming restore for any variables created in the future.
      trackable_utils.streaming_restore(status=status, session=session)
    status.assert_nontrivial_match()
    return status

  # Everything below handles the HDF5 format.
  if h5py is None:
    raise ImportError(
        '`load_weights` requires h5py when loading weights from HDF5.')
  if not self._is_graph_network and not self.built:
    raise ValueError(
        'Unable to load weights saved in HDF5 format into a subclassed '
        'Model which has not created its variables yet. Call the Model '
        'first, then load the weights.')
  self._assert_weights_created()
  with h5py.File(filepath, 'r') as f:
    # If the file has no top-level 'layer_names' attribute but does have a
    # 'model_weights' group, read the weights from that group.
    if 'layer_names' not in f.attrs and 'model_weights' in f:
      f = f['model_weights']
    if by_name:
      hdf5_format.load_weights_from_hdf5_group_by_name(
          f, self.layers, skip_mismatch=skip_mismatch)
    else:
      hdf5_format.load_weights_from_hdf5_group(f, self.layers)
def _updated_config(self):
  """Builds the config dict shared by the serialization methods.

  Returns:
    Model config with Keras version information added.
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  layer_config = self.get_config()
  return {
      'class_name': self.__class__.__name__,
      'config': layer_config,
      'keras_version': keras_version,
      'backend': backend.backend()
  }


def get_config(self):
  """Not implemented here; always raises `NotImplementedError`."""
  raise NotImplementedError


@classmethod
def from_config(cls, config, custom_objects=None):
  """Builds a model from `config` via `functional.Functional.from_config`."""
  # Since only FunctionalModel produces config, the model can only
  # be constructed for FunctionalModel
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  functional_cls = functional.Functional
  return functional_cls.from_config(config, custom_objects=custom_objects)


def to_json(self, **kwargs):
  """Serializes the network configuration to a JSON string.

  To load a network from a JSON save file, use
  `keras.models.model_from_json(json_string, custom_objects={})`.

  Args:
      **kwargs: Additional keyword arguments
          to be passed to `json.dumps()`.

  Returns:
      A JSON string.
  """
  return json.dumps(
      self._updated_config(), default=json_utils.get_json_type, **kwargs)


def to_yaml(self, **kwargs):
  """Serializes the network configuration to a YAML string.

  To load a network from a yaml save file, use
  `keras.models.model_from_yaml(yaml_string, custom_objects={})`.

  `custom_objects` should be a dictionary mapping
  the names of custom losses / layers / etc to the corresponding
  functions / classes.

  Args:
      **kwargs: Additional keyword arguments
          to be passed to `yaml.dump()`.

  Returns:
      A YAML string.

  Raises:
      ImportError: if yaml module is not found.
  """
  if yaml is None:
    raise ImportError(
        'Requires yaml module installed (`pip install pyyaml`).')
  model_config = self._updated_config()
  return yaml.dump(model_config, **kwargs)
def reset_states(self):
  """Calls `reset_states()` on every stateful layer that defines it."""
  for layer in self.layers:
    if not hasattr(layer, 'reset_states'):
      continue
    if getattr(layer, 'stateful', False):
      layer.reset_states()


@property
@doc_controls.do_not_generate_docs
def state_updates(self):
  """Deprecated, do NOT use!

  Collects the `updates` of all layers that are stateful.

  This is useful for separating training updates and
  state updates, e.g. when we need to update a layer's internal state
  during prediction.

  Returns:
      A list of update ops.
  """
  warnings.warn('`Model.state_updates` will be removed in a future version. '
                'This property should not be used in TensorFlow 2.0, '
                'as `updates` are applied automatically.')
  collected = []
  for layer in self.layers:
    # Only stateful layers that expose `updates` contribute.
    if getattr(layer, 'stateful', False) and hasattr(layer, 'updates'):
      collected += layer.updates
  return collected
2401 """ 2402 return self._dedup_weights(self._undeduplicated_weights) 2403 2404 @property 2405 def _undeduplicated_weights(self): 2406 """Returns the undeduplicated list of all layer variables/weights.""" 2407 self._assert_weights_created() 2408 weights = [] 2409 for layer in self._self_tracked_trackables: 2410 weights += layer.variables 2411 weights += (self._trainable_weights + self._non_trainable_weights) 2412 return weights 2413 2414 def summary(self, line_length=None, positions=None, print_fn=None): 2415 """Prints a string summary of the network. 2416 2417 Args: 2418 line_length: Total length of printed lines 2419 (e.g. set this to adapt the display to different 2420 terminal window sizes). 2421 positions: Relative or absolute positions of log elements 2422 in each line. If not provided, 2423 defaults to `[.33, .55, .67, 1.]`. 2424 print_fn: Print function to use. Defaults to `print`. 2425 It will be called on each line of the summary. 2426 You can set it to a custom function 2427 in order to capture the string summary. 2428 2429 Raises: 2430 ValueError: if `summary()` is called before the model is built. 2431 """ 2432 if not self.built: 2433 raise ValueError('This model has not yet been built. ' 2434 'Build the model first by calling `build()` or calling ' 2435 '`fit()` with some data, or specify ' 2436 'an `input_shape` argument in the first layer(s) for ' 2437 'automatic build.') 2438 layer_utils.print_summary(self, 2439 line_length=line_length, 2440 positions=positions, 2441 print_fn=print_fn) 2442 2443 @property 2444 def layers(self): 2445 return list(self._flatten_layers(include_self=False, recursive=False)) 2446 2447 def get_layer(self, name=None, index=None): 2448 """Retrieves a layer based on either its name (unique) or index. 2449 2450 If `name` and `index` are both provided, `index` will take precedence. 2451 Indices are based on order of horizontal graph traversal (bottom-up). 2452 2453 Args: 2454 name: String, name of layer. 
2455 index: Integer, index of layer. 2456 2457 Returns: 2458 A layer instance. 2459 2460 Raises: 2461 ValueError: In case of invalid layer name or index. 2462 """ 2463 # TODO(fchollet): We could build a dictionary based on layer names 2464 # since they are constant, but we have not done that yet. 2465 if index is not None and name is not None: 2466 raise ValueError('Provide only a layer name or a layer index.') 2467 2468 if index is not None: 2469 if len(self.layers) <= index: 2470 raise ValueError('Was asked to retrieve layer at index ' + str(index) + 2471 ' but model only has ' + str(len(self.layers)) + 2472 ' layers.') 2473 else: 2474 return self.layers[index] 2475 2476 if name is not None: 2477 for layer in self.layers: 2478 if layer.name == name: 2479 return layer 2480 raise ValueError('No such layer: ' + name + '.') 2481 raise ValueError('Provide either a layer name or layer index.') 2482 2483 @trackable.no_automatic_dependency_tracking 2484 def _set_save_spec(self, inputs): 2485 if self._saved_model_inputs_spec is not None: 2486 return # Already set. 2487 2488 input_names = self.input_names 2489 if not input_names: 2490 input_names = compile_utils.create_pseudo_input_names(inputs) 2491 2492 flat_inputs = nest.flatten(inputs) 2493 specs = [] 2494 for name, tensor in zip(input_names, flat_inputs): 2495 specs.append( 2496 tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name)) 2497 specs = nest.pack_sequence_as(inputs, specs) 2498 2499 self._saved_model_inputs_spec = specs 2500 2501 # Store the input shapes 2502 if (self.__class__.__name__ == 'Sequential' and 2503 self._build_input_shape is None): 2504 self._build_input_shape = nest.map_structure( 2505 lambda x: None if x is None else x.shape, specs) 2506 2507 def _assert_weights_created(self): 2508 """Asserts that all the weights for the model have been created. 2509 2510 For a non-dynamic model, the weights must already be created after the 2511 layer has been called. 
For a dynamic model, the exact list of weights can 2512 never be known for certain since it may change at any time during execution. 2513 2514 We run this check right before accessing weights or getting the Numpy value 2515 for the current weights. Otherwise, if the layer has never been called, 2516 the user would just get an empty list, which is misleading. 2517 2518 Raises: 2519 ValueError: if the weights of the network has not yet been created. 2520 """ 2521 if self.dynamic: 2522 return 2523 2524 if ('build' in self.__class__.__dict__ and 2525 self.__class__ != Model and 2526 not self.built): 2527 # For any model that has customized build() method but hasn't 2528 # been invoked yet, this will cover both sequential and subclass model. 2529 # Also make sure to exclude Model class itself which has build() defined. 2530 raise ValueError('Weights for model %s have not yet been created. ' 2531 'Weights are created when the Model is first called on ' 2532 'inputs or `build()` is called with an `input_shape`.' % 2533 self.name) 2534 2535 def _check_call_args(self, method_name): 2536 """Check that `call` has only one positional arg.""" 2537 # Always allow first arg, regardless of arg name. 2538 fullargspec = self._call_full_argspec 2539 if fullargspec.defaults: 2540 positional_args = fullargspec.args[:-len(fullargspec.defaults)] 2541 else: 2542 positional_args = fullargspec.args 2543 if 'training' in positional_args: 2544 positional_args.remove('training') 2545 2546 # self and first arg can be positional. 
2547 if len(positional_args) > 2: 2548 extra_args = positional_args[2:] 2549 raise ValueError( 2550 'Models passed to `' + method_name + '` can only have `training` ' 2551 'and the first argument in `call` as positional arguments, ' 2552 'found: ' + str(extra_args) + '.') 2553 2554 def _validate_compile(self, optimizer, metrics, **kwargs): 2555 """Performs validation checks for the default `compile`.""" 2556 if any( 2557 isinstance(opt, optimizer_v1.Optimizer) 2558 for opt in nest.flatten(optimizer)): 2559 raise ValueError( 2560 '`tf.compat.v1.keras` Optimizer (', optimizer, ') is ' 2561 'not supported when eager execution is enabled. Use a ' 2562 '`tf.keras` Optimizer instead, or disable eager ' 2563 'execution.') 2564 2565 kwargs.pop('cloning', None) # Legacy DistStrat argument, never used. 2566 kwargs.pop('experimental_run_tf_function', None) # Always `True`. 2567 if kwargs.pop('distribute', None) is not None: 2568 raise ValueError( 2569 'Distribute argument in compile is not available in TF 2.0 please ' 2570 'create the model under the distribution strategy scope.') 2571 if kwargs.pop('target_tensors', None) is not None: 2572 raise ValueError( 2573 'target_tensors argument is not supported when executing eagerly.') 2574 invalid_kwargs = set(kwargs) - {'sample_weight_mode'} 2575 if invalid_kwargs: 2576 raise TypeError('Invalid keyword argument(s) in `compile`: %s' % 2577 (invalid_kwargs,)) 2578 2579 # Model must be created and compiled with the same DistStrat. 2580 if self.built and ds_context.has_strategy(): 2581 strategy = ds_context.get_strategy() 2582 for v in self.variables: 2583 if not strategy.extended.variable_created_in_scope(v): 2584 raise ValueError( 2585 'Variable (%s) was not created in the distribution strategy ' 2586 'scope of (%s). It is most likely due to not all layers or ' 2587 'the model or optimizer being created outside the distribution ' 2588 'strategy scope. 
Try to make sure your code looks similar ' 2589 'to the following.\n' 2590 'with strategy.scope():\n' 2591 ' model=_create_model()\n' 2592 ' model.compile(...)' % (v, strategy)) 2593 2594 # Model metrics must be created in the same distribution strategy scope 2595 # as the model. 2596 strategy = self.distribute_strategy 2597 for metric in nest.flatten(metrics): 2598 for v in getattr(metric, 'variables', []): 2599 if not strategy.extended.variable_created_in_scope(v): 2600 raise ValueError( 2601 'Metric (%s) passed to model.compile was created inside of a ' 2602 'different distribution strategy scope than the model. All ' 2603 'metrics must be created in the same distribution strategy ' 2604 'scope as the model (in this case %s). If you pass in a string ' 2605 'identifier for a metric to compile the metric will ' 2606 'automatically be created in the correct distribution ' 2607 'strategy scope.' % (metric, strategy) 2608 ) 2609 2610 # Model metrics must be created in the same distribution strategy scope 2611 # as the model. 2612 for opt in nest.flatten(optimizer): 2613 for v in getattr(opt, '_weights', []): 2614 if not strategy.extended.variable_created_in_scope(v): 2615 raise ValueError( 2616 'Optimizer (%s) passed to model.compile was created inside of a ' 2617 'different distribution strategy scope than the model. All ' 2618 'optimizers must be created in the same distribution strategy ' 2619 'scope as the model (in this case %s). If you pass in a string ' 2620 'identifier for an optimizer to compile the optimizer will ' 2621 'automatically be created in the correct distribution ' 2622 'strategy scope.' % (opt, strategy)) 2623 2624 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch): 2625 """Maybe load initial epoch from ckpt considering possible worker recovery. 2626 2627 Refer to tensorflow/python/keras/distribute/worker_training_state.py 2628 for more information. 
2629 2630 Args: 2631 initial_epoch: The original initial_epoch user passes in in `fit()`. 2632 2633 Returns: 2634 If the training is recovering from previous failure under multi-worker 2635 training setting, return the epoch the training is supposed to continue 2636 at. Otherwise, return the `initial_epoch` the user passes in. 2637 """ 2638 if self._training_state is not None: 2639 return self._training_state.maybe_load_initial_epoch_from_ckpt( 2640 initial_epoch, mode=ModeKeys.TRAIN) 2641 return initial_epoch 2642 2643 def _assert_compile_was_called(self): 2644 # Checks whether `compile` has been called. If it has been called, 2645 # then the optimizer is set. This is different from whether the 2646 # model is compiled 2647 # (i.e. whether the model is built and its inputs/outputs are set). 2648 if not self._is_compiled: 2649 raise RuntimeError('You must compile your model before ' 2650 'training/testing. ' 2651 'Use `model.compile(optimizer, loss)`.') 2652 2653 def _set_inputs(self, inputs, outputs=None, training=None): 2654 """This method is for compat with Modelv1. Only inputs are needed here.""" 2655 self._set_save_spec(inputs) 2656 2657 @property 2658 def _trackable_saved_model_saver(self): 2659 return model_serialization.ModelSavedModelSaver(self) 2660 2661 def _list_functions_for_serialization(self, serialization_cache): 2662 # SavedModel needs to ignore the execution functions. 
2663 train_function = self.train_function 2664 test_function = self.test_function 2665 predict_function = self.predict_function 2666 self.train_function = None 2667 self.test_function = None 2668 self.predict_function = None 2669 functions = super( 2670 Model, self)._list_functions_for_serialization(serialization_cache) 2671 self.train_function = train_function 2672 self.test_function = test_function 2673 self.predict_function = predict_function 2674 return functions 2675 2676 def _should_eval(self, epoch, validation_freq): 2677 if self._cluster_coordinator: 2678 raise NotImplementedError( 2679 'Evaluation in `model.fit` with ' 2680 '`ParameterServerStrategy` is not yet supported.') 2681 epoch = epoch + 1 # one-index the user-facing epoch. 2682 if isinstance(validation_freq, int): 2683 return epoch % validation_freq == 0 2684 elif isinstance(validation_freq, list): 2685 return epoch in validation_freq 2686 else: 2687 raise ValueError('Expected `validation_freq` to be a list or int.') 2688 2689 ###################################################################### 2690 # Functions below exist only as v1 / v2 compatibility shims. 2691 ###################################################################### 2692 2693 def _get_compile_args(self, user_metrics=True): 2694 """Used for saving or cloning a Model. 2695 2696 Args: 2697 user_metrics: Whether to return user-supplied metrics or `Metric` objects. 2698 Defaults to returning the user-supplied metrics. 2699 2700 Returns: 2701 Dictionary of arguments that were used when compiling the model. 
2702 """ 2703 self._assert_compile_was_called() 2704 # pylint: disable=protected-access 2705 2706 saved_metrics = self.compiled_metrics._user_metrics 2707 saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics 2708 2709 if not user_metrics: 2710 if saved_metrics is not None: 2711 saved_metrics = self.compiled_metrics._metrics 2712 if saved_weighted_metrics is not None: 2713 saved_weighted_metrics = self.compiled_metrics._weighted_metrics 2714 2715 compile_args = { 2716 'optimizer': self.optimizer, 2717 'loss': self.compiled_loss._user_losses, 2718 'metrics': saved_metrics, 2719 'weighted_metrics': saved_weighted_metrics, 2720 'loss_weights': self.compiled_loss._user_loss_weights, 2721 } 2722 # pylint: enable=protected-access 2723 return compile_args 2724 2725 def _get_callback_model(self): 2726 return self 2727 2728 def _in_multi_worker_mode(self): 2729 return self.distribute_strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 2730 2731 @property 2732 def _compile_was_called(self): 2733 return self._is_compiled 2734 2735 2736def reduce_per_replica(values, strategy, reduction='first'): 2737 """Reduce PerReplica objects. 2738 2739 Args: 2740 values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are 2741 returned as-is. 2742 strategy: `tf.distribute.Strategy` object. 2743 reduction: One of 'first', 'concat'. 2744 2745 Returns: 2746 Structure of `Tensor`s. 
2747 """ 2748 2749 def _reduce(v): 2750 """Reduce a single `PerReplica` object.""" 2751 if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy): 2752 return _multi_worker_concat(v, strategy) 2753 if not isinstance(v, ds_values.PerReplica): 2754 return v 2755 elif reduction == 'first': 2756 return strategy.unwrap(v)[0] 2757 elif reduction == 'concat': 2758 if _is_tpu_multi_host(strategy): 2759 return _tpu_multi_host_concat(v, strategy) 2760 else: 2761 return concat(strategy.unwrap(v)) 2762 else: 2763 raise ValueError('`reduction` must be "first" or "concat".') 2764 2765 return nest.map_structure(_reduce, values) 2766 2767 2768def concat(tensors, axis=0): 2769 """Concats `tensor`s along `axis`.""" 2770 if isinstance(tensors[0], sparse_tensor.SparseTensor): 2771 return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors) 2772 return array_ops.concat(tensors, axis=axis) 2773 2774 2775def _is_tpu_multi_host(strategy): 2776 return (backend.is_tpu_strategy(strategy) and 2777 strategy.extended.num_hosts > 1) 2778 2779 2780def _tpu_multi_host_concat(v, strategy): 2781 """Correctly order TPU PerReplica objects.""" 2782 replicas = strategy.unwrap(v) 2783 # When distributed datasets are created from Tensors / NumPy, 2784 # TPUStrategy.experimental_distribute_dataset shards data in 2785 # (Replica, Host) order, and TPUStrategy.unwrap returns it in 2786 # (Host, Replica) order. 2787 # TODO(b/150317897): Figure out long-term plan here. 
2788 num_replicas_per_host = strategy.extended.num_replicas_per_host 2789 ordered_replicas = [] 2790 for replica_id in range(num_replicas_per_host): 2791 ordered_replicas += replicas[replica_id::num_replicas_per_host] 2792 return concat(ordered_replicas) 2793 2794 2795def _collective_all_reduce_multi_worker(strategy): 2796 return (isinstance(strategy, 2797 collective_all_reduce_strategy.CollectiveAllReduceStrategy) 2798 ) and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 2799 2800 2801# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather 2802# for all strategies 2803def _multi_worker_concat(v, strategy): 2804 """Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" 2805 replicas = strategy.gather(v, axis=0) 2806 # v might not have the same shape on different replicas 2807 if isinstance(v, ds_values.PerReplica): 2808 shapes = array_ops.concat([ 2809 array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0) 2810 for single_value in v.values 2811 ], 2812 axis=0) 2813 all_shapes = strategy.gather(shapes, axis=0) 2814 else: 2815 # v is a tensor. This may happen when, say, we have 2x1 multi-worker. 
2816 all_shapes = strategy.gather( 2817 array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0) 2818 2819 replicas = array_ops.split( 2820 replicas, 2821 num_or_size_splits=all_shapes, 2822 num=strategy.num_replicas_in_sync) 2823 ordered_replicas = [] 2824 num_replicas_per_worker = len(strategy.extended.worker_devices) 2825 for replica_id in range(num_replicas_per_worker): 2826 ordered_replicas += replicas[replica_id::num_replicas_per_worker] 2827 return concat(ordered_replicas) 2828 2829 2830def _is_scalar(x): 2831 return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 2832 2833 2834def write_scalar_summaries(logs, step): 2835 for name, value in logs.items(): 2836 if _is_scalar(value): 2837 summary_ops_v2.scalar('batch_' + name, value, step=step) 2838 2839 2840def _minimum_control_deps(outputs): 2841 """Returns the minimum control dependencies to ensure step succeeded.""" 2842 if context.executing_eagerly(): 2843 return [] # Control dependencies not needed. 2844 outputs = nest.flatten(outputs, expand_composites=True) 2845 for out in outputs: 2846 # Variables can't be control dependencies. 2847 if not isinstance(out, variables.Variable): 2848 return [out] # Return first Tensor or Op from outputs. 2849 return [] # No viable Tensor or Op to use for control deps. 2850 2851 2852def _disallow_inside_tf_function(method_name): 2853 if ops.inside_function(): 2854 error_msg = ( 2855 'Detected a call to `Model.{method_name}` inside a `tf.function`. ' 2856 '`Model.{method_name} is a high-level endpoint that manages its own ' 2857 '`tf.function`. Please move the call to `Model.{method_name}` outside ' 2858 'of all enclosing `tf.function`s. Note that you can call a `Model` ' 2859 'directly on `Tensor`s inside a `tf.function` like: `model(x)`.' 
2860 ).format(method_name=method_name) 2861 raise RuntimeError(error_msg) 2862 2863 2864def _detect_save_format(filepath): 2865 """Returns path to weights file and save format.""" 2866 2867 filepath = path_to_string(filepath) 2868 if saving_utils.is_hdf5_filepath(filepath): 2869 return filepath, 'h5' 2870 2871 # Filepath could be a TensorFlow checkpoint file prefix or SavedModel 2872 # directory. It's possible for filepath to be both a prefix and directory. 2873 # Prioritize checkpoint over SavedModel. 2874 if _is_readable_tf_checkpoint(filepath): 2875 save_format = 'tf' 2876 elif sm_loader.contains_saved_model(filepath): 2877 ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY, 2878 sm_constants.VARIABLES_FILENAME) 2879 if _is_readable_tf_checkpoint(ckpt_path): 2880 filepath = ckpt_path 2881 save_format = 'tf' 2882 else: 2883 raise ValueError('Unable to load weights. filepath {} appears to be a ' 2884 'SavedModel directory, but checkpoint either doesn\'t ' 2885 'exist, or is incorrectly formatted.'.format(filepath)) 2886 else: 2887 # Not a TensorFlow checkpoint. This filepath is likely an H5 file that 2888 # doesn't have the hdf5/keras extensions. 2889 save_format = 'h5' 2890 return filepath, save_format 2891 2892 2893def _is_readable_tf_checkpoint(filepath): 2894 try: 2895 py_checkpoint_reader.NewCheckpointReader(filepath) 2896 return True 2897 except errors_impl.DataLossError: 2898 # The checkpoint is not readable in TensorFlow format. 2899 return False 2900