1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils related to keras model saving."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections.abc as collections_abc
21import copy
22import os
23import six
24
25from tensorflow.python.eager import def_function
26from tensorflow.python.keras import backend as K
27from tensorflow.python.keras import losses
28from tensorflow.python.keras import optimizer_v1
29from tensorflow.python.keras import optimizers
30from tensorflow.python.keras.engine import base_layer_utils
31from tensorflow.python.keras.utils import generic_utils
32from tensorflow.python.keras.utils import version_utils
33from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.util import nest
36
37
def extract_model_metrics(model):
  """Collect the metrics given to `compile` into a name-keyed dictionary.

  This is used for converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    A dict mapping each compile metric's `name` to the metric object, or
    `None` when the model's `_compile_metrics` attribute is missing or falsy.
  """
  if not getattr(model, '_compile_metrics', None):
    return None

  # TODO(psv/kathywu): use this implementation in model to estimator flow.
  # `model.metrics` is deliberately not used here: it would also contain the
  # metrics that were added through the `add_metric` API.
  metrics_by_name = {}
  for metric in model._compile_metric_functions:  # pylint: disable=protected-access
    metrics_by_name[metric.name] = metric
  return metrics_by_name
56
57
def model_input_signature(model, keep_original_batch_size=False):
  """Build the input signature used when exporting `model`.

  Keras enforces that tensor inputs are passed in as the first argument, so
  the signature is a list holding a single (possibly-nested) object.

  For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
  will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  Args:
    model: Keras Model object.
    keep_original_batch_size: A boolean. When `False` (the default) the save
      spec is requested with `dynamic_batch=True`, i.e. the batch dim of the
      returned signature is set to `None`; when `True` the original batch
      size is kept.

  Returns:
    `None` if the model has no save spec. Otherwise a list containing either
    a single TensorSpec or an object with nested TensorSpecs. This list does
    not contain the `training` argument.
  """
  save_spec = model._get_save_spec(  # pylint: disable=protected-access
      dynamic_batch=not keep_original_batch_size)
  if save_spec is None:
    return None

  save_spec = _enforce_names_consistency(save_spec)

  # A one-element sequence already has the "list with a single element" shape.
  # A one-element dictionary is not a Sequence, so it is wrapped in a list
  # like every other structure.
  is_single_item_sequence = (
      isinstance(save_spec, collections_abc.Sequence) and len(save_spec) == 1)
  return save_spec if is_single_item_sequence else [save_spec]
91
92
def raise_model_input_error(model):
  """Raise the error reported when a model has no input shapes to save.

  Args:
    model: The model that cannot be saved; it is formatted into the message.

  Raises:
    ValueError: Always.
  """
  message = (
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.').format(model)
  raise ValueError(message)
99
100
def trace_model_call(model, input_signature=None):
  """Wrap the model's call in a tf.function for exporting a Keras model.

  The input signature is resolved in this order: the `input_signature`
  argument, then the signature of `model.call` when that is already a
  tf.function, then `model_input_signature(model)`.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures set.
    It returns a flat dict mapping output names to output tensors.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None and isinstance(model.call,
                                            def_function.Function):
    input_signature = model.call.input_signature

  if input_signature is None:
    input_signature = model_input_signature(model)
    if input_signature is None:
      raise_model_input_error(model)

  @def_function.function(input_signature=input_signature)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # Keras calls a single-input model on the tensor itself rather than on a
    # list consisting of that single tensor.
    if len(input_signature) == 1:
      model_inputs = args[0]
    else:
      model_inputs = list(args)

    call_ctx = base_layer_utils.call_context()
    with call_ctx.enter(
        model,
        inputs=model_inputs,
        build_graph=False,
        training=False,
        saving=True):
      model_outputs = model(model_inputs, training=False)

    # The result always has to be a flat dict keyed by output name.
    names = model.output_names  # Functional Model.
    if names is None:  # Subclassed Model.
      from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      names = compile_utils.create_pseudo_output_names(model_outputs)
    return dict(zip(names, nest.flatten(model_outputs)))

  return _wrapped_model
145
146
def model_metadata(model, include_optimizer=True, require_config=True):
  """Returns a dictionary containing the model metadata.

  Args:
    model: The Keras model to describe.
    include_optimizer: Whether the training/optimizer configuration should be
      added when the model has an optimizer.
    require_config: Whether a `NotImplementedError` raised by
      `model.get_config()` is re-raised. When `False` the model config only
      contains the class name in that case.

  Returns:
    A dict with the keys 'keras_version', 'backend' and 'model_config', plus
    'training_config' (including 'optimizer_config') when the optimizer is
    included and the model was compiled.

  Raises:
    NotImplementedError: If `model.get_config()` is not implemented and
      `require_config` is set, or if the optimizer was restored from a
      SavedModel and therefore cannot be saved.
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError:
    if require_config:
      raise

  metadata = {
      'keras_version': str(keras_version),
      'backend': K.backend(),
      'model_config': model_config,
  }

  if not (model.optimizer and include_optimizer):
    return metadata

  if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
    logging.warning(
        'TensorFlow optimizers do not make it possible to access '
        'optimizer attributes or optimizer state after instantiation. '
        'As a result, we cannot save the optimizer as part of the model '
        'save file. You will have to compile your model again after '
        'loading it. Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
    return metadata

  if not model._compile_was_called:  # pylint: disable=protected-access
    return metadata

  compile_args = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
  compile_args.pop('optimizer', None)  # Handled separately.
  metadata['training_config'] = _serialize_nested_config(compile_args)

  if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
    raise NotImplementedError(
        'As of now, Optimizers loaded from SavedModel cannot be saved. '
        'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
        ' please set the `include_optimizer` option to `False`. For '
        '`tf.saved_model.save`, delete the optimizer from the model.')

  metadata['training_config']['optimizer_config'] = {
      'class_name':
          generic_utils.get_registered_name(model.optimizer.__class__),
      'config':
          model.optimizer.get_config(),
  }
  return metadata
194
195
def should_overwrite(filepath, overwrite):
  """Decides whether `filepath` may be written to.

  Args:
    filepath: Path that is about to be written.
    overwrite: If truthy, the path is written without asking.

  Returns:
    `True` when `overwrite` is set or no regular file exists at `filepath`;
    otherwise the result of `ask_to_proceed_with_overwrite(filepath)`.
  """
  if overwrite or not os.path.isfile(filepath):
    return True
  # A file already exists and overwriting was not requested up front.
  return ask_to_proceed_with_overwrite(filepath)
202
203
def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config.

  Args:
    training_config: Dict with the serialized compile arguments. The keys
      'optimizer_config' and 'loss_weights' are required; 'loss', 'metrics',
      'weighted_metrics' and 'sample_weight_mode' are optional.
    custom_objects: Optional dict passed to `generic_utils.CustomObjectScope`
      for the duration of the deserialization. Defaults to an empty dict.

  Returns:
    A dict with the keys 'optimizer', 'loss', 'metrics', 'weighted_metrics',
    'loss_weights' and 'sample_weight_mode'. Optional entries that are absent
    from `training_config` are `None`.

  Raises:
    KeyError: If 'optimizer_config' or 'loss_weights' is missing from
      `training_config`.
  """
  if custom_objects is None:
    custom_objects = {}

  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    # `training_config` is a dict, so the entry must be looked up as a key.
    # An attribute check (`hasattr`) never matches a dict key and would
    # silently drop a saved `sample_weight_mode`.
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    loss_weights = training_config['loss_weights']

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)
243
244
def _deserialize_nested_config(deserialize_fn, config):
  """Recursively applies `deserialize_fn` to the leaves of a Keras `config`.

  A leaf is either a dict containing a 'class_name' key (a serialized Keras
  object) or a string (a serialized function or plain string). Any other dict
  is mapped value by value, and tuples/lists are mapped element by element
  into a list.

  Args:
    deserialize_fn: Callable invoked on every leaf.
    config: `None`, a leaf, or a nested dict/list/tuple of leaves.

  Returns:
    `None` if `config` is `None`, otherwise the deserialized structure.

  Raises:
    ValueError: If `config`, or a value nested in it, has an unsupported type.
  """
  if config is None:
    return None

  # Leaves are handed to `deserialize_fn` as they are.
  is_keras_object = isinstance(config, dict) and 'class_name' in config
  if is_keras_object or isinstance(config, six.string_types):
    return deserialize_fn(config)

  if isinstance(config, dict):
    return {
        key: _deserialize_nested_config(deserialize_fn, value)
        for key, value in config.items()
    }

  if isinstance(config, (tuple, list)):
    return [
        _deserialize_nested_config(deserialize_fn, item) for item in config
    ]

  raise ValueError('Saved configuration not understood.')
268
269
def _serialize_nested_config(config):
  """Serializes a nested structure of Keras objects.

  Every callable leaf of `config` is passed through
  `generic_utils.serialize_keras_object`; all other leaves are kept as-is.

  Args:
    config: A (possibly nested) structure accepted by `nest.map_structure`.

  Returns:
    A structure of the same layout with the callable leaves serialized.
  """
  return nest.map_structure(
      lambda leaf: (generic_utils.serialize_keras_object(leaf)
                    if callable(leaf) else leaf),
      config)
279
280
def _deserialize_metric(metric_config):
  """Deserializes a metric config, passing special strings through unchanged.

  Args:
    metric_config: A serialized metric, or one of the shorthand strings
      'accuracy', 'acc', 'crossentropy', 'ce'.

  Returns:
    `metric_config` itself for the shorthand strings, otherwise the result of
    `metrics.deserialize(metric_config)`.
  """
  from tensorflow.python.keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  # The accuracy and cross-entropy shorthands get special-case handling in
  # compile, based on the model output shape, so they are not deserialized
  # here.
  passthrough_names = ('accuracy', 'acc', 'crossentropy', 'ce')
  if metric_config not in passthrough_names:
    return metrics_module.deserialize(metric_config)
  return metric_config
289
290
def _enforce_names_consistency(specs):
  """Enforces that either all specs have names or none do.

  When some, but not all, of the flattened specs have a non-`None` `name`,
  every spec is deep-copied and the copy's `_name` is reset to `None`.
  Otherwise `specs` is returned untouched.

  Args:
    specs: A (possibly nested) structure of specs.

  Returns:
    `specs` itself, or a structure of renamed copies with the same layout.
  """

  def _is_named(spec):
    return getattr(spec, 'name', None) is not None

  def _without_name(spec):
    # Work on a copy so the caller's spec objects are never mutated.
    unnamed = copy.deepcopy(spec)
    if hasattr(unnamed, 'name'):
      unnamed._name = None  # pylint:disable=protected-access
    return unnamed

  named_flags = [_is_named(spec) for spec in nest.flatten(specs)]
  some_named_some_not = any(named_flags) and not all(named_flags)

  if not some_named_some_not:
    return specs
  return nest.map_structure(_without_name, specs)
311
312
def try_build_compiled_arguments(model):
  """Best-effort build of the compiled loss and metrics of `model`.

  Nothing is done for V1 layers/models or when `model.outputs` is `None`.
  Errors raised while building are logged as a warning instead of being
  propagated. Interpreter-level signals such as `KeyboardInterrupt` and
  `SystemExit` are not swallowed.

  Args:
    model: A compiled Keras model. Presumably one that was just loaded, going
      by the warning text; verify against callers.
  """
  if version_utils.is_v1_layer_or_model(model) or model.outputs is None:
    return

  try:
    model.compiled_loss.build(model.outputs)
    model.compiled_metrics.build(model.outputs, model.outputs)
  except Exception:  # pylint: disable=broad-except
    # Building is optional here; the containers get built on the first
    # train/evaluate call, so failing only deserves a warning.
    logging.warning(
        'Compiled the loaded model, but the compiled metrics have yet to '
        'be built. `model.compile_metrics` will be empty until you train '
        'or evaluate the model.')
324
325
def is_hdf5_filepath(filepath):
  """Returns whether `filepath` ends with '.h5', '.keras' or '.hdf5'."""
  hdf5_suffixes = ('.h5', '.keras', '.hdf5')
  return filepath.endswith(hdf5_suffixes)
329