1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and methods for processing debugger-decorated graphs."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.core.framework import graph_pb2
21from tensorflow.python.framework import op_def_registry
22from tensorflow.python.platform import tf_logging as logging
23
24
def parse_node_or_tensor_name(name):
  """Split a node-or-tensor name into its node name and output slot.

  A name is treated as a tensor name when it contains a colon and does not
  end with one; the text after the final colon is then parsed as the output
  slot.

  Args:
    name: A node name (e.g., "node_a") or a tensor name (e.g., "node_a:0"),
      as a str.

  Returns:
    1) The node name, as a str. For a tensor name this is everything before
      the final colon; otherwise it is the input, unchanged.
    2) The output slot, as an int, if the input is a tensor name; otherwise
      None.

  Raises:
    ValueError: If the text after the final colon is not a valid integer.
  """
  node_name, colon, slot_str = name.rpartition(":")

  # No colon at all, or nothing after the last colon: not a tensor name.
  if not colon or not slot_str:
    return name, None

  return node_name, int(slot_str)
47
48
def get_node_name(element_name):
  """Get the node name from the name of a graph element.

  Args:
    element_name: (`str`) a node name (e.g., "node_a") or a tensor name
      (e.g., "node_a:0").

  Returns:
    (`str`) the node name, i.e., the first item returned by
    `parse_node_or_tensor_name`.
  """
  return parse_node_or_tensor_name(element_name)[0]
52
53
def get_output_slot(element_name):
  """Get the output slot number from the name of a graph element.

  Args:
    element_name: (`str`) name of the graph element in question, either a
      node name or a tensor name.

  Returns:
    (`int`) the output slot parsed from the name, or 0 if the name carries
    no output slot (i.e., it is a plain node name).
  """
  slot = parse_node_or_tensor_name(element_name)[1]
  if slot is None:
    # Plain node name: fall back to the first output.
    return 0
  return slot
68
69
def is_copy_node(node_name):
  """Check whether a node name has the prefix of a debug Copy node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool: True if `node_name` begins with "__copy_", False otherwise.
  """
  copy_prefix = "__copy_"
  return node_name.startswith(copy_prefix)
84
85
def is_debug_node(node_name):
  """Check whether a node name has the prefix of a debug node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool: True if `node_name` begins with "__dbg_", False otherwise.
  """
  debug_prefix = "__dbg_"
  return node_name.startswith(debug_prefix)
99
100
def parse_debug_node_name(node_name):
  """Break a debug node name into its components.

  The expected layout, as parsed here, is
  "__dbg_<watched_node>:<slot>_<debug_op_index>_<debug_op>", e.g.,
  "__dbg_node_a:0_0_DebugIdentity".

  Args:
    node_name: Name of the debug node.

  Returns:
    1. Name of the watched node, as a str.
    2. Output slot index of the watched tensor, as an int.
    3. Index of the debug node, as an int.
    4. Name of the debug op, as a str, e.g, "DebugIdentity".

  Raises:
    ValueError: If the input node name is not a valid debug node name.
  """
  prefix = "__dbg_"
  if not node_name.startswith(prefix):
    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)

  remainder = node_name[len(prefix):]
  if remainder.count("_") < 2:
    raise ValueError("Invalid debug node name: '%s'" % node_name)

  # The watched node name may itself contain underscores, so peel the two
  # trailing fields off from the right.
  tensor_name, index_str, debug_op = remainder.rsplit("_", 2)
  debug_op_index = int(index_str)

  if tensor_name.count(":") != 1:
    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)

  watched_node_name, _, slot_str = tensor_name.partition(":")
  watched_output_slot = int(slot_str)

  return watched_node_name, watched_output_slot, debug_op_index, debug_op
140
141
class GraphTracingReachedDestination(Exception):
  """Raised by a graph tracer when its destination node has been reached."""
144
145
class DFSGraphTracer(object):
  """Graph input tracer using depth-first search."""

  def __init__(self,
               input_lists,
               skip_node_names=None,
               destination_node_name=None):
    """Constructor of DFSGraphTracer.

    Args:
      input_lists: A list of dicts. Each dict is an adjacency (input) map from
        the recipient node name as the key and the list of input node names
        as the value.
      skip_node_names: Optional: a list of node names to skip tracing. `None`
        (the default) means no node is skipped.
      destination_node_name: Optional: destination node name. If not `None`, it
        should be the name of a destination node as a str and the graph tracing
        will raise GraphTracingReachedDestination as soon as the node has been
        reached.
    """

    self._input_lists = input_lists
    # Normalize None to an empty container so that the membership test in
    # trace() works when the caller relies on the default.
    self._skip_node_names = (
        skip_node_names if skip_node_names is not None else ())

    self._inputs = []
    # A set, because membership is tested for every traversed edge.
    self._visited_nodes = set()
    self._depth_count = 0
    self._depth_list = []

    self._destination_node_name = destination_node_name

  def trace(self, graph_element_name):
    """Trace inputs.

    Args:
      graph_element_name: Name of the node or an output tensor of the node, as a
        str.

    Raises:
      GraphTracingReachedDestination: if destination_node_name of this tracer
        object is not None and the specified node is reached.
    """
    self._depth_count += 1
    try:
      node_name = get_node_name(graph_element_name)
      if node_name == self._destination_node_name:
        raise GraphTracingReachedDestination()

      if (node_name in self._skip_node_names or
          node_name in self._visited_nodes):
        return

      self._visited_nodes.add(node_name)

      for input_list in self._input_lists:
        if node_name not in input_list:
          continue
        for inp in input_list[node_name]:
          if get_node_name(inp) in self._visited_nodes:
            continue
          self._inputs.append(inp)
          self._depth_list.append(self._depth_count)
          self.trace(inp)
    finally:
      # Undo the increment on every exit path, including the early returns
      # for skipped / already-visited nodes; otherwise the depths recorded
      # for subsequent sibling inputs would be inflated.
      self._depth_count -= 1

  def inputs(self):
    """Return the traced inputs, in the order in which they were recorded."""
    return self._inputs

  def depth_list(self):
    """Return the depth of each traced input, parallel to `inputs()`.

    Direct inputs of the node passed to `trace()` are recorded with depth 1.
    """
    return self._depth_list
221
222
def _infer_device_name(graph_def):
  """Infer device name from a partition GraphDef.

  Args:
    graph_def: A GraphDef-like object whose `node` field is iterated; each
      node's `device` attribute is inspected.

  Returns:
    The `device` of the first node that has a non-empty device, or None if
    there is no such node (in which case a warning is logged).
  """
  device_name = next(
      (node.device for node in graph_def.node if node.device), None)
  if device_name is None:
    logging.warn(
        "Failed to infer device name from partition GraphDef: none of the "
        "nodes of the GraphDef has a non-empty device name.")
  return device_name
235
236
class DebugGraph(object):
  """Represents a debugger-decorated graph.

  On construction, the nodes of the debugger-decorated GraphDef are scanned
  once to build per-node lookup maps (attributes, data/control inputs,
  recipients, devices, op types). Debug nodes are left out of those maps, and
  Copy nodes are spliced out so that the maps describe the edges between the
  remaining nodes.
  """

  def __init__(self, debug_graph_def, device_name=None):
    """Constructor of DebugGraph.

    Args:
      debug_graph_def: The debugger-decorated GraphDef.
      device_name: Optional name of the device. If empty or None, the device
        name is inferred from the nodes of `debug_graph_def`.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    self._debug_graph_def = debug_graph_def
    self._non_debug_graph_def = None

    self._node_attributes = {}
    self._node_inputs = {}
    self._node_reversed_ref_inputs = {}
    self._node_ctrl_inputs = {}
    self._node_recipients = {}
    self._node_ctrl_recipients = {}
    self._node_devices = {}
    self._node_op_types = {}
    self._copy_send_nodes = []
    self._ref_args = {}

    self._device_name = device_name or _infer_device_name(debug_graph_def)

    for graph_node in debug_graph_def.node:
      self._process_debug_graph_node(graph_node)

    # Order matters: edges through Copy nodes must be rewired before the Copy
    # nodes themselves are dropped, and the recipient maps are derived from
    # the final input maps.
    self._prune_non_control_edges_of_debug_ops()
    self._prune_control_edges_of_debug_ops()
    self._prune_nodes_from_input_and_recipient_maps(self._get_copy_nodes())

    self._populate_recipient_maps()

  def _process_debug_graph_node(self, node):
    """Process a node from the debug GraphDef.

    Debug nodes are ignored. For any other node, its attributes, device, op
    type, ref-type outputs, data inputs and control inputs are recorded.

    Args:
      node: (NodeDef) A partition-graph node to be processed.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    name = node.name
    if is_debug_node(name):
      # Debug nodes are not included in the graph maps.
      return

    if name in self._node_inputs:
      raise ValueError("Duplicate node name on device %s: '%s'" %
                       (self._device_name, name))

    self._node_attributes[name] = node.attr

    data_inputs = self._node_inputs[name] = []
    ctrl_inputs = self._node_ctrl_inputs[name] = []
    self._node_recipients[name] = []
    self._node_ctrl_recipients[name] = []

    self._node_devices.setdefault(name, set()).add(
        node.device or self._device_name)
    self._node_op_types[name] = node.op
    self._ref_args[name] = self._get_ref_args(node)

    is_send_or_retval = node.op in ("_Send", "_Retval")
    for inp in node.input:
      if is_send_or_retval and is_copy_node(inp):
        self._copy_send_nodes.append(name)

      # A leading "^" marks a control input; it is stored without the marker.
      if inp.startswith("^"):
        ctrl_inputs.append(inp[1:])
      else:
        data_inputs.append(inp)

  def _get_ref_args(self, node):
    """Collect the names of the ref-type outputs of a node.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the arg names (as strs) that are ref-type: the bare node name
      for output 0, "<node_name>:<i>" for output i > 0. Empty if the node's op
      is not found in the op registry.
    """
    op_def = op_def_registry.get(node.op)
    if op_def is None:
      return []

    return [
        node.name if index == 0 else ("%s:%d" % (node.name, index))
        for index, output_arg in enumerate(op_def.output_arg)
        if output_arg.is_ref
    ]

  def _get_copy_nodes(self):
    """Find all Copy nodes in the loaded graph."""
    return [name for name in self._node_inputs if is_copy_node(name)]

  def _prune_non_control_edges_of_debug_ops(self):
    """Prune (non-control) edges related to debug ops.

    Wherever a node takes a Copy node as a non-control input, that input is
    replaced in place by the Copy node's own first input, i.e., the original
    input.
    """
    for inputs in self._node_inputs.values():
      for index, inp in enumerate(inputs):
        if is_copy_node(inp):
          # The first input of the Copy node should be the original input.
          inputs[index] = self._node_inputs[inp][0]

  def _prune_control_edges_of_debug_ops(self):
    """Prune control edges related to the debug ops."""
    for ctrl_inputs in self._node_ctrl_inputs.values():
      # Slice assignment keeps the same list object, filtered in place.
      ctrl_inputs[:] = [
          ctrl_inp for ctrl_inp in ctrl_inputs if not is_debug_node(ctrl_inp)
      ]

  def _populate_recipient_maps(self):
    """Populate the map from node name to recipient(s) of its output(s).

    This method also populates the input map based on reversed ref edges.
    """
    for recipient, inputs in self._node_inputs.items():
      for inp in inputs:
        input_node = get_node_name(inp)
        self._node_recipients.setdefault(input_node, []).append(recipient)

        # NOTE(review): `_ref_args` is keyed by node name and gets an entry
        # (possibly an empty list) for every processed node, so this test is
        # true for any processed input node, not only ref-type outputs.
        # Confirm whether that is intended.
        if input_node in self._ref_args:
          self._node_reversed_ref_inputs.setdefault(
              input_node, []).append(recipient)

    for recipient, ctrl_inputs in self._node_ctrl_inputs.items():
      for ctrl_inp in ctrl_inputs:
        if ctrl_inp in self._copy_send_nodes:
          continue
        self._node_ctrl_recipients.setdefault(ctrl_inp, []).append(recipient)

  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
    """Prune nodes out of input and recipient maps.

    Args:
      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
    """
    node_maps = (
        self._node_inputs,
        self._node_ctrl_inputs,
        self._node_recipients,
        self._node_ctrl_recipients,
    )
    for name in nodes_to_prune:
      for node_map in node_maps:
        del node_map[name]

  def _reconstruct_non_debug_graph_def(self):
    """Reconstruct non-debug GraphDef.

    Non-debug GraphDef means the original GraphDef without the Copy* and Debug
    nodes inserted by the debugger. The result is cached on the instance.
    """
    if self._non_debug_graph_def:
      return

    non_debug_graph_def = graph_pb2.GraphDef()
    self._non_debug_graph_def = non_debug_graph_def
    for node in self._debug_graph_def.node:
      if is_copy_node(node.name) or is_debug_node(node.name):
        continue

      new_node = non_debug_graph_def.node.add()
      new_node.CopyFrom(node)

      # Rebuild the input list from the pruned maps: in _debug_graph_def the
      # inputs can refer to Copy* and Debug* nodes inserted by the debugger.
      # Data inputs come first, followed by "^"-prefixed control inputs.
      del new_node.input[:]
      new_node.input.extend(list(self._node_inputs[node.name]))
      new_node.input.extend(
          ["^" + ctrl_inp for ctrl_inp in self._node_ctrl_inputs[node.name]])

  @property
  def device_name(self):
    """The device name, as given to or inferred by the constructor."""
    return self._device_name

  @property
  def debug_graph_def(self):
    """The debugger-decorated GraphDef."""
    return self._debug_graph_def

  @property
  def non_debug_graph_def(self):
    """The GraphDef without the Copy* and Debug* nodes added by the debugger."""
    self._reconstruct_non_debug_graph_def()
    return self._non_debug_graph_def

  @property
  def node_devices(self):
    """Dict mapping node name to the set of device names recorded for it."""
    return self._node_devices

  @property
  def node_op_types(self):
    """Dict mapping node name to the node's op type."""
    return self._node_op_types

  @property
  def node_attributes(self):
    """Dict mapping node name to the node's `attr` field."""
    return self._node_attributes

  @property
  def node_inputs(self):
    """Dict mapping node name to the list of its non-control inputs."""
    return self._node_inputs

  @property
  def node_ctrl_inputs(self):
    """Dict mapping node name to the list of its control inputs."""
    return self._node_ctrl_inputs

  @property
  def node_reversed_ref_inputs(self):
    """Dict mapping an input node name to the nodes recorded as ref recipients."""
    return self._node_reversed_ref_inputs

  @property
  def node_recipients(self):
    """Dict mapping node name to the nodes that take it as non-control input."""
    return self._node_recipients

  @property
  def node_ctrl_recipients(self):
    """Dict mapping node name to the nodes that take it as control input."""
    return self._node_ctrl_recipients
479
480
def reconstruct_non_debug_graph_def(debug_graph_def):
  """Reconstruct original (non-debugger-decorated) partition GraphDef.

  This method strips the input `tf.compat.v1.GraphDef` of the Copy* and
  Debug*-type nodes inserted by the debugger, by wrapping it in a `DebugGraph`
  and returning that object's `non_debug_graph_def`.

  The reconstructed partition graph is identical to the original (i.e.,
    non-debugger-decorated) partition graph except in the following respects:
      1) The exact names of the runtime-inserted internal nodes may differ.
         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
      2) As a consequence of 1, the nodes that receive input directly from such
         send- and recv-type ops will have different input names.
      3) The parallel_iteration attribute of while-loop Enter ops are set to 1.

  Args:
    debug_graph_def: The debugger-decorated `tf.compat.v1.GraphDef`, with the
      debugger-inserted Copy* and Debug* nodes.

  Returns:
    The reconstructed `tf.compat.v1.GraphDef` stripped of the debugger-inserted
    nodes.
  """
  debug_graph = DebugGraph(debug_graph_def)
  return debug_graph.non_debug_graph_def
504