# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TensorFlow Debugger (tfdbg) Stepper Module."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import copy
21import os
22import shutil
23import tempfile
24import time
25
26import six
27
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.python.debug.lib import debug_data
30from tensorflow.python.debug.lib import debug_graphs
31from tensorflow.python.debug.lib import debug_utils
32from tensorflow.python.framework import ops
33from tensorflow.python.ops import session_ops


# TODO(cais): Use nest.flatten once it handles nested dicts correctly.
def _flatten_fetches(fetches):
  """Flatten a list, tuple or dict of fetches, or a single fetch, into a list.

  Args:
    fetches: The fetches to flatten: can be a single Tensor or Op, or a
      potentially nested list, tuple or dict of such individual fetches.

  Returns:
    The fetches flattened to a list.
  """

  flattened = []
  if isinstance(fetches, (list, tuple)):
    for fetch in fetches:
      flattened.extend(_flatten_fetches(fetch))
  elif isinstance(fetches, dict):
    for key in fetches:
      flattened.extend(_flatten_fetches(fetches[key]))
  else:
    flattened.append(fetches)

  return flattened
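
# A hypothetical usage sketch of _flatten_fetches; the tensors a, b and c are
# assumptions, not defined in this module:
#   _flatten_fetches([a, {"sum": b, "others": (c, "d:0")}])
#   # -> [a, b, c, "d:0"] (dict traversal order may vary)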


class NodeStepper(object):
  """TensorFlow Debugger (tfdbg) stepper.

  The stepper provides the ability to perform "continue to" actions on a
  graph, given fetches and feeds. The stepper calculates the transitive
  closure of the fetch. cont() (continue-to) calls can be performed only on
  members of the transitive closure.

  On a cont() call, the stepper performs depth-first tracing of the input
  tree of the target. When it reaches an input where one of the following is
  available, it supplies the available value to the feed_dict of the cont()
  call:
    (1) Overriding (injected) values from the client.
    (2) TensorHandles from previous cont() calls.
    (3) Dumped intermediate Tensors from previous cont() calls.
    (4) Feeds supplied during the construction of the stepper instance.

  During the cont() call, intermediate Tensors are dumped to temporary
  directories. The dumped Tensor values will be used in subsequent cont()
  calls when they are required as data dependencies.

  The temporary directories are automatically cleaned up when the NodeStepper
  instance exits as a context manager.

  Once the tracing is complete, the stepper issues a run() call on the
  underlying session, using the feed_dict prepared by the input tracing, to
  achieve the "continue-to" action. The above process takes into account
  whether the transitive closure of an input contains Variables that were
  updated during previous cont() calls on this stepper instance. If such
  updates exist, we say the transitive closure is "dirty" and the stepper
  can restore the "clean" state of the Variable and avoid using the
  TensorHandle.

  Example of basic usage:
    a = tf.Variable(1.0, name="a")
    b = tf.Variable(2.0, name="b")
    c = tf.add(a, b, name="c")
    d = tf.multiply(a, c, name="d")

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    stepper = NodeStepper(sess, d)

    stepper.cont(c)  # Caches the handle to Tensor c:0.
    stepper.cont(d)  # Uses handle to Tensor c:0, avoiding recomputing c.
  """

  # Possible types of feed used during cont() calls.
  FEED_TYPE_CLIENT = "client"
  FEED_TYPE_HANDLE = "handle"
  FEED_TYPE_OVERRIDE = "override"
  FEED_TYPE_DUMPED_INTERMEDIATE = "dumped_intermediate"

  def __init__(self, sess, fetches, feed_dict=None):
    """Constructor for NodeStepper.

    Args:
      sess: (Session) the TensorFlow Session to step in.
      fetches: Same as the fetches input argument to `Session.run()`.
      feed_dict: Same as the feed_dict input argument to `Session.run()`.
    """

    self._sess = sess

    self._fetches = fetches
    flattened_fetches = _flatten_fetches(fetches)

    self._fetch_names, self._fetch_list = self._get_fetch_and_name_lists(
        flattened_fetches)

    # A map from Variable name to initializer op.
    self._variable_initializers = {}

    # A map from Variable name to initial value, used when overriding or
    # restoring Variable values.
    self._variable_initial_values = {}

    # Initialize the map for output recipients (targets).
    self._output_targets = {}

    # Sorted transitive closure of the fetched nodes.
    # We also collect the list of the names of the reference-type Tensors,
    # because we later need to avoid using intermediate dumps for such Tensors.
    (self._sorted_nodes,
     self._closure_elements,
     self._ref_tensor_names) = self._dfs_visit(self._sess.graph,
                                               self._fetch_list)

    self._transitive_closure_set = set(self._sorted_nodes)

    # A map from Variable name to the old value (before any cont() calls).
    self._cached_variable_values = {}

    # A cache map from tensor name to the Variables that may invalidate the
    # tensor.
    self._cached_invalidation_path = {}

    # Keep track of which variables are in a dirty state.
    self._dirty_variables = set()

    # Variables updated in the last cont() call.
    self._last_updated = None

    # Cached tensor handles: a dict with keys as tensor names and values as
    # tensor handles.
    self._tensor_handles = {}

    # Cached intermediate tensor values: a dict mapping tensor names to
    # DebugTensorDatum.
    self._dumped_intermediate_tensors = {}
    self._dump_session_root = tempfile.mkdtemp(prefix="tfdbg_stepper_")

    # Feed dict from the client.
    self._client_feed_dict = {}
    if feed_dict:
      for key in feed_dict:
        if isinstance(key, ops.Tensor):
          self._client_feed_dict[key.name] = feed_dict[key]
        else:
          self._client_feed_dict[key] = feed_dict[key]

    # Overriding tensor values.
    self._override_tensors = {}

    # Feed types used by the last cont() call.
    self._last_feed_types = {}

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, exc_traceback):
    if os.path.isdir(self._dump_session_root):
      shutil.rmtree(self._dump_session_root)

  def _get_fetch_and_name_lists(self, flattened_fetches):
    """Get the lists of fetches and their names.

    Args:
      flattened_fetches: A list of fetches or their names. Can mix fetches and
        names.

    Returns:
      (list of str): A list of the names of the fetches.
      (list): A list of the fetches.
    """

    fetch_names = []
    fetch_list = []
    for fetch in flattened_fetches:
      if isinstance(fetch, six.string_types):
        fetch_names.append(fetch)
        fetch_list.append(self._sess.graph.as_graph_element(fetch))
      else:
        fetch_names.append(fetch.name)
        fetch_list.append(fetch)

    return fetch_names, fetch_list
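
  # A hypothetical sketch of _get_fetch_and_name_lists on a mix of graph
  # elements and name strings (the tensor c is an assumption):
  #   names, elems = stepper._get_fetch_and_name_lists([c, "d:0"])
  #   # names -> ["c:0", "d:0"]; elems -> [<Tensor c:0>, <Tensor d:0>]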

  def _dfs_visit(self, graph, elem_list):
    """Trace back the inputs of graph elements, using depth-first search.

    Uses a non-recursive implementation to prevent stack overflow for deep
    graphs.

    Also performs the following action(s):
      1) When encountering a Variable, obtain its initializer op, to
         facilitate possible subsequent restoration / overriding of the
         variable's value.

    Args:
      graph: A TF graph instance.
      elem_list: A list of graph elements, each a Tensor or an Operation.

    Returns:
      (list of str) A topologically-sorted list of all nodes (not tensors)
        in the transitive closure of elem_list. The topological sort is not
        unique in general; the return value here is just an arbitrary one of
        potentially many possible topological sorts.
      (list of str) A list of all graph elements (nodes and/or tensors) in the
        transitive closure.
      (set of str) Names of the reference-type Tensors (e.g., Variable
        outputs) in the transitive closure.
    """

    # This set should hold only strings, i.e., names of the nodes.
    done = set()  # Keep track of visited graph elements.

    # A map from the name of each node in the transitive closure to the set
    # of names of its input nodes.
    node_inputs = dict()

    elem_stack = copy.copy(elem_list)

    # Graph elements in the transitive closure, including the nodes and tensors.
    closure_elements = [elem.name for elem in elem_list]

    ref_tensor_names = set()
    for element in elem_list:
      if isinstance(element, ops.Tensor) and element.dtype._is_ref_dtype:  # pylint: disable=protected-access
        ref_tensor_names.add(element.name)

    while elem_stack:
      curr_elem = elem_stack.pop()
      curr_node = self._get_node(curr_elem)

      done.add(curr_node.name)

      non_control_inputs = [inp for inp in curr_node.inputs]
      control_inputs = [inp for inp in curr_node.control_inputs]
      all_inputs = set(non_control_inputs + control_inputs)

      if curr_node.name not in node_inputs:
        all_input_nodes = set()
        for inp in all_inputs:
          all_input_nodes.add(self._get_node(inp).name)
        node_inputs[curr_node.name] = all_input_nodes

      # Iterate through all inputs (both non-control and control).
      for inp in all_inputs:
        # Set up the output map.
        if inp.name not in self._output_targets:
          self._output_targets[inp.name] = set([curr_elem.name])
        else:
          self._output_targets[inp.name].add(curr_elem.name)

        if (isinstance(inp, ops.Tensor) and
            inp.op.type in ["Variable", "VariableV2"] and
            inp.name not in self._variable_initializers):
          # Obtain the initializer op of the variable, in case the Variable's
          # value needs to be restored later.
          initializer = graph.as_graph_element(inp.op.name + "/Assign")
          self._variable_initializers[inp.name] = initializer
          self._variable_initial_values[inp.name] = initializer.inputs[1]

        inp_node = self._get_node(inp)
        if inp_node.name in done:
          # Already visited.
          continue

        elem_stack.append(inp)
        closure_elements.append(inp.name)
        if isinstance(inp, ops.Tensor) and inp.dtype._is_ref_dtype:  # pylint: disable=protected-access
          ref_tensor_names.add(inp.name)

    # Now that we have traversed the transitive closure and obtained the
    # node-input map, we can topologically sort the nodes.
    sorted_nodes = []
    stack = []
    for node in node_inputs:
      if not node_inputs[node]:
        stack.append(node)
    for node in stack:
      del node_inputs[node]

    while stack:
      curr_node = stack.pop()
      sorted_nodes.append(curr_node)

      # Iterate through the node-input map and remove the child.
      pushes = []
      for node in node_inputs:
        if curr_node in node_inputs[node]:
          node_inputs[node].remove(curr_node)
          if not node_inputs[node]:
            pushes.append(node)

      # Delete the newly-pushed nodes from the node-input map.
      for node in pushes:
        del node_inputs[node]

      stack.extend(pushes)

    return sorted_nodes, closure_elements, ref_tensor_names
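
  # The topological sort above is Kahn's algorithm: nodes with no remaining
  # inputs are repeatedly popped, and each pop erases the node from the sets
  # in the node-input map. A hypothetical trace for the docstring example
  # (d = a * (a + b)):
  #   node_inputs: {"a": set(), "b": set(), "c": {"a", "b"}, "d": {"a", "c"}}
  #   sorted_nodes may come out as ["a", "b", "c", "d"] or
  #   ["b", "a", "c", "d"]; both are valid topological orders.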

  def sorted_nodes(self):
    """Get a topologically-sorted list of node names of the stepper.

    These are the names of the nodes (i.e., not Tensors) in the transitive
    closure of the stepper, in a topologically-sorted order.

    Returns:
      (list of str): Sorted transitive inputs to the fetch of the stepper
        instance. The fetch itself is included in the list.
    """

    return self._sorted_nodes

  def closure_elements(self):
    """Get a name list of the graph elements of the stepper.

    Returns:
      (list of str): Names of the graph elements (i.e., nodes and tensors) in
        the transitive closure of the stepper, in an arbitrary order.
    """

    return self._closure_elements

  def output_slots_in_closure(self, node_name):
    """Get the output tensors in the transitive closure from a node.

    Args:
      node_name: (str) Name of the node in question.

    Returns:
      (list of int) Output slots of the output tensors of the node that are in
        the transitive closure of the stepper.
    """

    node = self._sess.graph.as_graph_element(node_name)

    tensor_slots = []
    for i, _ in enumerate(node.outputs):
      tensor_name = node_name + ":%d" % i
      if tensor_name in self._closure_elements:
        tensor_slots.append(i)

    return tensor_slots
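
  # A hypothetical sketch for output_slots_in_closure, assuming a node "svd"
  # with three output tensors of which only svd:0 and svd:2 appear in the
  # transitive closure:
  #   stepper.output_slots_in_closure("svd")  # -> [0, 2]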

  def is_feedable(self, name):
    """Determine if a graph element is feedable.

    Args:
      name: (str) Name of the graph element (Tensor or Operation).

    Returns:
      (bool) Whether the graph element is feedable.
    """

    if not isinstance(name, six.string_types):
      raise TypeError("Expected type str; got type %s" % type(name))

    elem = self._sess.graph.as_graph_element(name)
    return self._sess.graph.is_feedable(elem)

  def override_tensor(self, tensor_name, overriding_val):
    """Override the value of a tensor.

    Args:
      tensor_name: (str) Name of the tensor to override.
      overriding_val: (numpy.ndarray) Overriding tensor value.

    Raises:
      TypeError: If tensor_name is not a str.
      ValueError: If tensor_name does not correspond to a tensor in the input
        tree to the fetched graph element of this stepper instance.
    """

    if not isinstance(tensor_name, six.string_types):
      raise TypeError("Expected type str; got type %s" % type(tensor_name))

    node_name = self._get_node_name(tensor_name)
    if node_name not in self._transitive_closure_set:
      raise ValueError(
          "Cannot override tensor \"%s\" because it does not exist in the "
          "input tree to the fetch \"%s\"" %
          (tensor_name, repr(self._fetch_names)))

    self._override_tensors[tensor_name] = overriding_val

    # Invalidate cache by tracing outputs.
    self._invalidate_transitively_outgoing_cache(tensor_name)

  def remove_override(self, tensor_name):
    """Remove the overriding value on a tensor.

    Args:
      tensor_name: (str) Name of the tensor to remove the overriding value
        from.

    Raises:
      ValueError: If no overriding value exists for tensor_name.
    """

    if tensor_name not in self._override_tensors:
      raise ValueError("No overriding value exists for tensor \"%s\"." %
                       tensor_name)

    del self._override_tensors[tensor_name]

    # Invalidate cache by tracing outputs.
    self._invalidate_transitively_outgoing_cache(tensor_name)

  def last_feed_types(self):
    """Obtain information about the feeds used in the last cont() call.

    Returns:
      (dict) A dict mapping tensor names to feed types.
    """

    return self._last_feed_types
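
  # A hypothetical last_feed_types() return value after a cont() call that
  # used one client feed and one cached tensor handle (tensor names are
  # assumptions):
  #   {"ph:0": NodeStepper.FEED_TYPE_CLIENT,
  #    "c:0": NodeStepper.FEED_TYPE_HANDLE}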

  def cont(self,
           target,
           use_tensor_handles=True,
           use_dumped_intermediates=True,
           use_overrides=True,
           invalidate_from_updated_variables=False,
           restore_variable_values=False):
    """Continue until the completion of the specified target tensor.

    Args:
      target: A single fetched Tensor or Op, or a name (str) representing the
        Tensor or Op. In the case of a name str, the graph will be searched
        to find the corresponding Tensor or Op.
        # TODO(cais): Support multiple fetches as in Session.run() interface.
      use_tensor_handles: (bool) Whether this cont() run will use cached tensor
        handles to avoid recomputation. Default: True.
      use_dumped_intermediates: (bool) Whether this cont() call will use dumped
        intermediate tensors to avoid recomputation. Default: True.
      use_overrides: (bool) Whether the overriding tensor values supplied by
        the client are to be used in this cont() call. Default: True.
      invalidate_from_updated_variables: (bool) Whether to invalidate the
        cached tensor handles and intermediate tensors affected by the
        Variable updates that happen in this cont() call. Default: False.
      restore_variable_values: (bool) Whether the old values of the variables
        (before any cont() calls in this object) are to be restored.
        Default: False.

    Returns:
      Value from Session.run() of the target.

    Raises:
      ValueError: If the target is specified as a string and the string does
        not correspond to any tensors in the Session graph.
        Or if the target of this cont() is not in the input list of the
        Stepper object's target.
        Or if the target is a Placeholder.
    """

    self._last_feed_types = {}

    if isinstance(target, six.string_types):
      # The fetch target is a string: assume it is the name of a Tensor or Op
      # and attempt to find it in the Session's graph.
      target_name = target
    else:
      target_name = target.name

    graph_element = self._sess.graph.as_graph_element(target_name)
    # Any additional tensor handles to obtain in this cont() action.
    additional_handle_requests = []

    if (isinstance(graph_element, ops.Tensor) and
        graph_element.op.type == "Placeholder"):
      self._last_feed_types[graph_element.name] = self.FEED_TYPE_CLIENT
      return self._client_feed_dict[graph_element.name]
    elif (isinstance(graph_element, ops.Operation) and
          graph_element.type == "Placeholder"):
      tensor_name = graph_element.name + ":0"
      self._last_feed_types[tensor_name] = self.FEED_TYPE_CLIENT
      return self._client_feed_dict[tensor_name]

    if isinstance(graph_element, ops.Operation) and graph_element.outputs:
      # Check if this op has any output tensors that also fall into this
      # stepper's transitive closure.
      node_outputs = [
          output.name for output in graph_element.outputs
          if output.name in self._closure_elements
      ]
      if node_outputs:
        # The target is an op with at least one output within the transitive
        # closure. The cont() action will amount to using the 0-th
        # output Tensor as the target, as well as obtaining handles to it
        # and to the rest of the output tensors in the transitive closure
        # (if any).
        target_name = node_outputs[0]
        additional_handle_requests = node_outputs[1:]

    # Verify that the target is in the transitive closure of the stepper's
    # fetch.
    target_node_name = self._get_node_name(target_name)
    if target_node_name not in self._transitive_closure_set:
      raise ValueError(
          "Target \"%s\" is not in the transitive closure for the fetch of the "
          "stepper: \"%s\"." % (target_name, repr(self._fetch_names)))

    # Check if a cached tensor handle can be used on the fetch directly.
    if use_tensor_handles and target_name in self._tensor_handles:
      self._last_feed_types[target_name] = self.FEED_TYPE_HANDLE
      return self._tensor_handles[target_name].eval()

    # Check if a dumped intermediate tensor can be used on the fetch directly.
    if (use_dumped_intermediates and
        target_name in self._dumped_intermediate_tensors):
      self._last_feed_types[target_name] = self.FEED_TYPE_DUMPED_INTERMEDIATE
      return self._dumped_intermediate_tensors[target_name].get_tensor()

    # Check if an overriding tensor value can be used directly.
    if use_overrides and target_name in self._override_tensors:
      # An override is available. Return the value right away.
      self._last_feed_types[target_name] = self.FEED_TYPE_OVERRIDE
      return self._override_tensors[target_name]

    # Keep track of which variables are restored in this cont() call.
    restored_variables = set()

    # Keep track of which variables are "touched" (i.e., possibly updated) in
    # this cont() call.
    self._last_updated = set()

    # =========================================================================
    # Use a non-recursive method to trace the inputs from the node and set up
    # the feeds.
    feeds = {}  # The feeds to be used in the Session.run() call.
    fetched = self._sess.graph.as_graph_element(target_name)
    elem_stack = [fetched]
    done = set()

    while elem_stack:
      curr_elem = elem_stack.pop()
      curr_node = self._get_node(curr_elem)

      done.add(curr_node.name)

      non_control_inputs = [inp for inp in curr_node.inputs]
      control_inputs = [inp for inp in curr_node.control_inputs]
      all_inputs = set(non_control_inputs + control_inputs)

      # Iterate through all inputs (both non-control and control).
      for inp in all_inputs:
        # Determine whether the input is feedable. Reference-type tensors,
        # e.g., Variables, should not be fed, because they can change.
        if isinstance(inp, ops.Tensor):
          is_inp_ref = inp.dtype._is_ref_dtype  # pylint: disable=protected-access
          can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref
        else:
          is_inp_ref = False
          can_feed = False

        if (restore_variable_values and inp.name in self._dirty_variables and
            inp.name not in restored_variables and
            inp.name not in self._last_updated):
          # Do not restore Variables touched or restored previously in this
          # cont() call.
          initializer_op = self._variable_initializers[inp.name]
          initial_value_tensor = self._variable_initial_values[inp.name]
          self._sess.run(initializer_op,
                         feed_dict={
                             initial_value_tensor:
                                 self._cached_variable_values[inp.name]
                         })

          # Mark the variable as restored.
          restored_variables.add(inp.name)

        # Determine if this is a reference-type input from a variable, and
        # the recipient node is not Identity. In that case, the Variable
        # needs to be marked as dirty and its current value recorded, because
        # the receiving op may mutate the value of the Variable.
        if (is_inp_ref and inp.op.type in ["Variable", "VariableV2"] and
            curr_node.type != "Identity"):
          # Mark the variable as dirty.
          self._last_updated.add(inp.name)

          # Obtain the old value of the variable and cache it.
          if inp.name not in self._cached_variable_values:
            old_value = self._sess.run(inp)
            self._cached_variable_values[inp.name] = old_value

        # N.B.: The order of the logical branches matters. For example,
        # _client_feed_dict comes after _tensor_handles, so that tensor
        # handles stored in cont() calls can override the original client
        # feeds. Also, _override_tensors comes first, so that manual
        # overrides, if they exist, always take effect.
        if use_overrides and can_feed and inp.name in self._override_tensors:
          # Use the client-supplied overriding tensor value.
          feeds[inp] = self._override_tensors[inp.name]
          self._last_feed_types[inp.name] = self.FEED_TYPE_OVERRIDE
        elif (can_feed and inp not in feeds and
              use_tensor_handles and inp.name in self._tensor_handles):
          # Tensor handle found in cache.
          feeds[inp] = self._tensor_handles[inp.name]
          self._last_feed_types[inp.name] = self.FEED_TYPE_HANDLE
        elif (can_feed and inp not in feeds and
              use_dumped_intermediates and
              inp.name in self._dumped_intermediate_tensors):
          # Dumped intermediate Tensor found.
          feeds[inp] = self._dumped_intermediate_tensors[inp.name].get_tensor()
          self._last_feed_types[inp.name] = self.FEED_TYPE_DUMPED_INTERMEDIATE
        elif inp.name in self._client_feed_dict:
          # This input is available in the client feed_dict.
          feeds[inp] = self._client_feed_dict[inp.name]
          self._last_feed_types[inp.name] = self.FEED_TYPE_CLIENT
        else:
          # There is no feed available for this input, so keep tracing its
          # input(s).
          inp_node = self._get_node(inp)
          if inp_node.name in done:
            # Already visited.
            continue

          elem_stack.append(inp)
          done.add(inp_node.name)

    # =========================================================================

    if self._last_updated:
      self._dirty_variables.update(self._last_updated)

    for variable in restored_variables:
      self._dirty_variables.remove(variable)

    (dump_path,
     run_options) = self._prepare_cont_call_dump_path_and_run_options()
    if isinstance(fetched, ops.Operation):
      # The fetch is an Operation: no tensor handle will be obtained.
      self._sess.run(fetched, feed_dict=feeds, options=run_options)
      return_value = None
    else:
      # The fetch is a Tensor: obtain a tensor handle for it and cache it.
      # Also obtain the additional requested tensor handles (if any).
      tensors_to_get_handles_for = [fetched]
      handle_names = [target_name]

      tensors_to_get_handles_for.extend([
          self._sess.graph.as_graph_element(h)
          for h in additional_handle_requests
      ])
      handle_names.extend(additional_handle_requests)

      handles = self._sess.run(
          [session_ops.get_session_handle(tensor) for tensor in
           tensors_to_get_handles_for],
          feed_dict=feeds,
          options=run_options)
      for handle_name, handle in zip(handle_names, handles):
        self._tensor_handles[handle_name] = handle

      return_value = self._tensor_handles[target_name].eval()

    self._load_dumped_intermediate_tensors(dump_path, target_name)

    if invalidate_from_updated_variables:
      # Invalidate caches at the end.
      for last_updated_variable in self._last_updated:
        self._invalidate_transitively_outgoing_cache(last_updated_variable)

    return return_value
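
  # A hypothetical usage sketch of cont(), continuing the class docstring
  # example (sess and the tensor d are assumptions):
  #   with NodeStepper(sess, d) as stepper:
  #     stepper.override_tensor("c:0", 5.0)  # Inject a value for c:0.
  #     stepper.cont("d:0")                  # Uses the override for c:0.
  #     stepper.cont("d:0", use_overrides=False,
  #                  restore_variable_values=True)  # Recompute from scratch.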

  def _prepare_cont_call_dump_path_and_run_options(self):
    """Prepare the dump path and RunOptions for the next cont() call.

    Returns:
      dump_path: (str) Directory path to which the intermediate tensors will
        be dumped.
      run_options: (config_pb2.RunOptions) The RunOptions containing the
        tensor watch options for this graph.
    """
    run_options = config_pb2.RunOptions()
    dump_path = self._cont_call_dump_path()
    for element_name in self._closure_elements:
      if ":" in element_name:
        debug_utils.add_debug_tensor_watch(
            run_options,
            debug_graphs.get_node_name(element_name),
            output_slot=debug_graphs.get_output_slot(element_name),
            debug_urls=["file://" + dump_path])

    return dump_path, run_options

  def _cont_call_dump_path(self):
    return os.path.join(self._dump_session_root,
                        "cont_%d" % int(time.time() * 1e6))

  def _load_dumped_intermediate_tensors(self, dump_path, target_name):
    dump_dir = debug_data.DebugDumpDir(dump_path, validate=False)
    for dump in dump_dir.dumped_tensor_data:
      if (dump.tensor_name not in self._ref_tensor_names and
          dump.tensor_name not in self._tensor_handles and
          dump.tensor_name not in self._override_tensors and
          dump.tensor_name != target_name):
        self._dumped_intermediate_tensors[dump.tensor_name] = dump

  def _get_node_name(self, graph_element_name):
    return graph_element_name.split(":")[0]

  def _invalidate_transitively_outgoing_cache(self, source_element):
    """Invalidate the cached tensor handles by tracing outputs.

    This method is used to invalidate caches, such as cached TensorHandles
    and intermediate tensor values, when a Variable mutation happens or when
    the client overrides tensor values.

    Uses a non-recursive implementation to avoid stack overflow on deep
    networks.

    Args:
      source_element: The source graph element (e.g., a Variable output slot)
        to trace the outputs from.
    """

    if not self._tensor_handles and not self._dumped_intermediate_tensors:
      return

    # First, use cached invalidation paths to eliminate some cached tensor
    # handles and intermediate tensors.
    to_delete_handles = []
    for handle_name in self._tensor_handles:
      if (handle_name in self._cached_invalidation_path and
          source_element in self._cached_invalidation_path[handle_name]):
        to_delete_handles.append(handle_name)
    for handle_name in to_delete_handles:
      del self._tensor_handles[handle_name]

    to_delete_intermediates = []
    for intm_tensor_name in self._dumped_intermediate_tensors:
      if (intm_tensor_name in self._cached_invalidation_path and
          source_element in self._cached_invalidation_path[intm_tensor_name]):
        to_delete_intermediates.append(intm_tensor_name)
    for intermediate in to_delete_intermediates:
      del self._dumped_intermediate_tensors[intermediate]

    if not self._tensor_handles and not self._dumped_intermediate_tensors:
      return

    stack = [source_element]
    done = set()

    while stack:
      curr_element = stack.pop()
      done.add(curr_element)

      if (curr_element in self._tensor_handles or
          curr_element in self._dumped_intermediate_tensors):
        # Cache the invalidation path for potential future use.
        if curr_element not in self._cached_invalidation_path:
          self._cached_invalidation_path[curr_element] = set([source_element])
        else:
          self._cached_invalidation_path[curr_element].add(source_element)

        if curr_element in self._tensor_handles:
          del self._tensor_handles[curr_element]
        else:
          del self._dumped_intermediate_tensors[curr_element]

      targets = self._output_targets.get(curr_element, [])
      for target in targets:
        if target in done:
          continue
        else:
          stack.append(target)
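
  # A hypothetical invalidation trace for the docstring example
  # (d = a * (a + b)): overriding "a:0" walks the output map
  # (a:0 -> {c:0, d:0}), deleting any cached handles or dumped intermediates
  # for c:0 and d:0, since both depend transitively on a:0.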

  def finalize(self):
    """Run the final fetch(es).

    Restore the dirty variables; ignore the client-supplied overriding tensor
    values.

    Returns:
      The same return value as self.cont() as called on the final fetch.
    """

    self.restore_variable_values()
    return self._sess.run(self._fetches, feed_dict=self._client_feed_dict)

  def restore_variable_values(self):
    """Restore variables to their initial values.

    "Initial value" refers to the value when this NodeStepper instance was
    first constructed.
    """

    for var_name in self._dirty_variables:
      self._sess.run(self._variable_initializers[var_name],
                     feed_dict={
                         self._variable_initial_values[var_name]:
                             self._cached_variable_values[var_name]
                     })

  def handle_names(self):
    """Return the names of the TensorHandles that the debugger is holding.

    Returns:
      (list of str) Names of the tensors for which TensorHandles are
        available.
    """

    return [name for name in self._tensor_handles]

  def handle_node_names(self):
    """Get the set of names of the nodes for which handles are available.

    Returns:
      (set of str) Names of the nodes.
    """

    return set([self._get_node_name(name) for name in self._tensor_handles])

  def intermediate_tensor_names(self):
    """Get the names of the Tensors for which dumps are available.

    Returns:
      (list of str) Names of the Tensors for which intermediate dumps are
        available.
    """

    return list(self._dumped_intermediate_tensors.keys())

  def last_updated(self):
    """Get the names of the variables updated in the last cont() call.

    Returns:
      A set of the variable names updated in the previous cont() call.
      If no cont() call has occurred before, returns None.
    """

    return self._last_updated

  def dirty_variables(self):
    """Get the set of variables that are currently "dirty".

    "Dirty" means that previous cont() calls have updated the value of the
    Variable, and the Variable's old value (the value before any cont() calls
    happened) has not been restored.

    Returns:
      (set) A set of dirty variables.
    """

    return self._dirty_variables

  def is_placeholder(self, graph_element_name):
    """Check whether a graph element is a Placeholder, by name.

    Args:
      graph_element_name: (str) Name of the tensor or op to be tested.

    Returns:
      (bool) Whether the graph element of the specified name is a Placeholder
        op or the output Tensor of a Placeholder op.

    Raises:
      ValueError: If graph_element_name is not in the transitive closure of
        the stepper instance.
    """

    node_name = self._get_node_name(graph_element_name)
    if node_name not in self.sorted_nodes():
      raise ValueError(
          "%s is not in the transitive closure of this NodeStepper "
          "instance" % graph_element_name)

    graph_element = self._sess.graph.as_graph_element(graph_element_name)
    if not isinstance(graph_element, ops.Operation):
      graph_element = graph_element.op
    return graph_element.type == "Placeholder"

  def placeholders(self):
    """Get the names of the Placeholder nodes in the transitive closure.

    Returns:
      (list of str) Names of the Placeholder nodes in the transitive closure.
    """

    placeholders = []
    for item in self.sorted_nodes():
      if self.is_placeholder(item):
        placeholders.append(item)

    return placeholders

  def get_tensor_value(self, tensor_name):
    """Get the value of a tensor that the stepper has access to.

    Args:
      tensor_name: (str) Name of the tensor.

    Returns:
      Value of the tensor, from overriding values or cached tensor handles.

    Raises:
      ValueError: If the value is not available as an overriding value
        or through a TensorHandle.
    """

    if self.is_placeholder(tensor_name):
      if ":" not in tensor_name:
        tensor_name += ":0"
      return self._client_feed_dict[tensor_name]
    elif tensor_name in self._override_tensors:
      return self._override_tensors[tensor_name]
    elif tensor_name in self._tensor_handles:
      return self._tensor_handles[tensor_name].eval()
    elif tensor_name in self._dumped_intermediate_tensors:
      return self._dumped_intermediate_tensors[tensor_name].get_tensor()
    else:
      raise ValueError(
          "This stepper instance does not have access to the value of "
          "tensor \"%s\"" % tensor_name)

  def override_names(self):
    """Return the names of the tensors that have overriding values.

    Returns:
      (list of str) Names of the tensors for which overriding values are
        available.
    """
    return [name for name in self._override_tensors]

  def _get_node(self, element):
    """Get the node of a graph element.

    Args:
      element: A graph element (Op, Tensor or Node).

    Returns:
      The node associated with the element in the graph.
    """

    node_name, _ = debug_graphs.parse_node_or_tensor_name(element.name)
    return self._sess.graph.as_graph_element(node_name)