1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functions for summarizing and describing TensorFlow graphs. 16 17This contains functions that generate string descriptions from 18TensorFlow graphs, for debugging, testing, and model size 19estimation. 20""" 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import re 26from tensorflow.contrib.specs.python import specs 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.ops import array_ops 30 31# These are short abbreviations for common TensorFlow operations used 32# in test cases with tf_structure to verify that specs_lib generates a 33# graph structure with the right operations. Operations outside the 34# scope of specs (e.g., Const and Placeholder) are just assigned "_" 35# since they are not relevant to testing. 36 37SHORT_NAMES_SRC = """ 38BiasAdd biasadd 39Const _ 40Conv2D conv 41MatMul dot 42Placeholder _ 43Sigmoid sig 44Variable var 45""".split() 46 47SHORT_NAMES = { 48 x: y 49 for x, y in zip(SHORT_NAMES_SRC[::2], SHORT_NAMES_SRC[1::2]) 50} 51 52 53def _truncate_structure(x): 54 """A helper function that disables recursion in tf_structure. 55 56 Some constructs (e.g., HorizontalLstm) are complex unrolled 57 structures and don't need to be represented in the output 58 of tf_structure or tf_print. This helper function defines 59 which tree branches should be pruned. This is a very imperfect 60 way of dealing with unrolled LSTM's (since it truncates 61 useful information as well), but it's not worth doing something 62 better until the new fused and unrolled ops are ready. 63 64 Args: 65 x: a Tensor or Op 66 67 Returns: 68 A bool indicating whether the subtree should be pruned. 69 """ 70 if "/HorizontalLstm/" in x.name: 71 return True 72 return False 73 74 75def tf_structure(x, include_shapes=False, finished=None): 76 """A postfix expression summarizing the TF graph. 77 78 This is intended to be used as part of test cases to 79 check for gross differences in the structure of the graph. 80 The resulting string is not invertible or unabiguous 81 and cannot be used to reconstruct the graph accurately. 82 83 Args: 84 x: a tf.Tensor or tf.Operation 85 include_shapes: include shapes in the output string 86 finished: a set of ops that have already been output 87 88 Returns: 89 A string representing the structure as a string of 90 postfix operations. 91 """ 92 if finished is None: 93 finished = set() 94 if isinstance(x, ops.Tensor): 95 shape = x.get_shape().as_list() 96 x = x.op 97 else: 98 shape = [] 99 if x in finished: 100 return " <>" 101 finished |= {x} 102 result = "" 103 if not _truncate_structure(x): 104 for y in x.inputs: 105 result += tf_structure(y, include_shapes, finished) 106 if include_shapes: 107 result += " %s" % (shape,) 108 if x.type != "Identity": 109 name = SHORT_NAMES.get(x.type, x.type.lower()) 110 result += " " + name 111 return result 112 113 114def tf_print(x, depth=0, finished=None, printer=print): 115 """A simple print function for a TensorFlow graph. 116 117 Args: 118 x: a tf.Tensor or tf.Operation 119 depth: current printing depth 120 finished: set of nodes already output 121 printer: print function to use 122 123 Returns: 124 Total number of parameters found in the 125 subtree. 126 """ 127 128 if finished is None: 129 finished = set() 130 if isinstance(x, ops.Tensor): 131 shape = x.get_shape().as_list() 132 x = x.op 133 else: 134 shape = "" 135 if x.type == "Identity": 136 x = x.inputs[0].op 137 if x in finished: 138 printer("%s<%s> %s %s" % (" " * depth, x.name, x.type, shape)) 139 return 140 finished |= {x} 141 printer("%s%s %s %s" % (" " * depth, x.name, x.type, shape)) 142 if not _truncate_structure(x): 143 for y in x.inputs: 144 tf_print(y, depth + 1, finished, printer=printer) 145 146 147def tf_num_params(x): 148 """Number of parameters in a TensorFlow subgraph. 149 150 Args: 151 x: root of the subgraph (Tensor, Operation) 152 153 Returns: 154 Total number of elements found in all Variables 155 in the subgraph. 156 """ 157 158 if isinstance(x, ops.Tensor): 159 shape = x.get_shape() 160 x = x.op 161 if x.type in ["Variable", "VariableV2"]: 162 return shape.num_elements() 163 totals = [tf_num_params(y) for y in x.inputs] 164 return sum(totals) 165 166 167def tf_left_split(op): 168 """Split the parameters of op for left recursion. 169 170 Args: 171 op: tf.Operation 172 173 Returns: 174 A tuple of the leftmost input tensor and a list of the 175 remaining arguments. 176 """ 177 178 if len(op.inputs) < 1: 179 return None, [] 180 if op.type == "Concat": 181 return op.inputs[1], op.inputs[2:] 182 return op.inputs[0], op.inputs[1:] 183 184 185def tf_parameter_iter(x): 186 """Iterate over the left branches of a graph and yield sizes. 187 188 Args: 189 x: root of the subgraph (Tensor, Operation) 190 191 Yields: 192 A triple of name, number of params, and shape. 193 """ 194 195 while 1: 196 if isinstance(x, ops.Tensor): 197 shape = x.get_shape().as_list() 198 x = x.op 199 else: 200 shape = "" 201 left, right = tf_left_split(x) 202 totals = [tf_num_params(y) for y in right] 203 total = sum(totals) 204 yield x.name, total, shape 205 if left is None: 206 break 207 x = left 208 209 210def _combine_filter(x): 211 """A filter for combining successive layers with similar names.""" 212 last_name = None 213 last_total = 0 214 last_shape = None 215 for name, total, shape in x: 216 name = re.sub("/.*", "", name) 217 if name == last_name: 218 last_total += total 219 continue 220 if last_name is not None: 221 yield last_name, last_total, last_shape 222 last_name = name 223 last_total = total 224 last_shape = shape 225 if last_name is not None: 226 yield last_name, last_total, last_shape 227 228 229def tf_parameter_summary(x, printer=print, combine=True): 230 """Summarize parameters by depth. 231 232 Args: 233 x: root of the subgraph (Tensor, Operation) 234 printer: print function for output 235 combine: combine layers by top-level scope 236 """ 237 seq = tf_parameter_iter(x) 238 if combine: 239 seq = _combine_filter(seq) 240 seq = reversed(list(seq)) 241 for name, total, shape in seq: 242 printer("%10d %-20s %s" % (total, name, shape)) 243 244 245def tf_spec_structure(spec, 246 inputs=None, 247 input_shape=None, 248 input_type=dtypes.float32): 249 """Return a postfix representation of the specification. 250 251 This is intended to be used as part of test cases to 252 check for gross differences in the structure of the graph. 253 The resulting string is not invertible or unabiguous 254 and cannot be used to reconstruct the graph accurately. 255 256 Args: 257 spec: specification 258 inputs: input to the spec construction (usually a Tensor) 259 input_shape: tensor shape (in lieu of inputs) 260 input_type: type of the input tensor 261 262 Returns: 263 A string with a postfix representation of the 264 specification. 265 """ 266 267 if inputs is None: 268 inputs = array_ops.placeholder(input_type, input_shape) 269 outputs = specs.create_net(spec, inputs) 270 return str(tf_structure(outputs).strip()) 271 272 273def tf_spec_summary(spec, 274 inputs=None, 275 input_shape=None, 276 input_type=dtypes.float32): 277 """Output a summary of the specification. 278 279 This prints a list of left-most tensor operations and summarized the 280 variables found in the right branches. This kind of representation 281 is particularly useful for networks that are generally structured 282 like pipelines. 283 284 Args: 285 spec: specification 286 inputs: input to the spec construction (usually a Tensor) 287 input_shape: optional shape of input 288 input_type: type of the input tensor 289 """ 290 291 if inputs is None: 292 inputs = array_ops.placeholder(input_type, input_shape) 293 outputs = specs.create_net(spec, inputs) 294 tf_parameter_summary(outputs) 295 296 297def tf_spec_print(spec, 298 inputs=None, 299 input_shape=None, 300 input_type=dtypes.float32): 301 """Print a tree representing the spec. 302 303 Args: 304 spec: specification 305 inputs: input to the spec construction (usually a Tensor) 306 input_shape: optional shape of input 307 input_type: type of the input tensor 308 """ 309 310 if inputs is None: 311 inputs = array_ops.placeholder(input_type, input_shape) 312 outputs = specs.create_net(spec, inputs) 313 tf_print(outputs) 314