1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions for summarizing and describing TensorFlow graphs.
16
17This contains functions that generate string descriptions from
18TensorFlow graphs, for debugging, testing, and model size
19estimation.
20"""
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import re
26from tensorflow.contrib.specs.python import specs
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30
31# These are short abbreviations for common TensorFlow operations used
32# in test cases with tf_structure to verify that specs_lib generates a
33# graph structure with the right operations. Operations outside the
34# scope of specs (e.g., Const and Placeholder) are just assigned "_"
35# since they are not relevant to testing.
36
# Flat, whitespace-separated token list: each TensorFlow op type is
# immediately followed by its abbreviation.
SHORT_NAMES_SRC = """
BiasAdd biasadd
Const _
Conv2D conv
MatMul dot
Placeholder _
Sigmoid sig
Variable var
""".split()

# Pair the even-indexed tokens (op types) with the odd-indexed tokens
# (abbreviations) to get the op type -> short name lookup table.
SHORT_NAMES = dict(zip(SHORT_NAMES_SRC[0::2], SHORT_NAMES_SRC[1::2]))
51
52
53def _truncate_structure(x):
54  """A helper function that disables recursion in tf_structure.
55
56  Some constructs (e.g., HorizontalLstm) are complex unrolled
57  structures and don't need to be represented in the output
58  of tf_structure or tf_print. This helper function defines
59  which tree branches should be pruned. This is a very imperfect
60  way of dealing with unrolled LSTM's (since it truncates
61  useful information as well), but it's not worth doing something
62  better until the new fused and unrolled ops are ready.
63
64  Args:
65      x: a Tensor or Op
66
67  Returns:
68      A bool indicating whether the subtree should be pruned.
69  """
70  if "/HorizontalLstm/" in x.name:
71    return True
72  return False
73
74
def tf_structure(x, include_shapes=False, finished=None):
  """Summarize the graph feeding into `x` as a postfix expression.

  Meant for test cases that check for gross structural differences
  between graphs. The resulting string is neither invertible nor
  unambiguous, so the graph cannot be reconstructed from it.

  Args:
      x: a tf.Tensor or tf.Operation
      include_shapes: if true, emit the shape before each op name
      finished: set of ops already emitted; it is updated in place and
          shared across the recursion. A fresh set is used when None.

  Returns:
      A string of space-prefixed tokens in postfix order. Ops that were
      already emitted appear as " <>".
  """
  if finished is None:
    finished = set()

  # Only tensors carry a shape; a bare operation gets an empty list.
  shape = []
  if isinstance(x, ops.Tensor):
    shape = x.get_shape().as_list()
    x = x.op

  if x in finished:
    return " <>"
  finished.add(x)

  pieces = []
  # Postfix order: inputs first (unless this subtree is pruned) ...
  if not _truncate_structure(x):
    pieces.extend(tf_structure(y, include_shapes, finished) for y in x.inputs)
  # ... then the optional shape ...
  if include_shapes:
    pieces.append(" %s" % (shape,))
  # ... and finally the op itself. Identity ops contribute no name token.
  if x.type != "Identity":
    pieces.append(" " + SHORT_NAMES.get(x.type, x.type.lower()))
  return "".join(pieces)
112
113
def tf_print(x, depth=0, finished=None, printer=print):
  """Print an indented tree view of the graph feeding into `x`.

  One line is emitted per node, containing its name, op type and shape.
  A node that was already printed is shown again with its name wrapped
  in angle brackets, and its inputs are not expanded a second time.

  Args:
      x: a tf.Tensor or tf.Operation
      depth: current nesting depth (two spaces of indent per level)
      finished: set of nodes already output; updated in place and shared
          across the recursion. A fresh set is used when None.
      printer: callable invoked with each output line

  Returns:
      None. All output goes through `printer`.
  """
  if finished is None:
    finished = set()

  # Only tensors carry a shape; a bare operation prints an empty shape.
  shape = ""
  if isinstance(x, ops.Tensor):
    shape = x.get_shape().as_list()
    x = x.op

  # Step over one Identity op and report the op producing its first input.
  # The shape printed is still the one read from the original tensor.
  if x.type == "Identity":
    x = x.inputs[0].op

  indent = "  " * depth
  if x in finished:
    printer("%s<%s> %s %s" % (indent, x.name, x.type, shape))
    return
  finished.add(x)

  printer("%s%s %s %s" % (indent, x.name, x.type, shape))
  if _truncate_structure(x):
    # Pruned subtree: do not descend into the inputs.
    return
  for y in x.inputs:
    tf_print(y, depth + 1, finished, printer=printer)
145
146
def tf_num_params(x):
  """Number of parameters in a TensorFlow subgraph.

  Recursively walks the inputs of `x` and adds up the element counts of
  every "Variable" / "VariableV2" op that is reached. No record of
  visited nodes is kept, so a variable reachable through several paths
  is counted once per path.

  Args:
      x: root of the subgraph (Tensor, Operation)

  Returns:
      Total number of elements found in all Variables
      in the subgraph.
  """
  if isinstance(x, ops.Tensor):
    shape = x.get_shape()
    x = x.op
  else:
    # An Operation has no shape of its own; resolved below if needed.
    shape = None
  if x.type in ("Variable", "VariableV2"):
    if shape is None:
      # A variable Operation was passed directly rather than its output
      # tensor. Previously `shape` was unbound on this path and the call
      # failed with UnboundLocalError; use the op's first output instead.
      shape = x.outputs[0].get_shape()
    return shape.num_elements()
  return sum(tf_num_params(y) for y in x.inputs)
165
166
def tf_left_split(op):
  """Separate an op's leftmost input from the rest of its inputs.

  For ops of type "Concat" the input at index 0 is skipped: the
  leftmost input is taken from index 1 and the remainder starts at
  index 2. Every other op type splits at index 0.

  Args:
    op: tf.Operation

  Returns:
    A tuple of the leftmost input tensor and a list of the
    remaining arguments, or (None, []) if the op has no inputs.
  """
  if len(op.inputs) == 0:
    return None, []
  # Position of the input treated as "leftmost" for this op type.
  first = 1 if op.type == "Concat" else 0
  return op.inputs[first], op.inputs[first + 1:]
183
184
def tf_parameter_iter(x):
  """Walk the leftmost-input chain of a graph, yielding parameter sizes.

  Starting at `x`, each step reports the current op and then moves on
  to its leftmost input (as chosen by tf_left_split). The parameter
  count reported for an op is the sum of tf_num_params over its
  remaining (non-leftmost) inputs. Iteration ends after an op without
  inputs has been reported.

  Args:
      x: root of the subgraph (Tensor, Operation)

  Yields:
      A triple of name, number of params, and shape. The shape is a
      list for tensors and "" when the node is a bare operation.
  """
  node = x
  while True:
    shape = ""
    if isinstance(node, ops.Tensor):
      shape = node.get_shape().as_list()
      node = node.op
    leftmost, others = tf_left_split(node)
    yield node.name, sum([tf_num_params(arg) for arg in others]), shape
    if leftmost is None:
      return
    node = leftmost
208
209
210def _combine_filter(x):
211  """A filter for combining successive layers with similar names."""
212  last_name = None
213  last_total = 0
214  last_shape = None
215  for name, total, shape in x:
216    name = re.sub("/.*", "", name)
217    if name == last_name:
218      last_total += total
219      continue
220    if last_name is not None:
221      yield last_name, last_total, last_shape
222    last_name = name
223    last_total = total
224    last_shape = shape
225  if last_name is not None:
226    yield last_name, last_total, last_shape
227
228
def tf_parameter_summary(x, printer=print, combine=True):
  """Print per-layer parameter counts along the leftmost path.

  The entries produced by tf_parameter_iter are collected, optionally
  merged by top-level scope, and printed in reverse of the order in
  which they were produced (i.e. the root `x` is printed last).

  Args:
      x: root of the subgraph (Tensor, Operation)
      printer: print function for output
      combine: combine layers by top-level scope
  """
  rows = tf_parameter_iter(x)
  if combine:
    rows = _combine_filter(rows)
  # Materialize the generator so it can be walked back to front.
  for name, total, shape in list(rows)[::-1]:
    printer("%10d %-20s %s" % (total, name, shape))
243
244
def tf_spec_structure(spec,
                      inputs=None,
                      input_shape=None,
                      input_type=dtypes.float32):
  """Build a network from a spec and return its postfix structure.

  Meant for test cases that check for gross structural differences
  between graphs. The resulting string is neither invertible nor
  unambiguous, so the graph cannot be reconstructed from it.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor); when
          None, a placeholder is created from input_type and input_shape
      input_shape: tensor shape (in lieu of inputs)
      input_type: type of the input tensor

  Returns:
      A string with a postfix representation of the
      specification, stripped of surrounding whitespace.
  """
  net_input = inputs
  if net_input is None:
    net_input = array_ops.placeholder(input_type, input_shape)
  net_output = specs.create_net(spec, net_input)
  postfix = tf_structure(net_output)
  return str(postfix.strip())
271
272
def tf_spec_summary(spec,
                    inputs=None,
                    input_shape=None,
                    input_type=dtypes.float32):
  """Build a network from a spec and print its parameter summary.

  This prints a list of left-most tensor operations and summarizes the
  variables found in the right branches. This kind of representation
  is particularly useful for networks that are generally structured
  like pipelines.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor); when
          None, a placeholder is created from input_type and input_shape
      input_shape: optional shape of input
      input_type: type of the input tensor
  """
  net_input = inputs
  if net_input is None:
    net_input = array_ops.placeholder(input_type, input_shape)
  net_output = specs.create_net(spec, net_input)
  tf_parameter_summary(net_output)
295
296
def tf_spec_print(spec,
                  inputs=None,
                  input_shape=None,
                  input_type=dtypes.float32):
  """Build a network from a spec and print it as an indented tree.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor); when
          None, a placeholder is created from input_type and input_shape
      input_shape: optional shape of input
      input_type: type of the input tensor
  """
  net_input = inputs
  if net_input is None:
    net_input = array_ops.placeholder(input_type, input_shape)
  net_output = specs.create_net(spec, net_input)
  tf_print(net_output)
314