1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16# pylint: disable=g-import-not-at-top 17"""Utilities related to model visualization.""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23import sys 24from tensorflow.python.keras.utils.io_utils import path_to_string 25from tensorflow.python.util import nest 26from tensorflow.python.util.tf_export import keras_export 27 28 29try: 30 # pydot-ng is a fork of pydot that is better maintained. 31 import pydot_ng as pydot 32except ImportError: 33 # pydotplus is an improved version of pydot 34 try: 35 import pydotplus as pydot 36 except ImportError: 37 # Fall back on pydot if necessary. 38 try: 39 import pydot 40 except ImportError: 41 pydot = None 42 43 44def check_pydot(): 45 """Returns True if PyDot and Graphviz are available.""" 46 if pydot is None: 47 return False 48 try: 49 # Attempt to create an image of a blank graph 50 # to check the pydot/graphviz installation. 51 pydot.Dot.create(pydot.Dot()) 52 return True 53 except (OSError, pydot.InvocationException): 54 return False 55 56 57def is_wrapped_model(layer): 58 from tensorflow.python.keras.engine import functional 59 from tensorflow.python.keras.layers import wrappers 60 return (isinstance(layer, wrappers.Wrapper) and 61 isinstance(layer.layer, functional.Functional)) 62 63 64def add_edge(dot, src, dst): 65 if not dot.get_edge(src, dst): 66 dot.add_edge(pydot.Edge(src, dst)) 67 68 69@keras_export('keras.utils.model_to_dot') 70def model_to_dot(model, 71 show_shapes=False, 72 show_dtype=False, 73 show_layer_names=True, 74 rankdir='TB', 75 expand_nested=False, 76 dpi=96, 77 subgraph=False): 78 """Convert a Keras model to dot format. 79 80 Args: 81 model: A Keras model instance. 82 show_shapes: whether to display shape information. 83 show_dtype: whether to display layer dtypes. 84 show_layer_names: whether to display layer names. 85 rankdir: `rankdir` argument passed to PyDot, 86 a string specifying the format of the plot: 87 'TB' creates a vertical plot; 88 'LR' creates a horizontal plot. 89 expand_nested: whether to expand nested models into clusters. 90 dpi: Dots per inch. 91 subgraph: whether to return a `pydot.Cluster` instance. 92 93 Returns: 94 A `pydot.Dot` instance representing the Keras model or 95 a `pydot.Cluster` instance representing nested model if 96 `subgraph=True`. 97 98 Raises: 99 ImportError: if graphviz or pydot are not available. 100 """ 101 from tensorflow.python.keras.layers import wrappers 102 from tensorflow.python.keras.engine import sequential 103 from tensorflow.python.keras.engine import functional 104 105 if not check_pydot(): 106 message = ( 107 'You must install pydot (`pip install pydot`) ' 108 'and install graphviz ' 109 '(see instructions at https://graphviz.gitlab.io/download/) ', 110 'for plot_model/model_to_dot to work.') 111 if 'IPython.core.magics.namespace' in sys.modules: 112 # We don't raise an exception here in order to avoid crashing notebook 113 # tests where graphviz is not available. 114 print(message) 115 return 116 else: 117 raise ImportError(message) 118 119 if subgraph: 120 dot = pydot.Cluster(style='dashed', graph_name=model.name) 121 dot.set('label', model.name) 122 dot.set('labeljust', 'l') 123 else: 124 dot = pydot.Dot() 125 dot.set('rankdir', rankdir) 126 dot.set('concentrate', True) 127 dot.set('dpi', dpi) 128 dot.set_node_defaults(shape='record') 129 130 sub_n_first_node = {} 131 sub_n_last_node = {} 132 sub_w_first_node = {} 133 sub_w_last_node = {} 134 135 layers = model.layers 136 if not model._is_graph_network: 137 node = pydot.Node(str(id(model)), label=model.name) 138 dot.add_node(node) 139 return dot 140 elif isinstance(model, sequential.Sequential): 141 if not model.built: 142 model.build() 143 layers = super(sequential.Sequential, model).layers 144 145 # Create graph nodes. 146 for i, layer in enumerate(layers): 147 layer_id = str(id(layer)) 148 149 # Append a wrapped layer's label to node's label, if it exists. 150 layer_name = layer.name 151 class_name = layer.__class__.__name__ 152 153 if isinstance(layer, wrappers.Wrapper): 154 if expand_nested and isinstance(layer.layer, 155 functional.Functional): 156 submodel_wrapper = model_to_dot( 157 layer.layer, 158 show_shapes, 159 show_dtype, 160 show_layer_names, 161 rankdir, 162 expand_nested, 163 subgraph=True) 164 # sub_w : submodel_wrapper 165 sub_w_nodes = submodel_wrapper.get_nodes() 166 sub_w_first_node[layer.layer.name] = sub_w_nodes[0] 167 sub_w_last_node[layer.layer.name] = sub_w_nodes[-1] 168 dot.add_subgraph(submodel_wrapper) 169 else: 170 layer_name = '{}({})'.format(layer_name, layer.layer.name) 171 child_class_name = layer.layer.__class__.__name__ 172 class_name = '{}({})'.format(class_name, child_class_name) 173 174 if expand_nested and isinstance(layer, functional.Functional): 175 submodel_not_wrapper = model_to_dot( 176 layer, 177 show_shapes, 178 show_dtype, 179 show_layer_names, 180 rankdir, 181 expand_nested, 182 subgraph=True) 183 # sub_n : submodel_not_wrapper 184 sub_n_nodes = submodel_not_wrapper.get_nodes() 185 sub_n_first_node[layer.name] = sub_n_nodes[0] 186 sub_n_last_node[layer.name] = sub_n_nodes[-1] 187 dot.add_subgraph(submodel_not_wrapper) 188 189 # Create node's label. 190 if show_layer_names: 191 label = '{}: {}'.format(layer_name, class_name) 192 else: 193 label = class_name 194 195 # Rebuild the label as a table including the layer's dtype. 196 if show_dtype: 197 198 def format_dtype(dtype): 199 if dtype is None: 200 return '?' 201 else: 202 return str(dtype) 203 204 label = '%s|%s' % (label, format_dtype(layer.dtype)) 205 206 # Rebuild the label as a table including input/output shapes. 207 if show_shapes: 208 209 def format_shape(shape): 210 return str(shape).replace(str(None), 'None') 211 212 try: 213 outputlabels = format_shape(layer.output_shape) 214 except AttributeError: 215 outputlabels = '?' 216 if hasattr(layer, 'input_shape'): 217 inputlabels = format_shape(layer.input_shape) 218 elif hasattr(layer, 'input_shapes'): 219 inputlabels = ', '.join( 220 [format_shape(ishape) for ishape in layer.input_shapes]) 221 else: 222 inputlabels = '?' 223 label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, 224 inputlabels, 225 outputlabels) 226 227 if not expand_nested or not isinstance( 228 layer, functional.Functional): 229 node = pydot.Node(layer_id, label=label) 230 dot.add_node(node) 231 232 # Connect nodes with edges. 233 for layer in layers: 234 layer_id = str(id(layer)) 235 for i, node in enumerate(layer._inbound_nodes): 236 node_key = layer.name + '_ib-' + str(i) 237 if node_key in model._network_nodes: 238 for inbound_layer in nest.flatten(node.inbound_layers): 239 inbound_layer_id = str(id(inbound_layer)) 240 if not expand_nested: 241 assert dot.get_node(inbound_layer_id) 242 assert dot.get_node(layer_id) 243 add_edge(dot, inbound_layer_id, layer_id) 244 else: 245 # if inbound_layer is not Model or wrapped Model 246 if (not isinstance(inbound_layer, 247 functional.Functional) and 248 not is_wrapped_model(inbound_layer)): 249 # if current layer is not Model or wrapped Model 250 if (not isinstance(layer, functional.Functional) and 251 not is_wrapped_model(layer)): 252 assert dot.get_node(inbound_layer_id) 253 assert dot.get_node(layer_id) 254 add_edge(dot, inbound_layer_id, layer_id) 255 # if current layer is Model 256 elif isinstance(layer, functional.Functional): 257 add_edge(dot, inbound_layer_id, 258 sub_n_first_node[layer.name].get_name()) 259 # if current layer is wrapped Model 260 elif is_wrapped_model(layer): 261 add_edge(dot, inbound_layer_id, layer_id) 262 name = sub_w_first_node[layer.layer.name].get_name() 263 add_edge(dot, layer_id, name) 264 # if inbound_layer is Model 265 elif isinstance(inbound_layer, functional.Functional): 266 name = sub_n_last_node[inbound_layer.name].get_name() 267 if isinstance(layer, functional.Functional): 268 output_name = sub_n_first_node[layer.name].get_name() 269 add_edge(dot, name, output_name) 270 else: 271 add_edge(dot, name, layer_id) 272 # if inbound_layer is wrapped Model 273 elif is_wrapped_model(inbound_layer): 274 inbound_layer_name = inbound_layer.layer.name 275 add_edge(dot, 276 sub_w_last_node[inbound_layer_name].get_name(), 277 layer_id) 278 return dot 279 280 281@keras_export('keras.utils.plot_model') 282def plot_model(model, 283 to_file='model.png', 284 show_shapes=False, 285 show_dtype=False, 286 show_layer_names=True, 287 rankdir='TB', 288 expand_nested=False, 289 dpi=96): 290 """Converts a Keras model to dot format and save to a file. 291 292 Example: 293 294 ```python 295 input = tf.keras.Input(shape=(100,), dtype='int32', name='input') 296 x = tf.keras.layers.Embedding( 297 output_dim=512, input_dim=10000, input_length=100)(input) 298 x = tf.keras.layers.LSTM(32)(x) 299 x = tf.keras.layers.Dense(64, activation='relu')(x) 300 x = tf.keras.layers.Dense(64, activation='relu')(x) 301 x = tf.keras.layers.Dense(64, activation='relu')(x) 302 output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x) 303 model = tf.keras.Model(inputs=[input], outputs=[output]) 304 dot_img_file = '/tmp/model_1.png' 305 tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) 306 ``` 307 308 Args: 309 model: A Keras model instance 310 to_file: File name of the plot image. 311 show_shapes: whether to display shape information. 312 show_dtype: whether to display layer dtypes. 313 show_layer_names: whether to display layer names. 314 rankdir: `rankdir` argument passed to PyDot, 315 a string specifying the format of the plot: 316 'TB' creates a vertical plot; 317 'LR' creates a horizontal plot. 318 expand_nested: Whether to expand nested models into clusters. 319 dpi: Dots per inch. 320 321 Returns: 322 A Jupyter notebook Image object if Jupyter is installed. 323 This enables in-line display of the model plots in notebooks. 324 """ 325 dot = model_to_dot( 326 model, 327 show_shapes=show_shapes, 328 show_dtype=show_dtype, 329 show_layer_names=show_layer_names, 330 rankdir=rankdir, 331 expand_nested=expand_nested, 332 dpi=dpi) 333 to_file = path_to_string(to_file) 334 if dot is None: 335 return 336 _, extension = os.path.splitext(to_file) 337 if not extension: 338 extension = 'png' 339 else: 340 extension = extension[1:] 341 # Save image to disk. 342 dot.write(to_file, format=extension) 343 # Return the image as a Jupyter Image object, to be displayed in-line. 344 # Note that we cannot easily detect whether the code is running in a 345 # notebook, and thus we always return the Image if Jupyter is available. 346 if extension != 'pdf': 347 try: 348 from IPython import display 349 return display.Image(filename=to_file) 350 except ImportError: 351 pass 352