# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Functions that save the model's config into different formats.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.util.tf_export import keras_export

# PyYAML is an optional dependency: `yaml` is left as None when it is missing
# and `model_from_yaml` reports the problem at call time.
# pylint: disable=g-import-not-at-top
try:
  import yaml
except ImportError:
  yaml = None
# pylint: enable=g-import-not-at-top


@keras_export('keras.models.model_from_config')
def model_from_config(config, custom_objects=None):
  """Instantiates a Keras model from its config.

  Usage:
  ```
  # for a Functional API model
  tf.keras.Model().from_config(model.get_config())

  # for a Sequential model
  tf.keras.Sequential().from_config(model.get_config())
  ```

  Args:
      config: Configuration dictionary.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).

  Raises:
      TypeError: if `config` is a list instead of a dictionary. Other
        non-dictionary values are not checked here; they are passed on to
        `deserialize` unchanged.
  """
  if isinstance(config, list):
    raise TypeError('`model_from_config` expects a dictionary, not a list. '
                    'Maybe you meant to use '
                    '`Sequential.from_config(config)`?')
  # Imported inside the function rather than at module level.
  from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
  return deserialize(config, custom_objects=custom_objects)


@keras_export('keras.models.model_from_yaml')
def model_from_yaml(yaml_string, custom_objects=None):
  """Parses a yaml model configuration file and returns a model instance.

  Warning: the YAML is parsed with PyYAML's unsafe loader, which can
  construct arbitrary Python objects. Only call this function on YAML that
  comes from a trusted source.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> try:
  ...   import yaml
  ...   config = model.to_yaml()
  ...   loaded_model = tf.keras.models.model_from_yaml(config)
  ... except ImportError:
  ...   pass

  Args:
      yaml_string: YAML string or open file encoding a model configuration.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).

  Raises:
      ImportError: if yaml module is not found.
  """
  if yaml is None:
    raise ImportError('Requires yaml module installed (`pip install pyyaml`).')
  # The method unsafe_load only exists in PyYAML 5.x+, so which loader is
  # exercised by tests depends on the installed version of PyYAML.
  #
  # Only the attribute lookup is used to detect the PyYAML version. Wrapping
  # the whole parse in `except AttributeError` would also swallow an
  # AttributeError raised while the document is being loaded and silently
  # retry with the legacy loader, hiding the real failure.
  load_yaml = getattr(yaml, 'unsafe_load', None)
  if load_yaml is None:
    # Older PyYAML without `unsafe_load`.
    load_yaml = yaml.load
  # SECURITY: both loaders can instantiate arbitrary Python objects described
  # in the YAML document. Do not pass untrusted input to this function.
  config = load_yaml(yaml_string)
  # Imported inside the function rather than at module level.
  from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
  return deserialize(config, custom_objects=custom_objects)


@keras_export('keras.models.model_from_json')
def model_from_json(json_string, custom_objects=None):
  """Parses a JSON model configuration string and returns a model instance.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> config = model.to_json()
  >>> loaded_model = tf.keras.models.model_from_json(config)

  Args:
      json_string: JSON string encoding a model configuration.
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.

  Returns:
      A Keras model instance (uncompiled).
  """
  config = json_utils.decode(json_string)
  # Imported inside the function rather than at module level.
  from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
  return deserialize(config, custom_objects=custom_objects)