1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorFlow Debugger (tfdbg) Stepper Module.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import copy 21import os 22import shutil 23import tempfile 24import time 25 26import six 27 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.python.debug.lib import debug_data 30from tensorflow.python.debug.lib import debug_graphs 31from tensorflow.python.debug.lib import debug_utils 32from tensorflow.python.framework import ops 33from tensorflow.python.ops import session_ops 34 35 36# TODO(cais): Use nest.flatten once it handles nest Dicts correctly. 37def _flatten_fetches(fetches): 38 """Flatten list, tuple of fetches, or a single fetch into a list of fetches. 39 40 Args: 41 fetches: The fetches to flatten: Can be a single Tensor, Op, or a 42 potentially nested list, tuple or dict of such individual fetches. 43 44 Returns: 45 The fetches flattened to a list. 46 """ 47 48 flattened = [] 49 if isinstance(fetches, (list, tuple)): 50 for fetch in fetches: 51 flattened.extend(_flatten_fetches(fetch)) 52 elif isinstance(fetches, dict): 53 for key in fetches: 54 flattened.extend(_flatten_fetches(fetches[key])) 55 else: 56 flattened.append(fetches) 57 58 return flattened 59 60 61class NodeStepper(object): 62 """TensorFlow Debugger (tfdbg) stepper. 63 64 The stepper provides ability to perform "continue to" actions on a graph, 65 given fetch and feeds. The stepper calculates the transitive closure of the 66 fetch. cont() (continue to) calls can only be performed on members of the 67 transitive closure. 68 69 On a cont() call, the stepper performs depth-first tracing of the input 70 tree of the target. When it reaches an input where one of the following is 71 available, it will supply the available value to the feed_dict of the cont() 72 call: 73 (1) Overriding (injected) values from the client. 74 (2) TensorHandles from previous cont() calls. 75 (3) Dumped intermediate Tensors from previous cont() calls. 76 (4) Feeds supplied during the construction of the stepper instance. 77 78 During the cont() call, intermediate Tensors are dumped to temporary 79 directories. The dumped Tensor values will be used in subsequent cont() calls 80 when they are required as data dependencies. 81 82 The temporary directories are automatically clean when the NodeStepper 83 instance exits as a context manager. 84 85 Once the tracing is complete, it will issue a run() call on the 86 underlying session, using the aforementioned feed_dict prepared by the input 87 tracing, to achieve the "continue-to" action. The above process takes into 88 account whether the transitive closure of an input contains Variables that 89 are updated during previous cont() calls on this stepper instance. If such 90 updates exist, we say the transitive closure is "dirty" and the stepper 91 can restore the "clean" state of the Variable and avoid using the 92 TensorHandle. 93 94 Example of basic usage: 95 a = tf.Variable(1.0, name="a") 96 b = tf.Variable(2.0, anme="b") 97 c = tf.add(a, b, name="c") 98 d = tf.multiply(a, c, name="d") 99 100 sess = tf.Session() 101 sess.run(tf.initialize_all_varialbes()) 102 stepper = NodeStepper(sess, d) 103 104 stepper.cont(c) # Caches the handle to Tensor c:0. 105 stepper.cont(d) # Uses handle to Tensor c:0, avoiding recomputing c. 106 """ 107 108 # Possible types of feed used during cont() calls. 109 FEED_TYPE_CLIENT = "client" 110 FEED_TYPE_HANDLE = "handle" 111 FEED_TYPE_OVERRIDE = "override" 112 FEED_TYPE_DUMPED_INTERMEDIATE = "dumped_intermediate" 113 114 def __init__(self, sess, fetches, feed_dict=None): 115 """Constructor for Debugger. 116 117 Args: 118 sess: (Session) the TensorFlow Session to step in. 119 fetches: Same as the fetches input argument to `Session.run()`. 120 feed_dict: Same as the feed_dict input argument to `Session.run()`. 121 """ 122 123 self._sess = sess 124 125 self._fetches = fetches 126 flattened_fetches = _flatten_fetches(fetches) 127 128 self._fetch_names, self._fetch_list = self._get_fetch_and_name_lists( 129 flattened_fetches) 130 131 # A map from Variable name to initializer op. 132 self._variable_initializers = {} 133 134 # A map from Variable name to initial value, used when overriding or 135 # restoring Variable values. 136 self._variable_initial_values = {} 137 138 # Initialize the map for output recipients (targets). 139 self._output_targets = {} 140 141 # Sorted transitive closure of the fetched node. 142 # We also collect the list of the names of the reference-type Tensors, 143 # because we later need to avoid using intermediate dumps for such Tensors. 144 (self._sorted_nodes, 145 self._closure_elements, 146 self._ref_tensor_names) = self._dfs_visit(self._sess.graph, 147 self._fetch_list) 148 149 self._transitive_closure_set = set(self._sorted_nodes) 150 151 # A map from Variable name to the old values (before any cont() calls). 152 self._cached_variable_values = {} 153 154 # A cache map from tensor name to what variables may invalidate the tensor 155 self._cached_invalidation_path = {} 156 157 # Keep track of which variables are in a dirty state. 158 self._dirty_variables = set() 159 160 # Variables updated in the last cont() call. 161 self._last_updated = None 162 163 # Cached tensor handles: a dict with keys as tensor names and values as 164 # tensor handles. 165 self._tensor_handles = {} 166 167 # Cached intermediate tensor values: a dict mapping tensor names to 168 # DebugTensorDatum. 169 self._dumped_intermediate_tensors = {} 170 self._dump_session_root = tempfile.mkdtemp(prefix="tfdbg_stepper_") 171 172 # Feed dict from the client. 173 self._client_feed_dict = {} 174 if feed_dict: 175 for key in feed_dict: 176 if isinstance(key, ops.Tensor): 177 self._client_feed_dict[key.name] = feed_dict[key] 178 else: 179 self._client_feed_dict[key] = feed_dict[key] 180 181 # Overriding tensor values. 182 self._override_tensors = {} 183 184 # What the feed types were used by the last cont() call. 185 self._last_feed_types = {} 186 187 def __enter__(self): 188 return self 189 190 def __exit__(self, exc_type, exc_value, exc_traceback): 191 if os.path.isdir(self._dump_session_root): 192 shutil.rmtree(self._dump_session_root) 193 194 def _get_fetch_and_name_lists(self, flattened_fetches): 195 """Get the lists of fetches and their names. 196 197 Args: 198 flattened_fetches: A list of fetches or their names. Can mix fetches and 199 names. 200 201 Returns: 202 (list of str): A list of the names of the fetches. 203 (list): A list of the fetches. 204 """ 205 206 fetch_names = [] 207 fetch_list = [] 208 for fetch in flattened_fetches: 209 if isinstance(fetch, six.string_types): 210 fetch_names.append(fetch) 211 fetch_list.append(self._sess.graph.as_graph_element(fetch)) 212 else: 213 fetch_names.append(fetch.name) 214 fetch_list.append(fetch) 215 216 return fetch_names, fetch_list 217 218 def _dfs_visit(self, graph, elem_list): 219 """Trace back the input of a graph element, using depth-first search. 220 221 Uses non-recursive implementation to prevent stack overflow for deep 222 graphs. 223 224 Also performs the following action(s): 225 1) When encountering a Variable, obtain its initializer op, to 226 facilitate possible subsequent restoration / overriding of variable 227 value. 228 229 Args: 230 graph: A TF graph instance. 231 elem_list: list of graph elements: a Tensor or an Operation. 232 233 Returns: 234 (list of str) A topologically-sorted list of all nodes (not tensors) 235 in the transitive closure of elem_list. Obviously, the topological sort 236 is not unique in general. The return value here is just an arbitrary 237 one of potentially many possible topological sorts. 238 (list of str) A list of all graph elements (nodes and/or tensors) in the 239 transitive closure. 240 """ 241 242 # These set should hold only strings, i.e, names of the nodes. 243 done = set() # Keep track of visited graph elements. 244 245 # A list of str: Names of the topologically-sorted graph elements. 246 node_inputs = dict() # New: Input map of nodes in the transitive closure. 247 248 elem_stack = copy.copy(elem_list) 249 250 # Graph elements in the transitive closure, including the nodes and tensors. 251 closure_elements = [elem.name for elem in elem_list] 252 253 ref_tensor_names = set() 254 for element in elem_list: 255 if isinstance(element, ops.Tensor) and element.dtype._is_ref_dtype: # pylint: disable=protected-access 256 ref_tensor_names.add(element.name) 257 258 while elem_stack: 259 curr_elem = elem_stack.pop() 260 curr_node = self._get_node(curr_elem) 261 262 done.add(curr_node.name) 263 264 non_control_inputs = [inp for inp in curr_node.inputs] 265 control_inputs = [inp for inp in curr_node.control_inputs] 266 all_inputs = set(non_control_inputs + control_inputs) 267 268 if curr_node.name not in node_inputs: 269 all_input_nodes = set() 270 for inp in all_inputs: 271 all_input_nodes.add(self._get_node(inp).name) 272 node_inputs[curr_node.name] = all_input_nodes 273 274 # Iterate through the (non-control) inputs. 275 for inp in all_inputs: 276 # Set up the non-control output map. 277 # if is_non_control_input: 278 if inp.name not in self._output_targets: 279 self._output_targets[inp.name] = set([curr_elem.name]) 280 else: 281 self._output_targets[inp.name].add(curr_elem.name) 282 283 if (isinstance(inp, ops.Tensor) and 284 inp.op.type in ["Variable", "VariableV2"] and 285 inp.name not in self._variable_initializers): 286 # Obtain the initializer op of the variable, in case the Variable's 287 # value needs to be restored later. 288 initializer = graph.as_graph_element(inp.op.name + "/Assign") 289 self._variable_initializers[inp.name] = initializer 290 self._variable_initial_values[inp.name] = initializer.inputs[1] 291 292 inp_node = self._get_node(inp) 293 if inp_node.name in done: 294 # Already visited. 295 continue 296 297 elem_stack.append(inp) 298 closure_elements.append(inp.name) 299 if isinstance(inp, ops.Tensor) and inp.dtype._is_ref_dtype: # pylint: disable=protected-access 300 ref_tensor_names.add(inp.name) 301 302 # Now that we have traversed the transitive closure and obtained the 303 # node-input map, we can topologically sort them. 304 sorted_nodes = [] 305 stack = [] 306 for node in node_inputs: 307 if not node_inputs[node]: 308 stack.append(node) 309 for node in stack: 310 del node_inputs[node] 311 312 while stack: 313 curr_node = stack.pop() 314 sorted_nodes.append(curr_node) 315 316 # Iterate through the node-input map and remove the child. 317 pushes = [] 318 for node in node_inputs: 319 if curr_node in node_inputs[node]: 320 node_inputs[node].remove(curr_node) 321 if not node_inputs[node]: 322 pushes.append(node) 323 324 # Delete new pushes from node-input map. 325 for node in pushes: 326 del node_inputs[node] 327 328 stack.extend(pushes) 329 330 return sorted_nodes, closure_elements, ref_tensor_names 331 332 def sorted_nodes(self): 333 """Get a topologically-sorted list of node names of the stepper. 334 335 These are the names of the nodes (i.e., not Tensors) in the transitive 336 closure of the stepper, in a topologically-sorted order. 337 338 Returns: 339 (list of str): Sorted transitive inputs to the fetch of the stepper 340 instance. The fetch itself is included in the list. 341 """ 342 343 return self._sorted_nodes 344 345 def closure_elements(self): 346 """Get a name list of the graph elements of the stepper. 347 348 Returns: 349 (list of str): names of the graph elements (i.e., nodes and tensors) in 350 the transitive closure of the stepper, in a random order. 351 """ 352 353 return self._closure_elements 354 355 def output_slots_in_closure(self, node_name): 356 """Get the output tensors in the transitive closure from node. 357 358 Args: 359 node_name: (str) Name of the node in question. 360 361 Returns: 362 (list of int) Output slots of the output tensors of the node that are in 363 the transitive closure of the stepper. 364 """ 365 366 node = self._sess.graph.as_graph_element(node_name) 367 368 tensor_slots = [] 369 for i, _ in enumerate(node.outputs): 370 tensor_name = node_name + ":%d" % i 371 if tensor_name in self._closure_elements: 372 tensor_slots.append(i) 373 374 return tensor_slots 375 376 def is_feedable(self, name): 377 """Determine if a graph element if feedable. 378 379 Args: 380 name: (str) name of the graph element (Tensor or Operation) 381 382 Returns: 383 (bool) whether the graph element is feedable. 384 """ 385 386 if not isinstance(name, six.string_types): 387 raise TypeError("Expected type str; got type %s" % type(name)) 388 389 elem = self._sess.graph.as_graph_element(name) 390 return self._sess.graph.is_feedable(elem) 391 392 def override_tensor(self, tensor_name, overriding_val): 393 """Override the value of a tensor. 394 395 Args: 396 tensor_name: (str) Name of the tensor to override. 397 overriding_val: (numpy.ndarray) Overriding tensor value. 398 399 Raises: 400 ValueError: If tensor_name does not correspond to a tensor in the input 401 tree to the fetched graph element of this stepper instance. 402 """ 403 404 if not isinstance(tensor_name, six.string_types): 405 raise TypeError("Expected type str; got type %s" % type(tensor_name)) 406 407 node_name = self._get_node_name(tensor_name) 408 if node_name not in self._transitive_closure_set: 409 raise ValueError( 410 "Cannot override tensor \"%s\" because it does not exist in the " 411 "input tree to the fetch \"%s\"" % 412 (tensor_name, repr(self._fetch_names))) 413 414 self._override_tensors[tensor_name] = overriding_val 415 416 # Invalidate cache by tracing outputs. 417 self._invalidate_transitively_outgoing_cache(tensor_name) 418 419 def remove_override(self, tensor_name): 420 """Remove the overriding value on a tensor. 421 422 Args: 423 tensor_name: (str) name of the tensor to remove the overriding value 424 from. 425 426 Raises: 427 ValueError: If no overriding value exists for tensor_name. 428 """ 429 430 if tensor_name not in self._override_tensors: 431 raise ValueError("No overriding value exists for tensor \"%s\"." % 432 tensor_name) 433 434 del self._override_tensors[tensor_name] 435 436 # Invalidate cache by tracing outputs. 437 self._invalidate_transitively_outgoing_cache(tensor_name) 438 439 def last_feed_types(self): 440 """Obtain information about the feed in the last cont() call. 441 442 Returns: 443 (dict) A dict mapping tensor names to feed types. 444 """ 445 446 return self._last_feed_types 447 448 def cont(self, 449 target, 450 use_tensor_handles=True, 451 use_dumped_intermediates=True, 452 use_overrides=True, 453 invalidate_from_updated_variables=False, 454 restore_variable_values=False): 455 """Continue till the completion of the specified target tensor. 456 457 Args: 458 target: A single fetched Tensor or Op, or a name (str) representing the 459 Tensor or Op. In the case of a name str, the graph will be searched 460 to find the corresponding Tensor or Op. 461 # TODO(cais): Support multiple fetches as in Session.run() interface. 462 use_tensor_handles: (bool) Whether this cont() run will use cached tensor 463 handles to avoid recomputation. Default: True. 464 use_dumped_intermediates: (bool) Whether this cont() call will use dumped 465 intermediate tensors to avoid recomputation. 466 use_overrides: (bool) Whether the overriding tensor values supplied by 467 the client are to be used in this cont() call. Default: True. 468 invalidate_from_updated_variables: (bool) Whether to invalidate the 469 tensor handles and intermediate tensor handles affected by the 470 Variable updates that happen in this cont() call. 471 restore_variable_values: (bool) Whether the old values of the variables 472 (before any cont() calls in this object) are to be restored. 473 474 Returns: 475 Value from Session.run() of the target. 476 477 Raises: 478 ValueError: If the target is specified as a string and the string does 479 not correspond to any tensors in the Session graph. 480 Or if the target of this cont() is not in the input list of the Stepper 481 object's target. 482 Or if target is a Placeholder. 483 """ 484 485 self._last_feed_types = {} 486 487 if isinstance(target, six.string_types): 488 # Fetch target is a string. Assume it is the name of the Tensor or Op and 489 # will attempt to find it in the Session's graph. 490 target_name = target 491 else: 492 target_name = target.name 493 494 graph_element = self._sess.graph.as_graph_element(target_name) 495 # Any additional tensor handles to obtain in this cont() action. 496 additional_handle_requests = [] 497 498 if (isinstance(graph_element, ops.Tensor) and 499 graph_element.op.type == "Placeholder"): 500 self._last_feed_types[graph_element.name] = self.FEED_TYPE_CLIENT 501 return self._client_feed_dict[graph_element.name] 502 elif (isinstance(graph_element, ops.Operation) and 503 graph_element.type == "Placeholder"): 504 tensor_name = graph_element.name + ":0" 505 self._last_feed_types[tensor_name] = self.FEED_TYPE_CLIENT 506 return self._client_feed_dict[tensor_name] 507 508 if isinstance(graph_element, ops.Operation) and graph_element.outputs: 509 # Check if this op has any output tensors that also fall into this 510 # stepper's transitive closure. 511 node_outputs = [ 512 output.name for output in graph_element.outputs 513 if output.name in self._closure_elements 514 ] 515 if node_outputs: 516 # The target is an op with at least one output within the transitive 517 # closure. The cont() action will amount to using the 0-th 518 # output Tensor as the target, as well as obtaining handles to it 519 # and to the rest of the outputs tensors in the transitive closure 520 # (if any). 521 target_name = node_outputs[0] 522 additional_handle_requests = node_outputs[1:] 523 524 # Verify that the target is in the transitive closure of the stepper's 525 # fetch. 526 target_node_name = self._get_node_name(target_name) 527 if target_node_name not in self._transitive_closure_set: 528 raise ValueError( 529 "Target \"%s\" is not in the transitive closure for the fetch of the " 530 "stepper: \"%s\"." % (target_name, repr(self._fetch_names))) 531 532 # Check if a cached tensor handle can be used on the fetch directly. 533 if use_tensor_handles and target_name in self._tensor_handles: 534 self._last_feed_types[target_name] = self.FEED_TYPE_HANDLE 535 return self._tensor_handles[target_name].eval() 536 537 # Check if a dumped intermediate tensor can be used on the fetch directly. 538 if (use_dumped_intermediates and 539 target_name in self._dumped_intermediate_tensors): 540 self._last_feed_types[target_name] = self.FEED_TYPE_DUMPED_INTERMEDIATE 541 return self._dumped_intermediate_tensors[target_name].get_tensor() 542 543 # Check if an overriding tensor value can be used directly. 544 if use_overrides and target_name in self._override_tensors: 545 # Override is available. Return the value right away. 546 self._last_feed_types[target_name] = self.FEED_TYPE_OVERRIDE 547 return self._override_tensors[target_name] 548 549 # Keep track of which variables are restored in this cont() call. 550 restored_variables = set() 551 552 # Keep track of which variables are "touched" (i.e., possibly updated) in 553 # this cont() call. 554 self._last_updated = set() 555 556 # ========================================================================= 557 # Use a non-recursive method to trace the inputs from the node and set up 558 # the feeds. 559 feeds = {} # The feeds to be used in the Session.run() call. 560 fetched = self._sess.graph.as_graph_element(target_name) 561 elem_stack = [fetched] 562 done = set() 563 564 while elem_stack: 565 curr_elem = elem_stack.pop() 566 curr_node = self._get_node(curr_elem) 567 568 done.add(curr_node.name) 569 570 non_control_inputs = [inp for inp in curr_node.inputs] 571 control_inputs = [inp for inp in curr_node.control_inputs] 572 all_inputs = set(non_control_inputs + control_inputs) 573 574 # Iterate through the (non-control) inputs. 575 for inp in all_inputs: 576 # Determine whether the input is feedable. Reference-type tensors, 577 # e.g., Variables, should not be fed, because they can change. 578 if isinstance(inp, ops.Tensor): 579 is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access 580 can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref 581 else: 582 is_inp_ref = False 583 can_feed = False 584 585 if (restore_variable_values and inp.name in self._dirty_variables and 586 inp.name not in restored_variables and 587 inp.name not in self._last_updated): 588 # Do not restore Variables touched or restored previously in this 589 # cont() call. 590 initializer_op = self._variable_initializers[inp.name] 591 initial_value_tensor = self._variable_initial_values[inp.name] 592 self._sess.run(initializer_op, 593 feed_dict={ 594 initial_value_tensor: 595 self._cached_variable_values[inp.name] 596 }) 597 598 # Mark the variable as restored. 599 restored_variables.add(inp.name) 600 601 # Determine if this is a reference-type input from a variable, and 602 # the recipient node is not Identity. In that case, the Variable 603 # needs to be marked as dirty and its current value recorded, due to 604 # the fact that the receiving op may mutate the value of the Variable. 605 if (is_inp_ref and inp.op.type in ["Variable", "VariableV2"] and 606 curr_node.type != "Identity"): 607 # Mark the variable as dirty. 608 self._last_updated.add(inp.name) 609 610 # Obtain the old value of the variable and cache it. 611 if inp.name not in self._cached_variable_values: 612 old_value = self._sess.run(inp) 613 self._cached_variable_values[inp.name] = old_value 614 615 # N.B.: The order of the logical branches matters. For example, 616 # _client_feed_dict comes after _tensor_handles, so that tensor 617 # handles stored in cont() calls can override the original client 618 # feeds. Also for example, _override_tensors comes the first, so 619 # the manual overriding, if exists, can always take effect. 620 if use_overrides and can_feed and inp.name in self._override_tensors: 621 # Use client-supplied overriding tensor value. 622 feeds[inp] = self._override_tensors[inp.name] 623 self._last_feed_types[inp.name] = self.FEED_TYPE_OVERRIDE 624 elif (can_feed and inp not in feeds and 625 use_tensor_handles and inp.name in self._tensor_handles): 626 # Tensor handle found in cache. 627 feeds[inp] = self._tensor_handles[inp.name] 628 self._last_feed_types[inp.name] = self.FEED_TYPE_HANDLE 629 elif (can_feed and inp not in feeds and 630 use_dumped_intermediates and 631 inp.name in self._dumped_intermediate_tensors): 632 # Dumped intermediate Tensor found. 633 feeds[inp] = self._dumped_intermediate_tensors[inp.name].get_tensor() 634 self._last_feed_types[inp.name] = self.FEED_TYPE_DUMPED_INTERMEDIATE 635 elif inp.name in self._client_feed_dict: 636 # This input is available in the client feed_dict. 637 feeds[inp] = self._client_feed_dict[inp.name] 638 self._last_feed_types[inp.name] = self.FEED_TYPE_CLIENT 639 else: 640 # There is no feed available for this input. So keep tracing its 641 # input(s). 642 inp_node = self._get_node(inp) 643 if inp_node.name in done: 644 # Already visited. 645 continue 646 647 elem_stack.append(inp) 648 done.add(inp_node.name) 649 650 # ========================================================================= 651 652 if self._last_updated: 653 self._dirty_variables.update(self._last_updated) 654 655 for variable in restored_variables: 656 self._dirty_variables.remove(variable) 657 658 (dump_path, 659 run_options) = self._prepare_cont_call_dump_path_and_run_options() 660 if isinstance(fetched, ops.Operation): 661 # The fetched is an Operation: Will not get tensor handle. 662 self._sess.run(fetched, feed_dict=feeds, options=run_options) 663 return_value = None 664 else: 665 # This is a Tensor: Will get tensor handle and cache it. 666 # Will also get the additional requested tensor handles (if any). 667 tensors_to_get_handles_for = [fetched] 668 handle_names = [target_name] 669 670 tensors_to_get_handles_for.extend([ 671 self._sess.graph.as_graph_element(h) 672 for h in additional_handle_requests 673 ]) 674 handle_names.extend(additional_handle_requests) 675 676 handles = self._sess.run( 677 [session_ops.get_session_handle(tensor) for tensor in 678 tensors_to_get_handles_for], 679 feed_dict=feeds, 680 options=run_options) 681 for handle_name, handle in zip(handle_names, handles): 682 self._tensor_handles[handle_name] = handle 683 684 return_value = self._tensor_handles[target_name].eval() 685 686 self._load_dumped_intermediate_tensors(dump_path, target_name) 687 688 if invalidate_from_updated_variables: 689 # Invalidate caches at the end. 690 for last_updated_variable in self._last_updated: 691 self._invalidate_transitively_outgoing_cache(last_updated_variable) 692 693 return return_value 694 695 def _prepare_cont_call_dump_path_and_run_options(self): 696 """Prepare the dump path and RunOptions for next cont() call. 697 698 Returns: 699 dump_path: (str) Directory path to which the intermediate tensor will be 700 dumped. 701 run_options: (config_pb2.RunOptions) The RunOptions containing the tensor 702 watch options for this graph. 703 """ 704 run_options = config_pb2.RunOptions() 705 dump_path = self._cont_call_dump_path() 706 for element_name in self._closure_elements: 707 if ":" in element_name: 708 debug_utils.add_debug_tensor_watch( 709 run_options, 710 debug_graphs.get_node_name(element_name), 711 output_slot=debug_graphs.get_output_slot(element_name), 712 debug_urls=["file://" + dump_path]) 713 714 return dump_path, run_options 715 716 def _cont_call_dump_path(self): 717 return os.path.join(self._dump_session_root, 718 "cont_%d" % int(time.time() * 1e6)) 719 720 def _load_dumped_intermediate_tensors(self, dump_path, target_name): 721 dump_dir = debug_data.DebugDumpDir(dump_path, validate=False) 722 for dump in dump_dir.dumped_tensor_data: 723 if (dump.tensor_name not in self._ref_tensor_names and 724 dump.tensor_name not in self._tensor_handles and 725 dump.tensor_name not in self._override_tensors and 726 dump.tensor_name != target_name): 727 self._dumped_intermediate_tensors[dump.tensor_name] = dump 728 729 def _get_node_name(self, graph_element_name): 730 return graph_element_name.split(":")[0] 731 732 def _invalidate_transitively_outgoing_cache(self, source_element): 733 """Invalidate the cached tensor handles by tracing output. 734 735 This method is used to invalidate caches such as cached TensorHandles 736 and intermediate tensor values when Variable mutation happens or when 737 client overrides tensor values. 738 739 Uses non-recursive implementation to avoid stack overflow on deep networks. 740 741 Args: 742 source_element: The source graph element (e.g., a Variable output slot) 743 to trace the output from. 744 """ 745 746 if not self._tensor_handles and not self._dumped_intermediate_tensors: 747 return 748 749 # First, use cached invalidation paths to eliminate some cached tensor 750 # handles and intermediate tensors. 751 to_delete_handles = [] 752 for handle_name in self._tensor_handles: 753 if (handle_name in self._cached_invalidation_path and 754 source_element in self._cached_invalidation_path[handle_name]): 755 to_delete_handles.append(handle_name) 756 for handle_name in to_delete_handles: 757 del self._tensor_handles[handle_name] 758 759 to_delete_intermediates = [] 760 for intm_tensor_name in self._dumped_intermediate_tensors: 761 if (intm_tensor_name in self._cached_invalidation_path and 762 source_element in self._cached_invalidation_path[intm_tensor_name]): 763 to_delete_intermediates.append(intm_tensor_name) 764 for intermediate in to_delete_intermediates: 765 del self._dumped_intermediate_tensors[intermediate] 766 767 if not self._tensor_handles and not self._dumped_intermediate_tensors: 768 return 769 770 stack = [source_element] 771 done = set() 772 773 while stack: 774 curr_element = stack.pop() 775 done.add(curr_element) 776 777 if (curr_element in self._tensor_handles or 778 curr_element in self._dumped_intermediate_tensors): 779 # Cache the invalidation path for potential future use. 780 if curr_element not in self._cached_invalidation_path: 781 self._cached_invalidation_path[curr_element] = set([source_element]) 782 else: 783 self._cached_invalidation_path[curr_element].add(source_element) 784 785 if curr_element in self._tensor_handles: 786 del self._tensor_handles[curr_element] 787 else: 788 del self._dumped_intermediate_tensors[curr_element] 789 790 targets = self._output_targets.get(curr_element, []) 791 for target in targets: 792 if target in done: 793 continue 794 else: 795 stack.append(target) 796 797 def finalize(self): 798 """Run the final fetch(es). 799 800 Restore the dirty variables; ignore the client-supplied overriding tensor 801 values. 802 803 Returns: 804 The same return value as self.cont() as called on the final fetch. 805 """ 806 807 self.restore_variable_values() 808 return self._sess.run(self._fetches, feed_dict=self._client_feed_dict) 809 810 def restore_variable_values(self): 811 """Restore variables to the initial values. 812 813 "Initial value" refers to the value when this NodeStepper instance was 814 first constructed. 815 """ 816 817 for var_name in self._dirty_variables: 818 self._sess.run(self._variable_initializers[var_name], 819 feed_dict={ 820 self._variable_initial_values[var_name]: 821 self._cached_variable_values[var_name] 822 }) 823 824 def handle_names(self): 825 """Return names of the TensorHandles that the debugger is holding. 826 827 Returns: 828 (list of str) Name of the tensors for which TensorHandle is available. 829 """ 830 831 return [name for name in self._tensor_handles] 832 833 def handle_node_names(self): 834 """Get list of names of the nodes for which handles are available. 835 836 Returns: 837 (set of str) List of names of the nodes. 838 """ 839 840 return set([self._get_node_name(name) for name in self._tensor_handles]) 841 842 def intermediate_tensor_names(self): 843 """Get list of the names of the Tensors for which dumps are available. 844 845 Returns: 846 (list of str) List of the names of the Tensors for which intermediate 847 dumps are available. 848 """ 849 850 return self._dumped_intermediate_tensors.keys() 851 852 def last_updated(self): 853 """Get the names of the variables updated in the last cont() call. 854 855 Returns: 856 A set of the variable names updated in the previous cont() call. 857 If no cont() call has occurred before, returns None. 858 """ 859 860 return self._last_updated 861 862 def dirty_variables(self): 863 """Get the set of variables that are currently "dirty". 864 865 "dirty" means: 866 previous cont() calls have updated the value of the Variable, 867 and the Variable's old value (the value before any cont() calls 868 happened) was not restored. 869 870 Returns: 871 (set) A set of dirty variables. 872 """ 873 874 return self._dirty_variables 875 876 def is_placeholder(self, graph_element_name): 877 """Check whether a graph element is a Placeholder, by name. 878 879 Args: 880 graph_element_name: (str) Name of the tensor or op to be tested. 881 882 Returns: 883 (bool) Whether the graph element of the specified name is a Placeholder 884 op or the output Tensor of a Placeholder op. 885 886 Raises: 887 ValueError: If graph_element_name is not in the transitive closure of the 888 stepper instance. 889 """ 890 891 node_name = self._get_node_name(graph_element_name) 892 if node_name not in self.sorted_nodes(): 893 raise ValueError( 894 "%s is not in the transitive closure of this NodeStepper " 895 "instance" % graph_element_name) 896 897 graph_element = self._sess.graph.as_graph_element(graph_element_name) 898 if not isinstance(graph_element, ops.Operation): 899 graph_element = graph_element.op 900 return graph_element.type == "Placeholder" 901 902 def placeholders(self): 903 """Get the list of Placeholder Tensors in the transitive closure. 904 905 Returns: 906 (list of str) A list of Placeholder Tensors or ops in the transitive 907 closure. 908 """ 909 910 placeholders = [] 911 for item in self.sorted_nodes(): 912 if self.is_placeholder(item): 913 placeholders.append(item) 914 915 return placeholders 916 917 def get_tensor_value(self, tensor_name): 918 """Get the value of a tensor that the stepper has access to. 919 920 Args: 921 tensor_name: (str) Name of the tensor. 922 923 Returns: 924 Value of the tensor, from overriding values or cached tensor handles. 925 926 Raises: 927 ValueError: If the value is not available as an overriding value 928 or through a TensorHandle. 929 """ 930 931 if self.is_placeholder(tensor_name): 932 if ":" not in tensor_name: 933 tensor_name += ":0" 934 return self._client_feed_dict[tensor_name] 935 elif tensor_name in self._override_tensors: 936 return self._override_tensors[tensor_name] 937 elif tensor_name in self._tensor_handles: 938 return self._tensor_handles[tensor_name].eval() 939 elif tensor_name in self._dumped_intermediate_tensors: 940 return self._dumped_intermediate_tensors[tensor_name].get_tensor() 941 else: 942 raise ValueError( 943 "This stepper instance does not have access to the value of " 944 "tensor \"%s\"" % tensor_name) 945 946 def override_names(self): 947 """Return names of the TensorHandles that the debugger is holding. 948 949 Returns: 950 (list of str) Name of the tensor for which overriding tensor values are 951 available. 952 """ 953 return [name for name in self._override_tensors] 954 955 def _get_node(self, element): 956 """Get the node of a graph element. 957 958 Args: 959 element: A graph element (Op, Tensor or Node) 960 961 Returns: 962 The node associated with element in the graph. 963 """ 964 965 node_name, _ = debug_graphs.parse_node_or_tensor_name(element.name) 966 return self._sess.graph.as_graph_element(node_name) 967