1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras SavedModel deserialization."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21import re
22import types
23
24from google.protobuf import message
25
26from tensorflow.core.framework import versions_pb2
27from tensorflow.python.eager import context
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import sparse_tensor
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.keras import backend
33from tensorflow.python.keras import regularizers
34from tensorflow.python.keras.engine import input_spec
35from tensorflow.python.keras.protobuf import saved_metadata_pb2
36from tensorflow.python.keras.saving import saving_utils
37from tensorflow.python.keras.saving.saved_model import constants
38from tensorflow.python.keras.saving.saved_model import json_utils
39from tensorflow.python.keras.saving.saved_model import utils
40from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
41from tensorflow.python.keras.utils import generic_utils
42from tensorflow.python.keras.utils import metrics_utils
43from tensorflow.python.keras.utils.generic_utils import LazyLoader
44from tensorflow.python.ops.ragged import ragged_tensor
45from tensorflow.python.platform import gfile
46from tensorflow.python.platform import tf_logging as logging
47from tensorflow.python.saved_model import load as tf_load
48from tensorflow.python.saved_model import loader_impl
49from tensorflow.python.saved_model import nested_structure_coder
50from tensorflow.python.saved_model import revived_types
51from tensorflow.python.training.tracking import base as trackable
52from tensorflow.python.training.tracking import data_structures
53from tensorflow.python.training.tracking.tracking import delete_tracking
54from tensorflow.python.util import compat
55from tensorflow.python.util import nest
56
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
# Each name below is bound to a LazyLoader for the given module path instead of
# a regular top-level import.
# NOTE(review): LazyLoader presumably performs the real import on first
# attribute access -- confirm against generic_utils.LazyLoader.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
layers_module = LazyLoader(
    "layers_module", globals(),
    "tensorflow.python.keras.layers")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(),
    "tensorflow.python.keras.engine.functional")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(),
    "tensorflow.python.keras.engine.training_v1")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes
89
90
# Attribute names that `KerasObjectLoader.del_tracking` stops tracking on each
# loaded layer: every common endpoint function, every common checkpointable
# object, and the Keras attribute itself.
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects,
    (constants.KERAS_ATTR,))
94
95
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel and returns the root object.

  Keras layers, models and metrics described by the Keras metadata are
  recreated as Keras objects; everything else is loaded by the core SavedModel
  loader. When the SavedModel has no `keras_metadata.pb` file, the metadata is
  rebuilt from the object graph stored in the SavedModel itself. When no Keras
  objects are found at all, the result of `tf_load.load` is returned directly.

  If the root object is a Keras `Model` and `compile` is true, the model is
  compiled from the training config stored in its metadata. New compile
  arguments are created from that config; when no training config was saved,
  a warning is logged and the model is returned uncompiled.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.

  Raises:
    IOError: If the Keras metadata file exists but cannot be parsed.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Collect the Keras metadata, either from the dedicated file or from the
  # object graph of older SavedModels.
  metadata = saved_metadata_pb2.SavedMetadata()
  object_graph_def = (
      loader_impl.parse_saved_model(path).meta_graphs[0].object_graph_def)
  metadata_file = os.path.join(path, constants.SAVED_METADATA_PATH)
  if not gfile.Exists(metadata_file):
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)
  else:
    try:
      with gfile.GFile(metadata_file, 'rb') as metadata_reader:
        serialized_metadata = metadata_reader.read()
      metadata.ParseFromString(serialized_metadata)
    except message.DecodeError as e:
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(metadata_file, str(e)))

  if not metadata.nodes:
    # Nothing Keras-specific was saved; defer entirely to the core loader.
    return tf_load.load(path, options=options)

  # Recreate layers and metrics from the metadata before the SavedModel itself
  # is loaded.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Hand the pre-created nodes, keyed by their object-graph path, to the
  # partial loader.
  nodes_to_load = {'root': None}
  nodes_to_load.update(
      (keras_loader.get_path(node_id), revived_node)
      for node_id, revived_node in keras_loader.loaded_nodes.items())
  revived_objects = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finish wiring up the Keras objects, then drop the dependencies that were
  # only tracked for the purpose of loading.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = revived_objects['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config')
    if training_config is None:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
    else:
      compile_kwargs = saving_utils.compile_args_from_training_config(
          training_config)
      model.compile(**compile_kwargs)
      saving_utils.try_build_compiled_arguments(model)
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    session = backend.get_session()  # Variables are initialized by this call.
    session.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
185
186
def _read_legacy_metadata(object_graph_def, metadata):
  """Fills `metadata` with the Keras nodes found in an ObjectGraphDef.

  Older SavedModels keep the Keras metadata inside the object graph proto
  rather than in a separate pb file. Every `user_object` node whose identifier
  is a Keras object identifier is copied into `metadata.nodes`.

  Args:
    object_graph_def: ObjectGraphDef proto read from the SavedModel.
    metadata: SavedMetadata proto that receives the nodes (mutated in place).
  """
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, node_proto in enumerate(object_graph_def.nodes):
    if node_proto.WhichOneof('kind') != 'user_object':
      continue
    user_object = node_proto.user_object
    if user_object.identifier not in constants.KERAS_OBJECT_IDENTIFIERS:
      continue
    metadata.nodes.add(
        node_id=node_id,
        node_path=node_paths[node_id],
        version=versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]),
        identifier=user_object.identifier,
        metadata=user_object.metadata)
202
203
def _generate_object_paths(object_graph_def):
  """Maps every node reachable from the root to a dotted attribute path.

  Node 0 is the root and gets the path 'root'. The graph is walked with a
  stack; a node keeps the first path under which it is discovered.

  Args:
    object_graph_def: object exposing `nodes`, where each node has `children`
      references carrying `node_id` and `local_name`.

  Returns:
    Dict mapping node id -> path string for all reachable nodes.
  """
  root_id = 0
  paths = {root_id: 'root'}
  stack = [root_id]

  while stack:
    parent_id = stack.pop()
    parent_path = paths[parent_id]
    for edge in object_graph_def.nodes[parent_id].children:
      child_id = edge.node_id
      if child_id not in paths:
        paths[child_id] = '{}.{}'.format(parent_path, edge.local_name)
        stack.append(child_id)

  return paths
220
221
def _is_graph_network(layer):
  """Returns a truthy value if `layer` is a Functional or Sequential network.

  `RevivedNetwork` instances and objects that are not `Functional` are never
  treated as graph networks.
  """
  # pylint: disable=protected-access
  if (isinstance(layer, RevivedNetwork) or
      not isinstance(layer, functional_lib.Functional)):
    return False
  return (layer._is_graph_network or
          isinstance(layer, models_lib.Sequential))
231
232
233class KerasObjectLoader(object):
234  """Loader that recreates Keras objects (e.g. layers, models).
235
236  Layers and models are revived from either the config or SavedModel following
237  these rules:
238  1. If object is a graph network (i.e. Sequential or Functional) then it will
239     be initialized using the structure from the config only after the children
240     layers have been created. Graph networks must be initialized with inputs
241     and outputs, so all child layers must be created beforehand.
242  2. If object's config exists and the class can be found, then revive from
243     config.
244  3. Object may have already been created if its parent was revived from config.
245     In this case, do nothing.
  4. If none of the above applies, compose the various artifacts from the
247     SavedModel to create a subclassed layer or model. At this time, custom
248     metrics are not supported.
249
250  """
251
252  def __init__(self, metadata, object_graph_def):
253    self._metadata = metadata
254    self._proto = object_graph_def
255
256    self._node_paths = {node_data.node_id: node_data.node_path
257                        for node_data in metadata.nodes}
258    self.loaded_nodes = {}  # Maps node path -> loaded node
259
260    # Store all node ids that have already been traversed when tracking nodes
261    # that were recreated from the config.
262    self._traversed_nodes_from_config = set()
263
264    # Maps model id -> (blank model obj, list of child layer or their node ids)
265    # This tracks all layers in functional and sequential models. These models
266    # are only reconstructed after all of their child layers have been created.
267    self.model_layer_dependencies = {}
268    self._models_to_reconstruct = []
269
270  def del_tracking(self):
271    """Removes tracked references that are only used when loading the model."""
272    # Now that the node object has been fully loaded, and the checkpoint has
273    # been restored, the object no longer needs to track objects added from
274    # SerializedAttributes. (Note that saving a training checkpoint still
275    # functions correctly, because layers and variables are tracked separately
276    # by the Layer object.)
277    # TODO(kathywu): Instead of outright deleting these nodes (which would
278    # make restoring from a different checkpoint tricky), mark them as extra
279    # dependencies that are OK to overwrite.
280    for node in self.loaded_nodes.values():
281      node = node[0]
282      if not isinstance(node, base_layer.Layer):
283        # Loaded nodes can contain other trackable objects created when
284        # loading layers from the config, such as variables.
285        continue
286      for name in PUBLIC_ATTRIBUTES:
287        delete_tracking(node, name)
288
289      if isinstance(node, functional_lib.Functional):
290        # Delete the temporary layer dependencies, which were used to restore
291        # the checkpointed values. When the model is live, the user can delete
292        # or add layers to the model at any time, so these layer dependencies
293        # may be obsolete.
294        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
295        for name in dependencies:
296          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
297            delete_tracking(node, name)
298
  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    Walks the children of `obj` (as listed in `proto`, plus any metrics saved
    under the Keras attribute's `layer_metrics`), and registers each trackable
    child in `self.loaded_nodes` and `self._node_paths` so that it is reused
    rather than recreated.

    Args:
      obj: Object that was recreated from its config.
      proto: Object graph node proto corresponding to `obj`.
      node_id: Id of `proto` in the object graph.
    """
    # pylint: disable=protected-access
    # Each node is processed at most once; this also terminates the recursion
    # on cyclic references.
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    # Try to build layers that are not built yet, using the build input shape
    # stored in the node metadata (may be None).
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(proto.user_object.metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children, as
    # (child object, child node id, name relative to the parent) tuples.
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      # Saved metrics are matched to the live metrics by name.
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      # Children that are not trackable objects are skipped.
      if not isinstance(obj_child, trackable.Trackable):
        continue
      # Pick the setter that is stored alongside the child in `loaded_nodes`.
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          # NOTE(review): the other call sites in this file use
          # `logging.warning` rather than `logging.warn`.
          logging.warn('Looks like there is an object (perhaps variable or '
                       'layer) that is shared between different layers/models. '
                       'This may cause issues when restoring the variable '
                       'values. Object: {}'.format(obj_child))
        # The node is already registered; keep the existing entry.
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      # Trackable data structures get a setter that does nothing.
      if isinstance(obj_child, data_structures.TrackableDataStructure):
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      # Descend into the child before recording the child itself.
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter
370
371  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
372    """Load all layer nodes from the metadata."""
373    # Load metrics after models and layers, since it's likely that models
374    # and layers will create the metric when initialized (this avoids wasting
375    # time by creating objects multiple times).
376    metric_list = []
377    for node_metadata in self._metadata.nodes:
378      if node_metadata.identifier == constants.METRIC_IDENTIFIER:
379        metric_list.append(node_metadata)
380        continue
381
382      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
383          node_metadata.node_id, node_metadata.identifier,
384          node_metadata.metadata)
385
386    for node_metadata in metric_list:
387      try:
388        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
389            node_metadata.node_id, node_metadata.identifier,
390            node_metadata.metadata)
391      except ValueError:
392        # Metrics are only needed when the model is compiled later. We ignore
393        # errors when trying to load custom metrics when `compile=False` until
394        # custom metrics are serialized properly (b/135550038).
395        if compile:
396          raise
397        logging.warning('Unable to restore custom metric. Please ensure that '
398                        'the layer implements `get_config` and `from_config` '
399                        'when saving. In addition, please use the '
400                        '`custom_objects` arg when calling `load_model()`.')
401
402  def _load_layer(self, node_id, identifier, metadata):
403    """Load a single layer from a SavedUserObject proto."""
404    metadata = json_utils.decode(metadata)
405
406    # If node was already created
407    if node_id in self.loaded_nodes:
408      node, setter = self.loaded_nodes[node_id]
409
410      # Revive setter requires the object to have a `_serialized_attributes`
411      # property. Add it here.
412      _maybe_add_serialized_attributes(node, metadata)
413
414      config = metadata.get('config')
415      if _is_graph_network(node) and generic_utils.validate_config(config):
416        child_nodes = self._get_child_layer_node_ids(node_id)
417        self.model_layer_dependencies[node_id] = (node, child_nodes)
418        if not child_nodes:
419          self._models_to_reconstruct.append(node_id)
420      return node, setter
421
422    # Detect whether this object can be revived from the config. If not, then
423    # revive from the SavedModel instead.
424    obj, setter = self._revive_from_config(identifier, metadata, node_id)
425    if obj is None:
426      obj, setter = revive_custom_object(identifier, metadata)
427
428    # Add an attribute that stores the extra functions/objects saved in the
429    # SavedModel. Most of these functions/objects are ignored, but some are
430    # used later in the loading process (e.g. the list of regularization
431    # losses, or the training config of compiled models).
432    _maybe_add_serialized_attributes(obj, metadata)
433    return obj, setter
434
435  def _revive_from_config(self, identifier, metadata, node_id):
436    """Revives a layer/model from config, or returns None."""
437    if identifier == constants.METRIC_IDENTIFIER:
438      obj = self._revive_metric_from_config(metadata)
439    else:
440      obj = (
441          self._revive_graph_network(identifier, metadata, node_id) or
442          self._revive_layer_or_model_from_config(metadata, node_id))
443
444    if obj is None:
445      return None, None
446
447    setter = self._config_node_setter(_revive_setter)
448    self._add_children_recreated_from_config(
449        obj, self._proto.nodes[node_id], node_id)
450    return obj, setter
451
452  def _revive_graph_network(self, identifier, metadata, node_id):
453    """Revives a graph network from config."""
454    # Determine whether the metadata contains information for reviving a
455    # functional or Sequential model.
456    config = metadata.get('config')
457    if not generic_utils.validate_config(config):
458      return None
459
460    class_name = compat.as_str(metadata['class_name'])
461    if generic_utils.get_registered_object(class_name) is not None:
462      return None
463    model_is_functional_or_sequential = (
464        metadata.get('is_graph_network', False) or
465        class_name == 'Sequential' or
466        class_name == 'Functional')
467    if not model_is_functional_or_sequential:
468      return None
469
470    # Revive functional and sequential models as blank model objects for now (
471    # must be initialized to enable setattr tracking and attribute caching).
472    # Reconstruction of the network is deferred until all of the model's layers
473    # have been revived.
474    if class_name == 'Sequential':
475      model = models_lib.Sequential(name=config['name'])
476    # The model is a custom Sequential model.
477    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
478      # Uses the custom class name, since the config does not have one.
479      model = models_lib.Sequential(name=class_name)
480    else:
481      model = models_lib.Functional(
482          inputs=[], outputs=[], name=config['name'])
483
484    # Record this model and its layers. This will later be used to reconstruct
485    # the model.
486    layers = self._get_child_layer_node_ids(node_id)
487    self.model_layer_dependencies[node_id] = (model, layers)
488    if not layers:
489      self._models_to_reconstruct.append(node_id)
490    return model
491
492  def _revive_layer_or_model_from_config(self, metadata, node_id):
493    """Revives a layer/custom model from config; returns None if infeasible."""
494    # Check that the following requirements are met for reviving from config:
495    #    1. Object can be deserialized from config.
496    #    2. If the object needs to be built, then the build input shape can be
497    #       found.
498    class_name = metadata.get('class_name')
499    config = metadata.get('config')
500    shared_object_id = metadata.get('shared_object_id')
501    must_restore_from_config = metadata.get('must_restore_from_config')
502    if not generic_utils.validate_config(config):
503      return None
504
505    try:
506      obj = layers_module.deserialize(
507          generic_utils.serialize_keras_class_and_config(
508              class_name, config, shared_object_id=shared_object_id))
509    except ValueError:
510      if must_restore_from_config:
511        raise RuntimeError(
512            'Unable to restore a layer of class {cls}. Layers of '
513            'class {cls} require that the class be provided to '
514            'the model loading code, either by registering the '
515            'class using @keras.utils.register_keras_serializable '
516            'on the class def and including that file in your '
517            'program, or by passing the class in a '
518            'keras.utils.CustomObjectScope that wraps this load '
519            'call.'.format(cls=class_name))
520      else:
521        return None
522
523    # Use the dtype, name, and trainable status. Often times these are not
524    # specified in custom configs, so retrieve their values from the metadata.
525    # pylint: disable=protected-access
526    obj._name = metadata['name']
527    if metadata.get('trainable') is not None:
528      obj.trainable = metadata['trainable']
529    if metadata.get('dtype') is not None:
530      obj._set_dtype_policy(metadata['dtype'])
531    if metadata.get('stateful') is not None:
532      obj.stateful = metadata['stateful']
533    # Restore model save spec for subclassed models. (layers do not store a
534    # SaveSpec)
535    if isinstance(obj, training_lib.Model):
536      save_spec = metadata.get('save_spec')
537      if save_spec is not None:
538        obj._set_save_spec(save_spec)
539    # pylint: enable=protected-access
540
541    build_input_shape = metadata.get('build_input_shape')
542    built = self._try_build_layer(obj, node_id, build_input_shape)
543
544    if not built:
545      # If the layer cannot be built, revive a custom layer instead.
546      return None
547    return obj
548
549  def _revive_metric_from_config(self, metadata):
550    """Revives a metric object using the config saved in the metadata."""
551    class_name = compat.as_str(metadata['class_name'])
552    config = metadata.get('config')
553
554    if not generic_utils.validate_config(config):
555      return None
556
557    try:
558      obj = metrics.deserialize(
559          generic_utils.serialize_keras_class_and_config(class_name, config))
560    except ValueError:
561      return None
562
563    build_input_shape = metadata.get('build_input_shape')
564    if build_input_shape is not None and hasattr(obj, '_build'):
565      obj._build(build_input_shape)  # pylint: disable=protected-access
566
567    return obj
568
569  def _try_build_layer(self, obj, node_id, build_input_shape):
570    """Attempts to build the layer."""
571    if obj.built or hasattr(obj.build, '_is_default'):
572      obj.built = True
573      return True
574
575    if build_input_shape is None:
576      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
577
578    if build_input_shape is not None:
579      obj.build(build_input_shape)
580      base_layer.Layer.build(obj, build_input_shape)
581      return True
582
583    return False
584
585  def _load_edges(self):
586    """Add edges for all nodes that are not waiting on initialization."""
587    for node_id, proto in enumerate(self._proto.nodes):
588      if node_id not in self.model_layer_dependencies:
589        self._add_object_graph_edges(proto, node_id)
590
591  def get_path(self, node_id):
592    return self._node_paths[node_id]
593
594  def finalize_objects(self):
595    """Finish setting up Keras objects.
596
597    This function is executed after all objects and functions have been created.
598    Call functions and losses are attached to each layer, and once all layers
599    have been fully set up, graph networks are initialized.
600
601    Subclassed models that are revived from the SavedModel are treated like
602    layers, and have their call/loss functions attached here.
603    """
604    # Finish setting up layers and subclassed models. This step attaches call
605    # functions and losses to each object, and sets model inputs/outputs.
606    layers_revived_from_config = []
607    layers_revived_from_saved_model = []
608    for node_id, (node, _) in self.loaded_nodes.items():
609      if (not isinstance(node, base_layer.Layer) or
610          # Don't finalize models until all layers have finished loading.
611          node_id in self.model_layer_dependencies):
612        continue
613
614      self._unblock_model_reconstruction(node_id, node)
615
616      if isinstance(node, input_layer.InputLayer):
617        continue
618      elif isinstance(node, metrics.Metric):
619        continue
620
621      if isinstance(node, (RevivedLayer, RevivedInputLayer)):
622        layers_revived_from_saved_model.append(node)
623      else:
624        layers_revived_from_config.append(node)
625
626    _finalize_saved_model_layers(layers_revived_from_saved_model)
627    _finalize_config_layers(layers_revived_from_config)
628
629    # Initialize graph networks, now that layer dependencies have been resolved.
630    self._reconstruct_all_models()
631
632  def _unblock_model_reconstruction(self, layer_id, layer):
633    """Removes layer from blocking model reconstruction."""
634    for model_id, v in self.model_layer_dependencies.items():
635      _, layers = v
636      if layer_id not in layers:
637        continue
638      layers[layers.index(layer_id)] = layer
639      if all(isinstance(x, base_layer.Layer) for x in layers):
640        self._models_to_reconstruct.append(model_id)
641
642  def _reconstruct_all_models(self):
643    """Reconstructs the network structure of all models."""
644    all_initialized_models = set()
645    while self._models_to_reconstruct:
646      model_id = self._models_to_reconstruct.pop(0)
647      all_initialized_models.add(model_id)
648      model, layers = self.model_layer_dependencies[model_id]
649      self._reconstruct_model(model_id, model, layers)
650      _finalize_config_layers([model])
651
652    if all_initialized_models != set(self.model_layer_dependencies.keys()):
653      # This should not happen.
654      uninitialized_model_ids = (
655          set(self.model_layer_dependencies.keys()) - all_initialized_models)
656      uninitialized_model_names = [
657          self.model_layer_dependencies[model_id][0].name
658          for model_id in uninitialized_model_ids]
659      raise ValueError('Error when loading from SavedModel -- the following '
660                       'models could not be initialized: {}'
661                       .format(uninitialized_model_names))
662
  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure.

    Re-initializes the blank `model` in place from the config stored in the
    metadata of node `model_id`, reusing the already loaded `layers`.

    Args:
      model: Blank Sequential or Functional model to initialize.
      model_id: Node id of the model in the object graph.
      layers: List of the model's loaded layers. For Sequential models an
        InputLayer may be inserted at the front of this list.
    """
    config = json_utils.decode(
        self._proto.nodes[model_id].user_object.metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      # Make sure the layer list starts with an InputLayer when the config
      # provides enough information to create one.
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          # Derive the InputLayer from the first layer's batch input shape:
          # element 0 is used as the batch size, the rest as the input shape.
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      # Re-run the constructor on the existing object with the full layer list.
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        # The model still has no inputs; infer them from the first child layer
        # recorded in the object graph.
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      # Loaded layers are handed over by name so that they are reused.
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype and trainable status.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)
705
706  def _get_child_layer_node_ids(self, node_id):
707    """Returns the node ids of each layer in a Sequential/Functional model."""
708    # Sequential and Functional track layers with names following the format
709    # "layer-N". Use this to generate the list of layers.
710    num_layers = 0
711    child_layers = {}
712    pattern = re.compile('layer-(\\d+)')
713
714    for child in self._proto.nodes[node_id].children:
715      m = pattern.match(child.local_name)
716      if m is None:
717        continue
718      layer_n = int(m.group(1))
719      num_layers = max(layer_n + 1, num_layers)
720      child_layers[layer_n] = child.node_id
721
722    ordered = []
723    for n in range(num_layers):
724      child = child_layers.get(n)
725      if child is None:
726        break
727      ordered.append(child)
728    return ordered
729
730  def _search_for_child_node(self, parent_id, path_to_child):
731    """Returns node id of child node.
732
733    A helper method for traversing the object graph proto.
734
735    As an example, say that the object graph proto in the SavedModel contains an
736    object with the following child and grandchild attributes:
737
738    `parent.child_a.child_b`
739
740    This method can be used to retrieve the node id of `child_b` using the
741    parent's node id by calling:
742
743    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.
744
745    Args:
746      parent_id: node id of parent node
747      path_to_child: list of children names.
748
749    Returns:
750      node_id of child, or None if child isn't found.
751    """
752    if not path_to_child:
753      return parent_id
754
755    for child in self._proto.nodes[parent_id].children:
756      if child.local_name == path_to_child[0]:
757        return self._search_for_child_node(child.node_id, path_to_child[1:])
758    return None
759
760  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
761    """Infers input shape of layer from SavedModel functions."""
762    coder = nested_structure_coder.StructureCoder()
763    call_fn_id = self._search_for_child_node(
764        layer_node_id, ['call_and_return_all_conditional_losses'])
765    if call_fn_id is None:
766      return None
767
768    concrete_functions = (
769        self._proto.nodes[call_fn_id].function.concrete_functions)
770    if not concrete_functions:
771      return None
772    call_fn_name = concrete_functions[0]
773    call_fn_proto = self._proto.concrete_functions[call_fn_name]
774    structured_input_signature = coder.decode_proto(
775        call_fn_proto.canonicalized_input_signature)
776    inputs = structured_input_signature[0][0]
777    if convert_to_shapes:
778      return nest.map_structure(lambda spec: spec.shape, inputs)
779    else:
780      return inputs
781
782  def _config_node_setter(self, setter):
783    """Creates edges for nodes that are recreated from config."""
784    def setattr_wrapper(obj, name, value):
785      # Avoid overwriting attributes of objects recreated from the config.
786      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
787        setter(obj, name, value)
788    return setattr_wrapper
789
790
def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel.

  Args:
    layers: Iterable of revived layer objects. Each layer is marked as built
      and gets its `call`, losses and metrics restored in place.
  """
  # pylint: disable=protected-access
  call_fn_attr = 'call_and_return_conditional_losses'

  # 1. Set up call functions for all layers (skip this step for Sequential and
  # Functional models).
  for layer in layers:
    layer.built = True
    keras_attr = _get_keras_attr(layer)
    if not hasattr(keras_attr, call_fn_attr):
      # No serialized call function: calling the layer raises an error.
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)
      continue
    layer.call = utils.use_wrapped_call(
        layer, getattr(keras_attr, call_fn_attr), return_method=True)
    layer._init_call_fn_args()

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      keras_attr = _get_keras_attr(layer)
      if hasattr(keras_attr, call_fn_attr):
        call_fn = getattr(keras_attr, call_fn_attr)
        if call_fn.input_signature is not None:
          inputs = call_fn.input_signature[0]
        else:
          inputs = infer_inputs_from_restored_call_function(call_fn)
        layer._set_inputs(inputs)

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)

  # pylint: enable=protected-access
828
829
830def _unable_to_call_layer_due_to_serialization_issue(
831    layer, *unused_args, **unused_kwargs):
832  """Replaces the `layer.call` if the layer was not fully serialized.
833
834  Keras Model/Layer serialization is relatively relaxed because SavedModels
835  are not always loaded back as keras models. Thus, when there is an issue
836  tracing a non-signature function, a warning is logged instead of raising an
837  error. This results in a SavedModel where the model's call function is saved,
838  but the internal layer call functions are not.
839
840  When deserialized with `tf.keras.models.load_model`, the internal layers
841  which do not have serialized call functions should raise an error when called.
842
843  Args:
844    layer: Layer without the serialized call function.
845
846  Raises:
847    ValueError
848  """
849
850  raise ValueError(
851      'Cannot call custom layer {} of type {}, because the call function was '
852      'not serialized to the SavedModel.'
853      'Please try one of the following methods to fix this issue:'
854      '\n\n(1) Implement `get_config` and `from_config` in the layer/model '
855      'class, and pass the object to the `custom_objects` argument when '
856      'loading the model. For more details, see: '
857      'https://www.tensorflow.org/guide/keras/save_and_serialize'
858      '\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
859      'and not `__call__`. The input shape and dtype will be automatically '
860      'recorded when the object is called, and used when saving. To manually '
861      'specify the input shape/dtype, decorate the call function with '
862      '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
863
864
def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers from config.

  Args:
    layers: Iterable of layers that were recreated from their config. Losses,
      metrics and (for stateful RNN layers) states are restored in place.
  """
  for layer in layers:
    # Layers are assumed to define their unconditional losses once recreated
    # from the config and built. Functional and Sequential models are the
    # exception: their config only stores conditional losses (losses dependent
    # on the inputs), so unconditional losses like weight regularization must
    # be revived from the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Some layers, like Dense, record their activation loss function in the
    # config. However, not all layers do this, so the activation loss may be
    # missing when restored from the config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure consistent
    # loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)

    # Restore metrics list.
    _restore_layer_metrics(layer)

    # Restore RNN layer states; only stateful RNN layers are considered.
    if not (isinstance(layer, recurrent.RNN) and layer.stateful):
      continue
    keras_attr = _get_keras_attr(layer)
    if hasattr(keras_attr, 'states'):
      layer.states = getattr(keras_attr, 'states', None)
      for variable in nest.flatten(layer.states):
        backend.track_variable(variable)
893
894
def _finalize_metric(metric):
  """Points a revived metric's `update_state`/`result` at its saved functions.

  Args:
    metric: Metric object exposing a `keras_api` attribute with `update_state`
      and `result`. It is modified in place.
  """
  wrapped_update_state = metrics_utils.update_state_wrapper(
      metric.keras_api.update_state)
  # Bind the wrapped function so that it receives `metric` as first argument.
  metric.update_state = types.MethodType(wrapped_update_state, metric)
  metric.result = metric.keras_api.result
899
900
def _restore_layer_unconditional_losses(layer):
  """Restore unconditional losses from SavedModel.

  Args:
    layer: Layer whose saved regularization losses are re-added through
      `layer.add_loss`.
  """
  keras_attr = _get_keras_attr(layer)
  if not hasattr(keras_attr, 'layer_regularization_losses'):
    # Some earlier SavedModels may not have layer_regularization_losses
    # serialized separately. Fall back to using the regularization_losses
    # list if it does not exist.
    losses = layer._serialized_attributes.get('regularization_losses', [])  # pylint: disable=protected-access
  else:
    losses = getattr(keras_attr, 'layer_regularization_losses', [])

  for loss in losses:
    layer.add_loss(loss)
912
913
def _restore_layer_activation_loss(layer):
  """Restore activation loss from SavedModel.

  Args:
    layer: Layer whose `activity_regularizer` is set to the saved
      `activity_regularizer_fn` when the layer does not already have one.
  """
  # Use wrapped activity regularizer function if the layer's activity
  # regularizer wasn't created during initialization.
  saved_regularizer = getattr(_get_keras_attr(layer),
                              'activity_regularizer_fn', None)
  if not saved_regularizer or layer.activity_regularizer:
    return

  try:
    layer.activity_regularizer = saved_regularizer
  except AttributeError:
    # This may happen if a layer wrapper is saved with an activity
    # regularizer. The wrapper object's activity regularizer is unsettable.
    pass
927
928
def revive_custom_object(identifier, metadata):
  """Revives object from SavedModel.

  A new class named `metadata['class_name']` is created on the fly from the
  revived/Keras base class pair registered for `identifier`, and initialized
  from the metadata.

  Args:
    identifier: Object identifier; must be one of the input layer, layer,
      model, network or sequential identifiers from `constants`.
    metadata: Dict of metadata of the saved object. Must contain
      'class_name' plus whatever the revived class's `_init_from_metadata`
      reads.

  Returns:
    Whatever the revived class's `_init_from_metadata` returns.

  Raises:
    ValueError: If `identifier` is not one of the supported identifiers.
  """
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer, input_layer.InputLayer),
      constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
      constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
      constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
      constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
  }
  parent_classes = revived_classes.get(identifier, None)

  if parent_classes is None:
    # The space after `get_config` keeps the adjacent string literals from
    # being glued together in the rendered message.
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config` '
                     'and `from_config` when saving. In addition, please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))

  revived_cls = type(
      compat.as_str(metadata['class_name']), parent_classes, {})
  return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
957
958
def _restore_layer_metrics(layer):
  """Appends saved metrics that the layer does not already have, by name.

  Args:
    layer: Layer whose `_metrics` list is extended in place with the entries
      of the saved `layer_metrics` mapping whose names are not yet present.
  """
  # pylint: disable=protected-access
  saved_metrics = getattr(_get_keras_attr(layer), 'layer_metrics', {})
  existing_names = {m.name for m in layer._metrics}
  for name, metric in saved_metrics.items():
    if name in existing_names:
      # Metrics may be added during initialization/building of custom layers.
      continue
    layer._metrics.append(metric)
  # pylint: enable=protected-access
966
967
# TODO(kathywu): Centrally define keys and functions for both serialization and
# deserialization.
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived layer from metadata stored in the SavedModel proto.

    Args:
      metadata: Dict of saved layer metadata. 'name', 'trainable' and
        'expects_training_arg' are required; the other keys read here are
        optional.

    Returns:
      Tuple of the revived object and the setter function used to assign its
      saved attributes.
    """
    init_args = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    # Optional constructor arguments are only forwarded when present.
    for optional_arg in ('dtype', 'batch_input_shape'):
      if metadata.get(optional_arg) is not None:
        init_args[optional_arg] = metadata[optional_arg]

    revived_obj = cls(**init_args)

    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']

      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      if metadata.get('input_spec') is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            metadata['input_spec'],
            module_objects={'InputSpec': input_spec.InputSpec})

      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])

      # Plain attributes that are copied over verbatim when present.
      for attr_name in ('_is_feature_layer', 'stateful'):
        if metadata.get(attr_name) is not None:
          setattr(revived_obj, attr_name, metadata[attr_name])
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    """The saved Keras-specific attributes, or None if there are none."""
    serialized = self._serialized_attributes
    return serialized.get(constants.KERAS_ATTR, None)

  def get_config(self):
    """Returns the saved config; raises NotImplementedError if none was set."""
    if not hasattr(self, '_config'):
      raise NotImplementedError
    return self._config
1016
1017
def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary.

  Args:
    layer: Revived layer object that receives the attribute.
    name: Name of the attribute (edge name in the SavedModel).
    value: Value to assign.
  """
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, trackable.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  elif (isinstance(layer, functional_lib.Functional) and
        re.match(r'^layer(_with_weights)?-\d+', name) is not None):
    # `\d+` (one or more digits) rather than the character class `[\d+]`,
    # which would also accept a literal '+' after the dash.
    #
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.
    layer._track_trackable(value, name)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is not None:
    # Don't overwrite already defined attributes.
    pass
  else:
    setattr(layer, name, value)
1041
1042
class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata.

    Args:
      metadata: Dict with the keys 'name', 'dtype', 'sparse', 'ragged',
        'batch_input_shape' and 'config'.

    Returns:
      Tuple of the revived object and the builtin `setattr`, which is used as
      its attribute setter.
    """
    constructor_keys = (
        'name', 'dtype', 'sparse', 'ragged', 'batch_input_shape')
    revived_obj = cls(**{key: metadata[key] for key in constructor_keys})

    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access

    return revived_obj, setattr

  def get_config(self):
    """Returns the config stored when the layer was revived."""
    return self._config
1063
1064
def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserialize Keras object from a nested structure.

  Args:
    config: A dict, tuple or list. Dicts containing 'class_name' are
      deserialized as Keras objects; other dicts and sequences are traversed
      recursively.
    module_objects: Passed through to
      `generic_utils.deserialize_keras_object`.

  Returns:
    The deserialized structure. Tuples and lists both come back as lists.

  Raises:
    ValueError: If `config` (or a nested value) is not a dict, tuple or list.
  """
  if isinstance(config, dict):
    if 'class_name' not in config:
      return {
          key: recursively_deserialize_keras_object(value, module_objects)
          for key, value in config.items()
      }
    return generic_utils.deserialize_keras_object(
        config, module_objects=module_objects)

  if isinstance(config, (tuple, list)):
    return [recursively_deserialize_keras_object(item, module_objects)
            for item in config]

  raise ValueError('Unable to decode config: {}'.format(config))
1080
1081
def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: A `TensorShape`, or None.
    y: A `TensorShape`, or None.

  Returns:
    None if both `x` and `y` are None. Otherwise a `TensorShape`: of unknown
    rank if the ranks differ or are unknown, else with one entry per
    dimension that keeps the dimension value when both shapes agree on a
    known value and is None otherwise.

  Raises:
    RuntimeError: If exactly one of `x` and `y` is None.
    TypeError: If `x` or `y` is neither None nor a `TensorShape`.
  """
  # The parentheses are required: `x is None != y is None` is a chained
  # comparison, i.e. `(x is None) and (None != y) and (y is None)`, which can
  # never be true.
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      # Mismatching or unknown dimensions become unknown in the result.
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
1105
1106
def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  The input specs of all concrete functions of `fn` are merged pairwise, so
  that the resulting shapes are the common shapes of all of them.

  Args:
    fn: Restored layer call function. It is assumed that the inputs are entirely
      in the first argument.

  Returns:
    TensorSpec of call function inputs.
  """
  def merge_specs(x, y):
    """Builds a spec of the same kind as `x` with the shape common to both."""
    merged_shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensorSpec(merged_shape, x.dtype)
    if isinstance(x, ragged_tensor.RaggedTensorSpec):
      return ragged_tensor.RaggedTensorSpec(merged_shape, x.dtype)
    return tensor_spec.TensorSpec(merged_shape, x.dtype, x.name)

  # Start from the first concrete function and fold in the remaining ones.
  merged = fn.concrete_functions[0].structured_input_signature[0][0]
  for concrete in fn.concrete_functions[1:]:
    other = concrete.structured_input_signature[0][0]
    merged = nest.map_structure(merge_specs, merged, other)
  return merged
1130
1131
class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto.

    Args:
      metadata: Dict of saved network metadata. 'name' and
        'expects_training_arg' are required; 'config' and
        'activity_regularizer' are optional.

    Returns:
      Tuple of the revived object and the setter function used to assign its
      saved attributes.
    """
    revived_obj = cls(name=metadata['name'])

    # Store attributes revived from SerializedAttributes in a un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']

      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      serialized_regularizer = metadata.get('activity_regularizer')
      if serialized_regularizer is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      # pylint:enable=protected-access

    return revived_obj, _revive_setter
1156
1157
def _set_network_attributes_from_metadata(revived_obj):
  """Sets attributes recorded in the metadata.

  Args:
    revived_obj: Object whose `_serialized_attributes['metadata']` provides
      the optional 'dtype' and the required 'trainable' entries.
  """
  with trackable.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    metadata = revived_obj._serialized_attributes['metadata']
    dtype = metadata.get('dtype')
    if dtype is not None:
      revived_obj._set_dtype_policy(dtype)
    revived_obj.trainable = metadata['trainable']
    # pylint:enable=protected-access
1167
1168
def _maybe_add_serialized_attributes(layer, metadata):
  """Attaches a `_serialized_attributes` dict to `layer` if it has none.

  Args:
    layer: Object to receive the dictionary.
    metadata: Stored under the 'metadata' key of the new dictionary.
  """
  if hasattr(layer, '_serialized_attributes'):
    return

  # Store attributes revived from SerializedAttributes in a un-tracked
  # dictionary. The attributes are the ones listed in CommonEndpoints or
  # "keras_api" for keras-specific attributes.
  with trackable.no_automatic_dependency_tracking_scope(layer):
    layer._serialized_attributes = dict(metadata=metadata)  # pylint: disable=protected-access
1176
1177
def _get_keras_attr(layer):
  """Returns the saved Keras-specific attributes of `layer`, or None.

  None is returned both when the layer has no `_serialized_attributes` and
  when that dictionary has no `constants.KERAS_ATTR` entry.
  """
  serialized_attributes = getattr(layer, '_serialized_attributes', {})
  return serialized_attributes.get(constants.KERAS_ATTR, None)
1181