1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Code for model cloning, plus model-related API entries.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.keras import backend as K
23from tensorflow.python.keras import metrics as metrics_module
24from tensorflow.python.keras import optimizers
25from tensorflow.python.keras.engine import sequential
26from tensorflow.python.keras.engine import training
27from tensorflow.python.keras.engine.base_layer import Layer
28from tensorflow.python.keras.engine.input_layer import Input
29from tensorflow.python.keras.engine.input_layer import InputLayer
30from tensorflow.python.keras.engine.network import Network
31from tensorflow.python.keras.saving import hdf5_format
32from tensorflow.python.keras.saving import model_config
33from tensorflow.python.keras.utils import generic_utils
34from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
35from tensorflow.python.util import nest
36from tensorflow.python.util.tf_export import keras_export
37
38
# API entries importable from `keras.models`:
# The core model classes are re-exported here; the cloning helpers below also
# use `Model` and `Sequential` for their `isinstance` dispatch.
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
# Save/load helpers, re-exported from `saving.hdf5_format`.
save_model = hdf5_format.save_model
load_model = hdf5_format.load_model
# Config/YAML/JSON model constructors, re-exported from `saving.model_config`.
model_from_config = model_config.model_from_config
model_from_yaml = model_config.model_from_yaml
model_from_json = model_config.model_from_json
47
48
49def _clone_layer(layer):
50  return layer.__class__.from_config(layer.get_config())
51
52
def _clone_functional_model(model, input_tensors=None, share_weights=False):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Arguments:
      model: Instance of `Model`.
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created. Tensors that are not Keras tensors
          are wrapped in a new `InputLayer`.
      share_weights: flag to enable sharing of non-input layers between the
          cloned and original model. Note this still clones the input layers.
          This is required when we create a per-replica copy of the model with
          distribution strategy; we want the weights to be shared but still
          feed inputs separately so we create new input layers.

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new inputs tensors,
      using newly instantiated weights (or the original non-input layers
      when `share_weights` is True).

  Raises:
      ValueError: in case of invalid `model` argument value.
  """
  if not isinstance(model, Model):
    raise ValueError('Expected `model` argument '
                     'to be a `Model` instance, got %s' % (model,))
  if isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'got a `Sequential` instance instead: %s' % (model,))

  layer_map = {}  # Cache for created layers.
  tensor_map = {}  # Map {reference_tensor: corresponding_tensor}
  if input_tensors is None:
    # Create placeholders to build the model on top of.
    input_tensors = []
    for layer in model._input_layers:
      input_tensor = Input(
          batch_shape=layer._batch_input_shape,
          dtype=layer.dtype,
          sparse=layer.sparse,
          name=layer.name)
      input_tensors.append(input_tensor)
      # Cache newly created input layer.
      newly_created_input_layer = input_tensor._keras_history[0]
      layer_map[layer] = newly_created_input_layer
  else:
    # Make sure that all input tensors come from a Keras layer: tensors that
    # are not Keras tensors are wrapped in a new InputLayer, which is cached
    # as the clone of the corresponding original input layer.
    keras_input_tensors = []
    for i, input_tensor in enumerate(nest.flatten(input_tensors)):
      if not K.is_keras_tensor(input_tensor):
        original_input_layer = model._input_layers[i]
        name = original_input_layer.name
        input_tensor = Input(tensor=input_tensor,
                             name='input_wrapper_for_' + name)
        # Cache newly created input layer.
        newly_created_input_layer = input_tensor._keras_history[0]
        layer_map[original_input_layer] = newly_created_input_layer
      keras_input_tensors.append(input_tensor)
    input_tensors = keras_input_tensors

  # The reference model's inputs map straight onto the new input tensors.
  for x, y in zip(model.inputs, input_tensors):
    tensor_map[x] = y

  # The outputs of the original input layers are already in `tensor_map`, so
  # their nodes must never be (re-)called. This matters for Keras-tensor
  # `input_tensors`, whose input layers are not cached in `layer_map`: without
  # this check such an input layer would be cloned (or, with `share_weights`,
  # the original reused) and called again on the new tensor, adding an
  # unnecessary extra InputLayer call to the clone's graph.
  original_input_layers = set(model._input_layers)

  # Iterate over every node in the reference model, in depth order.
  for depth in sorted(model._nodes_by_depth.keys(), reverse=True):
    for node in model._nodes_by_depth[depth]:
      # Recover the corresponding layer.
      layer = node.outbound_layer
      if layer in original_input_layers:
        continue

      # Get or create layer.
      if layer not in layer_map:
        if not share_weights:
          # Clone layer.
          new_layer = _clone_layer(layer)
          layer_map[layer] = new_layer
          layer = new_layer
      else:
        # Reuse previously cloned layer.
        layer = layer_map[layer]
        # Don't call InputLayer multiple times.
        if isinstance(layer, InputLayer):
          continue

      # If all of this node's input tensors are available in tensor_map,
      # call the (cloned or shared) layer on their counterparts.
      if all(
          tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
        computed_tensors = nest.map_structure(lambda t: tensor_map[t],
                                              node.input_tensors)
        # Call layer.
        kwargs = node.arguments or {}
        output_tensors = layer(computed_tensors, **kwargs)

        for x, y in zip(
            nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
          tensor_map[x] = y

  # Check that we did compute the model outputs,
  # then instantiate a new model from inputs and outputs.
  output_tensors = []
  for x in model.outputs:
    assert x in tensor_map, 'Could not compute output ' + str(x)
    output_tensors.append(tensor_map[x])

  input_tensors = nest.pack_sequence_as(model._nested_inputs, input_tensors)
  output_tensors = nest.pack_sequence_as(model._nested_outputs, output_tensors)
  return Model(input_tensors, output_tensors, name=model.name)
173
174
def _clone_sequential_model(model, input_tensors=None, share_weights=False):
  """Clone a `Sequential` model instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Arguments:
      model: Instance of `Sequential`.
      input_tensors: optional tensor, or list/tuple holding exactly one
          tensor, to build the model upon. If not provided,
          placeholders will be created.
      share_weights: flag to enable sharing of non-input layers between the
          cloned and original model. Note this still clones the input layers.
          This is required when we create a per-replica copy of the model with
          distribution strategy; we want the weights to be shared but still
          feed inputs separately so we create new input layers.

  Returns:
      An instance of `Sequential` reproducing the behavior
      of the original model, on top of new inputs tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value, if
          `input_tensors` does not hold exactly one tensor, or if that tensor
          comes from a Keras layer other than an `InputLayer`.
  """
  if not isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a `Sequential` model instance, '
                     'but got: %s' % (model,))

  # Use model._layers to ensure that all layers are cloned. The model's layers
  # property will exclude the initial InputLayer (if it exists) in the model,
  # resulting in a different Sequential model structure.
  if input_tensors is None:
    if share_weights:
      # In preserve weights case we still want the input layers to be cloned.
      layers = []
      for layer in model._layers:
        if isinstance(layer, InputLayer):
          layers.append(_clone_layer(layer))
        else:
          layers.append(layer)
    else:
      layers = [_clone_layer(layer) for layer in model._layers]
    return Sequential(layers=layers, name=model.name)

  # Validate `input_tensors` before cloning anything. Tuples are normalized to
  # lists *before* counting: `generic_utils.to_list` is assumed to wrap any
  # non-list (a tuple included) in a one-element list, so a tuple of several
  # tensors would otherwise pass the check and all but its first tensor would
  # be silently dropped.
  if isinstance(input_tensors, tuple):
    input_tensors = list(input_tensors)
  input_tensors = generic_utils.to_list(input_tensors)
  if len(input_tensors) != 1:
    raise ValueError('To clone a `Sequential` model, we expect '
                     'exactly one tensor '
                     'as part of `input_tensors`.')
  x = input_tensors[0]

  # Resolve the InputLayer that will replace the original model's InputLayer.
  if K.is_keras_tensor(x):
    input_layer = x._keras_history[0]
    if not isinstance(input_layer, InputLayer):
      raise ValueError('Cannot clone a `Sequential` model on top '
                       'of a tensor that comes from a Keras layer '
                       'other than an `InputLayer`. '
                       'Use the functional API instead.')
  else:
    input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
    input_layer = input_tensor._keras_history[0]

  # If input tensors are provided, the original model's InputLayer is
  # overwritten with a different InputLayer.
  layers = [
      layer for layer in model._layers if not isinstance(layer, InputLayer)]
  if not share_weights:
    layers = [_clone_layer(layer) for layer in layers]
  return Sequential(layers=[input_layer] + layers, name=model.name)
248
249
@keras_export('keras.models.clone_model')
def clone_model(model, input_tensors=None):
  """Clone any `Model` instance.

  Cloning behaves like calling the model on new inputs, except that new
  layers (and therefore new weights) are created instead of reusing the
  layers of `model`.

  Arguments:
      model: Instance of `Model`
          (could be a functional model or a Sequential model).
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new inputs tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value.
  """
  # Sequential models get the dedicated cloner; everything else is treated
  # as a functional model.
  if isinstance(model, Sequential):
    cloner = _clone_sequential_model
  else:
    cloner = _clone_functional_model
  return cloner(model, input_tensors=input_tensors)
277
278
279# "Clone" a subclassed model by reseting all of the attributes.
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not serializable.
  To "instantiate" an identical model in a new TF graph, we reuse the original
  model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However the original state can be restored fully by calling
  `in_place_subclassed_model_state_restoration`.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model, a layer with
      sub-layers, or a list of layers as an inner layer/attribute. These
      checks run before the model's layers are replaced.
  """
  assert not model._is_graph_network  # Only makes sense for subclassed networks
  # Retrieve all layers tracked by the model as well as their attribute names
  attributes_cache = {}
  for name in dir(model):
    try:
      value = getattr(model, name)
    except (AttributeError, ValueError, TypeError):
      continue
    if isinstance(value, Layer):
      attributes_cache[name] = value
      assert value in model.layers
      if hasattr(value, 'layers') and value.layers:
        raise ValueError('We do not support the use of nested layers '
                         'in `model_to_estimator` at this time. Found nested '
                         'layer: %s' % value)
    elif isinstance(
        value, (list, tuple)) and name not in ('layers', '_layers', 'metrics',
                                               '_compile_metric_functions',
                                               '_output_loss_metrics'):
      # Handle case: list/tuple of layers (also tracked by the Network API).
      if value and all(isinstance(val, Layer) for val in value):
        raise ValueError('We do not support the use of list-of-layers '
                         'attributes in subclassed models used with '
                         '`model_to_estimator` at this time. Found list '
                         'model: %s' % name)

  original_layers = model._layers[:]
  # Reject nested subclassed models before mutating anything, so the model is
  # not left half-reset (with `_layers` cleared and attribute tracking
  # disabled). This also runs ahead of `layer.get_config()`, which presumably
  # cannot serialize such layers (TODO confirm).
  # Supporting them would be theoretically possible, but would add complexity.
  # Only do it if users complain.
  for layer in original_layers:
    if isinstance(layer, Network) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)

  # Replace layers on the model with fresh layers
  layers_to_names = {value: key for key, value in attributes_cache.items()}
  setattr_tracking = model._setattr_tracking
  model._setattr_tracking = False
  try:
    model._layers = []
    for layer in original_layers:  # We preserve layer order.
      config = layer.get_config()
      fresh_layer = layer.__class__.from_config(config)
      name = layers_to_names[layer]
      setattr(model, name, fresh_layer)
      model._layers.append(fresh_layer)

    # Cache original model build attributes (in addition to layers)
    if (not hasattr(model, '_original_attributes_cache') or
        model._original_attributes_cache is None):
      if model.built:
        attributes_to_cache = [
            'inputs',
            'outputs',
            '_feed_outputs',
            '_feed_output_names',
            '_feed_output_shapes',
            '_feed_loss_fns',
            'loss_weights_list',
            'targets',
            '_feed_targets',
            'sample_weight_modes',
            'total_loss',
            'sample_weights',
            '_feed_sample_weights',
            'train_function',
            'test_function',
            'predict_function',
            '_collected_trainable_weights',
            '_feed_inputs',
            '_feed_input_names',
            '_feed_input_shapes',
            'optimizer',
        ]
        for name in attributes_to_cache:
          attributes_cache[name] = getattr(model, name)
    # NOTE(review): this overwrites any existing cache. On a second reset
    # without an intervening restoration, the layer entries collected above
    # are the fresh layers from the previous reset, so the original layers
    # can no longer be restored -- confirm callers never reset twice.
    model._original_attributes_cache = attributes_cache
    _reset_build_compile_trackers(model)
  finally:
    # Restore attribute tracking even if re-creating a layer fails.
    model._setattr_tracking = setattr_tracking
378  model._setattr_tracking = setattr_tracking
379
380
381def _reset_build_compile_trackers(model):
382  """Reset state trackers for model.
383
384  Note that we do not actually zero out attributes such as optimizer,
385  but instead rely on the expectation that all of the attrs will be
386  over-written on calling build/compile/etc. This is somewhat fragile,
387  insofar as we check elsewhere for the presence of these attributes as
388  evidence of having been built/compiled/etc. Pending a better way to do this,
389  we reset key attributes here to allow building and compiling.
390
391  Args:
392    model: the model that is being reset
393  """
394  # Reset build state
395  model.built = False
396  model.inputs = None
397  model.outputs = None
398  # Reset compile state
399  model._is_compiled = False  # pylint:disable=protected-access
400  model.optimizer = None
401
402
def in_place_subclassed_model_state_restoration(model):
  """Undo `_in_place_subclassed_model_reset` on a subclassed model.

  `clone_and_build_model` resets subclassed models in place when
  `in_place_reset` is True. This function puts back every attribute stored in
  `model._original_attributes_cache` and then clears that cache. Without such
  a cache, it resets the build/compile trackers instead, leaving the model like
  one that has never been built or compiled.

  Args:
    model: Instance of a Keras model created via subclassing, on which
      `_in_place_subclassed_model_reset` was previously called.
  """
  assert not model._is_graph_network
  saved_attributes = getattr(model, '_original_attributes_cache', None)
  if saved_attributes is None:
    # Restore to the state of a never-called model.
    _reset_build_compile_trackers(model)
    return

  # Models have sticky attribute assignment. Tracking is switched off while
  # the previous attributes are put back, so that Layers are tracked only via
  # `_layers` (under their original names) and no dependencies are added on
  # "utility" attributes that Models exempt at construction time.
  previous_tracking = model._setattr_tracking
  model._setattr_tracking = False
  model._layers = []
  # NOTE(review): `_layers` is rebuilt in cache order. The reset fills the
  # cache while iterating `dir(model)`, which is sorted by attribute name, so
  # this order may differ from the layer order before the reset. Confirm this
  # does not matter for weight ordering.
  for attr_name, attr_value in saved_attributes.items():
    setattr(model, attr_name, attr_value)
    if isinstance(attr_value, Layer):
      model._layers.append(attr_value)
  model._original_attributes_cache = None
  model._setattr_tracking = previous_tracking
433
434
def clone_and_build_model(
    model, input_tensors=None, target_tensors=None, custom_objects=None,
    compile_clone=True, in_place_reset=False, optimizer_iterations=None):
  """Clone a `Model` and build/compile it with the same settings used before.

  This function can be run in the same graph or in a separate graph from the
  model. When using a separate graph, `in_place_reset` must be `False`.

  Note that, currently, the clone produced from this function may not work with
  TPU DistributionStrategy. Try at your own risk.

  Args:
    model: `tf.keras.Model` object. Can be Functional, Sequential, or
      sub-classed.
    input_tensors: Optional list of input tensors to build the model upon. If
      not provided, placeholders will be created.
    target_tensors: Optional list of target tensors for compiling the model. If
      not provided, placeholders will be created.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions.
    compile_clone: Boolean, whether to compile model clone (default `True`).
    in_place_reset: Boolean, whether to reset the model in place. Only used if
      the model is a subclassed model. In the case of a subclassed model,
      this argument must be set to `True` (default `False`). To restore the
      original model, use the function
      `in_place_subclassed_model_state_restoration(model)`.
    optimizer_iterations: An iterations variable that will be incremented by the
      optimizer if the clone is compiled. This argument is used when a Keras
      model is cloned into an Estimator model function, because Estimators
      create their own global step variable.

  Returns:
    Clone of the model.

  Raises:
    ValueError: Cloning fails in the following cases
      - cloning a subclassed model with `in_place_reset` set to False.
      - compiling the clone when the original model has not been compiled.
  """
  # Read the optimizer before anything else: a subclassed model is reset in
  # place below, and that reset sets `model.optimizer` to None.
  source_optimizer = model.optimizer
  if compile_clone and not source_optimizer:
    raise ValueError(
        'Error when cloning model: compile_clone was set to True, but the '
        'original model has not been compiled.')

  has_clonable_topology = (
      model._is_graph_network or isinstance(model, Sequential))
  if not has_clonable_topology:
    # Subclassed model: the topology cannot be cloned, so the original object
    # is reset in place and reused as the "clone".
    if not in_place_reset:
      raise ValueError(
          'This model is a subclassed model. '
          'Such a model cannot be cloned, but there is a workaround where '
          'the model is reset in-place. To use this, please set the argument '
          '`in_place_reset` to `True`. This will reset the attributes in the '
          'original model. To restore the attributes, call '
          '`in_place_subclassed_model_state_restoration(model)`.')
    clone = model
    _in_place_subclassed_model_reset(clone)
    if input_tensors is not None:
      # A single-element list/tuple is unwrapped before setting the inputs.
      if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
        input_tensors = input_tensors[0]
      clone._set_inputs(input_tensors)
  else:
    if custom_objects:
      with CustomObjectScope(custom_objects):
        clone = clone_model(model, input_tensors=input_tensors)
    else:
      clone = clone_model(model, input_tensors=input_tensors)

    needs_placeholder_inputs = all([
        isinstance(clone, Sequential),
        not clone._is_graph_network,
        getattr(model, '_build_input_shape', None) is not None,
    ])
    if needs_placeholder_inputs:
      # Set model inputs to build the model and add input/output properties.
      # TODO(kathywu): Add multiple placeholders to handle edge case where
      # sequential model has multiple inputs.
      placeholder = K.placeholder(
          model._build_input_shape, dtype=model.inputs[0].dtype)
      clone._set_inputs(placeholder)

  if not compile_clone:
    return clone

  # Build a fresh optimizer mirroring the original one.
  if isinstance(source_optimizer, optimizers.TFOptimizer):
    new_optimizer = optimizers.TFOptimizer(
        source_optimizer.optimizer, optimizer_iterations)
    K.track_tf_optimizer(new_optimizer)
  else:
    source_config = source_optimizer.get_config()
    new_optimizer = source_optimizer.__class__.from_config(source_config)
    if optimizer_iterations is not None:
      new_optimizer.iterations = optimizer_iterations

  clone.compile(
      new_optimizer,
      model.loss,
      metrics=metrics_module.clone_metrics(model._compile_metrics),
      loss_weights=model.loss_weights,
      sample_weight_mode=model.sample_weight_mode,
      weighted_metrics=metrics_module.clone_metrics(
          model._compile_weighted_metrics),
      target_tensors=target_tensors)
  return clone
535