1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16# pylint: disable=g-import-not-at-top 17"""Utilities related to model visualization.""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23import sys 24from tensorflow.python.util.tf_export import keras_export 25 26 27try: 28 # pydot-ng is a fork of pydot that is better maintained. 29 import pydot_ng as pydot 30except ImportError: 31 # pydotplus is an improved version of pydot 32 try: 33 import pydotplus as pydot 34 except ImportError: 35 # Fall back on pydot if necessary. 36 try: 37 import pydot 38 except ImportError: 39 pydot = None 40 41 42def _check_pydot(): 43 try: 44 # Attempt to create an image of a blank graph 45 # to check the pydot/graphviz installation. 46 pydot.Dot.create(pydot.Dot()) 47 return True 48 except Exception: # pylint: disable=broad-except 49 # pydot raises a generic Exception here, 50 # so no specific class can be caught. 51 return False 52 53 54def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): 55 """Convert a Keras model to dot format. 56 57 Arguments: 58 model: A Keras model instance. 59 show_shapes: whether to display shape information. 60 show_layer_names: whether to display layer names. 61 rankdir: `rankdir` argument passed to PyDot, 62 a string specifying the format of the plot: 63 'TB' creates a vertical plot; 64 'LR' creates a horizontal plot. 65 66 Returns: 67 A `pydot.Dot` instance representing the Keras model (or None if the Dot 68 file could not be generated). 69 70 Raises: 71 ImportError: if graphviz or pydot are not available. 72 """ 73 from tensorflow.python.keras.layers.wrappers import Wrapper 74 from tensorflow.python.keras.models import Sequential 75 from tensorflow.python.util import nest 76 77 check = _check_pydot() 78 if not check: 79 if 'IPython.core.magics.namespace' in sys.modules: 80 # We don't raise an exception here in order to avoid crashing notebook 81 # tests where graphviz is not available. 82 print('Failed to import pydot. You must install pydot' 83 ' and graphviz for `pydotprint` to work.') 84 return 85 else: 86 raise ImportError('Failed to import pydot. You must install pydot' 87 ' and graphviz for `pydotprint` to work.') 88 89 dot = pydot.Dot() 90 dot.set('rankdir', rankdir) 91 dot.set('concentrate', True) 92 dot.set_node_defaults(shape='record') 93 94 if isinstance(model, Sequential): 95 if not model.built: 96 model.build() 97 layers = model._layers 98 99 # Create graph nodes. 100 for layer in layers: 101 layer_id = str(id(layer)) 102 103 # Append a wrapped layer's label to node's label, if it exists. 104 layer_name = layer.name 105 class_name = layer.__class__.__name__ 106 if isinstance(layer, Wrapper): 107 layer_name = '{}({})'.format(layer_name, layer.layer.name) 108 child_class_name = layer.layer.__class__.__name__ 109 class_name = '{}({})'.format(class_name, child_class_name) 110 111 # Create node's label. 112 if show_layer_names: 113 label = '{}: {}'.format(layer_name, class_name) 114 else: 115 label = class_name 116 117 # Rebuild the label as a table including input/output shapes. 118 if show_shapes: 119 try: 120 outputlabels = str(layer.output_shape) 121 except AttributeError: 122 outputlabels = 'multiple' 123 if hasattr(layer, 'input_shape'): 124 inputlabels = str(layer.input_shape) 125 elif hasattr(layer, 'input_shapes'): 126 inputlabels = ', '.join([str(ishape) for ishape in layer.input_shapes]) 127 else: 128 inputlabels = 'multiple' 129 label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels, 130 outputlabels) 131 node = pydot.Node(layer_id, label=label) 132 dot.add_node(node) 133 134 # Connect nodes with edges. 135 for layer in layers: 136 layer_id = str(id(layer)) 137 for i, node in enumerate(layer._inbound_nodes): 138 node_key = layer.name + '_ib-' + str(i) 139 if node_key in model._network_nodes: # pylint: disable=protected-access 140 for inbound_layer in nest.flatten(node.inbound_layers): 141 inbound_layer_id = str(id(inbound_layer)) 142 layer_id = str(id(layer)) 143 dot.add_edge(pydot.Edge(inbound_layer_id, layer_id)) 144 return dot 145 146 147@keras_export('keras.utils.plot_model') 148def plot_model(model, 149 to_file='model.png', 150 show_shapes=False, 151 show_layer_names=True, 152 rankdir='TB'): 153 """Converts a Keras model to dot format and save to a file. 154 155 Arguments: 156 model: A Keras model instance 157 to_file: File name of the plot image. 158 show_shapes: whether to display shape information. 159 show_layer_names: whether to display layer names. 160 rankdir: `rankdir` argument passed to PyDot, 161 a string specifying the format of the plot: 162 'TB' creates a vertical plot; 163 'LR' creates a horizontal plot. 164 165 Returns: 166 A Jupyter notebook Image object if Jupyter is installed. 167 This enables in-line display of the model plots in notebooks. 168 """ 169 dot = model_to_dot(model, show_shapes, show_layer_names, rankdir) 170 if dot is None: 171 return 172 _, extension = os.path.splitext(to_file) 173 if not extension: 174 extension = 'png' 175 else: 176 extension = extension[1:] 177 # Save image to disk. 178 dot.write(to_file, format=extension) 179 # Return the image as a Jupyter Image object, to be displayed in-line. 180 # Note that we cannot easily detect whether the code is running in a 181 # notebook, and thus we always return the Image if Jupyter is available. 182 try: 183 from IPython import display 184 return display.Image(filename=to_file) 185 except ImportError: 186 pass 187