1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions to handle debug-dump data of TensorFlow Debugger."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import glob
23import json
24import os
25import platform
26import re
27
28import numpy as np
29import six
30
31from tensorflow.core.framework import graph_pb2
32from tensorflow.core.framework import types_pb2
33from tensorflow.core.util import event_pb2
34from tensorflow.python.debug.lib import debug_graphs
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.platform import gfile
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.util import compat
39
40
# TODO(cais): Tie these string constants in with C++?
# Common prefix of the names of the tfdbg metadata files and of the per-device
# dump directories that this module looks for under a dump root.
METADATA_FILE_PREFIX = "_tfdbg_"
# Tag that follows METADATA_FILE_PREFIX in the names of core-metadata files.
CORE_METADATA_TAG = "core_metadata_"
# Tag that follows METADATA_FILE_PREFIX in the names of partition-graph files.
GRAPH_FILE_TAG = "graph_"
# Tag that follows METADATA_FILE_PREFIX in the names of device directories
# (see `device_name_to_device_path`).
DEVICE_TAG = "device_"
# NOTE(review): not referenced in the visible part of this module; presumably
# consumed by other modules - confirm before changing.
HASH_TAG = "hash"

# Tags that follow METADATA_FILE_PREFIX in the names of the files that hold the
# fetches info and the feed-keys info of a run, respectively.
FETCHES_INFO_FILE_TAG = "fetches_info_"
FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_"
50
51
def _glob(glob_pattern):
  """List the paths that match a glob pattern.

  Python's built-in `glob` module is used on Windows; `gfile.Glob` is used on
  every other platform.

  Args:
    glob_pattern: (`str`) the glob pattern to match.

  Returns:
    The matching paths, as returned by the selected glob implementation.
  """
  on_windows = platform.system() == "Windows"
  glob_fn = glob.glob if on_windows else gfile.Glob
  return glob_fn(glob_pattern)
57
58
class InconvertibleTensorProto(object):
  """Wrapper for a `TensorProto` that has no `np.ndarray` representation."""

  def __init__(self, tensor_proto, initialized=True):
    """Constructor.

    Args:
      tensor_proto: the `TensorProto` object that cannot be represented as a
        `np.ndarray` object.
      initialized: (`bool`) whether the Tensor is initialized.
    """
    self._initialized = initialized
    self._tensor_proto = tensor_proto

  @property
  def initialized(self):
    """Whether the wrapped Tensor is initialized, as given at construction."""
    return self._initialized

  def __str__(self):
    """Return the text form of the proto, flagged if it is uninitialized."""
    proto_text = str(self._tensor_proto)
    if self._initialized:
      return proto_text
    return "Uninitialized tensor:\n" + proto_text
81
82
def load_tensor_from_event_file(event_file_path):
  """Load a tensor from an event file.

  Assumes that the event file contains a serialized `Event` protobuf and that
  the `Event` protobuf contains a `Tensor` value.

  Args:
    event_file_path: (`str`) path to the event file.

  Returns:
    The result of `load_tensor_from_event` for the parsed `Event`: a
    `numpy.ndarray` when the tensor value can be converted to one, otherwise an
    `InconvertibleTensorProto` that wraps the `TensorProto`.
  """
  with gfile.Open(event_file_path, "rb") as event_file:
    serialized_event = event_file.read()

  event = event_pb2.Event()
  event.ParseFromString(serialized_event)
  return load_tensor_from_event(event)
103
104
def load_tensor_from_event(event):
  """Load a tensor from an Event proto.

  Args:
    event: The Event proto, assumed to hold a tensor value in its
        summary.value[0] field.

  Returns:
    The tensor value as a `numpy.ndarray`, if representation of the tensor
    value by a `numpy.ndarray` is possible. Otherwise an
    `InconvertibleTensorProto` wrapping the `TensorProto`:
      * with `initialized` equal to False if the proto carries neither
        `tensor_content` nor `string_val` and its shape is non-empty;
      * with `initialized` equal to True if the dtype is `DT_RESOURCE` or if
        the conversion to `numpy.ndarray` fails with a `KeyError`.
  """
  tensor_proto = event.summary.value[0].tensor

  # The element count is 0 only if at least one dimension has size 0.
  num_elements = 1
  for dim_size in tensor_util.TensorShapeProtoToList(
      tensor_proto.tensor_shape):
    num_elements *= dim_size

  holds_value = (
      tensor_proto.tensor_content or tensor_proto.string_val or
      not num_elements)
  if not holds_value:
    # Uninitialized tensor or tensor of unconvertible data type.
    return InconvertibleTensorProto(tensor_proto, False)

  # Initialized tensor or empty tensor.
  if tensor_proto.dtype == types_pb2.DT_RESOURCE:
    return InconvertibleTensorProto(tensor_proto)
  try:
    return tensor_util.MakeNdarray(tensor_proto)
  except KeyError:
    return InconvertibleTensorProto(tensor_proto)
141
142
def _load_graph_def_from_event_file(event_file_path):
  """Read the `GraphDef` stored in the `graph_def` field of an event file.

  Args:
    event_file_path: (`str`) path to a file holding a serialized `Event`.

  Returns:
    The `GraphDef` parsed from the `graph_def` field of the `Event`.
  """
  with gfile.Open(event_file_path, "rb") as event_file:
    serialized_event = event_file.read()

  event = event_pb2.Event()
  event.ParseFromString(serialized_event)
  return graph_pb2.GraphDef.FromString(event.graph_def)
149
150
def _load_log_message_from_event_file(event_file_path):
  """Read the log message text stored in an event file.

  Args:
    event_file_path: (`str`) path to a file holding a serialized `Event`.

  Returns:
    The `log_message.message` field of the parsed `Event`.
  """
  with gfile.Open(event_file_path, "rb") as event_file:
    serialized_event = event_file.read()

  event = event_pb2.Event()
  event.ParseFromString(serialized_event)
  return event.log_message.message
157
158
def _is_graph_file(file_name):
  """Check whether a file name starts with the partition-graph file prefix."""
  graph_file_prefix = METADATA_FILE_PREFIX + GRAPH_FILE_TAG
  return file_name.startswith(graph_file_prefix)
161
162
def _is_run_fetches_info_file(file_name):
  """Check whether a file name is exactly the fetches-info file name."""
  fetches_info_file_name = METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG
  return file_name == fetches_info_file_name
165
166
def _is_run_feed_keys_info_file(file_name):
  """Check whether a file name is exactly the feed-keys-info file name."""
  feed_keys_info_file_name = METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG
  return file_name == feed_keys_info_file_name
169
170
171def _get_tensor_name(node_name, output_slot):
172  """Get tensor name given node name and output slot index.
173
174  Args:
175    node_name: Name of the node that outputs the tensor, as a string.
176    output_slot: Output slot index of the tensor, as an integer.
177
178  Returns:
179    Name of the tensor, as a string.
180  """
181
182  return "%s:%d" % (node_name, output_slot)
183
184
def _get_tensor_watch_key(node_name, output_slot, debug_op):
  """Build the string representation of a debug watch on a tensor.

  Args:
    node_name: Name of the node by which the watched tensor is produced, as a
        string.
    output_slot: Output slot index of the tensor, as an integer.
    debug_op: Name of the debug op that is used to watch the tensor, as a
        string.

  Returns:
    A string representing the debug watch on the tensor (i.e., the "watch
        key"), of the form `tensor_name`:`debug_op`.
  """
  tensor_name = _get_tensor_name(node_name, output_slot)
  return "%s:%s" % (tensor_name, debug_op)
200
201
def has_inf_or_nan(datum, tensor):
  """A predicate for whether a tensor consists of any bad numerical values.

  This predicate is common enough to merit definition in this module.
  Bad numerical values include `nan`s and `inf`s.
  The signature of this function follows the requirement of the method
  `DebugDumpDir.find()`.

  Args:
    datum: (`DebugTensorDatum`) Datum metadata.
    tensor: (`numpy.ndarray` or `InconvertibleTensorProto`) Value of the
      tensor. An `InconvertibleTensorProto` represents an uninitialized tensor
      or a tensor that cannot be represented as a `numpy.ndarray`.

  Returns:
    (`bool`) True if and only if tensor consists of any nan or inf values.
  """

  _ = datum  # Datum metadata is unused in this predicate.

  if isinstance(tensor, InconvertibleTensorProto):
    # Uninitialized tensor doesn't have bad numerical values.
    # Also return False for data types that cannot be represented as numpy
    # arrays.
    return False

  # Use the abstract NumPy scalar types. The `np.complex` alias of the builtin
  # `complex` has been removed from NumPy, and a concrete complex type would
  # not match every complex dtype (e.g., complex64).
  numeric_kinds = (np.floating, np.complexfloating, np.integer)
  if not any(np.issubdtype(tensor.dtype, kind) for kind in numeric_kinds):
    return False

  # Convert to a plain `bool` so the return type matches the documentation.
  return bool(np.any(np.isnan(tensor)) or np.any(np.isinf(tensor)))
232
233
# Core metadata of a `Session.run()` call, as extracted from a core metadata
# event by `extract_core_metadata_from_event_proto`.
_CoreMetadata = collections.namedtuple(
    "CoreMetadata",
    ["global_step", "session_run_index", "executor_step_index",
     "input_names", "output_names", "target_nodes"])


def extract_core_metadata_from_event_proto(event):
  """Extract the core metadata from an `Event` proto.

  Args:
    event: An `Event` proto whose `log_message.message` field holds a JSON
      object with the keys `global_step`, `session_run_index`,
      `executor_step_index`, `input_names`, `output_names` and `target_nodes`.

  Returns:
    A `_CoreMetadata` namedtuple populated from the JSON object.

  Raises:
    KeyError: If any of the expected keys is missing from the JSON object.
  """
  fields = json.loads(event.log_message.message)
  return _CoreMetadata(
      global_step=fields["global_step"],
      session_run_index=fields["session_run_index"],
      executor_step_index=fields["executor_step_index"],
      input_names=fields["input_names"],
      output_names=fields["output_names"],
      target_nodes=fields["target_nodes"])
248
249
def device_name_to_device_path(device_name):
  """Convert a device name to the name of its device dump directory.

  The device name is split on "/", every ":" is replaced by "_" and the items
  are re-joined with ",", behind the device directory prefix.

  Args:
    device_name: Name of the device; converted to text before processing.

  Returns:
    (`str`) The device path.
  """
  name_items = compat.as_text(device_name).split("/")
  encoded_name = ",".join(item.replace(":", "_") for item in name_items)
  return METADATA_FILE_PREFIX + DEVICE_TAG + encoded_name
255
256
def device_path_to_device_name(device_dir):
  """Parse device name from device path.

  Only the base name of `device_dir` is used. The device directory prefix is
  stripped from it and the remainder is split on ",". In each item, "device_"
  becomes "device:" and then the first remaining "_" becomes ":". The items
  are finally joined with "/".

  Args:
    device_dir: (str) a directory name for the device.

  Returns:
    (str) parsed device name.
  """
  prefix_length = len(METADATA_FILE_PREFIX) + len(DEVICE_TAG)
  encoded_name = os.path.basename(device_dir)[prefix_length:]

  name_items = []
  for path_item in encoded_name.split(","):
    decoded_item = path_item.replace("device_", "device:")
    name_items.append(decoded_item.replace("_", ":", 1))
  return "/".join(name_items)
271
272
class DebugTensorDatum(object):
  """Metadata of a single tensor dumped by TensorFlow Debugger (tfdbg).

  Holds the `timestamp`, `node_name`, `output_slot` and `debug_op` of the
  dumped tensor, the name of its device, and the path to the dump file
  (`file_path`).

  The tensor value itself (a generally space-expensive numpy array) is not
  kept in this object. It is loaded from the dump file on demand by the
  `get_tensor` method.
  """

  def __init__(self, dump_root, debug_dump_rel_path):
    """`DebugTensorDatum` constructor.

    Args:
      dump_root: (`str`) Debug dump root directory. This path should not include
        the path component that represents the device name (see also below).
      debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the
        `dump_root`. The first item of this relative path is assumed to be
        a path representing the name of the device that the Tensor belongs to.
        See `device_path_to_device_name` for more details on the device path.
        For example, suppose the debug dump root
        directory is `/tmp/tfdbg_1` and the dump file is at
        `/tmp/tfdbg_1/<device_path>/ns_1/node_a_0_DebugIdentity_123456789`,
        then the value of the debug_dump_rel_path should be
        `<device_path>/ns_1/node_a_0_DebugIdentity_123456789`.

    Raises:
      ValueError: If the base file name of the dump file does not conform to
        the dump file naming pattern:
        `node_name`_`output_slot`_`debug_op`_`timestamp`
    """
    path_components = os.path.normpath(debug_dump_rel_path).split(os.sep)
    self._device_name = device_path_to_device_name(path_components[0])

    base_name = path_components[-1]
    if base_name.count("_") < 3:
      raise ValueError(
          "Dump file path does not conform to the naming pattern: %s" %
          base_name)

    # The last three "_"-separated items are the output slot, the debug op and
    # the (extended) timestamp. Whatever precedes them belongs to the node
    # name, which may itself contain underscores.
    node_base_name, output_slot, debug_op, extended_timestamp = (
        base_name.rsplit("_", 3))

    self._extended_timestamp = extended_timestamp
    # The extended timestamp may include an index suffix (e.g., "-1") at the
    # end if file path collision happened due to identical timestamps. Only
    # the part before the first "-" is the timestamp proper.
    self._timestamp = int(extended_timestamp.partition("-")[0])

    self._debug_op = debug_op
    self._output_slot = int(output_slot)

    # Directories between the device directory and the file are name scopes.
    self._node_name = "/".join(path_components[1:-1] + [node_base_name])

    self._file_path = os.path.join(dump_root, debug_dump_rel_path)
    if gfile.Exists(self._file_path):
      self._dump_size_bytes = gfile.Stat(self._file_path).length
    else:
      self._dump_size_bytes = None

  def __str__(self):
    return "{DebugTensorDatum (%s) %s:%d @ %s @ %d}" % (
        self.device_name, self.node_name, self.output_slot, self.debug_op,
        self.timestamp)

  def __repr__(self):
    return str(self)

  def get_tensor(self):
    """Load the tensor value from the dump (`Event`) file.

    Returns:
      The tensor loaded from the dump (`Event`) file, as returned by
      `load_tensor_from_event_file`.
    """
    return load_tensor_from_event_file(self.file_path)

  # TODO(cais): Add time unit suffix to timestamp and t0 (us).
  @property
  def timestamp(self):
    """Timestamp of when this tensor value was dumped.

    Returns:
      (`int`) The timestamp in microseconds.
    """
    return self._timestamp

  @property
  def extended_timestamp(self):
    """Extended timestamp, possibly with an index suffix.

    The index suffix, e.g., "-1", is for disambiguating multiple dumps of the
    same tensor with the same timestamp, which can occur if the dumping events
    are spaced by shorter than the temporal resolution of the timestamps.

    Returns:
      (`str`) The extended timestamp.
    """
    return self._extended_timestamp

  @property
  def debug_op(self):
    """Name of the debug op.

    Returns:
      (`str`) debug op name (e.g., `DebugIdentity`).
    """
    return self._debug_op

  @property
  def device_name(self):
    """Name of the device that the tensor belongs to.

    Returns:
      (`str`) device name.
    """
    return self._device_name

  @property
  def node_name(self):
    """Name of the node from which the tensor value was dumped.

    Returns:
      (`str`) name of the node watched by the debug op.
    """
    return self._node_name

  @property
  def output_slot(self):
    """Output slot index from which the tensor value was dumped.

    Returns:
      (`int`) output slot index watched by the debug op.
    """
    return self._output_slot

  @property
  def tensor_name(self):
    """Name of the tensor watched by the debug op.

    Returns:
      (`str`) `Tensor` name, in the form of `node_name`:`output_slot`
    """
    return _get_tensor_name(self.node_name, self.output_slot)

  @property
  def watch_key(self):
    """Watch key that identifies a debug watch on a tensor.

    Returns:
      (`str`) A watch key, in the form of `tensor_name`:`debug_op`.
    """
    return _get_tensor_watch_key(
        self.node_name, self.output_slot, self.debug_op)

  @property
  def file_path(self):
    """Path to the file which stores the value of the dumped tensor."""
    return self._file_path

  @property
  def dump_size_bytes(self):
    """Size of the dump file, as determined at construction time.

    Unit: byte.

    Returns:
      If the dump file exists, size of the dump file, in bytes.
      If the dump file does not exist, None.
    """
    return self._dump_size_bytes
456
457
class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
  """`ValueError` subclass for a watch key absent from a debug dump dir."""
460
461
462class DebugDumpDir(object):
463  """Data set from a debug-dump directory on filesystem.
464
465  An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances
466  in a tfdbg dump root directory.
467  """
468
  def __init__(self, dump_root, partition_graphs=None, validate=True):
    """`DebugDumpDir` constructor.

    Loads, in this order, the core metadata, the fetches info, the feed-keys
    info and the dumps of all devices found under `dump_root`.

    Args:
      dump_root: (`str`) path to the dump root directory.
      partition_graphs: A repeated field of GraphDefs representing the
          partition graphs executed by the TensorFlow runtime.
      validate: (`bool`) whether the dump files are to be validated against the
          partition graphs.

    Raises:
      IOError: If dump_root does not exist as a directory.
      ValueError: Propagated from the loading of the device dumps, e.g., if
        the name of a dump file does not conform to the dump file naming
        pattern or, when `validate` is True, if the dumps are found to be
        inconsistent with the partition graphs.
    """

    if not gfile.IsDirectory(dump_root):
      raise IOError("Dump root directory %s does not exist" % dump_root)

    # Receives one entry per core metadata file found under the dump root;
    # more than one file is accepted.
    self._core_metadata = []

    # Must be set before any of the loaders below, which all read it.
    self._dump_root = dump_root

    self._load_core_metadata()
    self._load_fetches_info()
    self._load_feeds_info()
    self._load_all_device_dumps(partition_graphs, validate)

    # No Python graph until one is supplied through `set_python_graph()`.
    self._python_graph = None
499
500  def _load_all_device_dumps(self, partition_graphs, validate):
501    """Load the dump data for all devices."""
502    device_dirs = _glob(os.path.join(
503        self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*"))
504
505    self._device_names = []
506    self._t0s = {}
507    self._dump_tensor_data = {}
508    self._dump_graph_file_paths = {}
509    self._debug_watches = {}
510    self._watch_key_to_devices = {}
511    self._watch_key_to_datum = {}
512    self._watch_key_to_rel_time = {}
513    self._watch_key_to_dump_size_bytes = {}
514    for device_dir in device_dirs:
515      device_name = device_path_to_device_name(device_dir)
516      self._device_names.append(device_name)
517      self._load_device_dumps(device_name, device_dir)
518    self._load_partition_graphs(partition_graphs, validate)
519    self._calculate_t0()
520
521    for device_name in self._device_names:
522      self._create_tensor_watch_maps(device_name)
523
524  def _load_device_dumps(self, device_name, device_root):
525    """Load `DebugTensorDatum` instances from the dump root of a given device.
526
527    Populates a map {device_name: a list of `DebugTensorDatum`}, where the list
528    is sorted by ascending timestamp.
529
530    This sorting order reflects the order in which the TensorFlow executor
531    processed the nodes of the graph. It is (one of many possible) topological
532    sort of the nodes. This is useful for displaying tensors in the debugger
533    frontend as well as for the use case in which the user wants to find a
534    "culprit tensor", i.e., the first tensor in the graph that exhibits certain
535    problematic properties, i.e., all zero values, or bad numerical values such
536    as nan and inf.
537
538    In addition, creates a map from node name to debug watches. In this Map,
539    the key is the watched node name; the value is a dictionary.
540    Of this dictionary, the key is the watched_output_slot.
541
542    This method attempts to load the debug watches from the tensor dump files
543    first, before loading the full set of debug watches from the partition
544    graphs as done later. This is necessary because sometimes the partition
545    graphs may not be available, e.g., when the run errors out.
546
547    Args:
548      device_name: (`str`) name of the device.
549      device_root: (`str`) dump root directory of the given device.
550
551    Raises:
552      ValueError: If GraphDef for the device is not available.
553    """
554
555    self._dump_tensor_data[device_name] = []
556    self._debug_watches[device_name] = collections.defaultdict(
557        lambda: collections.defaultdict(set))
558
559    for root, _, files in gfile.Walk(device_root):
560      for f in files:
561        if _is_graph_file(f):
562          self._dump_graph_file_paths[device_name] = os.path.join(root, f)
563        else:
564          datum = self._dump_file_name_to_datum(root, f)
565          self._dump_tensor_data[device_name].append(datum)
566          self._debug_watches[device_name][datum.node_name][
567              datum.output_slot].add(datum.debug_op)
568
569    self._dump_tensor_data[device_name] = sorted(
570        self._dump_tensor_data[device_name],
571        key=lambda x: x.extended_timestamp)
572
573    if self._dump_tensor_data[device_name]:
574      self._t0s[device_name] = self._dump_tensor_data[device_name][0].timestamp
575    else:
576      self._t0s[device_name] = None
577
578  def _calculate_t0(self):
579    """Calculate the first timestamp across all devices."""
580    t0s = [t0 for t0 in six.itervalues(self._t0s) if t0 is not None]
581    self._t0 = min(t0s) if t0s else None
582
583  def _load_core_metadata(self):
584    core_metadata_files = _glob(os.path.join(
585        self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*"))
586    for core_metadata_file in core_metadata_files:
587      with gfile.Open(core_metadata_file, "rb") as f:
588        event = event_pb2.Event()
589        event.ParseFromString(f.read())
590        self._core_metadata.append(
591            extract_core_metadata_from_event_proto(event))
592
593  def _load_fetches_info(self):
594    fetches_info_files = _glob(os.path.join(
595        self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*"))
596    self._run_fetches_info = []
597    for fetches_info_file in fetches_info_files:
598      self._run_fetches_info.append(
599          _load_log_message_from_event_file(fetches_info_file))
600
601  def _load_feeds_info(self):
602    feeds_info_files = _glob(os.path.join(
603        self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*"))
604    self._run_feed_keys_info = []
605    for feeds_info_file in feeds_info_files:
606      self._run_feed_keys_info.append(
607          _load_log_message_from_event_file(feeds_info_file))
608
609  def _dump_file_name_to_datum(self, dir_name, file_name):
610    """Obtain a DebugTensorDatum from the directory and file name.
611
612    Args:
613      dir_name: (`str`) Name of the directory in which the dump file resides.
614      file_name: (`str`) Base name of the dump file.
615
616    Returns:
617      (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file.
618    """
619
620    # Calculate the relative path of the dump file with respect to the root.
621    debug_dump_rel_path = os.path.join(
622        os.path.relpath(dir_name, self._dump_root), file_name)
623    return DebugTensorDatum(self._dump_root, debug_dump_rel_path)
624
625  def _create_tensor_watch_maps(self, device_name):
626    """Create maps from tensor watch keys to datum and to timestamps.
627
628    Create a map from watch key (tensor name + debug op) to `DebugTensorDatum`
629    item. Also make a map from watch key to relative timestamp.
630    "relative" means (absolute timestamp - t0).
631
632    Args:
633      device_name: (str) name of the device.
634    """
635
636    self._watch_key_to_datum[device_name] = {}
637    self._watch_key_to_rel_time[device_name] = {}
638    self._watch_key_to_dump_size_bytes[device_name] = {}
639    for datum in self._dump_tensor_data[device_name]:
640      if datum.watch_key not in self._watch_key_to_devices:
641        self._watch_key_to_devices[datum.watch_key] = {device_name}
642      else:
643        self._watch_key_to_devices[datum.watch_key].add(device_name)
644
645      if datum.watch_key not in self._watch_key_to_datum[device_name]:
646        self._watch_key_to_datum[device_name][datum.watch_key] = [datum]
647        self._watch_key_to_rel_time[device_name][datum.watch_key] = [
648            datum.timestamp - self._t0]
649        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key] = [
650            datum.dump_size_bytes]
651      else:
652        self._watch_key_to_datum[device_name][datum.watch_key].append(datum)
653        self._watch_key_to_rel_time[device_name][datum.watch_key].append(
654            datum.timestamp - self._t0)
655        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key].append(
656            datum.dump_size_bytes)
657
658  def set_python_graph(self, python_graph):
659    """Provide Python `Graph` object to the wrapper.
660
661    Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph`
662    is a Python object and carries additional information such as the traceback
663    of the construction of the nodes in the graph.
664
665    Args:
666      python_graph: (ops.Graph) The Python Graph object.
667    """
668
669    self._python_graph = python_graph
670    self._node_traceback = {}
671    if self._python_graph:
672      for op in self._python_graph.get_operations():
673        self._node_traceback[op.name] = op.traceback
674
675  @property
676  def python_graph(self):
677    """Get the Python graph.
678
679    Returns:
680      If the Python graph has been set, returns a `tf.Graph` object. Otherwise,
681      returns None.
682    """
683
684    return self._python_graph
685
686  @property
687  def core_metadata(self):
688    """Metadata about the `Session.run()` call from the core runtime.
689
690    Of the three counters available in the return value, `global_step` is
691    supplied by the caller of the debugged `Session.run()`, while
692    `session_run_index` and `executor_step_index` are determined by the state
693    of the core runtime, automatically. For the same fetch list, feed keys and
694    debug tensor watch options, the same executor will be used and
695    `executor_step_index` should increase by one at a time. However, runs with
696    different fetch lists, feed keys and debug_tensor watch options that all
697    share the same `Session` object can lead to gaps in `session_run_index`.
698
699    Returns:
700      If core metadata are loaded, a `namedtuple` with the fields:
701        `global_step`: A global step count supplied by the caller of
702          `Session.run()`. It is optional to the caller. If the caller did not
703          supply this parameter, its value will be -1.
704        `session_run_index`: A sorted index for Run() calls to the underlying
705          TensorFlow `Session` object.
706        `executor_step_index`: A counter for invocations of a given runtime
707          executor. The same executor is re-used for the same fetched tensors,
708          target nodes, input feed keys and debug tensor watch options.
709        `input_names`: Names of the input (feed) Tensors.
710        `output_names`: Names of the output (fetched) Tensors.
711        `target_nodes`: Names of the target nodes.
712      If the core metadata have not been loaded, `None`.
713      If more than one core metadata files exist, return a list of the
714        `nametuple` described above.
715    """
716
717    output = self._core_metadata
718    return output[0] if len(output) == 1 else output
719
720  @property
721  def dumped_tensor_data(self):
722    """Retrieve dumped tensor data."""
723    if len(self.devices()) == 1:
724      return self._dump_tensor_data[self.devices()[0]]
725    else:
726      all_devices_data = six.itervalues(self._dump_tensor_data)
727      data = []
728      for device_data in all_devices_data:
729        data.extend(device_data)
730      return sorted(data, key=lambda x: x.extended_timestamp)
731
732  @property
733  def t0(self):
734    """Absolute timestamp of the first dumped tensor across all devices.
735
736    Returns:
737      (`int`) absolute timestamp of the first dumped tensor, in microseconds.
738    """
739    return self._t0
740
741  @property
742  def size(self):
743    """Total number of dumped tensors in the dump root directory.
744
745    Returns:
746      (`int`) The total number of dumped tensors in the dump root directory.
747    """
748    return sum(len(self._dump_tensor_data[device_name])
749               for device_name in self._dump_tensor_data)
750
751  def _load_partition_graphs(self, client_partition_graphs, validate):
752    """Load and process partition graphs.
753
754    Load the graphs; parse the input and control input structure; obtain the
755    device and op type of each node; remove the Copy and debug ops inserted
756    by the debugger. The gathered information can be used to validate the
757    tensor dumps.
758
759    Args:
760      client_partition_graphs: A repeated field of GraphDefs representing the
761        partition graphs executed by the TensorFlow runtime, from the Python
762        client. These partition graphs are used only if partition graphs
763        cannot be loaded from the dump directory on the file system.
764      validate: (`bool`) Whether the dump files are to be validated against the
765        partition graphs.
766
767    Raises:
768      ValueError: If the partition GraphDef of one or more devices fail to be
769        loaded.
770    """
771    self._debug_graphs = {}
772    self._node_devices = {}
773
774    partition_graphs_and_device_names = []
775    for device_name in self._device_names:
776      partition_graph = None
777      if device_name in self._dump_graph_file_paths:
778        partition_graph = _load_graph_def_from_event_file(
779            self._dump_graph_file_paths[device_name])
780      else:
781        logging.warn(
782            "Failed to load partition graphs for device %s from disk. "
783            "As a fallback, the client graphs will be used. This "
784            "may cause mismatches in device names." % device_name)
785        partition_graph = self._find_partition_graph(client_partition_graphs,
786                                                     device_name)
787
788      if partition_graph:
789        partition_graphs_and_device_names.append((partition_graph,
790                                                  device_name))
791
792    for partition_graph, maybe_device_name in partition_graphs_and_device_names:
793      debug_graph = debug_graphs.DebugGraph(partition_graph,
794                                            device_name=maybe_device_name)
795      self._debug_graphs[debug_graph.device_name] = debug_graph
796      self._collect_node_devices(debug_graph)
797
798      if validate and debug_graph.device_name in self._dump_tensor_data:
799        self._validate_dump_with_graphs(debug_graph.device_name)
800
801  def _find_partition_graph(self, partition_graphs, device_name):
802    if partition_graphs is None:
803      return None
804    else:
805      for graph_def in partition_graphs:
806        for node_def in graph_def.node:
807          if node_def.device == device_name:
808            return graph_def
809      return None
810
811  def _collect_node_devices(self, debug_graph):
812    for node_name in debug_graph.node_devices:
813      if node_name in self._node_devices:
814        self._node_devices[node_name] = self._node_devices[node_name].union(
815            debug_graph.node_devices[node_name])
816      else:
817        self._node_devices[node_name] = debug_graph.node_devices[node_name]
818
819  def _validate_dump_with_graphs(self, device_name):
820    """Validate the dumped tensor data against the partition graphs.
821
822    Only the watched nodes are validated by this method, because tfdbg allows
823    clients to watch only a subset of the nodes.
824
825    Args:
826      device_name: (`str`) device name.
827
828    Raises:
829      LookupError: If the partition graphs have not been loaded yet.
830      ValueError: If dumps contain node names not found in partition graph.
831        Or if the temporal order of the dump's timestamps violate the
832        input relations on the partition graphs.
833    """
834    if not self._debug_graphs:
835      raise LookupError(
836          "No partition graphs loaded for device %s" % device_name)
837    debug_graph = self._debug_graphs[device_name]
838
839    # Verify that the node names in the dump data are all present in the
840    # partition graphs.
841    for datum in self._dump_tensor_data[device_name]:
842      if datum.node_name not in debug_graph.node_inputs:
843        raise ValueError("Node name '%s' is not found in partition graphs of "
844                         "device %s." % (datum.node_name, device_name))
845
846    pending_inputs = {}
847    for node in debug_graph.node_inputs:
848      pending_inputs[node] = []
849      inputs = debug_graph.node_inputs[node]
850      for inp in inputs:
851        inp_node = debug_graphs.get_node_name(inp)
852        inp_output_slot = debug_graphs.get_output_slot(inp)
853        # Inputs from Enter and NextIteration nodes are not validated because
854        # DebugNodeInserter::InsertNodes() in the debugger core skips creating
855        # control edges from debug ops watching these types of nodes.
856        if (inp_node in self._debug_watches[device_name] and
857            inp_output_slot in self._debug_watches[device_name][inp_node] and
858            debug_graph.node_op_types.get(inp) not in (
859                "Enter", "NextIteration") and
860            (inp_node, inp_output_slot) not in pending_inputs[node]):
861          pending_inputs[node].append((inp_node, inp_output_slot))
862
863    for i, datum in enumerate(self._dump_tensor_data[device_name]):
864      node = datum.node_name
865      slot = datum.output_slot
866      # In some cases (e.g., system clocks with insufficient precision),
867      # the upstream and downstream tensors may have identical timestamps, the
868      # following check examines this possibility and avoids raising an error if
869      # that is the case.
870      if not self._satisfied_at_timestamp(
871          device_name, pending_inputs[node], datum.timestamp, start_i=i + 1):
872        raise ValueError("Causality violated in timing relations of debug "
873                         "dumps: %s (%d): "
874                         "these input(s) are not satisfied: %s" %
875                         (node, datum.timestamp, repr(pending_inputs[node])))
876
877      recipients = debug_graph.node_recipients[node]
878      for recipient in recipients:
879        recipient_pending_inputs = pending_inputs[recipient]
880        if (node, slot) in recipient_pending_inputs:
881          if self.node_op_type(recipient) == "Merge":
882            # If this is a Merge op, we automatically clear the list because
883            # a Merge node only requires one of its two inputs.
884            del recipient_pending_inputs[:]
885          else:
886            del recipient_pending_inputs[
887                recipient_pending_inputs.index((node, slot))]
888
889  def _satisfied_at_timestamp(self, device_name, pending, timestamp, start_i=0):
890    """Determine whether pending inputs are satisfied at given timestamp.
891
892    Note: This method mutates the input argument "pending".
893
894    Args:
895      device_name: (str) device name.
896      pending: A list of 2-tuple (node_name, output_slot): the dependencies to
897        check.
898      timestamp: (int) the timestamp in question.
899      start_i: (int) the index in self._dump_tensor_data to start searching for
900        the timestamp.
901
902    Returns:
903      (bool) Whether all the dependencies in pending are satisfied at the
904        timestamp. If pending is empty to begin with, return True.
905    """
906    if not pending:
907      return True
908
909    for datum in self._dump_tensor_data[device_name][start_i:]:
910      if datum.timestamp > timestamp:
911        break
912      if (datum.timestamp == timestamp and
913          (datum.node_name, datum.output_slot) in pending):
914        pending.remove((datum.node_name, datum.output_slot))
915        if not pending:
916          return True
917
918    return not pending
919
920  def loaded_partition_graphs(self):
921    """Test whether partition graphs have been loaded."""
922    return bool(self._debug_graphs)
923
924  def partition_graphs(self):
925    """Get the partition graphs.
926
927    Returns:
928      Partition graphs as a list of GraphDef.
929
930    Raises:
931      LookupError: If no partition graphs have been loaded.
932    """
933    if not self._debug_graphs:
934      raise LookupError("No partition graphs have been loaded.")
935    return [self._debug_graphs[key].debug_graph_def
936            for key in self._debug_graphs]
937
938  def reconstructed_non_debug_partition_graphs(self):
939    """Reconstruct partition graphs with the debugger-inserted ops stripped.
940
941    The reconstructed partition graphs are identical to the original (i.e.,
942    non-debugger-decorated) partition graphs except in the following respects:
943      1) The exact names of the runtime-inserted internal nodes may differ.
944         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
945      2) As a consequence of 1, the nodes that receive input directly from such
946         send- and recv-type ops will have different input names.
947      3) The parallel_iteration attribute of while-loop Enter ops are set to 1.
948
949    Returns:
950      A dict mapping device names (`str`s) to reconstructed `tf.GraphDef`s.
951    """
952    non_debug_graphs = dict()
953    for key in self._debug_graphs:
954      non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
955    return non_debug_graphs
956
957  @property
958  def run_fetches_info(self):
959    """Get a str representation of the fetches used in the Session.run() call.
960
961    Returns:
962      If the information is available from one `Session.run` call, a `str`
963        obtained from `repr(fetches)`.
964      If the information is available from multiple `Session.run` calls, a
965        `list` of `str` from `repr(fetches)`.
966      If the information is not available, `None`.
967    """
968
969    output = self._run_fetches_info
970    return output[0] if len(output) == 1 else output
971
972  @property
973  def run_feed_keys_info(self):
974    """Get a str representation of the feed_dict used in the Session.run() call.
975
976    Returns:
977      If the information is available from one `Session.run` call, a `str`
978        obtained from `repr(feed_dict)`.
979      If the information is available from multiple `Session.run` calls, a
980        `list` of `str` obtained from `repr(feed_dict)`.
981      If the information is not available, `None`.
982    """
983
984    output = self._run_feed_keys_info
985    return output[0] if len(output) == 1 else output
986
987  def _infer_device_name(self, device_name, node_name):
988    """Infer the device name given node name.
989
990    If device_name is provided (i.e., not None), it'll be simply returned right
991    away.
992
993    Args:
994      device_name: (str or None) name of the device. If None, will try to infer
995        the device name by looking at the available nodes.
996      node_name: (str) name of the node.
997
998    Returns:
999      (str) Inferred name of the device, if available.
1000
1001    Raises:
1002      ValueError: If the node name does not exist on any of the available
1003        devices or if there are multiple devices that contain the node with
1004        the given name.
1005    """
1006    if device_name is None:
1007      if node_name in self._node_devices:
1008        if len(self._node_devices[node_name]) == 1:
1009          return list(self._node_devices[node_name])[0]
1010        else:
1011          raise ValueError(
1012              "There are multiple (%d) devices with nodes named '%s' but "
1013              "device_name is not specified." %
1014              (len(self._node_devices[node_name]), node_name))
1015      else:
1016        raise ValueError("None of the %d device(s) has a node named '%s'." %
1017                         (len(self._device_names), node_name))
1018    else:
1019      return device_name
1020
1021  def nodes(self, device_name=None):
1022    """Get a list of all nodes from the partition graphs.
1023
1024    Args:
1025      device_name: (`str`) name of device. If None, all nodes from all available
1026        devices will be included.
1027
1028    Returns:
1029      All nodes' names, as a list of str.
1030
1031    Raises:
1032      LookupError: If no partition graphs have been loaded.
1033      ValueError: If specified node name does not exist.
1034    """
1035    if not self._debug_graphs:
1036      raise LookupError("No partition graphs have been loaded.")
1037    if device_name is None:
1038      nodes = []
1039      for device_name in self._debug_graphs:
1040        nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
1041      return nodes
1042    else:
1043      if device_name not in self._debug_graphs:
1044        raise ValueError("Invalid device name: %s" % device_name)
1045      return self._debug_graphs[device_name].node_inputs.keys()
1046
1047  def node_attributes(self, node_name, device_name=None):
1048    """Get the attributes of a node.
1049
1050    Args:
1051      node_name: Name of the node in question.
1052      device_name: (`str`) name of the device. If there is only one device or if
1053        node_name exists on only one device, this argument is optional.
1054
1055    Returns:
1056      Attributes of the node.
1057
1058    Raises:
1059      LookupError: If no partition graphs have been loaded.
1060    """
1061    if not self._debug_graphs:
1062      raise LookupError("No partition graphs have been loaded.")
1063
1064    device_name = self._infer_device_name(device_name, node_name)
1065    return self._debug_graphs[device_name].node_attributes[node_name]
1066
1067  def node_inputs(self, node_name, is_control=False, device_name=None):
1068    """Get the inputs of given node according to partition graphs.
1069
1070    Args:
1071      node_name: Name of the node.
1072      is_control: (`bool`) Whether control inputs, rather than non-control
1073        inputs, are to be returned.
1074      device_name: (`str`) name of the device. If there is only one device or if
1075        node_name exists on only one device, this argument is optional.
1076
1077    Returns:
1078      (`list` of `str`) inputs to the node, as a list of node names.
1079
1080    Raises:
1081      LookupError: If node inputs and control inputs have not been loaded
1082         from partition graphs yet.
1083    """
1084    if not self._debug_graphs:
1085      raise LookupError(
1086          "Node inputs are not loaded from partition graphs yet.")
1087
1088    device_name = self._infer_device_name(device_name, node_name)
1089    if is_control:
1090      return self._debug_graphs[device_name].node_ctrl_inputs[node_name]
1091    else:
1092      return self._debug_graphs[device_name].node_inputs[node_name]
1093
1094  def transitive_inputs(self,
1095                        node_name,
1096                        include_control=True,
1097                        include_reversed_ref=False,
1098                        device_name=None,):
1099    """Get the transitive inputs of given node according to partition graphs.
1100
1101    Args:
1102      node_name: Name of the node.
1103      include_control: Include control inputs (True by default).
1104      include_reversed_ref: Whether a ref input, say from A to B, is to be also
1105        considered as an input from B to A. The rationale is that ref inputs
1106        generally let the recipient (e.g., B in this case) mutate the value of
1107        the source (e.g., A in this case). So the reverse direction of the ref
1108        edge reflects the direction of information flow.
1109      device_name: (`str`) name of the device. If there is only one device or if
1110        node_name exists on only one device, this argument is optional.
1111
1112    Returns:
1113      (`list` of `str`) all transitive inputs to the node, as a list of node
1114        names.
1115
1116    Raises:
1117      LookupError: If node inputs and control inputs have not been loaded
1118         from partition graphs yet.
1119    """
1120    if not self._debug_graphs:
1121      raise LookupError(
1122          "Node inputs are not loaded from partition graphs yet.")
1123
1124    device_name = self._infer_device_name(device_name, node_name)
1125
1126    input_lists = [self._debug_graphs[device_name].node_inputs]
1127    if include_control:
1128      input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs)
1129    if include_reversed_ref:
1130      input_lists.append(
1131          self._debug_graphs[device_name].node_reversed_ref_inputs)
1132    tracer = debug_graphs.DFSGraphTracer(
1133        input_lists,
1134        skip_node_names=self._get_merge_node_names(device_name))
1135    tracer.trace(node_name)
1136    return tracer.inputs()
1137
1138  def _get_merge_node_names(self, device_name):
1139    """Lazily get a list of Merge nodes on a given device."""
1140    if device_name not in self._device_names:
1141      raise ValueError("Invalid device name: %s" % device_name)
1142
1143    if not hasattr(self, "_merge_node_names"):
1144      self._merge_node_names = {}
1145    if device_name not in self._merge_node_names:
1146      debug_graph = self._debug_graphs[device_name]
1147      self._merge_node_names[device_name] = [
1148          node for node in debug_graph.node_op_types
1149          if debug_graph.node_op_types[node] == "Merge"]
1150    return self._merge_node_names[device_name]
1151
1152  def find_some_path(self,
1153                     src_node_name,
1154                     dst_node_name,
1155                     include_control=True,
1156                     include_reversed_ref=False,
1157                     device_name=None):
1158    """Find a path between a source node and a destination node.
1159
1160    Limitation: the source and destination are required to be on the same
1161    device, i.e., this method does not yet take into account Send/Recv nodes
1162    across devices.
1163
1164    TODO(cais): Make this method work across device edges by tracing Send/Recv
1165      nodes.
1166
1167    Args:
1168      src_node_name: (`str`) name of the source node or name of an output tensor
1169        of the node.
1170      dst_node_name: (`str`) name of the destination node or name of an output
1171        tensor of the node.
1172      include_control: (`bool`) whrther control edges are considered in the
1173        graph tracing.
1174      include_reversed_ref: Whether a ref input, say from A to B, is to be also
1175        considered as an input from B to A. The rationale is that ref inputs
1176        generally let the recipient (e.g., B in this case) mutate the value of
1177        the source (e.g., A in this case). So the reverse direction of the ref
1178        edge reflects the direction of information flow.
1179      device_name: (`str`) name of the device. If there is only one device or if
1180        node_name exists on only one device, this argument is optional.
1181
1182    Returns:
1183      A path from the src_node_name to dst_node_name, as a `list` of `str`, if
1184      it exists. The list includes src_node_name as the first item and
1185      dst_node_name as the last.
1186      If such a path does not exist, `None`.
1187
1188    Raises:
1189      ValueError: If the source and destination nodes are not on the same
1190        device.
1191    """
1192    src_device_name = self._infer_device_name(device_name, src_node_name)
1193    dst_device_name = self._infer_device_name(device_name, dst_node_name)
1194
1195    if src_device_name != dst_device_name:
1196      raise ValueError(
1197          "Source (%s) and destination (%s) are not on the same device: "
1198          "%s vs. %s" % (src_node_name, dst_node_name, src_device_name,
1199                         dst_device_name))
1200
1201    input_lists = [self._debug_graphs[dst_device_name].node_inputs]
1202    debug_graph = self._debug_graphs[dst_device_name]
1203    if include_control:
1204      input_lists.append(debug_graph.node_ctrl_inputs)
1205    if include_reversed_ref:
1206      input_lists.append(debug_graph.node_reversed_ref_inputs)
1207    tracer = debug_graphs.DFSGraphTracer(
1208        input_lists,
1209        skip_node_names=self._get_merge_node_names(dst_device_name),
1210        destination_node_name=src_node_name)
1211    # Here the value of destination_node_name is src_node_name, because we
1212    # are tracing the graph from output to its inputs (i.e., going backwards
1213    # on the graph).
1214
1215    try:
1216      tracer.trace(dst_node_name)
1217    except debug_graphs.GraphTracingReachedDestination:
1218      # Prune nodes not on the path.
1219      inputs = [dst_node_name] + tracer.inputs()
1220      depth_list = [0] + tracer.depth_list()
1221
1222      path = []
1223      curr_depth = depth_list[-1]
1224      for inp, depth in zip(reversed(inputs), reversed(depth_list)):
1225        if depth == curr_depth:
1226          path.append(inp)
1227          curr_depth -= 1
1228      return path
1229
1230  def node_recipients(self, node_name, is_control=False, device_name=None):
1231    """Get recipient of the given node's output according to partition graphs.
1232
1233    Args:
1234      node_name: (`str`) name of the node.
1235      is_control: (`bool`) whether control outputs, rather than non-control
1236        outputs, are to be returned.
1237      device_name: (`str`) name of the device. If there is only one device or if
1238        node_name exists on only one device, this argument is optional.
1239
1240    Returns:
1241      (`list` of `str`) all inputs to the node, as a list of node names.
1242
1243    Raises:
1244      LookupError: If node inputs and control inputs have not been loaded
1245         from partition graphs yet.
1246    """
1247
1248    if not self._debug_graphs:
1249      raise LookupError(
1250          "Node recipients are not loaded from partition graphs yet.")
1251
1252    device_name = self._infer_device_name(device_name, node_name)
1253    debug_graph = self._debug_graphs[device_name]
1254    if is_control:
1255      return debug_graph.node_ctrl_recipients[node_name]
1256    else:
1257      return debug_graph.node_recipients[node_name]
1258
1259  def devices(self):
1260    """Get the list of device names.
1261
1262    Returns:
1263      (`list` of `str`) names of the devices.
1264    """
1265    return self._device_names
1266
1267  def node_exists(self, node_name, device_name=None):
1268    """Test if a node exists in the partition graphs.
1269
1270    Args:
1271      node_name: (`str`) name of the node to be checked.
1272      device_name: optional device name. If None, will search for the node
1273        on all available devices. Otherwise, search for the node only on
1274        the given device.
1275
1276    Returns:
1277      A boolean indicating whether the node exists.
1278
1279    Raises:
1280      LookupError: If no partition graphs have been loaded yet.
1281      ValueError: If device_name is specified but cannot be found.
1282    """
1283    if not self._debug_graphs:
1284      raise LookupError(
1285          "Nodes have not been loaded from partition graphs yet.")
1286
1287    if (device_name is not None) and device_name not in self._debug_graphs:
1288      raise ValueError(
1289          "The specified device_name '%s' cannot be found." % device_name)
1290
1291    for _, debug_graph in self._debug_graphs.items():
1292      if node_name in debug_graph.node_inputs:
1293        return True
1294    return False
1295
1296  def node_device(self, node_name):
1297    """Get the names of the devices that has nodes of the specified name.
1298
1299    Args:
1300      node_name: (`str`) name of the node.
1301
1302    Returns:
1303      (`str` or `list` of `str`) name of the device(s) on which the node of the
1304        given name is found. Returns a `str` if there is only one such device,
1305        otherwise return a `list` of `str`.
1306
1307    Raises:
1308      LookupError: If node inputs and control inputs have not been loaded
1309         from partition graphs yet.
1310      ValueError: If the node does not exist in partition graphs.
1311    """
1312    if not self._debug_graphs:
1313      raise LookupError(
1314          "Node devices are not loaded from partition graphs yet.")
1315
1316    if node_name not in self._node_devices:
1317      raise ValueError("Node '%s' does not exist in partition graphs." %
1318                       node_name)
1319
1320    output = list(self._node_devices[node_name])
1321    return output[0] if len(output) == 1 else output
1322
1323  def node_op_type(self, node_name, device_name=None):
1324    """Get the op type of given node.
1325
1326    Args:
1327      node_name: (`str`) name of the node.
1328      device_name: (`str`) name of the device. If there is only one device or if
1329        node_name exists on only one device, this argument is optional.
1330
1331    Returns:
1332      (`str`) op type of the node.
1333
1334    Raises:
1335      LookupError: If node op types have not been loaded
1336         from partition graphs yet.
1337    """
1338    if not self._debug_graphs:
1339      raise LookupError(
1340          "Node op types are not loaded from partition graphs yet.")
1341
1342    device_name = self._infer_device_name(device_name, node_name)
1343    return self._debug_graphs[device_name].node_op_types[node_name]
1344
1345  def debug_watch_keys(self, node_name, device_name=None):
1346    """Get all tensor watch keys of given node according to partition graphs.
1347
1348    Args:
1349      node_name: (`str`) name of the node.
1350      device_name: (`str`) name of the device. If there is only one device or if
1351        node_name exists on only one device, this argument is optional.
1352
1353    Returns:
1354      (`list` of `str`) all debug tensor watch keys. Returns an empty list if
1355        the node name does not correspond to any debug watch keys.
1356
1357    Raises:
1358      `LookupError`: If debug watch information has not been loaded from
1359        partition graphs yet.
1360    """
1361
1362    try:
1363      device_name = self._infer_device_name(device_name, node_name)
1364    except ValueError:
1365      return []
1366
1367    if node_name not in self._debug_watches[device_name]:
1368      return []
1369
1370    watch_keys = []
1371    for watched_slot in self._debug_watches[device_name][node_name]:
1372      debug_ops = self._debug_watches[device_name][node_name][watched_slot]
1373      for debug_op in debug_ops:
1374        watch_keys.append(
1375            _get_tensor_watch_key(node_name, watched_slot, debug_op))
1376
1377    return watch_keys
1378
1379  def watch_key_to_data(self, debug_watch_key, device_name=None):
1380    """Get all `DebugTensorDatum` instances corresponding to a debug watch key.
1381
1382    Args:
1383      debug_watch_key: (`str`) debug watch key.
1384      device_name: (`str`) name of the device. If there is only one device or if
1385        the specified debug_watch_key exists on only one device, this argument
1386        is optional.
1387
1388    Returns:
1389      A list of `DebugTensorDatum` instances that correspond to the debug watch
1390      key. If the watch key does not exist, returns an empty list.
1391
1392    Raises:
1393      ValueError: If there are multiple devices that have the debug_watch_key,
1394        but device_name is not specified.
1395    """
1396    if device_name is None:
1397      matching_device_names = [
1398          name for name in self._watch_key_to_datum
1399          if debug_watch_key in self._watch_key_to_datum[name]]
1400      if not matching_device_names:
1401        return []
1402      elif len(matching_device_names) == 1:
1403        device_name = matching_device_names[0]
1404      else:
1405        raise ValueError(
1406            "The debug watch key '%s' exists on multiple (%d) devices, but "
1407            "device name is not specified." %
1408            (debug_watch_key, len(matching_device_names)))
1409    elif device_name not in self._debug_key_to_datum:
1410      raise ValueError(
1411          "There is no device named '%s' consisting of debug watch keys." %
1412          device_name)
1413
1414    return self._watch_key_to_datum[device_name].get(debug_watch_key, [])
1415
1416  def find(self,
1417           predicate,
1418           first_n=0,
1419           device_name=None,
1420           exclude_node_names=None):
1421    """Find dumped tensor data by a certain predicate.
1422
1423    Args:
1424      predicate: A callable that takes two input arguments:
1425
1426        ```python
1427        def predicate(debug_tensor_datum, tensor):
1428          # returns a bool
1429        ```
1430
1431        where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which
1432        carries the metadata, such as the `Tensor`'s node name, output slot
1433        timestamp, debug op name, etc.; and `tensor` is the dumped tensor value
1434        as a `numpy.ndarray`.
1435      first_n: (`int`) return only the first n `DebugTensotDatum` instances (in
1436        time order) for which the predicate returns True. To return all the
1437        `DebugTensotDatum` instances, let first_n be <= 0.
1438      device_name: optional device name.
1439      exclude_node_names: Optional regular expression to exclude nodes with
1440        names matching the regular expression.
1441
1442    Returns:
1443      A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
1444       for which predicate returns True, sorted in ascending order of the
1445       timestamp.
1446    """
1447    if exclude_node_names:
1448      exclude_node_names = re.compile(exclude_node_names)
1449
1450    matched_data = []
1451    for device in (self._dump_tensor_data if device_name is None
1452                   else (self._dump_tensor_data[device_name],)):
1453      for datum in self._dump_tensor_data[device]:
1454        if exclude_node_names and exclude_node_names.match(datum.node_name):
1455          continue
1456
1457        if predicate(datum, datum.get_tensor()):
1458          matched_data.append(datum)
1459
1460          if first_n > 0 and len(matched_data) >= first_n:
1461            return matched_data
1462
1463    return matched_data
1464
1465  def get_tensor_file_paths(self,
1466                            node_name,
1467                            output_slot,
1468                            debug_op,
1469                            device_name=None):
1470    """Get the file paths from a debug-dumped tensor.
1471
1472    Args:
1473      node_name: (`str`) name of the node that the tensor is produced by.
1474      output_slot: (`int`) output slot index of tensor.
1475      debug_op: (`str`) name of the debug op.
1476      device_name: (`str`) name of the device. If there is only one device or if
1477        the specified debug_watch_key exists on only one device, this argument
1478        is optional.
1479
1480    Returns:
1481      List of file path(s) loaded. This is a list because each debugged tensor
1482        may be dumped multiple times.
1483
1484    Raises:
1485      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1486        the debug-dump data.
1487    """
1488
1489    device_name = self._infer_device_name(device_name, node_name)
1490    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1491    if watch_key not in self._watch_key_to_datum[device_name]:
1492      raise WatchKeyDoesNotExistInDebugDumpDirError(
1493          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1494          (watch_key, device_name))
1495
1496    return [datum.file_path for datum in
1497            self._watch_key_to_datum[device_name][watch_key]]
1498
1499  def get_tensors(self, node_name, output_slot, debug_op, device_name=None):
1500    """Get the tensor value from for a debug-dumped tensor.
1501
1502    The tensor may be dumped multiple times in the dump root directory, so a
1503    list of tensors (`numpy.ndarray`) is returned.
1504
1505    Args:
1506      node_name: (`str`) name of the node that the tensor is produced by.
1507      output_slot: (`int`) output slot index of tensor.
1508      debug_op: (`str`) name of the debug op.
1509      device_name: (`str`) name of the device. If there is only one device or if
1510        the specified debug_watch_key exists on only one device, this argument
1511        is optional.
1512
1513    Returns:
1514      List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s).
1515
1516    Raises:
1517      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1518        the debug-dump data.
1519    """
1520
1521    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1522    try:
1523      device_name = self._infer_device_name(device_name, node_name)
1524      return [datum.get_tensor() for datum in
1525              self._watch_key_to_datum[device_name][watch_key]]
1526    except (ValueError, KeyError):
1527      raise WatchKeyDoesNotExistInDebugDumpDirError(
1528          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1529          (watch_key, device_name))
1530
1531  def get_rel_timestamps(self,
1532                         node_name,
1533                         output_slot,
1534                         debug_op,
1535                         device_name=None):
1536    """Get the relative timestamp from for a debug-dumped tensor.
1537
1538    Relative timestamp means (absolute timestamp - `t0`), where `t0` is the
1539    absolute timestamp of the first dumped tensor in the dump root. The tensor
1540    may be dumped multiple times in the dump root directory, so a list of
1541    relative timestamps (`numpy.ndarray`) is returned.
1542
1543    Args:
1544      node_name: (`str`) name of the node that the tensor is produced by.
1545      output_slot: (`int`) output slot index of tensor.
1546      debug_op: (`str`) name of the debug op.
1547      device_name: (`str`) name of the device. If there is only one device or if
1548        the specified debug_watch_key exists on only one device, this argument
1549        is optional.
1550
1551    Returns:
1552      (`list` of `int`) list of relative timestamps.
1553
1554    Raises:
1555      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1556        exist in the debug dump data.
1557    """
1558
1559    device_name = self._infer_device_name(device_name, node_name)
1560    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1561    if watch_key not in self._watch_key_to_datum[device_name]:
1562      raise WatchKeyDoesNotExistInDebugDumpDirError(
1563          "Watch key \"%s\" does not exist in the debug dump" % watch_key)
1564
1565    # TODO(cais): Figure out whether this should be relative to the global t0.
1566    return self._watch_key_to_rel_time[device_name][watch_key]
1567
1568  def get_dump_sizes_bytes(self,
1569                           node_name,
1570                           output_slot,
1571                           debug_op,
1572                           device_name=None):
1573    """Get the sizes of the dump files for a debug-dumped tensor.
1574
1575    Unit of the file size: byte.
1576
1577    Args:
1578      node_name: (`str`) name of the node that the tensor is produced by.
1579      output_slot: (`int`) output slot index of tensor.
1580      debug_op: (`str`) name of the debug op.
1581      device_name: (`str`) name of the device. If there is only one device or if
1582        the specified debug_watch_key exists on only one device, this argument
1583        is optional.
1584
1585    Returns:
1586      (`list` of `int`): list of dump file sizes in bytes.
1587
1588    Raises:
1589      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1590        exist in the debug dump data.
1591    """
1592
1593    device_name = self._infer_device_name(device_name, node_name)
1594    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1595    if watch_key not in self._watch_key_to_datum[device_name]:
1596      raise WatchKeyDoesNotExistInDebugDumpDirError(
1597          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1598          (watch_key, device_name))
1599
1600    return self._watch_key_to_dump_size_bytes[device_name][watch_key]
1601
1602  def node_traceback(self, element_name):
1603    """Try to retrieve the Python traceback of node's construction.
1604
1605    Args:
1606      element_name: (`str`) Name of a graph element (node or tensor).
1607
1608    Returns:
1609      (list) The traceback list object as returned by the `extract_trace`
1610        method of Python's traceback module.
1611
1612    Raises:
1613      LookupError: If Python graph is not available for traceback lookup.
1614      KeyError: If the node cannot be found in the Python graph loaded.
1615    """
1616
1617    if self._python_graph is None:
1618      raise LookupError("Python graph is not available for traceback lookup")
1619
1620    node_name = debug_graphs.get_node_name(element_name)
1621    if node_name not in self._node_traceback:
1622      raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)
1623
1624    return self._node_traceback[node_name]
1625