1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and methods for processing debugger-decorated graphs."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from six.moves import xrange  # pylint: disable=redefined-builtin
21
22from tensorflow.core.framework import graph_pb2
23from tensorflow.python.framework import op_def_registry
24from tensorflow.python.platform import tf_logging as logging
25
26
def parse_node_or_tensor_name(name):
  """Split a node-or-tensor name into a node name and an output slot.

  Args:
    name: (`str`) A node name (e.g., "node_a") or a tensor name (e.g.,
      "node_a:0").

  Returns:
    A 2-tuple:
      1) (`str`) The node name. For a tensor name this is everything that
        precedes the final colon. If `name` has no colon, or has nothing
        after its final colon, `name` itself is returned unchanged.
      2) (`int` or `None`) The output slot parsed from the text following
        the final colon, or `None` when `name` is not a tensor name.

  Raises:
    ValueError: If the text after the final colon is non-empty but cannot be
      converted to an int.
  """
  node_part, colon, slot_part = name.rpartition(":")
  if not colon or not slot_part:
    # Either there is no colon at all, or the name ends with one: treat the
    # whole string as a node name.
    return name, None
  return node_part, int(slot_part)
49
50
def get_node_name(element_name):
  """Get the node name from the name of a graph element.

  Args:
    element_name: (`str`) name of a node or of one of its output tensors.

  Returns:
    (`str`) the first element returned by `parse_node_or_tensor_name`, i.e.,
    the node name with any trailing output slot stripped.
  """
  return parse_node_or_tensor_name(element_name)[0]
54
55
def get_output_slot(element_name):
  """Get the output slot number encoded in the name of a graph element.

  A plain node name carries no output slot; in that case slot 0 is assumed.

  Args:
    element_name: (`str`) name of the graph element in question.

  Returns:
    (`int`) output slot number.
  """
  slot = parse_node_or_tensor_name(element_name)[1]
  if slot is None:
    # No slot in the name: fall back to the first output.
    return 0
  return slot
70
71
def is_copy_node(node_name):
  """Tell whether a node name belongs to a debug Copy node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool: True if and only if `node_name` begins with the Copy-node prefix
    "__copy_".
  """
  copy_node_prefix = "__copy_"
  return node_name.startswith(copy_node_prefix)
86
87
def is_debug_node(node_name):
  """Tell whether a node name belongs to a debug node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool: True if and only if `node_name` begins with the debug-node prefix
    "__dbg_".
  """
  debug_node_prefix = "__dbg_"
  return node_name.startswith(debug_node_prefix)
101
102
def parse_debug_node_name(node_name):
  """Decompose the name of a debug node into the fields it encodes.

  The accepted layout is
  "__dbg_<watched node>:<output slot>_<debug op index>_<debug op>".

  Args:
    node_name: Name of the debug node.

  Returns:
    1. Name of the watched node, as a str.
    2. Output slot index of the watched tensor, as an int.
    3. Index of the debug node, as an int.
    4. Name of the debug op, as a str, e.g, "DebugIdentity".

  Raises:
    ValueError: If the input node name is not a valid debug node name.
  """
  debug_prefix = "__dbg_"
  if not node_name.startswith(debug_prefix):
    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)

  # Peel the two trailing "_"-separated fields off the right-hand side; the
  # watched tensor name on the left may itself contain underscores.
  fields = node_name[len(debug_prefix):].rsplit("_", 2)
  if len(fields) != 3:
    raise ValueError("Invalid debug node name: '%s'" % node_name)
  watched_tensor_name, debug_op_index_str, debug_op = fields

  # Converted before the tensor name is validated, so that a malformed index
  # is reported first.
  debug_op_index = int(debug_op_index_str)

  tensor_fields = watched_tensor_name.split(":")
  if len(tensor_fields) != 2:
    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)
  watched_node_name, output_slot_str = tensor_fields

  return watched_node_name, int(output_slot_str), debug_op_index, debug_op
142
143
class GraphTracingReachedDestination(Exception):
  """Raised by a graph tracer when it reaches its destination node."""
146
147
class DFSGraphTracer(object):
  """Graph input tracer using depth-first search.

  Starting from a graph element, the tracer follows the input maps it was
  given and records every input it encounters together with the recursion
  depth at which the input was found.
  """

  def __init__(self,
               input_lists,
               skip_node_names=None,
               destination_node_name=None):
    """Constructor of DFSGraphTracer.

    Args:
      input_lists: A list of dicts. Each dict is an adjacency (input) map from
        the recipient node name as the key and the list of input node names
        as the value.
      skip_node_names: Optional: a collection of node names to skip tracing.
        `None` (the default) means that no node is skipped.
      destination_node_name: Optional: name of a destination node, as a str.
        If not `None`, `trace()` raises `GraphTracingReachedDestination` as
        soon as that node is reached.
    """

    self._input_lists = input_lists
    # Fall back to an empty collection so that the membership test in
    # `trace()` works when the caller relies on the default of `None`.
    self._skip_node_names = () if skip_node_names is None else skip_node_names

    self._inputs = []
    # A set, because it is only ever used for membership tests.
    self._visited_nodes = set()
    self._depth_count = 0
    self._depth_list = []

    self._destination_node_name = destination_node_name

  def trace(self, graph_element_name):
    """Trace inputs.

    Every newly-discovered input is appended to the list returned by
    `inputs()`, and the depth at which it was found (1 for the direct inputs
    of the element passed to the outermost call) is appended to the list
    returned by `depth_list()`.

    Args:
      graph_element_name: Name of the node or an output tensor of the node, as a
        str.

    Raises:
      GraphTracingReachedDestination: if destination_node_name of this tracer
        object is not None and the specified node is reached.
    """
    self._depth_count += 1
    try:
      node_name = get_node_name(graph_element_name)
      if node_name == self._destination_node_name:
        raise GraphTracingReachedDestination()

      if (node_name in self._skip_node_names or
          node_name in self._visited_nodes):
        return

      self._visited_nodes.add(node_name)

      for input_list in self._input_lists:
        if node_name not in input_list:
          continue
        for inp in input_list[node_name]:
          if get_node_name(inp) in self._visited_nodes:
            continue
          self._inputs.append(inp)
          self._depth_list.append(self._depth_count)
          self.trace(inp)
    finally:
      # Undo the increment on every exit path (including the early returns
      # for skipped or already-visited nodes), so that the depths recorded
      # for inputs traced afterwards are not inflated.
      self._depth_count -= 1

  def inputs(self):
    """Return the list of inputs recorded so far, in order of discovery."""
    return self._inputs

  def depth_list(self):
    """Return the depths of the recorded inputs, parallel to `inputs()`."""
    return self._depth_list
223
224
def _infer_device_name(graph_def):
  """Infer device name from a partition GraphDef.

  Args:
    graph_def: An object whose `node` attribute can be iterated to yield nodes
      that each have a `device` attribute.

  Returns:
    The `device` value of the first node whose `device` is non-empty. If no
    node has one, a warning is logged and `None` is returned.
  """
  non_empty_devices = (node.device for node in graph_def.node if node.device)
  inferred_name = next(non_empty_devices, None)
  if inferred_name is None:
    logging.warn(
        "Failed to infer device name from partition GraphDef: none of the "
        "nodes of the GraphDef has a non-empty device name.")
  return inferred_name
237
238
class DebugGraph(object):
  """Represents a debugger-decorated graph.

  On construction, the nodes of the debugger-decorated GraphDef are indexed
  into a set of per-node maps (inputs, control inputs, recipients, devices,
  op types, attributes) from which the debugger-inserted Copy and debug nodes
  have been pruned.
  """

  def __init__(self, debug_graph_def, device_name=None):
    """Constructor of DebugGraph.

    Args:
      debug_graph_def: The debugger-decorated GraphDef; its `node` field is
        iterated to build the maps.
      device_name: Optional name of the device. If empty or `None`, the name
        is inferred from `debug_graph_def` by `_infer_device_name`.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    self._debug_graph_def = debug_graph_def
    # Built lazily by _reconstruct_non_debug_graph_def().
    self._non_debug_graph_def = None

    self._node_attributes = {}
    self._node_inputs = {}
    self._node_reversed_ref_inputs = {}
    self._node_ctrl_inputs = {}
    self._node_recipients = {}
    self._node_ctrl_recipients = {}
    self._node_devices = {}
    self._node_op_types = {}
    self._copy_send_nodes = []
    self._ref_args = {}

    self._device_name = device_name or _infer_device_name(debug_graph_def)

    for node_def in debug_graph_def.node:
      self._process_debug_graph_node(node_def)

    # The pruning steps run before the recipient maps are populated, so that
    # the recipients are derived from the restored (original) inputs.
    self._prune_non_control_edges_of_debug_ops()
    self._prune_control_edges_of_debug_ops()
    self._prune_nodes_from_input_and_recipient_maps(self._get_copy_nodes())

    self._populate_recipient_maps()

  def _process_debug_graph_node(self, node):
    """Index a single node from the debug GraphDef into the per-node maps.

    Debug nodes are ignored entirely. For every other node, its attributes,
    device, op type, ref-type args, inputs and control inputs are recorded.

    Args:
      node: (NodeDef) A partition-graph node to be processed.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    name = node.name
    if is_debug_node(name):
      # Debug nodes are not included in the graph.
      return

    if name in self._node_inputs:
      raise ValueError("Duplicate node name on device %s: '%s'" %
                       (self._device_name, name))

    self._node_attributes[name] = node.attr

    data_inputs = []
    ctrl_inputs = []
    self._node_inputs[name] = data_inputs
    self._node_ctrl_inputs[name] = ctrl_inputs
    self._node_recipients[name] = []
    self._node_ctrl_recipients[name] = []

    # Nodes without an explicit device are attributed to this graph's device.
    self._node_devices.setdefault(name, set()).add(
        node.device or self._device_name)
    self._node_op_types[name] = node.op
    self._ref_args[name] = self._get_ref_args(node)

    is_send_or_retval = node.op in ("_Send", "_Retval")
    for inp in node.input:
      if is_send_or_retval and is_copy_node(inp):
        # A _Send or _Retval node fed by a Copy node.
        self._copy_send_nodes.append(name)

      if inp.startswith("^"):
        # Control input: store it without the leading caret.
        ctrl_inputs.append(inp[1:])
      else:
        data_inputs.append(inp)

  def _get_ref_args(self, node):
    """List the ref-type outputs of a node.

    The op definition of `node.op` is looked up in the op registry and its
    output args are inspected.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the names (as strs) of the ref-type output args: `node.name`
      for output 0, and "<node.name>:<index>" for any other output. Empty if
      the op is not found in the registry.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    if not op_def:
      return []
    return [
        node.name if index == 0 else ("%s:%d" % (node.name, index))
        for index, output_arg in enumerate(op_def.output_arg)
        if output_arg.is_ref
    ]

  def _get_copy_nodes(self):
    """Return the names of all Copy nodes in the loaded graph, as a list."""
    return [name for name in self._node_inputs if is_copy_node(name)]

  def _prune_non_control_edges_of_debug_ops(self):
    """Prune (non-control) edges related to debug ops.

    Wherever a node takes a Copy node as a non-control input, that input is
    replaced, in place, with the first input of the Copy node itself, i.e.,
    with the original input.
    """
    for inputs in self._node_inputs.values():
      for index, inp in enumerate(inputs):
        if is_copy_node(inp):
          # The Copy node's own first input should be the original input.
          inputs[index] = self._node_inputs[inp][0]

  def _prune_control_edges_of_debug_ops(self):
    """Remove, in place, every debug node from the control-input lists."""
    for ctrl_inputs in self._node_ctrl_inputs.values():
      ctrl_inputs[:] = [
          ctrl_inp for ctrl_inp in ctrl_inputs if not is_debug_node(ctrl_inp)
      ]

  def _populate_recipient_maps(self):
    """Populate the maps from node name to recipient(s) of its output(s).

    This method also populates the input map based on reversed ref edges.
    """
    for node, inputs in self._node_inputs.items():
      for inp in inputs:
        input_node = get_node_name(inp)
        self._node_recipients.setdefault(input_node, []).append(node)

        # NOTE(review): _ref_args has a key for every processed node, so this
        # test holds for any input produced by a processed node, whether or
        # not that node has ref-type outputs -- confirm that this is intended.
        if input_node in self._ref_args:
          self._node_reversed_ref_inputs.setdefault(
              input_node, []).append(node)

    for node, ctrl_inputs in self._node_ctrl_inputs.items():
      for ctrl_inp in ctrl_inputs:
        # _Send/_Retval nodes fed by Copy nodes are not recorded as sources
        # of control edges.
        if ctrl_inp not in self._copy_send_nodes:
          self._node_ctrl_recipients.setdefault(ctrl_inp, []).append(node)

  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
    """Delete nodes from the input and recipient maps.

    Args:
      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
    """
    maps_to_prune = (
        self._node_inputs,
        self._node_ctrl_inputs,
        self._node_recipients,
        self._node_ctrl_recipients,
    )
    for node in nodes_to_prune:
      for node_map in maps_to_prune:
        del node_map[node]

  def _reconstruct_non_debug_graph_def(self):
    """Build the non-debug GraphDef, unless it has been built already.

    Non-debug GraphDef means the original GraphDef without the Copy* and Debug
    nodes inserted by the debugger.
    """
    if self._non_debug_graph_def:
      return

    non_debug_graph_def = graph_pb2.GraphDef()
    self._non_debug_graph_def = non_debug_graph_def
    for node in self._debug_graph_def.node:
      if is_copy_node(node.name) or is_debug_node(node.name):
        continue

      new_node = non_debug_graph_def.node.add()
      new_node.CopyFrom(node)

      # The copied input list may refer to Copy* and Debug* nodes inserted by
      # the debugger; rebuild it from the pruned input maps instead, data
      # inputs first and caret-prefixed control inputs after.
      del new_node.input[:]
      new_node.input.extend(self._node_inputs[node.name])
      new_node.input.extend(
          ["^" + ctrl_inp for ctrl_inp in self._node_ctrl_inputs[node.name]])

  @property
  def device_name(self):
    """Name of the device, as given to or inferred by the constructor."""
    return self._device_name

  @property
  def debug_graph_def(self):
    """The debugger-decorated GraphDef."""
    return self._debug_graph_def

  @property
  def non_debug_graph_def(self):
    """The GraphDef without the Copy* and Debug* nodes added by the debugger."""
    self._reconstruct_non_debug_graph_def()
    return self._non_debug_graph_def

  @property
  def node_devices(self):
    """Dict mapping each node name to the set of its device names."""
    return self._node_devices

  @property
  def node_op_types(self):
    """Dict mapping each node name to the op type of the node."""
    return self._node_op_types

  @property
  def node_attributes(self):
    """Dict mapping each node name to the `attr` field of the node."""
    return self._node_attributes

  @property
  def node_inputs(self):
    """Dict mapping each node name to the list of its non-control inputs."""
    return self._node_inputs

  @property
  def node_ctrl_inputs(self):
    """Dict mapping each node name to the list of its control inputs."""
    return self._node_ctrl_inputs

  @property
  def node_reversed_ref_inputs(self):
    """Dict of reversed ref edges: input node name to list of recipients."""
    return self._node_reversed_ref_inputs

  @property
  def node_recipients(self):
    """Dict mapping each node name to the recipients of its outputs."""
    return self._node_recipients

  @property
  def node_ctrl_recipients(self):
    """Dict mapping each node name to the recipients of its control edges."""
    return self._node_ctrl_recipients
480
481
def reconstruct_non_debug_graph_def(debug_graph_def):
  """Reconstruct original (non-debugger-decorated) partition GraphDef.

  The input `tf.GraphDef` is wrapped in a `DebugGraph`, whose
  `non_debug_graph_def` is returned; the Copy* and Debug*-type nodes inserted
  by the debugger are thereby stripped.

  The reconstructed partition graph is identical to the original (i.e.,
    non-debugger-decorated) partition graph except in the following respects:
      1) The exact names of the runtime-inserted internal nodes may differ.
         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
      2) As a consequence of 1, the nodes that receive input directly from such
         send- and recv-type ops will have different input names.
      3) The parallel_iteration attribute of while-loop Enter ops are set to 1.

  Args:
    debug_graph_def: The debugger-decorated `tf.GraphDef`, with the
      debugger-inserted Copy* and Debug* nodes.

  Returns:
    The reconstructed `tf.GraphDef` stripped of the debugger-inserted nodes.
  """
  debug_graph = DebugGraph(debug_graph_def)
  return debug_graph.non_debug_graph_def
504