1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
"""A `Network` is a way to compose layers: the topological form of a `Model`.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import copy
23import json
24import os
25
26from six.moves import zip  # pylint: disable=redefined-builtin
27
28from tensorflow.python import pywrap_tensorflow
29from tensorflow.python.eager import context
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import errors_impl
32from tensorflow.python.framework import func_graph
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.keras import backend
36from tensorflow.python.keras.engine import base_layer
37from tensorflow.python.keras.engine import base_layer_utils
38from tensorflow.python.keras.engine import training_utils
39from tensorflow.python.keras.mixed_precision.experimental import policy
40from tensorflow.python.keras.saving import hdf5_format
41from tensorflow.python.keras.utils import generic_utils
42from tensorflow.python.keras.utils import layer_utils
43from tensorflow.python.keras.utils import tf_utils
44from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.python.training import checkpoint_management
47from tensorflow.python.training.tracking import base as trackable
48from tensorflow.python.training.tracking import data_structures
49from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
50from tensorflow.python.training.tracking import util as trackable_utils
51from tensorflow.python.util import nest
52from tensorflow.python.util import serialization
53from tensorflow.python.util import tf_inspect
54
55
56# pylint: disable=g-import-not-at-top
57try:
58  import h5py
59except ImportError:
60  h5py = None
61
62try:
63  import yaml
64except ImportError:
65  yaml = None
66# pylint: enable=g-import-not-at-top
67
68
69class Network(base_layer.Layer):
70  """A `Network` is a composition of layers.
71
72  `Network` is the topological form of a "model". A `Model`
73  is simply a `Network` with added training routines.
74
75  Two types of `Networks` exist: Graph Networks and Subclass Networks. Graph
76  networks are used in the Keras Functional and Sequential APIs. Subclassed
77  networks are used when a user subclasses the `Model` class. In general,
78  more Keras features are supported with Graph Networks than with Subclassed
79  Networks, specifically:
80
81  - Model cloning (`keras.models.clone`)
  - Serialization (`model.get_config()/from_config`, `model.to_json()/to_yaml()`)
83  - Whole-model saving (`model.save()`)
84
85  A Graph Network can be instantiated by passing two arguments to `__init__`.
86  The first argument is the `keras.Input` Tensors that represent the inputs
87  to the Network. The second argument specifies the output Tensors that
88  represent the outputs of this Network. Both arguments can be a nested
89  structure of Tensors.
90
91  Example:
92
93  ```
94  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
95  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
  outputs = keras.layers.Add()([t, inputs['x2']])
97  network = Network(inputs, outputs)
98  ```
99
100  A Graph Network constructed using the Functional API can also include raw
101  TensorFlow functions, with the exception of functions that create Variables
102  or assign ops.
103
104  Example:
105
106  ```
107  inputs = keras.Input(shape=(10,))
108  x = keras.layers.Dense(1)(inputs)
109  outputs = tf.nn.relu(x)
110  network = Network(inputs, outputs)
111  ```
112
113  Subclassed Networks can be instantiated via `name` and (optional) `dynamic`
114  keyword arguments. Subclassed Networks keep track of their Layers, and their
115  `call` method can be overridden. Subclassed Networks are typically created
116  indirectly, by subclassing the `Model` class.
117
118  Example:
119
120  ```
121  class MyModel(keras.Model):
122    def __init__(self):
123      super(MyModel, self).__init__(name='my_model', dynamic=False)
124
125      self.layer1 = keras.layers.Dense(10, activation='relu')
126
127    def call(self, inputs):
128      return self.layer1(inputs)
129  ```
130  """
131
132  def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
133    # Signature detection
134    if (len(args) == 2 or
135        len(args) == 1 and 'outputs' in kwargs or
136        'inputs' in kwargs and 'outputs' in kwargs):
137      # Graph network
138      self._init_graph_network(*args, **kwargs)
139    else:
140      # Subclassed network
141      self._init_subclassed_network(**kwargs)
142
143  # Several Network methods have "no_automatic_dependency_tracking"
144  # annotations. Since Network does automatic dependency tracking on attribute
145  # assignment, including for common data structures such as lists, by default
146  # we'd have quite a few empty dependencies which users don't care about (or
147  # would need some way to ignore dependencies automatically, which is confusing
148  # when applied to user code). Some attributes, such as _layers, would cause
149  # structural issues (_layers being the place where Layers assigned to tracked
150  # attributes are stored).
151  #
152  # Aside from these aesthetic and structural issues, useless dependencies on
153  # empty lists shouldn't cause issues; adding or removing them will not break
154  # checkpoints, but may cause "all Python objects matched" assertions to fail
155  # (in which case less strict assertions may be substituted if necessary).
  @trackable.no_automatic_dependency_tracking
  def _base_init(self, name=None):
    """Initializes attributes shared by graph and subclassed networks.

    Invoked from both `_init_graph_network` and `_init_subclassed_network`,
    which afterwards override some of the defaults set here (e.g.
    `_expects_training_arg`, `_compute_output_and_mask_jointly`, `_layers`).

    Args:
      name: Optional network name, forwarded to `_init_set_name` with
        `zero_based=True`.
    """
    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # self.input_spec
    # self.losses
    # self.updates

    self._init_set_name(name, zero_based=True)
    self._activity_regularizer = None
    # This acts just like the `trainable` attribute of any layer instance.
    # It does not affect users of the underlying layers, only users of the
    # Network instance.
    self.trainable = True
    self._is_compiled = False
    # Default only; each init path sets the real value afterwards.
    self._expects_training_arg = False

    # Defaults to False here; `_init_graph_network` sets it to True for
    # Sequential and Functional (graph) networks.
    self._compute_output_and_mask_jointly = False

    self.supports_masking = False
    if not hasattr(self, 'optimizer'):
      # Don't reset optimizer if already set.
      self.optimizer = None

    # Private attributes to implement compatibility with Layer.
    self._trainable_weights = []
    self._non_trainable_weights = []
    self._updates = []  # Used in symbolic mode only.
    self._losses = []
    # Losses collected while executing eagerly; reset by `_clear_losses`.
    self._eager_losses = []
    # A list of metric instances corresponding to the symbolic metric tensors
    # added using the `add_metric` API.
    self._metrics = []
    # A dictionary that maps metric names to metric result tensors.
    self._metrics_tensors = {}
    self._scope = None  # Never used.
    self._reuse = None  # Never used.
    if context.executing_eagerly():
      self._graph = None
    else:
      self._graph = ops.get_default_graph()  # Used in symbolic mode only.
    # A Network does not create weights of its own, thus has no dtype.
    self._dtype = None

    # All layers in order of horizontal graph traversal.
    # Entries are unique. Includes input and output layers.
    self._layers = []

    # Used in symbolic mode only, only in conjunction with graph-networks
    self._outbound_nodes = []
    self._inbound_nodes = []

    self._trackable_saver = (
        trackable_utils.saver_with_op_caching(self))

    # Networks do not need to do any casting of inputs or variables, because
    # each of its layers will handle casting through the layer's own
    # implementation. Therefore networks use the 'infer' policy, which does no
    # casting.
    self._mixed_precision_policy = policy.Policy('infer')
218
  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs, name=None):
    """Initializes a graph (Functional / Sequential) network.

    Args:
      inputs: Input tensor(s) of the network, possibly a nested structure. A
        list whose flattened form has exactly one element is replaced by its
        first element.
      outputs: Output tensor(s) of the network, with the same normalization
        as `inputs`.
      name: Optional name for the network.
    """
    self._call_convention = (base_layer_utils
                             .CallConvention.EXPLICIT_INPUTS_ARGUMENT)
    # Normalize and set self.inputs, self.outputs.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_outputs = outputs
    self._nested_inputs = inputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    # Outputs lacking `_keras_history` (presumably produced by raw TensorFlow
    # ops rather than Keras layers) get Keras history created for them.
    if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
      base_layer_utils.create_keras_history(self._nested_outputs)

    self._base_init(name=name)
    self._validate_graph_inputs_and_outputs()

    # NOTE(review): `Network` defines `compute_mask`, so the `hasattr` check
    # is always true and this flag always ends up True here.
    self._compute_previous_mask = (
        'mask' in tf_inspect.getfullargspec(self.call).args or
        hasattr(self, 'compute_mask'))
    # A Network does not create weights of its own, thus it is already
    # built.
    self.built = True
    self._compute_output_and_mask_jointly = True
    self._is_graph_network = True
    self._dynamic = False
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one pass,
    # then cache them here. When any of these outputs is queried later, we
    # retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._layers = layers
    self._layers_by_depth = layers_by_depth
    # Cache each layer's `call` argspec so it is not re-inspected later.
    self._layer_call_argspecs = {}
    for layer in self._layers:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    self._track_layers(layers)

    # Create the node linking internal inputs to internal outputs.
    base_layer.Node(
        outbound_layer=self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=self._nested_inputs,
        output_tensors=self._nested_outputs)

    # Build self.input_names and self.output_names.
    self.input_names = []
    self.output_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for i, layer in enumerate(self._input_layers):
      self.input_names.append(layer.name)
      # Only placeholder-backed input layers are recorded as feedable inputs.
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        self._feed_input_shapes.append(backend.int_shape(self.inputs[i]))
        self._feed_inputs.append(layer.input)
    for layer in self._output_layers:
      self.output_names.append(layer.name)
318
319  @trackable.no_automatic_dependency_tracking
320  def _init_subclassed_network(self, name=None, dynamic=False):
321    self._base_init(name=name)
322    self._is_graph_network = False
323    self._dynamic = dynamic
324    call_argspec = tf_inspect.getfullargspec(self.call)
325    if 'training' in call_argspec.args:
326      self._expects_training_arg = True
327    else:
328      self._expects_training_arg = False
329    self._call_convention = self._determine_call_convention(call_argspec)
330    self.outputs = []
331    self.inputs = []
332    self.built = False
333
334  @property
335  def dynamic(self):
336    if self._is_graph_network:
337      return any(layer.dynamic for layer in self.layers)
338    return self._dynamic or any(layer.dynamic for layer in self.layers)
339
  def _determine_call_convention(self, call_argspec):
    """Decides how `self.call()` is invoked. See `CallConvention`.

    Args:
      call_argspec: The `FullArgSpec` of `self.call`, as produced by
        `tf_inspect.getfullargspec` (includes the bound `self` argument).

    Returns:
      A `base_layer_utils.CallConvention` value:
        - `SINGLE_POSITIONAL_ARGUMENT` when `call` has no `*args`, can be
          invoked with a single positional argument, and has exactly one
          positional argument without a default besides `self`.
        - Otherwise `EXPLICIT_INPUTS_ARGUMENT` when `call` has an argument
          named `inputs`.
        - Otherwise `POSITIONAL_ARGUMENTS_ARE_INPUTS`.

    Raises:
      TypeError: If `call` has a single required positional argument and
        also an optional argument named `inputs`, making it ambiguous which
        one receives the inputs.
    """
    if call_argspec.varargs:
      may_take_single_argument = False
    else:
      try:
        # Note: tf_inspect doesn't raise a TypeError when regular inspect would,
        # so we need to keep in mind that "getcallargs" may have returned
        # something even though we under-specified positional arguments.
        all_args = tf_inspect.getcallargs(self.call, None)
        # Names bound to `self` (i.e. the method's own instance argument).
        self_args = set()
        for arg_name, obj in all_args.items():
          if obj is self:
            self_args.add(arg_name)
        may_take_single_argument = True
      except TypeError:
        may_take_single_argument = False
    if may_take_single_argument:
      # A single positional argument (plus "self") is considered equivalent to
      # an "inputs" argument.
      # Count positional args that have no default value.
      all_positional_args = len(call_argspec.args)
      if call_argspec.defaults is not None:
        all_positional_args -= len(call_argspec.defaults)
      non_self_positional_args = all_positional_args
      for positional_arg_name in call_argspec.args[:all_positional_args]:
        if positional_arg_name in self_args:
          non_self_positional_args -= 1
      if non_self_positional_args == 1:
        # `call_argspec.args[all_positional_args:]` are the args with defaults.
        if 'inputs' in call_argspec.args[all_positional_args:]:
          raise TypeError(
              "Model.call() takes a single positional argument (to which "
              "inputs are passed by convention) and a separate 'inputs' "
              "argument. Unable to determine which arguments are inputs.")
        return base_layer_utils.CallConvention.SINGLE_POSITIONAL_ARGUMENT
    if 'inputs' in call_argspec.args:
      return base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT
    else:
      return base_layer_utils.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS
378
379  def _track_layers(self, layers):
380    """Add Trackable dependencies on a list of Layers."""
381    weight_layer_index = 0
382    for layer_index, layer in enumerate(layers):
383      if layer.weights:
384        # Keep a separate index for layers which have weights. This allows users
385        # to insert Layers without weights anywhere in the network without
386        # breaking checkpoints.
387        self._track_trackable(
388            layer, name='layer_with_weights-%d' % weight_layer_index,
389            overwrite=True)
390        weight_layer_index += 1
391      # Even if it doesn't have weights, we should still track everything in
392      # case it has/will have Trackable dependencies.
393      self._track_trackable(
394          layer, name='layer-%d' % layer_index, overwrite=True)
395
396  def __setattr__(self, name, value):
397    if not getattr(self, '_setattr_tracking', True):
398      super(Network, self).__setattr__(name, value)
399      return
400
401    if all(
402        isinstance(v, (base_layer.Layer,
403                       data_structures.TrackableDataStructure)) or
404        trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
405      try:
406        self._is_graph_network
407      except AttributeError:
408        raise RuntimeError('It looks like you are subclassing `Model` and you '
409                           'forgot to call `super(YourClass, self).__init__()`.'
410                           ' Always start with this line.')
411
412    super(Network, self).__setattr__(name, value)
413
414    # Keep track of metric instance created in subclassed model/layer.
415    # We do this so that we can maintain the correct order of metrics by adding
416    # the instance to the `metrics` list as soon as it is created.
417    from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
418    if isinstance(value, metrics_module.Metric):
419      self._metrics.append(value)
420
421  @property
422  def stateful(self):
423    return any((hasattr(layer, 'stateful') and layer.stateful)
424               for layer in self.layers)
425
426  def reset_states(self):
427    for layer in self.layers:
428      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
429        layer.reset_states()
430
431  @property
432  def state_updates(self):
433    """Returns the `updates` from all layers that are stateful.
434
435    This is useful for separating training updates and
436    state updates, e.g. when we need to update a layer's internal state
437    during prediction.
438
439    Returns:
440        A list of update ops.
441    """
442    state_updates = []
443    for layer in self.layers:
444      if getattr(layer, 'stateful', False):
445        if hasattr(layer, 'updates'):
446          state_updates += layer.updates
447    return state_updates
448
449  def get_weights(self):
450    """Retrieves the weights of the model.
451
452    Returns:
453        A flat list of Numpy arrays.
454    """
455    weights = []
456    for layer in self.layers:
457      weights += layer.weights
458    return backend.batch_get_value(weights)
459
460  def set_weights(self, weights):
461    """Sets the weights of the model.
462
463    Arguments:
464        weights: A list of Numpy arrays with shapes and types matching
465            the output of `model.get_weights()`.
466    """
467    tuples = []
468    for layer in self.layers:
469      num_param = len(layer.weights)
470      layer_weights = weights[:num_param]
471      for sw, w in zip(layer.weights, layer_weights):
472        tuples.append((sw, w))
473      weights = weights[num_param:]
474    backend.batch_set_value(tuples)
475
476  def compute_mask(self, inputs, mask):
477    if not self._is_graph_network:
478      return None
479
480    # TODO(omalleyt): b/123540974 This function is not really safe to call
481    # by itself because it will duplicate any updates and losses in graph
482    # mode by `call`ing the Layers again.
483    output_tensors = self._run_internal_graph(inputs, mask=mask)
484    return nest.map_structure(lambda t: t._keras_mask, output_tensors)
485
486  @property
487  def layers(self):
488    return trackable_layer_utils.filter_empty_layer_containers(
489        self._layers)
490
491  def get_layer(self, name=None, index=None):
492    """Retrieves a layer based on either its name (unique) or index.
493
494    If `name` and `index` are both provided, `index` will take precedence.
495    Indices are based on order of horizontal graph traversal (bottom-up).
496
497    Arguments:
498        name: String, name of layer.
499        index: Integer, index of layer.
500
501    Returns:
502        A layer instance.
503
504    Raises:
505        ValueError: In case of invalid layer name or index.
506    """
507    # TODO(fchollet): We could build a dictionary based on layer names
508    # since they are constant, but we have not done that yet.
509    if index is not None:
510      if len(self.layers) <= index:
511        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
512                         ' but model only has ' + str(len(self.layers)) +
513                         ' layers.')
514      else:
515        return self.layers[index]
516    else:
517      if not name:
518        raise ValueError('Provide either a layer name or layer index.')
519    for layer in self.layers:
520      if layer.name == name:
521        return layer
522    raise ValueError('No such layer: ' + name)
523
524  def _get_unfiltered_updates(self, check_trainable=True):
525    if check_trainable and not self.trainable and not self.stateful:
526      return []
527    updates = []
528    for layer in self.layers:
529      updates += layer._get_unfiltered_updates(check_trainable=check_trainable)
530    updates += list(self._updates)
531    return updates
532
533  @property
534  def _unfiltered_losses(self):
535    losses = []
536
537    # If any eager losses are present, we assume the model to be part of an
538    # eager training loop (either a custom one or the one used when
539    # `run_eagerly=True`), and so we always return just the eager losses in that
540    # case.
541    if self._eager_losses:
542      losses.extend(self._eager_losses)
543    else:
544      losses.extend(self._losses)
545    for layer in self.layers:
546      if isinstance(layer, Network):
547        losses += layer._unfiltered_losses
548      else:
549        losses += layer.losses
550    return losses
551
552  @trackable.no_automatic_dependency_tracking
553  def _clear_losses(self):
554    """Used every step in eager to reset losses."""
555    self._eager_losses = []
556    for layer in self.layers:
557      layer._clear_losses()
558
559  @property
560  def updates(self):
561    """Retrieves the network's updates.
562
563    Will only include updates that are either
564    unconditional, or conditional on inputs to this model
565    (e.g. will not include updates that were created by layers of this model
566    outside of the model).
567
568    When the network has no registered inputs, all updates are returned.
569
570    Effectively, `network.updates` behaves like `layer.updates`.
571
572    Concrete example:
573
574    ```python
575      bn = keras.layers.BatchNormalization()
576      x1 = keras.layers.Input(shape=(10,))
577      _ = bn(x1)  # This creates 2 updates.
578
579      x2 = keras.layers.Input(shape=(10,))
580      y2 = bn(x2)  # This creates 2 more updates.
581
582      # The BN layer has now 4 updates.
583      self.assertEqual(len(bn.updates), 4)
584
585      # Let's create a model from x2 to y2.
586      model = keras.models.Model(x2, y2)
587
588      # The model does not list all updates from its underlying layers,
589      # but only the updates that are relevant to it. Updates created by layers
590      # outside of the model are discarded.
591      self.assertEqual(len(model.updates), 2)
592
593      # If you keep calling the model, you append to its updates, just like
594      # what happens for a layer.
595      x3 = keras.layers.Input(shape=(10,))
596      y3 = model(x3)
597      self.assertEqual(len(model.updates), 4)
598
599      # But if you call the inner BN layer independently, you don't affect
600      # the model's updates.
601      x4 = keras.layers.Input(shape=(10,))
602      _ = bn(x4)
603      self.assertEqual(len(model.updates), 4)
604    ```
605
606    Returns:
607        A list of update ops.
608    """
609
610    updates = self._get_unfiltered_updates(check_trainable=True)
611
612    # `updates` might contain irrelevant updates, so it needs to be filtered
613    # with respect to inputs the model has been called on.
614    relevant_inputs = []
615    for i in range(0, len(self._inbound_nodes)):
616      inputs = self.get_input_at(i)
617      if isinstance(inputs, list):
618        relevant_inputs += inputs
619      else:
620        relevant_inputs.append(inputs)
621    if not relevant_inputs:
622      return list(set(updates))
623
624    reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates)
625    relevant_conditional_updates = [x for x in updates if x in reachable]
626    unconditional_updates = [
627        x for x in updates if x._unconditional_update]  # pylint: disable=protected-access
628    # A layer could be used multiple times in a nested structure,
629    # so the updates list must be de-duped.
630    return list(set(relevant_conditional_updates + unconditional_updates))
631
632  @property
633  def losses(self):
634    """Retrieves the network's losses.
635
636    Will only include losses that are either
637    unconditional, or conditional on inputs to this model
638    (e.g. will not include losses that depend on tensors
639    that aren't inputs to this model).
640
641    When the network has no registered inputs, all losses are returned.
642
643    Returns:
644        A list of loss tensors.
645    """
646    losses = self._unfiltered_losses
647
648    if context.executing_eagerly():
649      return losses
650
651    # TODO(kaftan/fchollet): Clean this up / make it obsolete.
652    # This is a super ugly, confusing check necessary to
653    # handle the case where we are executing in a function graph in eager mode
654    # but the model was constructed symbolically in a separate graph scope.
655    # We need to capture the losses created in the current graph function,
656    # and filter out the incorrect loss tensors created when symbolically
657    # building the graph.
658    # We have to use this check because the code after it that checks
659    # for reachable inputs only captures the part of the model that was
660    # built symbolically, and captures the wrong tensors from a different
661    # func graph (causing a crash later on when trying to execute the
662    # graph function)
663    with ops.init_scope():
664      if context.executing_eagerly():
665        return [loss for loss in losses
666                if loss.graph == ops.get_default_graph()]
667
668    relevant_inputs = []
669    for i in range(0, len(self._inbound_nodes)):
670      inputs = self.get_input_at(i)
671      if isinstance(inputs, list):
672        relevant_inputs += inputs
673      else:
674        relevant_inputs.append(inputs)
675    if not relevant_inputs:
676      return losses
677
678    reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses)
679    relevant_conditional_losses = [x for x in losses if x in reachable]
680    unconditional_losses = [
681        x for x in losses if x._unconditional_loss]  # pylint: disable=protected-access
682    return list(set(
683        relevant_conditional_losses + unconditional_losses + self._losses))
684
685  @property
686  def trainable_weights(self):
687    return trackable_layer_utils.gather_trainable_weights(
688        trainable=self.trainable,
689        sub_layers=self._layers,
690        extra_variables=self._trainable_weights)
691
692  @property
693  def non_trainable_weights(self):
694    return trackable_layer_utils.gather_non_trainable_weights(
695        trainable=self.trainable,
696        sub_layers=self._layers,
697        extra_variables=self._non_trainable_weights + self._trainable_weights)
698
699  @property
700  def metrics(self):
701    """Returns the network's symbolic metrics.
702
703    Model overrides this function to include the metrics from `compile` API.
704    """
705    metrics = []
706    for layer in self.layers:
707      metrics += layer._metrics  # pylint: disable=protected-access
708    return metrics + self._metrics
709
710  @property
711  def _all_metrics_tensors(self):
712    """Returns the network's symbolic metric tensors."""
713    # TODO(psv): Remove this property.
714    metrics_tensors = {}
715    for layer in self.layers:
716      if isinstance(layer, Network):
717        metrics_tensors.update(layer._all_metrics_tensors)
718      else:
719        metrics_tensors.update(layer._metrics_tensors)
720    metrics_tensors.update(self._metrics_tensors)
721    return metrics_tensors
722
723  @property
724  def input_spec(self):
725    """Gets the network's input specs.
726
727    Returns:
728        A list of `InputSpec` instances (one per input to the model)
729            or a single instance if the model has only one input.
730    """
731    # If subclassed model, can't assume anything.
732    if not self._is_graph_network:
733      return None
734
735    specs = []
736    for layer in self._input_layers:
737      if layer.input_spec is None:
738        specs.append(None)
739      else:
740        if not isinstance(layer.input_spec, list):
741          raise TypeError('Layer ' + layer.name +
742                          ' has an input_spec attribute that '
743                          'is not a list. We expect a list. '
744                          'Found input_spec = ' + str(layer.input_spec))
745        specs += layer.input_spec
746    if len(specs) == 1:
747      return specs[0]
748    return specs
749
750  @base_layer.default
751  def build(self, input_shape):
752    """Builds the model based on input shapes received.
753
754    This is to be used for subclassed models, which do not know at instantiation
755    time what their inputs look like.
756
757    This method only exists for users who want to call `model.build()` in a
758    standalone way (as a substitute for calling the model on real data to
759    build it). It will never be called by the framework (and thus it will
760    never throw unexpected errors in an unrelated workflow).
761
762    Args:
763     input_shape: Single tuple, TensorShape, or list of shapes, where shapes
764         are tuples, integers, or TensorShapes.
765
766    Raises:
767      ValueError:
768        1. In case of invalid user-provided data (not of type tuple,
769           list, or TensorShape).
770        2. If the model requires call arguments that are agnostic
771           to the input shapes (positional or kwarg in call signature).
772        3. If not all layers were properly built.
773        4. If float type inputs are not supported within the layers.
774
775      In each of these cases, the user should build their model by calling it
776      on real tensor data.
777    """
778    if self._is_graph_network:
779      self.built = True
780      return
781
782    # If subclass network
783    if input_shape is None:
784      raise ValueError('Input shape must be defined when calling build on a '
785                       'model subclass network.')
786    valid_types = (tuple, list, tensor_shape.TensorShape)
787    if not isinstance(input_shape, valid_types):
788      raise ValueError('Specified input shape is not one of the valid types. '
789                       'Please specify a batch input shape of type tuple or '
790                       'list of input shapes. User provided '
791                       'input type: {}'.format(type(input_shape)))
792
793    if input_shape and not self.inputs:
794      # We create placeholders for the `None`s in the shape and build the model
795      # in a Graph. Since tf.Variable is compatible with both eager execution
796      # and graph building, the variables created after building the model in
797      # a Graph are still valid when executing eagerly.
798      if context.executing_eagerly():
799        graph = func_graph.FuncGraph('build_graph')
800      else:
801        graph = backend.get_graph()
802      with graph.as_default():
803        if isinstance(input_shape, list):
804          x = [base_layer_utils.generate_placeholders_from_shape(shape)
805               for shape in input_shape]
806        else:
807          x = base_layer_utils.generate_placeholders_from_shape(input_shape)
808
809        kwargs = {}
810        call_signature = tf_inspect.getfullargspec(self.call)
811        call_args = call_signature.args
812        # Exclude `self`, `inputs`, and any argument with a default value.
813        if len(call_args) > 2:
814          if call_signature.defaults:
815            call_args = call_args[2:-len(call_signature.defaults)]
816          else:
817            call_args = call_args[2:]
818          for arg in call_args:
819            if arg == 'training':
820              # Case where `training` is a positional arg with no default.
821              kwargs['training'] = False
822            else:
823              # Has invalid call signature with unknown positional arguments.
824              raise ValueError(
825                  'Currently, you cannot build your model if it has '
826                  'positional or keyword arguments that are not '
827                  'inputs to the model, but are required for its '
828                  '`call` method. Instead, in order to instantiate '
829                  'and build your model, `call` your model on real '
830                  'tensor data with all expected call arguments.')
831        elif len(call_args) < 2:
832          # Signature without `inputs`.
833          raise ValueError('You can only call `build` on a model if its `call` '
834                           'method accepts an `inputs` argument.')
835        try:
836          self.call(x, **kwargs)
837        except (errors.InvalidArgumentError, TypeError):
838          raise ValueError('You cannot build your model by calling `build` '
839                           'if your layers do not support float type inputs. '
840                           'Instead, in order to instantiate and build your '
841                           'model, `call` your model on real tensor data (of '
842                           'the correct dtype).')
843    if self._layers:
844      self._track_layers(self._layers)
845    self.built = True
846
847  def call(self, inputs, training=None, mask=None):
848    """Calls the model on new inputs.
849
850    In this case `call` just reapplies
851    all ops in the graph to the new inputs
852    (e.g. build a new computational graph from the provided inputs).
853
854    Arguments:
855        inputs: A tensor or list of tensors.
856        training: Boolean or boolean scalar tensor, indicating whether to run
857          the `Network` in training mode or inference mode.
858        mask: A mask or list of masks. A mask can be
859            either a tensor or None (no mask).
860
861    Returns:
862        A tensor if there is a single output, or
863        a list of tensors if there are more than one outputs.
864    """
865    if not self._is_graph_network:
866      raise NotImplementedError('When subclassing the `Model` class, you should'
867                                ' implement a `call` method.')
868
869    return self._run_internal_graph(inputs, training=training, mask=mask)
870
871  def compute_output_shape(self, input_shape):
872    if not self._is_graph_network:
873      return super(Network, self).compute_output_shape(input_shape)
874
875    # Convert any shapes in tuple format to TensorShapes.
876    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
877
878    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
879      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
880                       ': model has ' + str(len(self._input_layers)) +
881                       ' tensor inputs.')
882
883    cache_key = generic_utils.object_list_uid(input_shape)
884    if cache_key in self._output_shape_cache:
885      # Cache hit. Return shapes as TensorShapes.
886      return self._output_shape_cache[cache_key]
887
888    layers_to_output_shapes = {}
889    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
890      # It's an input layer: then `compute_output_shape` is identity,
891      # and there is only one node and one tensor..
892      shape_key = layer.name + '_0_0'
893      layers_to_output_shapes[shape_key] = shape
894
895    depth_keys = list(self._nodes_by_depth.keys())
896    depth_keys.sort(reverse=True)
897    # Iterate over nodes, by depth level.
898    if len(depth_keys) > 1:
899      for depth in depth_keys:
900        nodes = self._nodes_by_depth[depth]
901        for node in nodes:
902          # This is always a single layer, never a list.
903          layer = node.outbound_layer
904          if layer in self._input_layers:
905            # We've already covered the input layers
906            # a few lines above.
907            continue
908          # Potentially redundant list,
909          # same size as node.input_tensors.
910          layer_input_shapes = []
911          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
912            input_layer_key = inbound_layer.name + '_%s_%s' % (node_id,
913                                                               tensor_id)
914            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
915          layer_input_shapes = nest.pack_sequence_as(node.inbound_layers,
916                                                     layer_input_shapes)
917          # Layers expect shapes to be tuples for `compute_output_shape`.
918          layer_input_shapes = tf_utils.convert_shapes(
919              layer_input_shapes, to_tuples=True)
920          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
921          # Convert back to TensorShapes.
922          layer_output_shapes = tf_utils.convert_shapes(
923              layer_output_shapes, to_tuples=False)
924
925          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
926          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
927            shape_key = layer.name + '_%s_%s' % (node_index, j)
928            layers_to_output_shapes[shape_key] = shape
929
930      # Read final output shapes from layers_to_output_shapes.
931      output_shapes = []
932      for i in range(len(self._output_layers)):
933        layer, node_index, tensor_index = self._output_coordinates[i]
934        shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
935        output_shapes.append(layers_to_output_shapes[shape_key])
936      output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
937      # Store in cache.
938      self._output_shape_cache[cache_key] = output_shapes
939
940    # Return shapes as TensorShapes.
941    return output_shapes
942
943  def _run_internal_graph(self, inputs, training=None, mask=None):
944    """Computes output tensors for new inputs.
945
946    # Note:
947        - Expects `inputs` to be a list (potentially with 1 element).
948        - Can be run on non-Keras tensors.
949
950    Arguments:
951        inputs: Tensor or nested structure of Tensors.
952        training: Boolean learning phase.
953        mask: (Optional) Tensor or nested structure of Tensors.
954
955    Returns:
956        Two lists: output_tensors, output_masks
957    """
958    # Note: masking support is relevant mainly for Keras.
959    # It cannot be factored out without having the fully reimplement the network
960    # calling logic on the Keras side. We choose to incorporate it in
961    # Network because 1) it may be useful to fully support in tf.layers in
962    # the future and 2) Keras is a major user of Network.  If you don't
963    # use masking, it does not interfere with regular behavior at all and you
964    # can ignore it.
965    inputs = nest.flatten(inputs)
966    if mask is None:
967      masks = [None for _ in range(len(inputs))]
968    else:
969      masks = nest.flatten(mask)
970
971    for input_t, mask in zip(inputs, masks):
972      input_t._keras_mask = mask
973
974    # Dictionary mapping reference tensors to computed tensors.
975    tensor_dict = {}
976
977    for x, y, mask in zip(self.inputs, inputs, masks):
978      tensor_dict[str(id(x))] = y
979
980    depth_keys = list(self._nodes_by_depth.keys())
981    depth_keys.sort(reverse=True)
982    # Ignore the InputLayers when computing the graph.
983    depth_keys = depth_keys[1:]
984
985    for depth in depth_keys:
986      nodes = self._nodes_by_depth[depth]
987      for node in nodes:
988        # This is always a single layer, never a list.
989        layer = node.outbound_layer
990
991        if all(
992            str(id(tensor)) in tensor_dict
993            for tensor in nest.flatten(node.input_tensors)):
994
995          # Call layer (reapplying ops to new inputs).
996          computed_tensors = nest.map_structure(
997              lambda t: tensor_dict[str(id(t))], node.input_tensors)
998
999          # Ensure `training` and `mask` arg propagation if applicable.
1000          kwargs = node.arguments or {}
1001          argspec = self._layer_call_argspecs[layer].args
1002          if 'training' in argspec:
1003            kwargs.setdefault('training', training)
1004          if 'mask' in argspec:
1005            computed_masks = nest.map_structure(lambda t: t._keras_mask,
1006                                                computed_tensors)
1007            kwargs.setdefault('mask', computed_masks)
1008
1009          # Compute outputs.
1010          output_tensors = layer(computed_tensors, **kwargs)
1011
1012          # Update tensor_dict.
1013          for x, y in zip(
1014              nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
1015            tensor_dict[str(id(x))] = y
1016
1017    output_tensors = []
1018    output_shapes = []
1019    for x in self.outputs:
1020      assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
1021      tensor = tensor_dict[str(id(x))]
1022      output_shapes.append(x.shape)
1023      output_tensors.append(tensor)
1024
1025    if output_shapes is not None:
1026      input_shapes = [x.shape for x in inputs]
1027      cache_key = generic_utils.object_list_uid(input_shapes)
1028      self._output_shape_cache[cache_key] = nest.pack_sequence_as(
1029          self._nested_outputs, output_shapes)
1030
1031    output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
1032    return output_tensors
1033
  def get_config(self):
    """Returns a serializable config describing this graph network.

    The config holds the network name, one entry per layer (its class name,
    its own config, and its inbound nodes restricted to those that belong to
    this network), and the coordinates of the network's input and output
    tensors. Node indices are renumbered so that they count only the nodes
    kept in this network.

    Returns:
        A deep copy of a dict with keys 'name', 'layers', 'input_layers' and
        'output_layers'.

    Raises:
        NotImplementedError: For subclassed (non-graph) networks.
    """
    if not self._is_graph_network:
      raise NotImplementedError

    config = {
        'name': self.name,
    }
    # Maps a node key (layer name + original node index) to the node's new
    # index, counted over only the nodes that belong to this network.
    node_conversion_map = {}
    for layer in self.layers:
      if issubclass(layer.__class__, Network):
        # Networks start with a pre-existing node
        # linking their input to output.
        kept_nodes = 1
      else:
        kept_nodes = 0
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
        if node_key in self._network_nodes:
          node_conversion_map[node_key] = kept_nodes
          kept_nodes += 1
    layer_configs = []
    for layer in self.layers:  # From the earliest layers on.
      layer_class_name = layer.__class__.__name__
      layer_config = layer.get_config()
      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
        if node_key in self._network_nodes:
          # The node is relevant to the model:
          # add to filtered_inbound_nodes.
          # Call kwargs are kept only if they are JSON-serializable;
          # otherwise they are dropped with a warning.
          if node.arguments:
            try:
              json.dumps(node.arguments)
              kwargs = node.arguments
            except TypeError:
              logging.warning(
                  'Layer ' + layer.name +
                  ' was passed non-serializable keyword arguments: ' +
                  str(node.arguments) + '. They will not be included '
                  'in the serialized model (and thus will be missing '
                  'at deserialization time).')
              kwargs = {}
          else:
            kwargs = {}
          if node.inbound_layers:
            node_data = []
            for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
              node_key = _make_node_key(inbound_layer.name, node_id)
              # NOTE(review): falls back to index 0 when the inbound node is
              # not in this network's map; confirm this is intended.
              new_node_index = node_conversion_map.get(node_key, 0)
              node_data.append(
                  tf_utils.ListWrapper(
                      [inbound_layer.name, new_node_index, tensor_id, kwargs]))
            # Restore the nesting of the node's input tensors.
            node_data = nest.pack_sequence_as(node.input_tensors, node_data)
            # Convert ListWrapper to list for backwards compatible configs.
            node_data = tf_utils.convert_inner_node_data(node_data)
            filtered_inbound_nodes.append(node_data)
      layer_configs.append({
          'name': layer.name,
          'class_name': layer_class_name,
          'config': layer_config,
          'inbound_nodes': filtered_inbound_nodes,
      })
    config['layers'] = layer_configs

    # Gather info about inputs and outputs.
    # Each entry is [layer name, renumbered node index, tensor index];
    # coordinates whose node is not part of this network are skipped.
    model_inputs = []
    for i in range(len(self._input_layers)):
      layer, node_index, tensor_index = self._input_coordinates[i]
      node_key = _make_node_key(layer.name, node_index)
      if node_key not in self._network_nodes:
        continue
      new_node_index = node_conversion_map[node_key]
      model_inputs.append(
          tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
    model_inputs = nest.pack_sequence_as(self._nested_inputs, model_inputs)
    # Preserve external Keras compat for Models with single input.
    if not nest.is_sequence(model_inputs):
      model_inputs = [model_inputs]
    model_inputs = tf_utils.convert_inner_node_data(model_inputs)
    config['input_layers'] = model_inputs

    model_outputs = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      node_key = _make_node_key(layer.name, node_index)
      if node_key not in self._network_nodes:
        continue
      new_node_index = node_conversion_map[node_key]
      model_outputs.append(
          tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
    model_outputs = nest.pack_sequence_as(self._nested_outputs, model_outputs)
    # Preserve external Keras compat for Models with single output.
    if not nest.is_sequence(model_outputs):
      model_outputs = [model_outputs]
    model_outputs = tf_utils.convert_inner_node_data(model_outputs)
    config['output_layers'] = model_outputs
    # Deep copy so callers cannot mutate layer configs shared with the model.
    return copy.deepcopy(config)
1131
  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    All layers are deserialized first. Their inbound nodes (layer calls) are
    then replayed in the order the layers appear in `config['layers']`, and
    the model is finally assembled from the tensors named in
    `config['input_layers']` and `config['output_layers']`.

    Arguments:
        config: Model config dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A model instance.

    Raises:
        ValueError: In case of improperly formatted config dict.
    """
    # Layer instances created during
    # the graph reconstruction process
    # (maps layer name -> layer instance).
    created_layers = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = {}

    def add_unprocessed_node(layer, node_data):
      # Append `node_data` to the pending list for `layer`.
      if layer not in unprocessed_nodes:
        unprocessed_nodes[layer] = [node_data]
      else:
        unprocessed_nodes[layer].append(node_data)

    def process_node(layer, node_data):
      """Deserialize a node.

      Calls `layer` on the tensors referenced by `node_data`. If any inbound
      node does not exist yet, the node is re-enqueued and nothing is called.

      Arguments:
          layer: layer instance.
          node_data: Nested structure of `ListWrapper`.

      Raises:
          ValueError: In case of improperly formatted `node_data`.
      """
      input_tensors = []
      for input_data in nest.flatten(node_data):
        # Each entry is [layer name, node index, tensor index(, kwargs)].
        input_data = input_data.as_list()
        inbound_layer_name = input_data[0]
        inbound_node_index = input_data[1]
        inbound_tensor_index = input_data[2]
        if len(input_data) == 3:
          kwargs = {}
        elif len(input_data) == 4:
          kwargs = input_data[3]
        else:
          raise ValueError('Improperly formatted model config.')

        inbound_layer = created_layers[inbound_layer_name]
        if len(inbound_layer._inbound_nodes) <= inbound_node_index:
          # The inbound node has not been created yet; retry later.
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(
            nest.flatten(inbound_node.output_tensors)[inbound_tensor_index])
      input_tensors = nest.pack_sequence_as(node_data, input_tensors)
      # Call layer on its inputs, thus creating the node
      # and building the layer if needed.
      if input_tensors is not None:
        # Preserve compatibility with older configs.
        flat_input_tensors = nest.flatten(input_tensors)
        if len(flat_input_tensors) == 1:
          layer(flat_input_tensors[0], **kwargs)
        else:
          layer(input_tensors, **kwargs)

    def process_layer(layer_data):
      """Deserializes a layer, then call it on appropriate inputs.

      The layer is registered in `created_layers` and each of its inbound
      nodes is enqueued; no layer calls are made here.

      Arguments:
          layer_data: layer config dict.

      Raises:
          ValueError: In case of improperly formatted `layer_data` dict.
      """
      layer_name = layer_data['name']

      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

      # Gather layer inputs and convert to `ListWrapper` objects.
      inbound_nodes_data = layer_data['inbound_nodes']
      inbound_nodes_data = tf_utils.convert_inner_node_data(
          inbound_nodes_data, wrap=True)
      for node_data in inbound_nodes_data:
        # We don't process nodes (i.e. make layer calls)
        # on the fly because the inbound node may not yet exist,
        # in case of layer shared at different topological depths
        # (e.g. a model such as A(B(A(B(x)))))
        add_unprocessed_node(layer, node_data)

    # First, we create all layers and enqueue nodes to be processed
    for layer_data in config['layers']:
      process_layer(layer_data)
    # Then we process nodes in order of layer depth.
    # Nodes that cannot yet be processed (if the inbound node
    # does not yet exist) are re-enqueued, and the process
    # is repeated until all nodes are processed.
    # NOTE(review): a config referencing a node that can never be created
    # keeps re-enqueuing it, so this loop would not terminate.
    while unprocessed_nodes:
      for layer_data in config['layers']:
        layer = created_layers[layer_data['name']]
        if layer in unprocessed_nodes:
          for node_data in unprocessed_nodes.pop(layer):
            process_node(layer, node_data)

    name = config.get('name')
    input_tensors = []
    output_tensors = []

    # Resolve [layer name, node index, tensor index] coordinates to tensors.
    input_layers = tf_utils.convert_inner_node_data(
        config['input_layers'], wrap=True)
    for layer_data in nest.flatten(input_layers):
      layer_name, node_index, tensor_index = layer_data.as_list()
      assert layer_name in created_layers
      layer = created_layers[layer_name]
      layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
      input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

    output_layers = tf_utils.convert_inner_node_data(
        config['output_layers'], wrap=True)
    for layer_data in nest.flatten(output_layers):
      layer_name, node_index, tensor_index = layer_data.as_list()
      assert layer_name in created_layers
      layer = created_layers[layer_name]
      layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
      output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

    # Restore the original nesting of the model's inputs and outputs.
    input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
    output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
    return cls(inputs=input_tensors, outputs=output_tensors, name=name)
1273
1274  def save(self, filepath, overwrite=True, include_optimizer=True):
1275    """Saves the model to a single HDF5 file.
1276
1277    The savefile includes:
1278        - The model architecture, allowing to re-instantiate the model.
1279        - The model weights.
1280        - The state of the optimizer, allowing to resume training
1281            exactly where you left off.
1282
1283    This allows you to save the entirety of the state of a model
1284    in a single file.
1285
1286    Saved models can be reinstantiated via `keras.models.load_model`.
1287    The model returned by `load_model`
1288    is a compiled model ready to be used (unless the saved model
1289    was never compiled in the first place).
1290
1291    Arguments:
1292        filepath: String, path to the file to save the weights to.
1293        overwrite: Whether to silently overwrite any existing file at the
1294            target location, or provide the user with a manual prompt.
1295        include_optimizer: If True, save optimizer's state together.
1296
1297    Example:
1298
1299    ```python
1300    from keras.models import load_model
1301
1302    model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
1303    del model  # deletes the existing model
1304
1305    # returns a compiled model
1306    # identical to the previous one
1307    model = load_model('my_model.h5')
1308    ```
1309    """
1310    if not self._is_graph_network:
1311      raise NotImplementedError(
1312          'The `save` method requires the model to be a Functional model or a '
1313          'Sequential model. It does not work for subclassed models, '
1314          'because such models are defined via the body of a Python method, '
1315          'which isn\'t safely serializable. Consider '
1316          'using `save_weights`, in order to save the weights of the model.')
1317
1318    from tensorflow.python.keras.models import save_model  # pylint: disable=g-import-not-at-top
1319    save_model(self, filepath, overwrite, include_optimizer)
1320
1321  def save_weights(self, filepath, overwrite=True, save_format=None):
1322    """Saves all layer weights.
1323
1324    Either saves in HDF5 or in TensorFlow format based on the `save_format`
1325    argument.
1326
1327    When saving in HDF5 format, the weight file has:
1328      - `layer_names` (attribute), a list of strings
1329          (ordered names of model layers).
1330      - For every layer, a `group` named `layer.name`
1331          - For every such layer group, a group attribute `weight_names`,
1332              a list of strings
1333              (ordered names of weights tensor of the layer).
1334          - For every weight in the layer, a dataset
1335              storing the weight value, named after the weight tensor.
1336
1337    When saving in TensorFlow format, all objects referenced by the network are
1338    saved in the same format as `tf.train.Checkpoint`, including any `Layer`
1339    instances or `Optimizer` instances assigned to object attributes. For
1340    networks constructed from inputs and outputs using `tf.keras.Model(inputs,
1341    outputs)`, `Layer` instances used by the network are tracked/saved
1342    automatically. For user-defined classes which inherit from `tf.keras.Model`,
1343    `Layer` instances must be assigned to object attributes, typically in the
1344    constructor. See the documentation of `tf.train.Checkpoint` and
1345    `tf.keras.Model` for details.
1346
1347    Arguments:
1348        filepath: String, path to the file to save the weights to. When saving
1349            in TensorFlow format, this is the prefix used for checkpoint files
1350            (multiple files are generated). Note that the '.h5' suffix causes
1351            weights to be saved in HDF5 format.
1352        overwrite: Whether to silently overwrite any existing file at the
1353            target location, or provide the user with a manual prompt.
1354        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
1355            '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
1356            `None` defaults to 'tf'.
1357
1358    Raises:
1359        ImportError: If h5py is not available when attempting to save in HDF5
1360            format.
1361        ValueError: For invalid/unknown format arguments.
1362    """
1363    filepath_is_h5 = _is_hdf5_filepath(filepath)
1364    if save_format is None:
1365      if filepath_is_h5:
1366        save_format = 'h5'
1367      else:
1368        save_format = 'tf'
1369    else:
1370      user_format = save_format.lower().strip()
1371      if user_format in ('tensorflow', 'tf'):
1372        save_format = 'tf'
1373      elif user_format in ('hdf5', 'h5', 'keras'):
1374        save_format = 'h5'
1375      else:
1376        raise ValueError(
1377            'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
1378                save_format,))
1379    if save_format == 'tf' and filepath_is_h5:
1380      raise ValueError(
1381          ('save_weights got save_format="tf"/"tensorflow", but the '
1382           'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
1383           'when saving in TensorFlow format.')
1384          % filepath)
1385
1386    if save_format == 'h5' and h5py is None:
1387      raise ImportError(
1388          '`save_weights` requires h5py when saving in hdf5.')
1389    if save_format == 'tf':
1390      check_filepath = filepath + '.index'
1391    else:
1392      check_filepath = filepath
1393    # If file exists and should not be overwritten:
1394    if not overwrite and os.path.isfile(check_filepath):
1395      proceed = ask_to_proceed_with_overwrite(check_filepath)
1396      if not proceed:
1397        return
1398    if save_format == 'h5':
1399      with h5py.File(filepath, 'w') as f:
1400        hdf5_format.save_weights_to_hdf5_group(f, self.layers)
1401    else:
1402      if context.executing_eagerly():
1403        session = None
1404      else:
1405        session = backend.get_session()
1406      optimizer = getattr(self, 'optimizer', None)
1407      if (optimizer
1408          and not isinstance(optimizer, trackable.Trackable)):
1409        logging.warning(
1410            ('This model was compiled with a Keras optimizer (%s) but is being '
1411             'saved in TensorFlow format with `save_weights`. The model\'s '
1412             'weights will be saved, but unlike with TensorFlow optimizers in '
1413             'the TensorFlow format the optimizer\'s state will not be '
1414             'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
1415            % (optimizer,))
1416      self._trackable_saver.save(filepath, session=session)
1417      # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
1418      checkpoint_management.update_checkpoint_state_internal(
1419          save_dir=os.path.dirname(filepath),
1420          model_checkpoint_path=filepath,
1421          save_relative_paths=True,
1422          all_model_checkpoint_paths=[filepath])
1423
1424  def load_weights(self, filepath, by_name=False):
1425    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
1426
1427    If `by_name` is False weights are loaded based on the network's
1428    topology. This means the architecture should be the same as when the weights
1429    were saved.  Note that layers that don't have weights are not taken into
1430    account in the topological ordering, so adding or removing layers is fine as
1431    long as they don't have weights.
1432
1433    If `by_name` is True, weights are loaded into layers only if they share the
1434    same name. This is useful for fine-tuning or transfer-learning models where
1435    some of the layers have changed.
1436
1437    Only topological loading (`by_name=False`) is supported when loading weights
1438    from the TensorFlow format. Note that topological loading differs slightly
1439    between TensorFlow and HDF5 formats for user-defined classes inheriting from
1440    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
1441    TensorFlow format loads based on the object-local names of attributes to
1442    which layers are assigned in the `Model`'s constructor.
1443
1444    Arguments:
1445        filepath: String, path to the weights file to load. For weight files in
1446            TensorFlow format, this is the file prefix (the same as was passed
1447            to `save_weights`).
1448        by_name: Boolean, whether to load weights by name or by topological
1449            order. Only topological loading is supported for weight files in
1450            TensorFlow format.
1451
1452    Returns:
1453        When loading a weight file in TensorFlow format, returns the same status
1454        object as `tf.train.Checkpoint.restore`. When graph building, restore
1455        ops are run automatically as soon as the network is built (on first call
1456        for user-defined classes inheriting from `Model`, immediately if it is
1457        already built).
1458
1459        When loading weights in HDF5 format, returns `None`.
1460
1461    Raises:
1462        ImportError: If h5py is not available and the weight file is in HDF5
1463            format.
1464    """
1465    if _is_hdf5_filepath(filepath):
1466      save_format = 'h5'
1467    else:
1468      try:
1469        pywrap_tensorflow.NewCheckpointReader(filepath)
1470        save_format = 'tf'
1471      except errors_impl.DataLossError:
1472        # The checkpoint is not readable in TensorFlow format. Try HDF5.
1473        save_format = 'h5'
1474    if save_format == 'tf':
1475      status = self._trackable_saver.restore(filepath)
1476      if by_name:
1477        raise NotImplementedError(
1478            'Weights may only be loaded based on topology into Models when '
1479            'loading TensorFlow-formatted weights (got by_name=True to '
1480            'load_weights).')
1481      if not context.executing_eagerly():
1482        session = backend.get_session()
1483        # Restore existing variables (if any) immediately, and set up a
1484        # streaming restore for any variables created in the future.
1485        trackable_utils.streaming_restore(status=status, session=session)
1486      status.assert_nontrivial_match()
1487      return status
1488    if h5py is None:
1489      raise ImportError(
1490          '`load_weights` requires h5py when loading weights from HDF5.')
1491    if self._is_graph_network and not self.built:
1492      raise NotImplementedError(
1493          'Unable to load weights saved in HDF5 format into a subclassed '
1494          'Model which has not created its variables yet. Call the Model '
1495          'first, then load the weights.')
1496    with h5py.File(filepath, 'r') as f:
1497      if 'layer_names' not in f.attrs and 'model_weights' in f:
1498        f = f['model_weights']
1499      if by_name:
1500        hdf5_format.load_weights_from_hdf5_group_by_name(f, self.layers)
1501      else:
1502        hdf5_format.load_weights_from_hdf5_group(f, self.layers)
1503
1504  def _updated_config(self):
1505    """Util shared between different serialization methods.
1506
1507    Returns:
1508        Model config with Keras version information added.
1509    """
1510    from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
1511
1512    config = self.get_config()
1513    model_config = {
1514        'class_name': self.__class__.__name__,
1515        'config': config,
1516        'keras_version': keras_version,
1517        'backend': backend.backend()
1518    }
1519    return model_config
1520
1521  def to_json(self, **kwargs):
1522    """Returns a JSON string containing the network configuration.
1523
1524    To load a network from a JSON save file, use
1525    `keras.models.model_from_json(json_string, custom_objects={})`.
1526
1527    Arguments:
1528        **kwargs: Additional keyword arguments
1529            to be passed to `json.dumps()`.
1530
1531    Returns:
1532        A JSON string.
1533    """
1534    model_config = self._updated_config()
1535    return json.dumps(
1536        model_config, default=serialization.get_json_type, **kwargs)
1537
1538  def to_yaml(self, **kwargs):
1539    """Returns a yaml string containing the network configuration.
1540
1541    To load a network from a yaml save file, use
1542    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
1543
1544    `custom_objects` should be a dictionary mapping
1545    the names of custom losses / layers / etc to the corresponding
1546    functions / classes.
1547
1548    Arguments:
1549        **kwargs: Additional keyword arguments
1550            to be passed to `yaml.dump()`.
1551
1552    Returns:
1553        A YAML string.
1554
1555    Raises:
1556        ImportError: if yaml module is not found.
1557    """
1558    if yaml is None:
1559      raise ImportError(
1560          'Requires yaml module installed (`pip install pyyaml`).')
1561    return yaml.dump(self._updated_config(), **kwargs)
1562
1563  def summary(self, line_length=None, positions=None, print_fn=None):
1564    """Prints a string summary of the network.
1565
1566    Arguments:
1567        line_length: Total length of printed lines
1568            (e.g. set this to adapt the display to different
1569            terminal window sizes).
1570        positions: Relative or absolute positions of log elements
1571            in each line. If not provided,
1572            defaults to `[.33, .55, .67, 1.]`.
1573        print_fn: Print function to use. Defaults to `print`.
1574            It will be called on each line of the summary.
1575            You can set it to a custom function
1576            in order to capture the string summary.
1577
1578    Raises:
1579        ValueError: if `summary()` is called before the model is built.
1580    """
1581    if not self.built:
1582      raise ValueError('This model has not yet been built. '
1583                       'Build the model first by calling `build()` or calling '
1584                       '`fit()` with some data, or specify '
1585                       'an `input_shape` argument in the first layer(s) for '
1586                       'automatic build.')
1587    layer_utils.print_summary(self,
1588                              line_length=line_length,
1589                              positions=positions,
1590                              print_fn=print_fn)
1591
1592  def _validate_graph_inputs_and_outputs(self):
1593    """Validates the inputs and outputs of a Graph Network."""
1594    # Check for redundancy in inputs.
1595    if len(set(self.inputs)) != len(self.inputs):
1596      raise ValueError('The list of inputs passed to the model '
1597                       'is redundant. '
1598                       'All inputs should only appear once.'
1599                       ' Found: ' + str(self.inputs))
1600
1601    for x in self.inputs:
1602      # Check that x has appropriate `_keras_history` metadata.
1603      if not hasattr(x, '_keras_history'):
1604        cls_name = self.__class__.__name__
1605        raise ValueError('Input tensors to a ' + cls_name + ' ' +
1606                         'must come from `tf.keras.Input`. '
1607                         'Received: ' + str(x) +
1608                         ' (missing previous layer metadata).')
1609      # Check that x is an input tensor.
1610      # pylint: disable=protected-access
1611      layer, _, _ = x._keras_history
1612      if len(layer._inbound_nodes) > 1 or (
1613          layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
1614        cls_name = self.__class__.__name__
1615        logging.warning(cls_name + ' inputs must come from '
1616                        '`tf.keras.Input` (thus holding past layer metadata), '
1617                        'they cannot be the output of '
1618                        'a previous non-Input layer. '
1619                        'Here, a tensor specified as '
1620                        'input to "' + self.name + '" was not an Input tensor, '
1621                        'it was generated by layer ' + layer.name + '.\n'
1622                        'Note that input tensors are '
1623                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
1624                        'The tensor that caused the issue was: ' + str(x.name))
1625
1626    # Check compatibility of batch sizes of Input Layers.
1627    input_batch_sizes = [
1628        training_utils.get_static_batch_size(x._keras_history[0])
1629        for x in self.inputs
1630    ]
1631    consistent_batch_size = None
1632    for batch_size in input_batch_sizes:
1633      if batch_size is not None:
1634        if (consistent_batch_size is not None and
1635            batch_size != consistent_batch_size):
1636          raise ValueError('The specified batch sizes of the Input Layers'
1637                           ' are incompatible. Found batch sizes: {}'.format(
1638                               input_batch_sizes))
1639        consistent_batch_size = batch_size
1640
1641    for x in self.outputs:
1642      if not hasattr(x, '_keras_history'):
1643        cls_name = self.__class__.__name__
1644        raise ValueError('Output tensors to a ' + cls_name + ' must be '
1645                         'the output of a TensorFlow `Layer` '
1646                         '(thus holding past layer metadata). Found: ' + str(x))
1647
1648
1649def _is_hdf5_filepath(filepath):
1650  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
1651          filepath.endswith('.hdf5'))
1652
1653
1654def _make_node_key(layer_name, node_index):
1655  return layer_name + '_ib-' + str(node_index)
1656
1657
1658def _map_graph_network(inputs, outputs):
1659  """Validates a network's topology and gather its layers and nodes.
1660
1661  Arguments:
1662    inputs: List of input tensors.
1663    outputs: List of outputs tensors.
1664
1665  Returns:
1666    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
1667    - nodes: list of Node instances.
1668    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
1669    - layers: list of Layer instances.
1670    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.
1671
1672  Raises:
1673    ValueError: In case the network is not valid (e.g. disconnected graph).
1674  """
1675  # Network_nodes: set of nodes included in the graph of layers
1676  # (not all nodes included in the layers are relevant to the current graph).
1677  network_nodes = set()  # ids of all nodes relevant to the Network
1678  nodes_depths = {}  # dict {node: depth value}
1679  layers_depths = {}  # dict {layer: depth value}
1680  layer_indices = {}  # dict {layer: index in traversal}
1681  nodes_in_decreasing_depth = []
1682
1683  def build_map(tensor,
1684                finished_nodes,
1685                nodes_in_progress,
1686                layer,
1687                node_index,
1688                tensor_index):
1689    """Builds a map of the graph of layers.
1690
1691    This recursively updates the map `layer_indices`,
1692    the list `nodes_in_decreasing_depth` and the set `network_nodes`.
1693
1694    Arguments:
1695        tensor: Some tensor in a graph.
1696        finished_nodes: Set of nodes whose subgraphs have been traversed
1697            completely. Useful to prevent duplicated work.
1698        nodes_in_progress: Set of nodes that are currently active on the
1699            recursion stack. Useful to detect cycles.
1700        layer: Layer from which `tensor` comes from. If not provided,
1701            will be obtained from `tensor._keras_history`.
1702        node_index: Node index from which `tensor` comes from.
1703        tensor_index: Tensor_index from which `tensor` comes from.
1704
1705    Raises:
1706        ValueError: if a cycle is detected.
1707    """
1708    node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access
1709
1710    # Prevent cycles.
1711    if node in nodes_in_progress:
1712      raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
1713                       layer.name + '" is part of a cycle.')
1714
1715    # Don't repeat work for shared subgraphs
1716    if node in finished_nodes:
1717      return
1718
1719    node_key = _make_node_key(layer.name, node_index)
1720    # Update network_nodes.
1721    network_nodes.add(node_key)
1722
1723    # Store the traversal order for layer sorting.
1724    if layer not in layer_indices:
1725      layer_indices[layer] = len(layer_indices)
1726
1727    nodes_in_progress.add(node)
1728
1729    # Propagate to all previous tensors connected to this node.
1730    for layer, node_index, tensor_index, tensor in node.iterate_inbound():
1731      build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
1732                tensor_index)
1733
1734    finished_nodes.add(node)
1735    nodes_in_progress.remove(node)
1736    nodes_in_decreasing_depth.append(node)
1737
1738  finished_nodes = set()
1739  nodes_in_progress = set()
1740  for x in outputs:
1741    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
1742    build_map(x, finished_nodes, nodes_in_progress,
1743              layer=layer,
1744              node_index=node_index,
1745              tensor_index=tensor_index)
1746
1747  for node in reversed(nodes_in_decreasing_depth):
1748    # If the depth is not set, the node has no outbound nodes (depth 0).
1749    depth = nodes_depths.setdefault(node, 0)
1750
1751    # Update the depth of the corresponding layer
1752    previous_depth = layers_depths.get(node.outbound_layer, 0)
1753    # If we've seen this layer before at a higher depth,
1754    # we should use that depth instead of the node depth.
1755    # This is necessary for shared layers that have inputs at different
1756    # depth levels in the graph.
1757    depth = max(depth, previous_depth)
1758    layers_depths[node.outbound_layer] = depth
1759    nodes_depths[node] = depth
1760
1761    # Update the depth of inbound nodes.
1762    # The "depth" of a node is the max of the depths
1763    # of all layers it is connected to.
1764    for inbound_layer, node_index, _, _ in node.iterate_inbound():
1765      inbound_node = inbound_layer._inbound_nodes[node_index]  # pylint: disable=protected-access
1766      previous_depth = nodes_depths.get(inbound_node, 0)
1767      nodes_depths[inbound_node] = max(depth + 1, previous_depth)
1768
1769  # Build a dict {depth: list of nodes with this depth}
1770  nodes_by_depth = {}
1771  for node, depth in nodes_depths.items():
1772    if depth not in nodes_by_depth:
1773      nodes_by_depth[depth] = []
1774    nodes_by_depth[depth].append(node)
1775
1776  # Build a dict {depth: list of layers with this depth}
1777  layers_by_depth = {}
1778  for layer, depth in layers_depths.items():
1779    if depth not in layers_by_depth:
1780      layers_by_depth[depth] = []
1781    layers_by_depth[depth].append(layer)
1782
1783  # Get sorted list of layer depths.
1784  depth_keys = list(layers_by_depth.keys())
1785  depth_keys.sort(reverse=True)
1786
1787  # Set self.layers and self._layers_by_depth.
1788  layers = []
1789  for depth in depth_keys:
1790    layers_for_depth = layers_by_depth[depth]
1791    # Network.layers needs to have a deterministic order:
1792    # here we order them by traversal order.
1793    layers_for_depth.sort(key=lambda x: layer_indices[x])
1794    layers.extend(layers_for_depth)
1795
1796  # Get sorted list of node depths.
1797  depth_keys = list(nodes_by_depth.keys())
1798  depth_keys.sort(reverse=True)
1799
1800  # Check that all tensors required are computable.
1801  # computable_tensors: all tensors in the graph
1802  # that can be computed from the inputs provided.
1803  computable_tensors = []
1804  for x in inputs:
1805    computable_tensors.append(x)
1806
1807  layers_with_complete_input = []  # To provide a better error msg.
1808  for depth in depth_keys:
1809    for node in nodes_by_depth[depth]:
1810      layer = node.outbound_layer
1811      if layer:
1812        for x in nest.flatten(node.input_tensors):
1813          if x not in computable_tensors:
1814            raise ValueError('Graph disconnected: '
1815                             'cannot obtain value for tensor ' + str(x) +
1816                             ' at layer "' + layer.name + '". '
1817                             'The following previous layers '
1818                             'were accessed without issue: ' +
1819                             str(layers_with_complete_input))
1820        for x in nest.flatten(node.output_tensors):
1821          computable_tensors.append(x)
1822        layers_with_complete_input.append(layer.name)
1823
1824  # Ensure name unicity, which will be crucial for serialization
1825  # (since serialized nodes refer to layers by their name).
1826  all_names = [layer.name for layer in layers]
1827  for name in all_names:
1828    if all_names.count(name) != 1:
1829      raise ValueError('The name "' + name + '" is used ' +
1830                       str(all_names.count(name)) + ' times in the model. '
1831                       'All layer names should be unique.')
1832  return network_nodes, nodes_by_depth, layers, layers_by_depth
1833