1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-import-not-at-top
17"""Utilities related to model visualization."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23import sys
24from tensorflow.python.keras.utils.io_utils import path_to_string
25from tensorflow.python.util import nest
26from tensorflow.python.util.tf_export import keras_export
27
28
29try:
30  # pydot-ng is a fork of pydot that is better maintained.
31  import pydot_ng as pydot
32except ImportError:
33  # pydotplus is an improved version of pydot
34  try:
35    import pydotplus as pydot
36  except ImportError:
37    # Fall back on pydot if necessary.
38    try:
39      import pydot
40    except ImportError:
41      pydot = None
42
43
def check_pydot():
  """Returns True if PyDot and Graphviz are available.

  Availability is probed by rendering an empty graph, which exercises both
  the pydot module and the Graphviz installation it drives.

  Returns:
    False if no pydot flavor could be imported, or if rendering the empty
    graph fails with an OSError / pydot invocation error; True otherwise.
  """
  if pydot is None:
    return False
  # `InvocationException` is not guaranteed to exist in every pydot flavor
  # this module may import (pydot_ng / pydotplus / pydot). Resolve it up
  # front so that evaluating the `except` clause can never itself raise
  # AttributeError; OSError is already listed, so it is a harmless fallback.
  invocation_error = getattr(pydot, 'InvocationException', OSError)
  try:
    # Attempt to create an image of a blank graph
    # to check the pydot/graphviz installation.
    pydot.Dot.create(pydot.Dot())
    return True
  except (OSError, invocation_error):
    return False
55
56
def is_wrapped_model(layer):
  """Tells whether `layer` is a `Wrapper` around a functional model.

  Args:
    layer: Any Keras layer.

  Returns:
    True when `layer` is a `wrappers.Wrapper` instance and the layer it
    wraps (`layer.layer`) is a `functional.Functional` model, else False.
  """
  # Imported lazily, matching the rest of this module.
  from tensorflow.python.keras.engine import functional
  from tensorflow.python.keras.layers import wrappers
  if not isinstance(layer, wrappers.Wrapper):
    return False
  inner = layer.layer
  return isinstance(inner, functional.Functional)
62
63
def add_edge(dot, src, dst):
  """Adds a `src -> dst` edge to `dot` unless `dot` already reports one.

  Args:
    dot: The pydot graph/cluster to modify.
    src: Name of the source node.
    dst: Name of the destination node.
  """
  existing_edges = dot.get_edge(src, dst)
  if existing_edges:
    # Already connected; avoid drawing a duplicate edge.
    return
  dot.add_edge(pydot.Edge(src, dst))
67
68
def _format_dtype(dtype):
  """Formats a layer dtype for display, using '?' when it is None."""
  if dtype is None:
    return '?'
  return str(dtype)


def _make_layer_label(layer, layer_name, class_name, show_layer_names,
                      show_dtype, show_shapes):
  """Builds the Graphviz `record` label for a single layer node.

  Args:
    layer: The layer being drawn.
    layer_name: Name to display (may already include a wrapped layer's name).
    class_name: Class name to display (may already include a wrapped layer's
      class name).
    show_layer_names: Whether to prefix the label with `layer_name`.
    show_dtype: Whether to append the layer's dtype as an extra field.
    show_shapes: Whether to append an input/output shape table.

  Returns:
    The label string, using `record` field syntax (`|` and `{}`).
  """
  if show_layer_names:
    label = '{}: {}'.format(layer_name, class_name)
  else:
    label = class_name

  if show_dtype:
    label = '%s|%s' % (label, _format_dtype(layer.dtype))

  if show_shapes:
    # Reading `output_shape` may raise AttributeError; show '?' instead.
    try:
      outputlabels = str(layer.output_shape)
    except AttributeError:
      outputlabels = '?'
    if hasattr(layer, 'input_shape'):
      inputlabels = str(layer.input_shape)
    elif hasattr(layer, 'input_shapes'):
      inputlabels = ', '.join(str(ishape) for ishape in layer.input_shapes)
    else:
      inputlabels = '?'
    label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                   inputlabels,
                                                   outputlabels)
  return label


@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False):
  """Convert a Keras model to dot format.

  Args:
    model: A Keras model instance.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: whether to expand nested models into clusters.
    dpi: Dots per inch.
    subgraph: whether to return a `pydot.Cluster` instance.

  Returns:
    A `pydot.Dot` instance representing the Keras model or
    a `pydot.Cluster` instance representing nested model if
    `subgraph=True`. Returns `None` (after printing an explanatory
    message) when pydot/graphviz are unavailable and IPython's
    `IPython.core.magics.namespace` module is loaded.

  Raises:
    ImportError: if graphviz or pydot are not available (outside IPython).
  """
  from tensorflow.python.keras.layers import wrappers
  from tensorflow.python.keras.engine import sequential
  from tensorflow.python.keras.engine import functional

  if not check_pydot():
    # Adjacent string literals are concatenated into ONE string; a comma
    # between them would silently turn `message` into a tuple.
    message = (
        'You must install pydot (`pip install pydot`) '
        'and install graphviz '
        '(see instructions at https://graphviz.gitlab.io/download/) '
        'for plot_model/model_to_dot to work.')
    if 'IPython.core.magics.namespace' in sys.modules:
      # We don't raise an exception here in order to avoid crashing notebook
      # tests where graphviz is not available.
      print(message)
      return None
    raise ImportError(message)

  if subgraph:
    # Nested models are drawn as a dashed, labelled cluster; graph-wide
    # attributes (rankdir, dpi, node shape) are only set on the top-level Dot.
    dot = pydot.Cluster(style='dashed', graph_name=model.name)
    dot.set('label', model.name)
    dot.set('labeljust', 'l')
  else:
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', dpi)
    dot.set_node_defaults(shape='record')

  # First/last node of the cluster drawn for each expanded nested model, keyed
  # by model name: `sub_n_*` for models used directly as layers ("not
  # wrapped"), `sub_w_*` for models held inside a `Wrapper` layer.
  sub_n_first_node = {}
  sub_n_last_node = {}
  sub_w_first_node = {}
  sub_w_last_node = {}

  if not model._is_graph_network:
    # No inspectable layer graph: draw the whole model as a single node.
    dot.add_node(pydot.Node(str(id(model)), label=model.name))
    return dot

  layers = model.layers
  if isinstance(model, sequential.Sequential):
    if not model.built:
      model.build()
    # Bypasses `Sequential.layers` via super(); presumably so the
    # auto-created input layer is drawn too -- TODO confirm.
    layers = super(sequential.Sequential, model).layers

  # Create graph nodes.
  for layer in layers:
    layer_id = str(id(layer))

    # Append a wrapped layer's label to node's label, if it exists.
    layer_name = layer.name
    class_name = layer.__class__.__name__

    if isinstance(layer, wrappers.Wrapper):
      if expand_nested and isinstance(layer.layer, functional.Functional):
        submodel_wrapper = model_to_dot(
            layer.layer,
            show_shapes=show_shapes,
            show_dtype=show_dtype,
            show_layer_names=show_layer_names,
            rankdir=rankdir,
            expand_nested=expand_nested,
            subgraph=True)
        sub_w_nodes = submodel_wrapper.get_nodes()
        sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
        sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
        dot.add_subgraph(submodel_wrapper)
      else:
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    if expand_nested and isinstance(layer, functional.Functional):
      # The nested model is drawn as its own cluster instead of one node.
      submodel_not_wrapper = model_to_dot(
          layer,
          show_shapes=show_shapes,
          show_dtype=show_dtype,
          show_layer_names=show_layer_names,
          rankdir=rankdir,
          expand_nested=expand_nested,
          subgraph=True)
      sub_n_nodes = submodel_not_wrapper.get_nodes()
      sub_n_first_node[layer.name] = sub_n_nodes[0]
      sub_n_last_node[layer.name] = sub_n_nodes[-1]
      dot.add_subgraph(submodel_not_wrapper)
    else:
      label = _make_layer_label(layer, layer_name, class_name,
                                show_layer_names, show_dtype, show_shapes)
      dot.add_node(pydot.Node(layer_id, label=label))

  # Connect nodes with edges.
  for layer in layers:
    layer_id = str(id(layer))
    layer_is_model = isinstance(layer, functional.Functional)
    layer_is_wrapped_model = is_wrapped_model(layer)
    for i, node in enumerate(layer._inbound_nodes):
      # Skip inbound nodes not registered in this model's `_network_nodes`
      # (presumably calls of a shared layer made outside this model).
      node_key = layer.name + '_ib-' + str(i)
      if node_key not in model._network_nodes:
        continue
      for inbound_layer in nest.flatten(node.inbound_layers):
        inbound_layer_id = str(id(inbound_layer))
        if not expand_nested:
          assert dot.get_node(inbound_layer_id)
          assert dot.get_node(layer_id)
          add_edge(dot, inbound_layer_id, layer_id)
          continue

        # Edge source: an expanded (possibly wrapped) nested model is left
        # through the last node of its cluster, any other layer through its
        # own node.
        inbound_is_plain = False
        if isinstance(inbound_layer, functional.Functional):
          src = sub_n_last_node[inbound_layer.name].get_name()
        elif is_wrapped_model(inbound_layer):
          src = sub_w_last_node[inbound_layer.layer.name].get_name()
        else:
          src = inbound_layer_id
          inbound_is_plain = True

        # Edge destination: an expanded nested model has no node of its own
        # and is entered through the first node of its cluster. A wrapped
        # model keeps its wrapper node, which is then linked into the first
        # node of the wrapped model's cluster.
        if layer_is_model:
          add_edge(dot, src, sub_n_first_node[layer.name].get_name())
        else:
          if inbound_is_plain and not layer_is_wrapped_model:
            assert dot.get_node(inbound_layer_id)
            assert dot.get_node(layer_id)
          add_edge(dot, src, layer_id)
          if layer_is_wrapped_model:
            add_edge(dot, layer_id,
                     sub_w_first_node[layer.layer.name].get_name())
  return dot
279
280
@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_dtype=False,
               show_layer_names=True,
               rankdir='TB',
               expand_nested=False,
               dpi=96):
  """Converts a Keras model to dot format and save to a file.

  Example:

  ```python
  input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
  x = tf.keras.layers.Embedding(
      output_dim=512, input_dim=10000, input_length=100)(input)
  x = tf.keras.layers.LSTM(32)(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
  model = tf.keras.Model(inputs=[input], outputs=[output])
  dot_img_file = '/tmp/model_1.png'
  tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
  ```

  Args:
    model: A Keras model instance
    to_file: File name of the plot image. Its extension (case-insensitive)
      selects the output format passed to pydot; 'png' is used when the
      name has no extension.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: Whether to expand nested models into clusters.
    dpi: Dots per inch.

  Returns:
    A Jupyter notebook Image object if Jupyter is installed and
    `IPython.display.Image` can embed the output format.
    This enables in-line display of the model plots in notebooks.
    Returns `None` otherwise: for PDF output, when IPython is missing or
    cannot embed the format, or when `model_to_dot` returned `None`.
  """
  dot = model_to_dot(
      model,
      show_shapes=show_shapes,
      show_dtype=show_dtype,
      show_layer_names=show_layer_names,
      rankdir=rankdir,
      expand_nested=expand_nested,
      dpi=dpi)
  to_file = path_to_string(to_file)
  if dot is None:
    # pydot/graphviz unavailable under IPython; a message was already printed.
    return None
  _, extension = os.path.splitext(to_file)
  # Normalize the case so that e.g. 'model.PDF' is treated like 'model.pdf'.
  extension = extension[1:].lower() if extension else 'png'
  # Save image to disk.
  dot.write(to_file, format=extension)
  # Return the image as a Jupyter Image object, to be displayed in-line.
  # Note that we cannot easily detect whether the code is running in a
  # notebook, and thus we always return the Image if Jupyter is available.
  if extension != 'pdf':
    try:
      from IPython import display
      # Pass the format explicitly so extension-less file names still work.
      return display.Image(filename=to_file, format=extension)
    except ImportError:
      pass
    except ValueError:
      # IPython's Image refuses to embed some formats (e.g. 'svg'). The file
      # has already been written, so only the inline preview is skipped.
      pass
  return None
352