1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-import-not-at-top
17"""Utilities related to model visualization."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23import sys
24from tensorflow.python.util.tf_export import keras_export
25
26
27try:
28  # pydot-ng is a fork of pydot that is better maintained.
29  import pydot_ng as pydot
30except ImportError:
31  # pydotplus is an improved version of pydot
32  try:
33    import pydotplus as pydot
34  except ImportError:
35    # Fall back on pydot if necessary.
36    try:
37      import pydot
38    except ImportError:
39      pydot = None
40
41
def _check_pydot():
  """Probes whether pydot can actually render a graph.

  Returns:
      True if rendering an empty `pydot.Dot` graph succeeds, False if the
      attempt raises any exception (which includes the case where the
      module-level `pydot` is None because no pydot package was importable).
  """
  try:
    # Rendering a blank graph exercises the pydot/graphviz installation.
    pydot.Dot.create(pydot.Dot())
  except Exception:  # pylint: disable=broad-except
    # pydot reports failures with a generic Exception, so there is no
    # narrower class to catch here.
    return False
  else:
    return True
52
53
def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
  """Builds a pydot graph describing the layers of a Keras model.

  Every layer of the model becomes one record-shaped node, and every inbound
  connection that belongs to the model's own network becomes one edge.

  Arguments:
      model: A Keras model instance.
      show_shapes: whether to add input/output shape information to each node.
      show_layer_names: whether to prefix each node label with the layer name.
      rankdir: `rankdir` argument passed to PyDot,
          a string specifying the format of the plot:
          'TB' creates a vertical plot;
          'LR' creates a horizontal plot.

  Returns:
      A `pydot.Dot` instance representing the Keras model, or None when the
      pydot/graphviz check fails while the IPython namespace magics module is
      loaded (in that case a message is printed instead of raising).

  Raises:
      ImportError: if the pydot/graphviz check fails outside of that IPython
          situation.
  """
  from tensorflow.python.keras.layers.wrappers import Wrapper
  from tensorflow.python.keras.models import Sequential
  from tensorflow.python.util import nest

  if not _check_pydot():
    message = ('Failed to import pydot. You must install pydot'
               ' and graphviz for `pydotprint` to work.')
    if 'IPython.core.magics.namespace' not in sys.modules:
      raise ImportError(message)
    # Inside a notebook we only report the problem, so that notebook tests
    # do not crash when graphviz is not available.
    print(message)
    return None

  graph = pydot.Dot()
  graph.set('rankdir', rankdir)
  graph.set('concentrate', True)
  graph.set_node_defaults(shape='record')

  # A Sequential model that was never built is built before its layers
  # are read.
  if isinstance(model, Sequential) and not model.built:
    model.build()
  all_layers = model._layers

  # First pass: one node per layer, keyed by the layer object's id.
  for current in all_layers:
    display_name = current.name
    type_name = current.__class__.__name__
    if isinstance(current, Wrapper):
      # Show the wrapped layer's name and class next to the wrapper's own.
      wrapped = current.layer
      display_name = '{}({})'.format(display_name, wrapped.name)
      type_name = '{}({})'.format(type_name, wrapped.__class__.__name__)

    text = '{}: {}'.format(display_name,
                           type_name) if show_layer_names else type_name

    if show_shapes:
      # Turn the label into a record table with input/output shape cells.
      try:
        out_text = str(current.output_shape)
      except AttributeError:
        out_text = 'multiple'
      if hasattr(current, 'input_shape'):
        in_text = str(current.input_shape)
      elif hasattr(current, 'input_shapes'):
        in_text = ', '.join(str(shape) for shape in current.input_shapes)
      else:
        in_text = 'multiple'
      text = '%s\n|{input:|output:}|{{%s}|{%s}}' % (text, in_text, out_text)

    graph.add_node(pydot.Node(str(id(current)), label=text))

  # Second pass: one edge per inbound layer, restricted to the inbound nodes
  # that are registered in the model's network.
  for current in all_layers:
    target_id = str(id(current))
    for position, inbound_node in enumerate(current._inbound_nodes):
      network_key = current.name + '_ib-' + str(position)
      if network_key not in model._network_nodes:
        continue
      for source in nest.flatten(inbound_node.inbound_layers):
        graph.add_edge(pydot.Edge(str(id(source)), target_id))
  return graph
145
146
@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_layer_names=True,
               rankdir='TB'):
  """Converts a Keras model to dot format and saves it to a file.

  Arguments:
      model: A Keras model instance
      to_file: File name of the plot image. The output format is taken from
          the file extension; 'png' is used when there is no extension.
      show_shapes: whether to display shape information.
      show_layer_names: whether to display layer names.
      rankdir: `rankdir` argument passed to PyDot,
          a string specifying the format of the plot:
          'TB' creates a vertical plot;
          'LR' creates a horizontal plot.

  Returns:
      A Jupyter notebook Image object if IPython is installed and is able to
      build an Image for the written file. This enables in-line display of
      the model plots in notebooks. Otherwise None (also when no dot graph
      could be generated).

  Raises:
      ImportError: propagated from `model_to_dot` when pydot or graphviz are
          not available.
  """
  dot = model_to_dot(model, show_shapes, show_layer_names, rankdir)
  if dot is None:
    return None

  _, extension = os.path.splitext(to_file)
  # Drop the leading dot of the extension; default to PNG without one.
  extension = extension[1:] if extension else 'png'

  # Save image to disk.
  dot.write(to_file, format=extension)

  # Return the image as a Jupyter Image object, to be displayed in-line.
  # Note that we cannot easily detect whether the code is running in a
  # notebook, and thus we always try to return the Image if Jupyter is
  # available.
  try:
    from IPython import display
  except ImportError:
    return None
  try:
    return display.Image(filename=to_file)
  except ValueError:
    # IPython refuses to embed some formats (e.g. pdf or svg). The plot has
    # already been written to disk, so a missing in-line preview must not
    # turn a successful save into an error.
    return None
187