1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
"""Functions that instantiate Keras models from configs in different formats.

Supported inputs are a config dictionary, a JSON string, or a YAML string.
"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import json
23
24from tensorflow.python.util.tf_export import keras_export
25
26# pylint: disable=g-import-not-at-top
27try:
28  import yaml
29except ImportError:
30  yaml = None
31# pylint: enable=g-import-not-at-top
32
33
@keras_export('keras.models.model_from_config')
def model_from_config(config, custom_objects=None):
  """Instantiates a Keras model from its config.

  Arguments:
      config: Configuration dictionary.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).

  Raises:
      TypeError: if `config` is a list. Only lists are rejected here; any
          other value is handed to `deserialize` unchanged, and whatever
          validation it performs applies.
  """
  # A list is rejected up front with a hint; the message suggests that a list
  # usually means the caller holds a `Sequential` config.
  if isinstance(config, list):
    raise TypeError('`model_from_config` expects a dictionary, not a list. '
                    'Maybe you meant to use '
                    '`Sequential.from_config(config)`?')
  # Function-level import is deliberate (see the pylint disable); presumably
  # it avoids an import cycle with the layers package -- TODO confirm.
  from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
  return deserialize(config, custom_objects=custom_objects)
56
57
@keras_export('keras.models.model_from_yaml')
def model_from_yaml(yaml_string, custom_objects=None):
  """Parses a YAML model configuration string and returns a model instance.

  Arguments:
      yaml_string: YAML string encoding a model configuration.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).

  Raises:
      ImportError: if yaml module is not found.
  """
  if yaml is None:
    raise ImportError('Requires yaml module installed (`pip install pyyaml`).')
  # `yaml.load` without an explicit `Loader` is deprecated since PyYAML 5.1
  # and raises TypeError since PyYAML 6.0, so the loader is passed
  # explicitly. `FullLoader` was the implicit default in PyYAML 5.1-5.4 and
  # still reconstructs `!!python/tuple` nodes that `yaml.dump` emits for
  # tuples. Releases older than 5.1 have no `FullLoader`; there, `yaml.Loader`
  # was the implicit default, so behavior is unchanged on every version.
  # NOTE(review): neither loader is safe for untrusted input; only pass YAML
  # from trusted sources.
  loader = getattr(yaml, 'FullLoader', yaml.Loader)
  config = yaml.load(yaml_string, Loader=loader)
  # Function-level import is deliberate (see the pylint disable); presumably
  # it avoids an import cycle with the layers package -- TODO confirm.
  from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
  return deserialize(config, custom_objects=custom_objects)
79
80
@keras_export('keras.models.model_from_json')
def model_from_json(json_string, custom_objects=None):
  """Builds an uncompiled Keras model from a JSON configuration string.

  Arguments:
      json_string: JSON text encoding a model configuration; it is decoded
          with `json.loads` before deserialization.
      custom_objects: Optional dictionary mapping names (strings) to custom
          classes or functions to be considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).
  """
  parsed_config = json.loads(json_string)
  # Function-level import is deliberate (see the pylint disable); presumably
  # it avoids an import cycle with the layers package -- TODO confirm.
  from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
  return deserialize_layer(parsed_config, custom_objects=custom_objects)
97