1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions to handle debug-dump data of TensorFlow Debugger."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import glob
23import json
24import os
25import platform
26import re
27
28import numpy as np
29import six
30
31from tensorflow.core.framework import graph_pb2
32from tensorflow.core.framework import types_pb2
33from tensorflow.core.util import event_pb2
34from tensorflow.python.debug.lib import debug_graphs
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.platform import gfile
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.util import compat
39
40
# TODO(cais): Tie these string constants in with C++?
# Common prefix of the names of metadata files and device directories that
# this module looks for under a dump root directory.
METADATA_FILE_PREFIX = "_tfdbg_"
# Tag (following the prefix) of the core-metadata files.
CORE_METADATA_TAG = "core_metadata_"
# Tag (following the prefix) of the partition-graph files.
GRAPH_FILE_TAG = "graph_"
# Tag (following the prefix) of the per-device dump directories.
DEVICE_TAG = "device_"
# NOTE(review): not referenced in the visible part of this module.
HASH_TAG = "hash"

# Tags (following the prefix) of the files that hold the fetches and the feed
# keys of the debugged run, respectively.
FETCHES_INFO_FILE_TAG = "fetches_info_"
FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_"
50
51
def _glob(glob_pattern):
  """Expand a glob pattern into the list of matching paths.

  Python's built-in `glob` module is used on Windows; `gfile.Glob` is used on
  every other platform.

  Args:
    glob_pattern: (`str`) the glob pattern to expand.

  Returns:
    The matching paths, as returned by the underlying glob implementation.
  """
  on_windows = platform.system() == "Windows"
  return glob.glob(glob_pattern) if on_windows else gfile.Glob(glob_pattern)
57
58
class InconvertibleTensorProto(object):
  """Wrapper around a `TensorProto` that has no `np.ndarray` representation."""

  def __init__(self, tensor_proto, initialized=True):
    """Wrap a `TensorProto` that cannot be converted.

    Args:
      tensor_proto: the `TensorProto` object that cannot be represented as a
        `np.ndarray` object.
      initialized: (`bool`) whether the Tensor is initialized.
    """
    self._tensor_proto = tensor_proto
    self._initialized = initialized

  def __str__(self):
    text = str(self._tensor_proto)
    if not self._initialized:
      # Flag uninitialized tensors on a header line of their own.
      text = "Uninitialized tensor:\n" + text
    return text

  @property
  def initialized(self):
    """Whether the wrapped Tensor is initialized, as given at construction."""
    return self._initialized
81
82
def load_tensor_from_event_file(event_file_path):
  """Read an `Event` proto from a file and load the tensor it holds.

  Assumes that the event file contains a serialized `Event` protobuf and that
  the `Event` protobuf contains a `Tensor` value.

  Args:
    event_file_path: (`str`) path to the event file.

  Returns:
    Whatever `load_tensor_from_event` returns for the parsed `Event`: a
    `numpy.ndarray` when the tensor can be converted, otherwise an
    `InconvertibleTensorProto` (e.g., for uninitialized tensors or tensors of
    dtype `tf.resource`).
  """
  with gfile.Open(event_file_path, "rb") as event_file:
    serialized_event = event_file.read()

  event = event_pb2.Event()
  event.ParseFromString(serialized_event)
  return load_tensor_from_event(event)
103
104
def load_tensor_from_event(event):
  """Load a tensor from an Event proto.

  Args:
    event: The Event proto, assumed to hold a tensor value in its
        summary.value[0] field.

  Returns:
    A `numpy.ndarray` if the tensor value can be converted to one. Otherwise an
    `InconvertibleTensorProto` wrapping the `TensorProto`:
      * with `initialized` set to False if the proto has neither
        `tensor_content` nor `string_val` although its shape has a non-zero
        number of elements (uninitialized tensor or tensor of an inconvertible
        data type);
      * with `initialized` set to True if the dtype is `DT_RESOURCE` or if the
        conversion raises `KeyError`.
  """
  tensor_proto = event.summary.value[0].tensor

  # Total element count implied by the shape; 0 marks an empty tensor.
  element_count = 1
  for dim in tensor_util.TensorShapeProtoToList(tensor_proto.tensor_shape):
    element_count *= dim

  has_value = (tensor_proto.tensor_content or tensor_proto.string_val or
               not element_count)
  if not has_value:
    # Uninitialized tensor or tensor of unconvertible data type.
    return InconvertibleTensorProto(tensor_proto, False)

  # Initialized tensor or empty tensor.
  if tensor_proto.dtype == types_pb2.DT_RESOURCE:
    return InconvertibleTensorProto(tensor_proto)

  try:
    return tensor_util.MakeNdarray(tensor_proto)
  except KeyError:
    return InconvertibleTensorProto(tensor_proto)
141
142
def _load_graph_def_from_event_file(event_file_path):
  """Parse the `GraphDef` stored in the `graph_def` field of an event file.

  Args:
    event_file_path: (`str`) path to a file holding a serialized `Event` proto.

  Returns:
    The `GraphDef` parsed from the `graph_def` bytes of the `Event`.
  """
  with gfile.Open(event_file_path, "rb") as event_file:
    serialized_event = event_file.read()

  event = event_pb2.Event()
  event.ParseFromString(serialized_event)
  return graph_pb2.GraphDef.FromString(event.graph_def)
149
150
def _load_log_message_from_event_file(event_file_path):
  """Read the log message text stored in an event file.

  Args:
    event_file_path: (`str`) path to a file holding a serialized `Event` proto.

  Returns:
    The `log_message.message` field of the parsed `Event`.
  """
  with gfile.Open(event_file_path, "rb") as event_file:
    serialized_event = event_file.read()

  event = event_pb2.Event()
  event.ParseFromString(serialized_event)
  return event.log_message.message
157
158
def _is_graph_file(file_name):
  """Return whether `file_name` starts with the partition-graph file prefix."""
  graph_file_prefix = METADATA_FILE_PREFIX + GRAPH_FILE_TAG
  return file_name.startswith(graph_file_prefix)
161
162
def _is_run_fetches_info_file(file_name):
  """Return whether `file_name` is exactly the fetches-info file name."""
  fetches_info_file_name = METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG
  return file_name == fetches_info_file_name
165
166
def _is_run_feed_keys_info_file(file_name):
  """Return whether `file_name` is exactly the feed-keys-info file name."""
  feed_keys_info_file_name = METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG
  return file_name == feed_keys_info_file_name
169
170
def _get_tensor_name(node_name, output_slot):
  """Compose a tensor name from a node name and an output slot index.

  Args:
    node_name: Name of the node that outputs the tensor, as a string.
    output_slot: Output slot index of the tensor, as an integer.

  Returns:
    Name of the tensor, as a string of the form `node_name`:`output_slot`.
  """
  tensor_name = "%s:%d" % (node_name, output_slot)
  return tensor_name
183
184
def _get_tensor_watch_key(node_name, output_slot, debug_op):
  """Compose the watch key of a debug watch on a tensor.

  Args:
    node_name: Name of the node by which the watched tensor is produced, as a
        string.
    output_slot: Output slot index of the tensor, as an integer.
    debug_op: Name of the debug op that is used to watch the tensor, as a
        string.

  Returns:
    A string representing the debug watch on the tensor (i.e., the "watch
        key"), of the form `node_name`:`output_slot`:`debug_op`.
  """
  tensor_name = _get_tensor_name(node_name, output_slot)
  return "%s:%s" % (tensor_name, debug_op)
200
201
def has_inf_or_nan(datum, tensor):
  """A predicate for whether a tensor consists of any bad numerical values.

  This predicate is common enough to merit definition in this module.
  Bad numerical values include `nan`s and `inf`s.
  The signature of this function follows the requirement of the method
  `DebugDumpDir.find()`.

  Args:
    datum: (`DebugTensorDatum`) Datum metadata. Unused.
    tensor: (`numpy.ndarray`, `InconvertibleTensorProto` or None) Value of the
      tensor. `InconvertibleTensorProto` and None stand for tensors without a
      numpy representation (e.g., uninitialized tensors).

  Returns:
    (`bool`) True if and only if tensor consists of any nan or inf values.
  """

  _ = datum  # Datum metadata is unused in this predicate.

  if tensor is None or isinstance(tensor, InconvertibleTensorProto):
    # Uninitialized tensor doesn't have bad numerical values.
    # Also return False for data types that cannot be represented as numpy
    # arrays.
    return False

  # `np.complexfloating` is the abstract parent of all complex dtypes. The
  # `np.complex` alias used previously was only the builtin `complex` and has
  # been removed from NumPy.
  if (np.issubdtype(tensor.dtype, np.floating) or
      np.issubdtype(tensor.dtype, np.complexfloating) or
      np.issubdtype(tensor.dtype, np.integer)):
    return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))

  # Non-numeric dtypes (e.g., strings, booleans) cannot hold nan or inf.
  return False
232
233
# Record type for the core metadata of one debugged `Session.run()` call; the
# field names match the keys of the JSON metadata object they are read from.
_CoreMetadata = collections.namedtuple(
    "CoreMetadata",
    (
        "global_step",
        "session_run_index",
        "executor_step_index",
        "input_names",
        "output_names",
        "target_nodes",
    ))
238
239
def extract_core_metadata_from_event_proto(event):
  """Build a core-metadata record from the JSON log message of an `Event`.

  Args:
    event: An `Event` proto whose `log_message.message` field holds a JSON
      object with the keys `global_step`, `session_run_index`,
      `executor_step_index`, `input_names`, `output_names` and `target_nodes`.

  Returns:
    A `CoreMetadata` namedtuple carrying the values of those keys.

  Raises:
    KeyError: If any of the keys is missing from the JSON object.
  """
  metadata = json.loads(event.log_message.message)
  return _CoreMetadata(
      global_step=metadata["global_step"],
      session_run_index=metadata["session_run_index"],
      executor_step_index=metadata["executor_step_index"],
      input_names=metadata["input_names"],
      output_names=metadata["output_names"],
      target_nodes=metadata["target_nodes"])
248
249
def device_name_to_device_path(device_name):
  """Convert device name to device path.

  The "/"-separated items of the device name are joined with "," after every
  ":" in them is replaced by "_"; the result is prepended with the metadata
  file prefix and the device tag.

  Args:
    device_name: device name; converted to text with `compat.as_text`.

  Returns:
    (`str`) the device path.
  """
  name_items = compat.as_text(device_name).split("/")
  encoded_name = ",".join(item.replace(":", "_") for item in name_items)
  return METADATA_FILE_PREFIX + DEVICE_TAG + encoded_name
255
256
def device_path_to_device_name(device_dir):
  """Parse device name from device path.

  Args:
    device_dir: (str) a directory name for the device. Only its base name is
      used; the base name is assumed to start with the metadata file prefix
      followed by the device tag.

  Returns:
    (str) parsed device name.
  """
  prefix_length = len(METADATA_FILE_PREFIX) + len(DEVICE_TAG)
  encoded_name = os.path.basename(device_dir)[prefix_length:]

  name_items = []
  for path_item in encoded_name.split(","):
    # Restore "device:" first, then turn only the first remaining "_" back
    # into ":"; any later underscores are left untouched.
    restored = path_item.replace("device_", "device:")
    name_items.append(restored.replace("_", ":", 1))
  return "/".join(name_items)
271
272
class DebugTensorDatum(object):
  """Metadata of a single tensor dumped by TensorFlow Debugger (tfdbg).

  Holds the `timestamp`, `node_name`, `output_slot` and `debug_op` parsed from
  the name of the dump file, together with the path to that file
  (`file_path`).

  The tensor value itself (a generally space-expensive numpy array) is not
  kept in memory; it is read from the dump file on demand by `get_tensor`.
  """

  def __init__(self, dump_root, debug_dump_rel_path):
    """`DebugTensorDatum` constructor.

    Args:
      dump_root: (`str`) Debug dump root directory. This path should not include
        the path component that represents the device name (see also below).
      debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the
        `dump_root`. The first item of this relative path is assumed to be
        a path representing the name of the device that the Tensor belongs to.
        See `device_path_to_device_name` for more details on the device path.
        For example, suppose the debug dump root
        directory is `/tmp/tfdbg_1` and the dump file is at
        `/tmp/tfdbg_1/<device_path>/ns_1/node_a_0_DebugIdentity_123456789`,
        then the value of the debug_dump_rel_path should be
        `<device_path>/ns_1/node_a_0_DebugIdentity_123456789`.

    Raises:
      ValueError: If the base file name of the dump file does not conform to
        the dump file naming pattern:
        `node_name`_`output_slot`_`debug_op`_`timestamp`
    """

    path_components = os.path.normpath(debug_dump_rel_path).split(os.sep)
    self._device_name = device_path_to_device_name(path_components[0])

    base = path_components[-1]
    if base.count("_") < 3:
      raise ValueError(
          "Dump file path does not conform to the naming pattern: %s" % base)

    # The last three "_"-separated fields are fixed; whatever precedes them
    # (underscores included) is the base name of the node.
    node_base_name, output_slot, debug_op, extended_timestamp = (
        base.rsplit("_", 3))

    self._extended_timestamp = extended_timestamp
    # It may include an index suffix at the end if file path collision happened
    # due to identical timestamps. Only the part before the first "-" is the
    # timestamp proper.
    self._timestamp = int(extended_timestamp.partition("-")[0])

    self._debug_op = debug_op
    self._output_slot = int(output_slot)

    # Directories between the device directory and the file form the name
    # scope of the node.
    self._node_name = "/".join(path_components[1:-1] + [node_base_name])

    self._file_path = os.path.join(dump_root, debug_dump_rel_path)
    if gfile.Exists(self._file_path):
      self._dump_size_bytes = gfile.Stat(self._file_path).length
    else:
      self._dump_size_bytes = None

  def __str__(self):
    return "{DebugTensorDatum (%s) %s:%d @ %s @ %d}" % (
        self.device_name, self.node_name, self.output_slot, self.debug_op,
        self.timestamp)

  def __repr__(self):
    return self.__str__()

  def get_tensor(self):
    """Load the tensor value from the dump (`Event`) file.

    Returns:
      The value returned by `load_tensor_from_event_file` for `file_path`.
    """

    return load_tensor_from_event_file(self.file_path)

  # TODO(cais): Add time unit suffix to timestamp and t0 (us).
  @property
  def timestamp(self):
    """Timestamp of when this tensor value was dumped.

    Returns:
      (`int`) The timestamp in microseconds.
    """

    return self._timestamp

  @property
  def extended_timestamp(self):
    """Timestamp field of the file name, possibly with an index suffix.

    The index suffix, e.g., "-1", is for disambiguating multiple dumps of the
    same tensor with the same timestamp, which can occur if the dumping events
    are spaced by shorter than the temporal resolution of the timestamps.

    Returns:
      (`str`) The extended timestamp.
    """

    return self._extended_timestamp

  @property
  def debug_op(self):
    """Name of the debug op.

    Returns:
      (`str`) debug op name (e.g., `DebugIdentity`).
    """

    return self._debug_op

  @property
  def device_name(self):
    """Name of the device that the tensor belongs to.

    Returns:
      (`str`) device name, parsed from the first component of the relative
        dump path.
    """

    return self._device_name

  @property
  def node_name(self):
    """Name of the node from which the tensor value was dumped.

    Returns:
      (`str`) name of the node watched by the debug op.
    """

    return self._node_name

  @property
  def output_slot(self):
    """Output slot index from which the tensor value was dumped.

    Returns:
      (`int`) output slot index watched by the debug op.
    """

    return self._output_slot

  @property
  def tensor_name(self):
    """Name of the tensor watched by the debug op.

    Returns:
      (`str`) `Tensor` name, in the form of `node_name`:`output_slot`
    """

    return _get_tensor_name(self.node_name, self.output_slot)

  @property
  def watch_key(self):
    """Watch key that identifies a debug watch on a tensor.

    Returns:
      (`str`) A watch key, in the form of `tensor_name`:`debug_op`.
    """

    return _get_tensor_watch_key(
        self.node_name, self.output_slot, self.debug_op)

  @property
  def file_path(self):
    """Path to the file which stores the value of the dumped tensor."""

    return self._file_path

  @property
  def dump_size_bytes(self):
    """Size of the dump file, measured when this object was constructed.

    Unit: byte.

    Returns:
      If the dump file exists, size of the dump file, in bytes.
      If the dump file does not exist, None.
    """

    return self._dump_size_bytes
456
457
class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
  """`ValueError` subtype for a watch key that is absent from a dump directory.

  NOTE(review): the raising sites are outside this section of the file; the
  description above is taken from the class name.
  """
460
461
462class DebugDumpDir(object):
463  """Data set from a debug-dump directory on filesystem.
464
465  An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances
466  in a tfdbg dump root directory.
467  """
468
469  def __init__(self, dump_root, partition_graphs=None, validate=True):
470    """`DebugDumpDir` constructor.
471
472    Args:
473      dump_root: (`str`) path to the dump root directory.
474      partition_graphs: A repeated field of GraphDefs representing the
475          partition graphs executed by the TensorFlow runtime.
476      validate: (`bool`) whether the dump files are to be validated against the
477          partition graphs.
478
479    Raises:
480      IOError: If dump_root does not exist as a directory.
481      ValueError: If more than one core metadata file is found under the dump
482        root directory.
483    """
484
485    if not gfile.IsDirectory(dump_root):
486      raise IOError("Dump root directory %s does not exist" % dump_root)
487
488    self._core_metadata = []
489
490    # Find the list of devices.
491    self._dump_root = dump_root
492
493    self._load_core_metadata()
494    self._load_fetches_info()
495    self._load_feeds_info()
496    self._load_all_device_dumps(partition_graphs, validate)
497
498    self._python_graph = None
499
500  def _load_all_device_dumps(self, partition_graphs, validate):
501    """Load the dump data for all devices."""
502    device_dirs = _glob(os.path.join(
503        self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*"))
504
505    self._device_names = []
506    self._t0s = {}
507    self._dump_tensor_data = {}
508    self._dump_graph_file_paths = {}
509    self._debug_watches = {}
510    self._watch_key_to_devices = {}
511    self._watch_key_to_datum = {}
512    self._watch_key_to_rel_time = {}
513    self._watch_key_to_dump_size_bytes = {}
514    for device_dir in device_dirs:
515      device_name = device_path_to_device_name(device_dir)
516      self._device_names.append(device_name)
517      self._load_device_dumps(device_name, device_dir)
518    self._load_partition_graphs(partition_graphs, validate)
519    self._calculate_t0()
520
521    for device_name in self._device_names:
522      self._create_tensor_watch_maps(device_name)
523
524  def _load_device_dumps(self, device_name, device_root):
525    """Load `DebugTensorDatum` instances from the dump root of a given device.
526
527    Populates a map {device_name: a list of `DebugTensorDatum`}, where the list
528    is sorted by ascending timestamp.
529
530    This sorting order reflects the order in which the TensorFlow executor
531    processed the nodes of the graph. It is (one of many possible) topological
532    sort of the nodes. This is useful for displaying tensors in the debugger
533    frontend as well as for the use case in which the user wants to find a
534    "culprit tensor", i.e., the first tensor in the graph that exhibits certain
535    problematic properties, i.e., all zero values, or bad numerical values such
536    as nan and inf.
537
538    In addition, creates a map from node name to debug watches. In this Map,
539    the key is the watched node name; the value is a dictionary.
540    Of this dictionary, the key is the watched_output_slot.
541
542    This method attempts to load the debug watches from the tensor dump files
543    first, before loading the full set of debug watches from the partition
544    graphs as done later. This is necessary because sometimes the partition
545    graphs may not be available, e.g., when the run errors out.
546
547    Args:
548      device_name: (`str`) name of the device.
549      device_root: (`str`) dump root directory of the given device.
550
551    Raises:
552      ValueError: If GraphDef for the device is not available.
553    """
554
555    self._dump_tensor_data[device_name] = []
556    self._debug_watches[device_name] = collections.defaultdict(
557        lambda: collections.defaultdict(set))
558
559    for root, _, files in gfile.Walk(device_root):
560      for f in files:
561        if _is_graph_file(f):
562          self._dump_graph_file_paths[device_name] = os.path.join(root, f)
563        else:
564          datum = self._dump_file_name_to_datum(root, f)
565          self._dump_tensor_data[device_name].append(datum)
566          self._debug_watches[device_name][datum.node_name][
567              datum.output_slot].add(datum.debug_op)
568
569    self._dump_tensor_data[device_name] = sorted(
570        self._dump_tensor_data[device_name],
571        key=lambda x: x.extended_timestamp)
572
573    if self._dump_tensor_data[device_name]:
574      self._t0s[device_name] = self._dump_tensor_data[device_name][0].timestamp
575    else:
576      self._t0s[device_name] = None
577
578  def _calculate_t0(self):
579    """Calculate the first timestamp across all devices."""
580    t0s = [t0 for t0 in six.itervalues(self._t0s) if t0 is not None]
581    self._t0 = min(t0s) if t0s else None
582
583  def _load_core_metadata(self):
584    core_metadata_files = _glob(os.path.join(
585        self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*"))
586    for core_metadata_file in core_metadata_files:
587      with gfile.Open(core_metadata_file, "rb") as f:
588        event = event_pb2.Event()
589        event.ParseFromString(f.read())
590        self._core_metadata.append(
591            extract_core_metadata_from_event_proto(event))
592
593  def _load_fetches_info(self):
594    fetches_info_files = _glob(os.path.join(
595        self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*"))
596    self._run_fetches_info = []
597    for fetches_info_file in fetches_info_files:
598      self._run_fetches_info.append(
599          _load_log_message_from_event_file(fetches_info_file))
600
601  def _load_feeds_info(self):
602    feeds_info_files = _glob(os.path.join(
603        self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*"))
604    self._run_feed_keys_info = []
605    for feeds_info_file in feeds_info_files:
606      self._run_feed_keys_info.append(
607          _load_log_message_from_event_file(feeds_info_file))
608
609  def _dump_file_name_to_datum(self, dir_name, file_name):
610    """Obtain a DebugTensorDatum from the directory and file name.
611
612    Args:
613      dir_name: (`str`) Name of the directory in which the dump file resides.
614      file_name: (`str`) Base name of the dump file.
615
616    Returns:
617      (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file.
618    """
619
620    # Calculate the relative path of the dump file with respect to the root.
621    debug_dump_rel_path = os.path.join(
622        os.path.relpath(dir_name, self._dump_root), file_name)
623    return DebugTensorDatum(self._dump_root, debug_dump_rel_path)
624
625  def _create_tensor_watch_maps(self, device_name):
626    """Create maps from tensor watch keys to datum and to timestamps.
627
628    Create a map from watch key (tensor name + debug op) to `DebugTensorDatum`
629    item. Also make a map from watch key to relative timestamp.
630    "relative" means (absolute timestamp - t0).
631
632    Args:
633      device_name: (str) name of the device.
634    """
635
636    self._watch_key_to_datum[device_name] = {}
637    self._watch_key_to_rel_time[device_name] = {}
638    self._watch_key_to_dump_size_bytes[device_name] = {}
639    for datum in self._dump_tensor_data[device_name]:
640      if datum.watch_key not in self._watch_key_to_devices:
641        self._watch_key_to_devices[datum.watch_key] = {device_name}
642      else:
643        self._watch_key_to_devices[datum.watch_key].add(device_name)
644
645      if datum.watch_key not in self._watch_key_to_datum[device_name]:
646        self._watch_key_to_datum[device_name][datum.watch_key] = [datum]
647        self._watch_key_to_rel_time[device_name][datum.watch_key] = [
648            datum.timestamp - self._t0]
649        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key] = [
650            datum.dump_size_bytes]
651      else:
652        self._watch_key_to_datum[device_name][datum.watch_key].append(datum)
653        self._watch_key_to_rel_time[device_name][datum.watch_key].append(
654            datum.timestamp - self._t0)
655        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key].append(
656            datum.dump_size_bytes)
657
658  def set_python_graph(self, python_graph):
659    """Provide Python `Graph` object to the wrapper.
660
661    Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph`
662    is a Python object and carries additional information such as the traceback
663    of the construction of the nodes in the graph.
664
665    Args:
666      python_graph: (ops.Graph) The Python Graph object.
667    """
668
669    self._python_graph = python_graph
670    self._node_traceback = {}
671    if self._python_graph:
672      for op in self._python_graph.get_operations():
673        self._node_traceback[op.name] = tuple(map(tuple, op.traceback))
674
  @property
  def python_graph(self):
    """Get the Python graph.

    Returns:
      The object most recently passed to `set_python_graph` (expected to be a
      `tf.Graph`). If the Python graph has not been set, returns None.
    """

    return self._python_graph
685
686  @property
687  def core_metadata(self):
688    """Metadata about the `Session.run()` call from the core runtime.
689
690    Of the three counters available in the return value, `global_step` is
691    supplied by the caller of the debugged `Session.run()`, while
692    `session_run_index` and `executor_step_index` are determined by the state
693    of the core runtime, automatically. For the same fetch list, feed keys and
694    debug tensor watch options, the same executor will be used and
695    `executor_step_index` should increase by one at a time. However, runs with
696    different fetch lists, feed keys and debug_tensor watch options that all
697    share the same `Session` object can lead to gaps in `session_run_index`.
698
699    Returns:
700      If core metadata are loaded, a `namedtuple` with the fields:
701        `global_step`: A global step count supplied by the caller of
702          `Session.run()`. It is optional to the caller. If the caller did not
703          supply this parameter, its value will be -1.
704        `session_run_index`: A sorted index for Run() calls to the underlying
705          TensorFlow `Session` object.
706        `executor_step_index`: A counter for invocations of a given runtime
707          executor. The same executor is re-used for the same fetched tensors,
708          target nodes, input feed keys and debug tensor watch options.
709        `input_names`: Names of the input (feed) Tensors.
710        `output_names`: Names of the output (fetched) Tensors.
711        `target_nodes`: Names of the target nodes.
712      If the core metadata have not been loaded, `None`.
713      If more than one core metadata files exist, return a list of the
714        `nametuple` described above.
715    """
716
717    output = self._core_metadata
718    return output[0] if len(output) == 1 else output
719
720  @property
721  def dumped_tensor_data(self):
722    """Retrieve dumped tensor data."""
723    if len(self.devices()) == 1:
724      return self._dump_tensor_data[self.devices()[0]]
725    else:
726      all_devices_data = six.itervalues(self._dump_tensor_data)
727      data = []
728      for device_data in all_devices_data:
729        data.extend(device_data)
730      return sorted(data, key=lambda x: x.extended_timestamp)
731
  @property
  def t0(self):
    """Absolute timestamp of the first dumped tensor across all devices.

    Returns:
      (`int`) absolute timestamp of the first dumped tensor, in microseconds.
      None if no device has any dumped tensor.
    """
    return self._t0
740
741  @property
742  def size(self):
743    """Total number of dumped tensors in the dump root directory.
744
745    Returns:
746      (`int`) The total number of dumped tensors in the dump root directory.
747    """
748    return sum(len(self._dump_tensor_data[device_name])
749               for device_name in self._dump_tensor_data)
750
751  def _load_partition_graphs(self, client_partition_graphs, validate):
752    """Load and process partition graphs.
753
754    Load the graphs; parse the input and control input structure; obtain the
755    device and op type of each node; remove the Copy and debug ops inserted
756    by the debugger. The gathered information can be used to validate the
757    tensor dumps.
758
759    Args:
760      client_partition_graphs: A repeated field of GraphDefs representing the
761        partition graphs executed by the TensorFlow runtime, from the Python
762        client. These partition graphs are used only if partition graphs
763        cannot be loaded from the dump directory on the file system.
764      validate: (`bool`) Whether the dump files are to be validated against the
765        partition graphs.
766
767    Raises:
768      ValueError: If the partition GraphDef of one or more devices fail to be
769        loaded.
770    """
771    self._debug_graphs = {}
772    self._node_devices = {}
773
774    partition_graphs_and_device_names = []
775    for device_name in self._device_names:
776      partition_graph = None
777      if device_name in self._dump_graph_file_paths:
778        partition_graph = _load_graph_def_from_event_file(
779            self._dump_graph_file_paths[device_name])
780      else:
781        logging.warn(
782            "Failed to load partition graphs for device %s from disk. "
783            "As a fallback, the client graphs will be used. This "
784            "may cause mismatches in device names." % device_name)
785        partition_graph = self._find_partition_graph(client_partition_graphs,
786                                                     device_name)
787
788      if partition_graph:
789        partition_graphs_and_device_names.append((partition_graph,
790                                                  device_name))
791
792    for partition_graph, maybe_device_name in partition_graphs_and_device_names:
793      debug_graph = debug_graphs.DebugGraph(partition_graph,
794                                            device_name=maybe_device_name)
795      self._debug_graphs[debug_graph.device_name] = debug_graph
796      self._collect_node_devices(debug_graph)
797
798      if validate and debug_graph.device_name in self._dump_tensor_data:
799        self._validate_dump_with_graphs(debug_graph.device_name)
800
801  def _find_partition_graph(self, partition_graphs, device_name):
802    if partition_graphs is None:
803      return None
804    else:
805      for graph_def in partition_graphs:
806        for node_def in graph_def.node:
807          if node_def.device == device_name:
808            return graph_def
809      return None
810
811  def _collect_node_devices(self, debug_graph):
812    for node_name in debug_graph.node_devices:
813      if node_name in self._node_devices:
814        self._node_devices[node_name] = self._node_devices[node_name].union(
815            debug_graph.node_devices[node_name])
816      else:
817        self._node_devices[node_name] = debug_graph.node_devices[node_name]
818
819  def _validate_dump_with_graphs(self, device_name):
820    """Validate the dumped tensor data against the partition graphs.
821
822    Only the watched nodes are validated by this method, because tfdbg allows
823    clients to watch only a subset of the nodes.
824
825    Args:
826      device_name: (`str`) device name.
827
828    Raises:
829      LookupError: If the partition graphs have not been loaded yet.
830      ValueError: If dumps contain node names not found in partition graph.
831        Or if the temporal order of the dump's timestamps violate the
832        input relations on the partition graphs.
833    """
834    if not self._debug_graphs:
835      raise LookupError(
836          "No partition graphs loaded for device %s" % device_name)
837    debug_graph = self._debug_graphs[device_name]
838
839    # Verify that the node names in the dump data are all present in the
840    # partition graphs.
841    for datum in self._dump_tensor_data[device_name]:
842      if datum.node_name not in debug_graph.node_inputs:
843        raise ValueError("Node name '%s' is not found in partition graphs of "
844                         "device %s." % (datum.node_name, device_name))
845
846    pending_inputs = {}
847    for node in debug_graph.node_inputs:
848      pending_inputs[node] = []
849      inputs = debug_graph.node_inputs[node]
850      for inp in inputs:
851        inp_node = debug_graphs.get_node_name(inp)
852        inp_output_slot = debug_graphs.get_output_slot(inp)
853        # Inputs from Enter and NextIteration nodes are not validated because
854        # DebugNodeInserter::InsertNodes() in the debugger core skips creating
855        # control edges from debug ops watching these types of nodes.
856        if (inp_node in self._debug_watches[device_name] and
857            inp_output_slot in self._debug_watches[device_name][inp_node] and
858            debug_graph.node_op_types.get(inp) not in (
859                "Enter", "NextIteration") and
860            (inp_node, inp_output_slot) not in pending_inputs[node]):
861          pending_inputs[node].append((inp_node, inp_output_slot))
862
863    for i, datum in enumerate(self._dump_tensor_data[device_name]):
864      node = datum.node_name
865      slot = datum.output_slot
866      # In some cases (e.g., system clocks with insufficient precision),
867      # the upstream and downstream tensors may have identical timestamps, the
868      # following check examines this possibility and avoids raising an error if
869      # that is the case.
870      if not self._satisfied_at_timestamp(
871          device_name, pending_inputs[node], datum.timestamp, start_i=i + 1):
872        raise ValueError("Causality violated in timing relations of debug "
873                         "dumps: %s (%d): "
874                         "these input(s) are not satisfied: %s" %
875                         (node, datum.timestamp, repr(pending_inputs[node])))
876
877      recipients = debug_graph.node_recipients[node]
878      for recipient in recipients:
879        recipient_pending_inputs = pending_inputs[recipient]
880        if (node, slot) in recipient_pending_inputs:
881          if self.node_op_type(recipient) == "Merge":
882            # If this is a Merge op, we automatically clear the list because
883            # a Merge node only requires one of its two inputs.
884            del recipient_pending_inputs[:]
885          else:
886            del recipient_pending_inputs[
887                recipient_pending_inputs.index((node, slot))]
888
889  def _satisfied_at_timestamp(self, device_name, pending, timestamp, start_i=0):
890    """Determine whether pending inputs are satisfied at given timestamp.
891
892    Note: This method mutates the input argument "pending".
893
894    Args:
895      device_name: (str) device name.
896      pending: A list of 2-tuple (node_name, output_slot): the dependencies to
897        check.
898      timestamp: (int) the timestamp in question.
899      start_i: (int) the index in self._dump_tensor_data to start searching for
900        the timestamp.
901
902    Returns:
903      (bool) Whether all the dependencies in pending are satisfied at the
904        timestamp. If pending is empty to begin with, return True.
905    """
906    if not pending:
907      return True
908
909    for datum in self._dump_tensor_data[device_name][start_i:]:
910      if datum.timestamp > timestamp:
911        break
912      if (datum.timestamp == timestamp and
913          (datum.node_name, datum.output_slot) in pending):
914        pending.remove((datum.node_name, datum.output_slot))
915        if not pending:
916          return True
917
918    return not pending
919
920  def loaded_partition_graphs(self):
921    """Test whether partition graphs have been loaded."""
922    return bool(self._debug_graphs)
923
924  def partition_graphs(self):
925    """Get the partition graphs.
926
927    Returns:
928      Partition graphs as a list of GraphDef.
929
930    Raises:
931      LookupError: If no partition graphs have been loaded.
932    """
933    if not self._debug_graphs:
934      raise LookupError("No partition graphs have been loaded.")
935    return [self._debug_graphs[key].debug_graph_def
936            for key in self._debug_graphs]
937
938  def reconstructed_non_debug_partition_graphs(self):
939    """Reconstruct partition graphs with the debugger-inserted ops stripped.
940
941    The reconstructed partition graphs are identical to the original (i.e.,
942    non-debugger-decorated) partition graphs except in the following respects:
943      1) The exact names of the runtime-inserted internal nodes may differ.
944         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
945      2) As a consequence of 1, the nodes that receive input directly from such
946         send- and recv-type ops will have different input names.
947      3) The parallel_iteration attribute of while-loop Enter ops are set to 1.
948
949    Returns:
950      A dict mapping device names (`str`s) to reconstructed
951      `tf.compat.v1.GraphDef`s.
952    """
953    non_debug_graphs = {}
954    for key in self._debug_graphs:
955      non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
956    return non_debug_graphs
957
958  @property
959  def run_fetches_info(self):
960    """Get a str representation of the fetches used in the Session.run() call.
961
962    Returns:
963      If the information is available from one `Session.run` call, a `str`
964        obtained from `repr(fetches)`.
965      If the information is available from multiple `Session.run` calls, a
966        `list` of `str` from `repr(fetches)`.
967      If the information is not available, `None`.
968    """
969
970    output = self._run_fetches_info
971    return output[0] if len(output) == 1 else output
972
973  @property
974  def run_feed_keys_info(self):
975    """Get a str representation of the feed_dict used in the Session.run() call.
976
977    Returns:
978      If the information is available from one `Session.run` call, a `str`
979        obtained from `repr(feed_dict)`.
980      If the information is available from multiple `Session.run` calls, a
981        `list` of `str` obtained from `repr(feed_dict)`.
982      If the information is not available, `None`.
983    """
984
985    output = self._run_feed_keys_info
986    return output[0] if len(output) == 1 else output
987
988  def _infer_device_name(self, device_name, node_name):
989    """Infer the device name given node name.
990
991    If device_name is provided (i.e., not None), it'll be simply returned right
992    away.
993
994    Args:
995      device_name: (str or None) name of the device. If None, will try to infer
996        the device name by looking at the available nodes.
997      node_name: (str) name of the node.
998
999    Returns:
1000      (str) Inferred name of the device, if available.
1001
1002    Raises:
1003      ValueError: If the node name does not exist on any of the available
1004        devices or if there are multiple devices that contain the node with
1005        the given name.
1006    """
1007    if device_name is None:
1008      if node_name in self._node_devices:
1009        if len(self._node_devices[node_name]) == 1:
1010          return list(self._node_devices[node_name])[0]
1011        else:
1012          raise ValueError(
1013              "There are multiple (%d) devices with nodes named '%s' but "
1014              "device_name is not specified." %
1015              (len(self._node_devices[node_name]), node_name))
1016      else:
1017        raise ValueError("None of the %d device(s) has a node named '%s'." %
1018                         (len(self._device_names), node_name))
1019    else:
1020      return device_name
1021
1022  def nodes(self, device_name=None):
1023    """Get a list of all nodes from the partition graphs.
1024
1025    Args:
1026      device_name: (`str`) name of device. If None, all nodes from all available
1027        devices will be included.
1028
1029    Returns:
1030      All nodes' names, as a list of str.
1031
1032    Raises:
1033      LookupError: If no partition graphs have been loaded.
1034      ValueError: If specified node name does not exist.
1035    """
1036    if not self._debug_graphs:
1037      raise LookupError("No partition graphs have been loaded.")
1038    if device_name is None:
1039      nodes = []
1040      for device_name in self._debug_graphs:
1041        nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
1042      return nodes
1043    else:
1044      if device_name not in self._debug_graphs:
1045        raise ValueError("Invalid device name: %s" % device_name)
1046      return self._debug_graphs[device_name].node_inputs.keys()
1047
1048  def node_attributes(self, node_name, device_name=None):
1049    """Get the attributes of a node.
1050
1051    Args:
1052      node_name: Name of the node in question.
1053      device_name: (`str`) name of the device. If there is only one device or if
1054        node_name exists on only one device, this argument is optional.
1055
1056    Returns:
1057      Attributes of the node.
1058
1059    Raises:
1060      LookupError: If no partition graphs have been loaded.
1061    """
1062    if not self._debug_graphs:
1063      raise LookupError("No partition graphs have been loaded.")
1064
1065    device_name = self._infer_device_name(device_name, node_name)
1066    return self._debug_graphs[device_name].node_attributes[node_name]
1067
1068  def node_inputs(self, node_name, is_control=False, device_name=None):
1069    """Get the inputs of given node according to partition graphs.
1070
1071    Args:
1072      node_name: Name of the node.
1073      is_control: (`bool`) Whether control inputs, rather than non-control
1074        inputs, are to be returned.
1075      device_name: (`str`) name of the device. If there is only one device or if
1076        node_name exists on only one device, this argument is optional.
1077
1078    Returns:
1079      (`list` of `str`) inputs to the node, as a list of node names.
1080
1081    Raises:
1082      LookupError: If node inputs and control inputs have not been loaded
1083         from partition graphs yet.
1084    """
1085    if not self._debug_graphs:
1086      raise LookupError(
1087          "Node inputs are not loaded from partition graphs yet.")
1088
1089    device_name = self._infer_device_name(device_name, node_name)
1090    if is_control:
1091      return self._debug_graphs[device_name].node_ctrl_inputs[node_name]
1092    else:
1093      return self._debug_graphs[device_name].node_inputs[node_name]
1094
1095  def transitive_inputs(self,
1096                        node_name,
1097                        include_control=True,
1098                        include_reversed_ref=False,
1099                        device_name=None,):
1100    """Get the transitive inputs of given node according to partition graphs.
1101
1102    Args:
1103      node_name: Name of the node.
1104      include_control: Include control inputs (True by default).
1105      include_reversed_ref: Whether a ref input, say from A to B, is to be also
1106        considered as an input from B to A. The rationale is that ref inputs
1107        generally let the recipient (e.g., B in this case) mutate the value of
1108        the source (e.g., A in this case). So the reverse direction of the ref
1109        edge reflects the direction of information flow.
1110      device_name: (`str`) name of the device. If there is only one device or if
1111        node_name exists on only one device, this argument is optional.
1112
1113    Returns:
1114      (`list` of `str`) all transitive inputs to the node, as a list of node
1115        names.
1116
1117    Raises:
1118      LookupError: If node inputs and control inputs have not been loaded
1119         from partition graphs yet.
1120    """
1121    if not self._debug_graphs:
1122      raise LookupError(
1123          "Node inputs are not loaded from partition graphs yet.")
1124
1125    device_name = self._infer_device_name(device_name, node_name)
1126
1127    input_lists = [self._debug_graphs[device_name].node_inputs]
1128    if include_control:
1129      input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs)
1130    if include_reversed_ref:
1131      input_lists.append(
1132          self._debug_graphs[device_name].node_reversed_ref_inputs)
1133    tracer = debug_graphs.DFSGraphTracer(
1134        input_lists,
1135        skip_node_names=self._get_merge_node_names(device_name))
1136    tracer.trace(node_name)
1137    return tracer.inputs()
1138
1139  def _get_merge_node_names(self, device_name):
1140    """Lazily get a list of Merge nodes on a given device."""
1141    if device_name not in self._device_names:
1142      raise ValueError("Invalid device name: %s" % device_name)
1143
1144    if not hasattr(self, "_merge_node_names"):
1145      self._merge_node_names = {}
1146    if device_name not in self._merge_node_names:
1147      debug_graph = self._debug_graphs[device_name]
1148      self._merge_node_names[device_name] = [
1149          node for node in debug_graph.node_op_types
1150          if debug_graph.node_op_types[node] == "Merge"]
1151    return self._merge_node_names[device_name]
1152
1153  def find_some_path(self,
1154                     src_node_name,
1155                     dst_node_name,
1156                     include_control=True,
1157                     include_reversed_ref=False,
1158                     device_name=None):
1159    """Find a path between a source node and a destination node.
1160
1161    Limitation: the source and destination are required to be on the same
1162    device, i.e., this method does not yet take into account Send/Recv nodes
1163    across devices.
1164
1165    TODO(cais): Make this method work across device edges by tracing Send/Recv
1166      nodes.
1167
1168    Args:
1169      src_node_name: (`str`) name of the source node or name of an output tensor
1170        of the node.
1171      dst_node_name: (`str`) name of the destination node or name of an output
1172        tensor of the node.
1173      include_control: (`bool`) whrther control edges are considered in the
1174        graph tracing.
1175      include_reversed_ref: Whether a ref input, say from A to B, is to be also
1176        considered as an input from B to A. The rationale is that ref inputs
1177        generally let the recipient (e.g., B in this case) mutate the value of
1178        the source (e.g., A in this case). So the reverse direction of the ref
1179        edge reflects the direction of information flow.
1180      device_name: (`str`) name of the device. If there is only one device or if
1181        node_name exists on only one device, this argument is optional.
1182
1183    Returns:
1184      A path from the src_node_name to dst_node_name, as a `list` of `str`, if
1185      it exists. The list includes src_node_name as the first item and
1186      dst_node_name as the last.
1187      If such a path does not exist, `None`.
1188
1189    Raises:
1190      ValueError: If the source and destination nodes are not on the same
1191        device.
1192    """
1193    src_device_name = self._infer_device_name(device_name, src_node_name)
1194    dst_device_name = self._infer_device_name(device_name, dst_node_name)
1195
1196    if src_device_name != dst_device_name:
1197      raise ValueError(
1198          "Source (%s) and destination (%s) are not on the same device: "
1199          "%s vs. %s" % (src_node_name, dst_node_name, src_device_name,
1200                         dst_device_name))
1201
1202    input_lists = [self._debug_graphs[dst_device_name].node_inputs]
1203    debug_graph = self._debug_graphs[dst_device_name]
1204    if include_control:
1205      input_lists.append(debug_graph.node_ctrl_inputs)
1206    if include_reversed_ref:
1207      input_lists.append(debug_graph.node_reversed_ref_inputs)
1208    tracer = debug_graphs.DFSGraphTracer(
1209        input_lists,
1210        skip_node_names=self._get_merge_node_names(dst_device_name),
1211        destination_node_name=src_node_name)
1212    # Here the value of destination_node_name is src_node_name, because we
1213    # are tracing the graph from output to its inputs (i.e., going backwards
1214    # on the graph).
1215
1216    try:
1217      tracer.trace(dst_node_name)
1218    except debug_graphs.GraphTracingReachedDestination:
1219      # Prune nodes not on the path.
1220      inputs = [dst_node_name] + tracer.inputs()
1221      depth_list = [0] + tracer.depth_list()
1222
1223      path = []
1224      curr_depth = depth_list[-1]
1225      for inp, depth in zip(reversed(inputs), reversed(depth_list)):
1226        if depth == curr_depth:
1227          path.append(inp)
1228          curr_depth -= 1
1229      return path
1230
1231  def node_recipients(self, node_name, is_control=False, device_name=None):
1232    """Get recipient of the given node's output according to partition graphs.
1233
1234    Args:
1235      node_name: (`str`) name of the node.
1236      is_control: (`bool`) whether control outputs, rather than non-control
1237        outputs, are to be returned.
1238      device_name: (`str`) name of the device. If there is only one device or if
1239        node_name exists on only one device, this argument is optional.
1240
1241    Returns:
1242      (`list` of `str`) all inputs to the node, as a list of node names.
1243
1244    Raises:
1245      LookupError: If node inputs and control inputs have not been loaded
1246         from partition graphs yet.
1247    """
1248
1249    if not self._debug_graphs:
1250      raise LookupError(
1251          "Node recipients are not loaded from partition graphs yet.")
1252
1253    device_name = self._infer_device_name(device_name, node_name)
1254    debug_graph = self._debug_graphs[device_name]
1255    if is_control:
1256      return debug_graph.node_ctrl_recipients[node_name]
1257    else:
1258      return debug_graph.node_recipients[node_name]
1259
1260  def devices(self):
1261    """Get the list of device names.
1262
1263    Returns:
1264      (`list` of `str`) names of the devices.
1265    """
1266    return self._device_names
1267
1268  def node_exists(self, node_name, device_name=None):
1269    """Test if a node exists in the partition graphs.
1270
1271    Args:
1272      node_name: (`str`) name of the node to be checked.
1273      device_name: optional device name. If None, will search for the node
1274        on all available devices. Otherwise, search for the node only on
1275        the given device.
1276
1277    Returns:
1278      A boolean indicating whether the node exists.
1279
1280    Raises:
1281      LookupError: If no partition graphs have been loaded yet.
1282      ValueError: If device_name is specified but cannot be found.
1283    """
1284    if not self._debug_graphs:
1285      raise LookupError(
1286          "Nodes have not been loaded from partition graphs yet.")
1287
1288    if (device_name is not None) and device_name not in self._debug_graphs:
1289      raise ValueError(
1290          "The specified device_name '%s' cannot be found." % device_name)
1291
1292    for _, debug_graph in self._debug_graphs.items():
1293      if node_name in debug_graph.node_inputs:
1294        return True
1295    return False
1296
1297  def node_device(self, node_name):
1298    """Get the names of the devices that has nodes of the specified name.
1299
1300    Args:
1301      node_name: (`str`) name of the node.
1302
1303    Returns:
1304      (`str` or `list` of `str`) name of the device(s) on which the node of the
1305        given name is found. Returns a `str` if there is only one such device,
1306        otherwise return a `list` of `str`.
1307
1308    Raises:
1309      LookupError: If node inputs and control inputs have not been loaded
1310         from partition graphs yet.
1311      ValueError: If the node does not exist in partition graphs.
1312    """
1313    if not self._debug_graphs:
1314      raise LookupError(
1315          "Node devices are not loaded from partition graphs yet.")
1316
1317    if node_name not in self._node_devices:
1318      raise ValueError("Node '%s' does not exist in partition graphs." %
1319                       node_name)
1320
1321    output = list(self._node_devices[node_name])
1322    return output[0] if len(output) == 1 else output
1323
1324  def node_op_type(self, node_name, device_name=None):
1325    """Get the op type of given node.
1326
1327    Args:
1328      node_name: (`str`) name of the node.
1329      device_name: (`str`) name of the device. If there is only one device or if
1330        node_name exists on only one device, this argument is optional.
1331
1332    Returns:
1333      (`str`) op type of the node.
1334
1335    Raises:
1336      LookupError: If node op types have not been loaded
1337         from partition graphs yet.
1338    """
1339    if not self._debug_graphs:
1340      raise LookupError(
1341          "Node op types are not loaded from partition graphs yet.")
1342
1343    device_name = self._infer_device_name(device_name, node_name)
1344    return self._debug_graphs[device_name].node_op_types[node_name]
1345
1346  def debug_watch_keys(self, node_name, device_name=None):
1347    """Get all tensor watch keys of given node according to partition graphs.
1348
1349    Args:
1350      node_name: (`str`) name of the node.
1351      device_name: (`str`) name of the device. If there is only one device or if
1352        node_name exists on only one device, this argument is optional.
1353
1354    Returns:
1355      (`list` of `str`) all debug tensor watch keys. Returns an empty list if
1356        the node name does not correspond to any debug watch keys.
1357
1358    Raises:
1359      `LookupError`: If debug watch information has not been loaded from
1360        partition graphs yet.
1361    """
1362
1363    try:
1364      device_name = self._infer_device_name(device_name, node_name)
1365    except ValueError:
1366      return []
1367
1368    if node_name not in self._debug_watches[device_name]:
1369      return []
1370
1371    watch_keys = []
1372    for watched_slot in self._debug_watches[device_name][node_name]:
1373      debug_ops = self._debug_watches[device_name][node_name][watched_slot]
1374      for debug_op in debug_ops:
1375        watch_keys.append(
1376            _get_tensor_watch_key(node_name, watched_slot, debug_op))
1377
1378    return watch_keys
1379
1380  def watch_key_to_data(self, debug_watch_key, device_name=None):
1381    """Get all `DebugTensorDatum` instances corresponding to a debug watch key.
1382
1383    Args:
1384      debug_watch_key: (`str`) debug watch key.
1385      device_name: (`str`) name of the device. If there is only one device or if
1386        the specified debug_watch_key exists on only one device, this argument
1387        is optional.
1388
1389    Returns:
1390      A list of `DebugTensorDatum` instances that correspond to the debug watch
1391      key. If the watch key does not exist, returns an empty list.
1392
1393    Raises:
1394      ValueError: If there are multiple devices that have the debug_watch_key,
1395        but device_name is not specified.
1396    """
1397    if device_name is None:
1398      matching_device_names = [
1399          name for name in self._watch_key_to_datum
1400          if debug_watch_key in self._watch_key_to_datum[name]]
1401      if not matching_device_names:
1402        return []
1403      elif len(matching_device_names) == 1:
1404        device_name = matching_device_names[0]
1405      else:
1406        raise ValueError(
1407            "The debug watch key '%s' exists on multiple (%d) devices, but "
1408            "device name is not specified." %
1409            (debug_watch_key, len(matching_device_names)))
1410    elif device_name not in self._debug_key_to_datum:
1411      raise ValueError(
1412          "There is no device named '%s' consisting of debug watch keys." %
1413          device_name)
1414
1415    return self._watch_key_to_datum[device_name].get(debug_watch_key, [])
1416
1417  def find(self,
1418           predicate,
1419           first_n=0,
1420           device_name=None,
1421           exclude_node_names=None):
1422    """Find dumped tensor data by a certain predicate.
1423
1424    Args:
1425      predicate: A callable that takes two input arguments:
1426
1427        ```python
1428        def predicate(debug_tensor_datum, tensor):
1429          # returns a bool
1430        ```
1431
1432        where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which
1433        carries the metadata, such as the `Tensor`'s node name, output slot
1434        timestamp, debug op name, etc.; and `tensor` is the dumped tensor value
1435        as a `numpy.ndarray`.
1436      first_n: (`int`) return only the first n `DebugTensotDatum` instances (in
1437        time order) for which the predicate returns True. To return all the
1438        `DebugTensotDatum` instances, let first_n be <= 0.
1439      device_name: optional device name.
1440      exclude_node_names: Optional regular expression to exclude nodes with
1441        names matching the regular expression.
1442
1443    Returns:
1444      A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
1445       for which predicate returns True, sorted in ascending order of the
1446       timestamp.
1447    """
1448    if exclude_node_names:
1449      exclude_node_names = re.compile(exclude_node_names)
1450
1451    matched_data = []
1452    for device in (self._dump_tensor_data if device_name is None
1453                   else (self._dump_tensor_data[device_name],)):
1454      for datum in self._dump_tensor_data[device]:
1455        if exclude_node_names and exclude_node_names.match(datum.node_name):
1456          continue
1457
1458        if predicate(datum, datum.get_tensor()):
1459          matched_data.append(datum)
1460
1461          if first_n > 0 and len(matched_data) >= first_n:
1462            return matched_data
1463
1464    return matched_data
1465
1466  def get_tensor_file_paths(self,
1467                            node_name,
1468                            output_slot,
1469                            debug_op,
1470                            device_name=None):
1471    """Get the file paths from a debug-dumped tensor.
1472
1473    Args:
1474      node_name: (`str`) name of the node that the tensor is produced by.
1475      output_slot: (`int`) output slot index of tensor.
1476      debug_op: (`str`) name of the debug op.
1477      device_name: (`str`) name of the device. If there is only one device or if
1478        the specified debug_watch_key exists on only one device, this argument
1479        is optional.
1480
1481    Returns:
1482      List of file path(s) loaded. This is a list because each debugged tensor
1483        may be dumped multiple times.
1484
1485    Raises:
1486      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1487        the debug-dump data.
1488    """
1489
1490    device_name = self._infer_device_name(device_name, node_name)
1491    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1492    if watch_key not in self._watch_key_to_datum[device_name]:
1493      raise WatchKeyDoesNotExistInDebugDumpDirError(
1494          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1495          (watch_key, device_name))
1496
1497    return [datum.file_path for datum in
1498            self._watch_key_to_datum[device_name][watch_key]]
1499
1500  def get_tensors(self, node_name, output_slot, debug_op, device_name=None):
1501    """Get the tensor value from for a debug-dumped tensor.
1502
1503    The tensor may be dumped multiple times in the dump root directory, so a
1504    list of tensors (`numpy.ndarray`) is returned.
1505
1506    Args:
1507      node_name: (`str`) name of the node that the tensor is produced by.
1508      output_slot: (`int`) output slot index of tensor.
1509      debug_op: (`str`) name of the debug op.
1510      device_name: (`str`) name of the device. If there is only one device or if
1511        the specified debug_watch_key exists on only one device, this argument
1512        is optional.
1513
1514    Returns:
1515      List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s).
1516
1517    Raises:
1518      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
1519        the debug-dump data.
1520    """
1521
1522    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1523    try:
1524      device_name = self._infer_device_name(device_name, node_name)
1525      return [datum.get_tensor() for datum in
1526              self._watch_key_to_datum[device_name][watch_key]]
1527    except (ValueError, KeyError):
1528      raise WatchKeyDoesNotExistInDebugDumpDirError(
1529          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1530          (watch_key, device_name))
1531
1532  def get_rel_timestamps(self,
1533                         node_name,
1534                         output_slot,
1535                         debug_op,
1536                         device_name=None):
1537    """Get the relative timestamp from for a debug-dumped tensor.
1538
1539    Relative timestamp means (absolute timestamp - `t0`), where `t0` is the
1540    absolute timestamp of the first dumped tensor in the dump root. The tensor
1541    may be dumped multiple times in the dump root directory, so a list of
1542    relative timestamps (`numpy.ndarray`) is returned.
1543
1544    Args:
1545      node_name: (`str`) name of the node that the tensor is produced by.
1546      output_slot: (`int`) output slot index of tensor.
1547      debug_op: (`str`) name of the debug op.
1548      device_name: (`str`) name of the device. If there is only one device or if
1549        the specified debug_watch_key exists on only one device, this argument
1550        is optional.
1551
1552    Returns:
1553      (`list` of `int`) list of relative timestamps.
1554
1555    Raises:
1556      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1557        exist in the debug dump data.
1558    """
1559
1560    device_name = self._infer_device_name(device_name, node_name)
1561    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1562    if watch_key not in self._watch_key_to_datum[device_name]:
1563      raise WatchKeyDoesNotExistInDebugDumpDirError(
1564          "Watch key \"%s\" does not exist in the debug dump" % watch_key)
1565
1566    # TODO(cais): Figure out whether this should be relative to the global t0.
1567    return self._watch_key_to_rel_time[device_name][watch_key]
1568
1569  def get_dump_sizes_bytes(self,
1570                           node_name,
1571                           output_slot,
1572                           debug_op,
1573                           device_name=None):
1574    """Get the sizes of the dump files for a debug-dumped tensor.
1575
1576    Unit of the file size: byte.
1577
1578    Args:
1579      node_name: (`str`) name of the node that the tensor is produced by.
1580      output_slot: (`int`) output slot index of tensor.
1581      debug_op: (`str`) name of the debug op.
1582      device_name: (`str`) name of the device. If there is only one device or if
1583        the specified debug_watch_key exists on only one device, this argument
1584        is optional.
1585
1586    Returns:
1587      (`list` of `int`): list of dump file sizes in bytes.
1588
1589    Raises:
1590      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
1591        exist in the debug dump data.
1592    """
1593
1594    device_name = self._infer_device_name(device_name, node_name)
1595    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
1596    if watch_key not in self._watch_key_to_datum[device_name]:
1597      raise WatchKeyDoesNotExistInDebugDumpDirError(
1598          "Watch key \"%s\" does not exist in the debug dump of device %s" %
1599          (watch_key, device_name))
1600
1601    return self._watch_key_to_dump_size_bytes[device_name][watch_key]
1602
1603  def node_traceback(self, element_name):
1604    """Try to retrieve the Python traceback of node's construction.
1605
1606    Args:
1607      element_name: (`str`) Name of a graph element (node or tensor).
1608
1609    Returns:
1610      (list) The traceback list object as returned by the `extract_trace`
1611        method of Python's traceback module.
1612
1613    Raises:
1614      LookupError: If Python graph is not available for traceback lookup.
1615      KeyError: If the node cannot be found in the Python graph loaded.
1616    """
1617
1618    if self._python_graph is None:
1619      raise LookupError("Python graph is not available for traceback lookup")
1620
1621    node_name = debug_graphs.get_node_name(element_name)
1622    if node_name not in self._node_traceback:
1623      raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)
1624
1625    return self._node_traceback[node_name]
1626