1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Functions that save the model's config into different formats.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.keras.saving.saved_model import json_utils
23from tensorflow.python.util.tf_export import keras_export
24
25# pylint: disable=g-import-not-at-top
26try:
27  import yaml
28except ImportError:
29  yaml = None
30# pylint: enable=g-import-not-at-top
31
32
@keras_export('keras.models.model_from_config')
def model_from_config(config, custom_objects=None):
  """Builds a Keras model out of a configuration dictionary.

  Usage:
  ```
  # for a Functional API model
  tf.keras.Model().from_config(model.get_config())

  # for a Sequential model
  tf.keras.Sequential().from_config(model.get_config())
  ```

  Args:
      config: Configuration dictionary describing the model.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions that should be
          taken into account while deserializing.

  Returns:
      A Keras model instance (uncompiled).

  Raises:
      TypeError: if `config` is a list rather than a dictionary.
  """
  # Guard clause: a list is rejected up front with a hint pointing at
  # `Sequential.from_config`.
  config_is_list = isinstance(config, list)
  if config_is_list:
    error_message = ('`model_from_config` expects a dictionary, not a list. '
                     'Maybe you meant to use '
                     '`Sequential.from_config(config)`?')
    raise TypeError(error_message)

  # Imported at call time rather than at module load.
  from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
  model = deserialize_layer(config, custom_objects=custom_objects)
  return model
64
65
@keras_export('keras.models.model_from_yaml')
def model_from_yaml(yaml_string, custom_objects=None):
  """Parses a yaml model configuration file and returns a model instance.

  Warning: the YAML is parsed with PyYAML's unsafe loader, which can
  construct arbitrary Python objects. Only pass YAML that comes from a
  trusted source.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> try:
  ...   import yaml
  ...   config = model.to_yaml()
  ...   loaded_model = tf.keras.models.model_from_yaml(config)
  ... except ImportError:
  ...   pass

  Args:
      yaml_string: YAML string or open file encoding a model configuration.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).

  Raises:
      ImportError: if yaml module is not found.
  """
  if yaml is None:
    raise ImportError('Requires yaml module installed (`pip install pyyaml`).')
  # The method unsafe_load only exists in PyYAML 5.x+, so which loader is
  # exercised by tests depends on the installed version of PyYAML.
  #
  # The loader is selected by feature detection *before* parsing instead of
  # wrapping the parse in `try/except AttributeError`: that way an
  # AttributeError raised while parsing the input is propagated as-is rather
  # than being mistaken for "unsafe_load is missing" and triggering a second
  # parse through `yaml.load`.
  #
  # NOTE(security): both loaders can instantiate arbitrary Python objects; do
  # not feed untrusted YAML to this function.
  load_yaml = getattr(yaml, 'unsafe_load', None)
  if load_yaml is None:
    # Older PyYAML without `unsafe_load`.
    load_yaml = yaml.load
  config = load_yaml(yaml_string)
  from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
  return deserialize(config, custom_objects=custom_objects)
105
106
@keras_export('keras.models.model_from_json')
def model_from_json(json_string, custom_objects=None):
  """Rebuilds a model instance from its JSON configuration string.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> config = model.to_json()
  >>> loaded_model = tf.keras.models.model_from_json(config)

  Args:
      json_string: JSON string that encodes a model configuration.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions that should be
          taken into account while deserializing.

  Returns:
      A Keras model instance (uncompiled).
  """
  # Step 1: turn the JSON text into a config object via `json_utils.decode`.
  decoded_config = json_utils.decode(json_string)

  # Step 2: hand the decoded config to the layer deserializer, which is
  # imported at call time rather than at module load.
  from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
  model = deserialize_layer(decoded_config, custom_objects=custom_objects)
  return model
131