# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Framework of debug wrapper sessions.

A debug wrapper session is a wrapper around a TensorFlow Python Session.
The wrapper preserves the Session interface, most importantly the run() method,
while providing abilities to:
a) Intercept a run() call to a wrapped session and insert debug tensor watches
   according to externally-specified debug URLs.

b) Release control to an external (i.e., non-Session) object before and after
   the run() call, so that the external object can perform actions such as
   launching a UI to let users inspect the intermediate tensors and partition
   graphs from the run() call.

c) (To be implemented) Intercept a run() call and give control to DebugStepper
   to let it perform stepping / continuing-to actions on the graph.

d) (To be implemented in a future CL) Enter an instruction loop to let an
   external object (e.g., remote client) launch run() and cont() calls
   remotely.

*** The lifetime of a debug wrapper session: ***

1) The wrapper session is created by calling the constructor with a
   wrapped (normal) session as the argument:
     wrapper = FooDebugWrapperSession(sess)
   wherein FooDebugWrapperSession is a concrete subclass implementing the
   abstract BaseDebugWrapperSession class below.

2) Near the end of the constructor call, the on_session_init() callback is
   invoked, with an OnSessionInitRequest object as the argument. The object
   carries the wrapped (normal) session object.

3) The callback handles the request and returns an OnSessionInitResponse
   object with an action field, directing the wrapper session what to do next.

If the action field in the OnSessionInitResponse is PROCEED, the constructor
returns. Control is released back to the caller of the constructor, which can
invoke run() method of wrapper session with the same syntax as a non-wrapped
session, e.g.,:
  wrapper.run(fetches, feed_dict=feeds, options=run_options)

Below, A1 - A2 is the lifetime of a wrapper run() call if the action is
PROCEED:

A1) Right at the start of each run() call, the on_run_start() callback is
    invoked, with an OnRunStartRequest object carrying information such as
    the fetches, the feed dict, the run options and run metadata used in
    this run call, along with a count of how many run calls have occurred
    on this wrapper session. The callback then returns an OnRunStartResponse
    object, of which the action field directs what the wrapper session
    will actually do for the run() call.

    If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
    with the debug URLs supplied in the debug_urls field of the response.
    These can be file:// or grpc:// URLs, for example.

    If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.

    If the action is INVOKE_STEPPER, no run() call will be issued to the
    wrapped session. But instead, a DebugStepper (i.e., "continuation
    debugger") will be used to perform stepping / continue-to actions on
    the graph.

TODO(cais): The event loop for the DebugStepper will request additional
   callbacks including on_cont_start() and on_cont_end(). Add those.

A2) Right before the run() returns, the on_run_end() callback is invoked,
    with an OnRunEndRequest object as the argument, which carries information
    including the actual action performed in the wrapper run() call and the
    run_metadata from the run() call.

However, if the action field in OnSessionInitResponse is
REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
that gives the control to a remote caller.

In the remote instruction loop, the following steps will happen:

B1) Callback on_instr_start() is invoked. The callback will return an
    OnInstrStartResponse object with an action field which can order one of
    the following actions:
        i) a run() call with fetches, feeds and debug_urls specified.
       ii) a DebugStepper cont() call with target specified.
      iii) value overrides in the cached tensors from the DebugStepper.
       iv) exit the instruction loop.

B2) The wrapper session carries out the action specified above.

B3) If still in the instruction loop, the wrapper session invokes the
    on_instr_end() callback. After the on_instr_end() callback returns, jump
    back to B1.

TODO(cais): Implement the instruction loop in B1 - B3.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import re
import threading

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import stepper
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import monitored_session
from tensorflow.python.util import nest


# Helper function.
def _check_type(obj, expected_types):
  """Check if an object is of the expected type.

  Args:
    obj: The object being checked.
    expected_types: (`type` or a `tuple` of `type`s) The expected `type`(s)
      of obj, in the form accepted by the builtin `isinstance`.

  Raises:
    TypeError: If obj is not an instance of expected_types.
  """
  if not isinstance(obj, expected_types):
    raise TypeError("Expected type %s; got type %s" %
                    (expected_types, type(obj)))


class OnSessionInitRequest(object):
  """Request to an on-session-init callback.

  This callback is invoked during the __init__ call to a debug-wrapper session.
  """

  def __init__(self, sess):
    """Constructor.

    Args:
      sess: A tensorflow Session object.

    Raises:
      TypeError: If `sess` is neither a `BaseSession` nor a
        `MonitoredSession`.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
    self.session = sess


class OnSessionInitAction(object):
  """Enum-like values for possible action to take on session init."""

  # Proceed, without special actions, in the wrapper session initialization.
  # What action the wrapper session performs next is determined by the caller
  # of the wrapper session. E.g., it can call run().
  PROCEED = "proceed"

  # Instead of letting the caller of the wrapper session determine what actions
  # the wrapper session will perform next, enter a loop to receive instructions
  # from a remote client.
  # For example, TensorBoard visual debugger can use this action so that it can
  # launch session.run() calls remotely.
  REMOTE_INSTR_LOOP = "remote_instr_loop"


class OnSessionInitResponse(object):
  """Response from an on-session-init callback."""

  def __init__(self, action):
    """Constructor.

    Args:
      action: (`OnSessionInitAction`) Debugger action to take on session init.

    Raises:
      TypeError: If `action` is not a `str`.
    """
    _check_type(action, str)
    self.action = action


class OnRunStartRequest(object):
  """Request to an on-run-start callback.

  This callback is invoked during a run() call of the debug-wrapper
  session, immediately after the run() call counter is incremented.
  """

  def __init__(self, fetches, feed_dict, run_options, run_metadata,
               run_call_count, is_callable_runner=False):
    """Constructor of `OnRunStartRequest`.

    Args:
      fetches: Fetch targets of the run() call.
      feed_dict: The feed dictionary to the run() call.
      run_options: RunOptions input to the run() call.
      run_metadata: RunMetadata input to the run() call.
        The above four arguments are identical to the input arguments to the
        run() method of a non-wrapped TensorFlow session.
      run_call_count: 1-based count of how many run calls (including this one)
        has been invoked.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """
    self.fetches = fetches
    self.feed_dict = feed_dict
    self.run_options = run_options
    self.run_metadata = run_metadata
    self.run_call_count = run_call_count
    self.is_callable_runner = is_callable_runner


class OnRunStartAction(object):
  """Enum-like values for possible action to take on start of a run() call."""

  # Run once with debug tensor-watching.
  DEBUG_RUN = "debug_run"

  # Run once with profiler.
  PROFILE_RUN = "profile_run"

  # Run without debug tensor-watching.
  NON_DEBUG_RUN = "non_debug_run"

  # Instead of running the fetches as a whole, as would normally happen, invoke
  # the (to-be-implemented) debug stepper.
  # TODO(cais): Remove "to-be-implemented".
  INVOKE_STEPPER = "invoke_stepper"


class OnRunStartResponse(object):
  """Response from an on-run-start callback.

  The caller of the callback can use this response object to specify what
  action the debug-wrapper session actually takes on the run() call.
  """

  def __init__(self,
               action,
               debug_urls,
               debug_ops="DebugIdentity",
               node_name_regex_whitelist=None,
               op_type_regex_whitelist=None,
               tensor_dtype_regex_whitelist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of `OnRunStartResponse`.

    Args:
      action: (`OnRunStartAction`) the action actually taken by the wrapped
        session for the run() call.
      debug_urls: (`list` of `str`) debug_urls used in watching the tensors
        during the run() call.
      debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
        debugger.
      node_name_regex_whitelist: Regular-expression whitelist for node
        name.
      op_type_regex_whitelist: Regular-expression whitelist for op type.
      tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.

    Raises:
      TypeError: If `action` is not a `str` or `debug_urls` is not a `list`.
    """

    _check_type(action, str)
    self.action = action

    _check_type(debug_urls, list)
    self.debug_urls = debug_urls

    self.debug_ops = debug_ops

    self.node_name_regex_whitelist = node_name_regex_whitelist
    self.op_type_regex_whitelist = op_type_regex_whitelist
    self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)


class OnRunEndRequest(object):
  """Request to an on-run-end callback.

  The callback is invoked immediately before the wrapped run() call ends.
  """

  def __init__(self,
               performed_action,
               run_metadata=None,
               client_graph_def=None,
               tf_error=None):
    """Constructor for `OnRunEndRequest`.

    Args:
      performed_action: (`OnRunStartAction`) Actually-performed action by the
        debug-wrapper session.
      run_metadata: run_metadata output from the run() call (if any).
      client_graph_def: (GraphDef) GraphDef from the client side, i.e., from
        the python front end of TensorFlow. Can be obtained with
        session.graph.as_graph_def().
      tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred
        during the run (if any).

    Raises:
      TypeError: If `performed_action` is not a `str`, or if a non-None
        `run_metadata` is not a `RunMetadata` proto.
    """

    _check_type(performed_action, str)
    self.performed_action = performed_action

    if run_metadata is not None:
      _check_type(run_metadata, config_pb2.RunMetadata)
    self.run_metadata = run_metadata
    self.client_graph_def = client_graph_def
    self.tf_error = tf_error


class OnRunEndResponse(object):
  """Response from an on-run-end callback."""

  def __init__(self):

    # Currently only a placeholder.
    pass


@six.add_metaclass(abc.ABCMeta)
class BaseDebugWrapperSession(session.SessionInterface):
  """Base class of debug-wrapper session classes.

  Concrete classes that inherit from this class need to implement the abstract
  methods such as on_session_init, on_run_start and on_run_end.
  """

  # TODO(cais): Add on_cont_start and on_cont_end callbacks once the stepper is
  # is available.

  def __init__(self, sess, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of `BaseDebugWrapperSession`.

    Args:
      sess: An (unwrapped) TensorFlow session instance. It should be a subtype
        of `BaseSession` or `tf.MonitoredSession`.
      thread_name_filter: Regular-expression filter (whitelist) for name(s) of
        thread(s) on which the wrapper session will be active. This regular
        expression is used in a start-anchored fashion on the thread name, i.e.,
        by applying the `match` method of the compiled pattern. The default
        `None` means that the wrapper session will be active on all threads.
        E.g., r"MainThread$", r"QueueRunnerThread.*".
      pass_through_operrors: If True, all captured OpErrors will be
        propagated.  By default this captures all OpErrors.

    Raises:
      TypeError: If `sess` is not of a supported session type, or if
        `on_session_init` does not return an `OnSessionInitResponse`.
      ValueError: On invalid `OnSessionInitAction` value.
      NotImplementedError: If the `on_session_init` callback requests the
        not-yet-implemented REMOTE_INSTR_LOOP action.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))

    # The session being wrapped.
    self._sess = sess
    self._thread_name_filter_pattern = (re.compile(thread_name_filter)
                                        if thread_name_filter else None)
    # TODO(cais/kstevens): Unittest this pass through feature.
    self._pass_through_operrors = pass_through_operrors

    # Keeps track of number of run calls that have been performed on this
    # debug-wrapper session. The count can be used for purposes such as
    # displaying the state of the Session in a UI and determining a run
    # number-dependent debug URL.
    self._run_call_count = 0

    # Invoke on-session-init callback.
    response = self.on_session_init(OnSessionInitRequest(self._sess))
    _check_type(response, OnSessionInitResponse)

    if response.action == OnSessionInitAction.PROCEED:
      pass
    elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
      # TODO(cais): Implement REMOTE_INSTR_LOOP
      raise NotImplementedError(
          "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
          "implemented.")
    else:
      raise ValueError(
          "Invalid OnSessionInitAction value: %s" % response.action)

    # Lazily created in __enter__; see as_default().
    self._default_session_context_manager = None

    # A cache for callables created from CallableOptions.
    # NOTE(review): keyed by id() of the CallableOptions object (see run());
    # ids can be reused once the original object is garbage collected.
    self._cached_callables_from_options = dict()

  @property
  def graph(self):
    """The `graph` attribute of the wrapped session."""
    return self._sess.graph

  @property
  def graph_def(self):
    """The `graph_def` attribute of the wrapped session."""
    return self._sess.graph_def

  @property
  def sess_str(self):
    """The `sess_str` attribute of the wrapped session."""
    return self._sess.sess_str

  @property
  def session(self):
    """The wrapped session object."""
    return self._sess

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Wrapper around Session.run() that inserts tensor watch options.

    Args:
      fetches: Same as the `fetches` arg to regular `Session.run()`.
      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
      options: Same as the `options` arg to regular `Session.run()`.
      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
      callable_runner: A `callable` returned by `Session.make_callable()`.
        If not `None`, `fetches` and `feed_dict` must both be `None`.
        Mutually exclusive with `callable_options`.
      callable_runner_args: An optional list of arguments to `callable_runner`
        or for `callable_options`. `None` is treated as no arguments.
      callable_options: An instance of `config_pb2.CallableOptions`, to be
        used with `Session._make_callable_from_options()`. Mutually exclusive
        with `callable_runner`.

    Returns:
      Simply forwards the output of the wrapped `Session.run()` call. If an
      `OpError` is captured during a debug run (i.e., `pass_through_operrors`
      is False), the captured error object is returned instead.

    Raises:
      ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
        or `callable_options` is specified together with a non-empty `fetches`
        or `feed_dict`, or if both `callable_runner` and `callable_options`
        are specified.
      NotImplementedError: If the stepper is requested for a callable created
        by `Session.make_callable()`.
      TypeError: If a callback returns an object of an unexpected type.
    """
    if callable_runner and callable_options:
      raise ValueError(
          "callable_runner and callable_options are mutually exclusive, but "
          "are both specified in this call to BaseDebugWrapperSession.run().")

    if callable_runner and (fetches or feed_dict):
      raise ValueError(
          "callable_runner and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")
    elif callable_options and (fetches or feed_dict):
      raise ValueError(
          "callable_options and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")

    # `callable_runner_args` is optional; normalize `None` so that it can be
    # star-expanded safely below.
    if callable_runner_args is None:
      callable_runner_args = ()

    self.increment_run_call_count()
    empty_fetches = not nest.flatten(fetches)
    if empty_fetches:
      tf_logging.info(
          "Due to empty fetches, tfdbg Session wrapper is letting a "
          "Session.run pass through without any debugging actions.")
    if self._is_disabled_thread() or empty_fetches:
      # No callbacks are invoked on this path.
      return self._run_without_debugging(
          fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)

    # Invoke on-run-start callback and obtain response.
    run_start_resp = self.on_run_start(
        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                          self._run_call_count,
                          is_callable_runner=bool(callable_runner)))
    _check_type(run_start_resp, OnRunStartResponse)

    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
      # Decorate RunOption to fill in debugger tensor watch specifications.
      decorated_run_options = None
      if callable_options:
        callable_options_id = id(callable_options)
        if callable_options_id not in self._cached_callables_from_options:
          # Make a copy of callable_options to avoid mutating it.
          new_callable_options = config_pb2.CallableOptions()
          new_callable_options.CopyFrom(callable_options)
          decorated_run_options = new_callable_options.run_options
        # Otherwise a decorated callable is already cached; nothing is
        # decorated and `decorated_run_options` stays None.
      else:
        decorated_run_options = options or config_pb2.RunOptions()

      run_metadata = run_metadata or config_pb2.RunMetadata()

      if decorated_run_options:
        self._decorate_run_options_for_debug(
            decorated_run_options,
            run_start_resp.debug_urls,
            debug_ops=run_start_resp.debug_ops,
            node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
            op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
            tensor_dtype_regex_whitelist=(
                run_start_resp.tensor_dtype_regex_whitelist),
            tolerate_debug_op_creation_failures=(
                run_start_resp.tolerate_debug_op_creation_failures))

      # Invoke the run() method of the wrapped Session. Catch any TensorFlow
      # runtime errors.
      tf_error = None
      try:
        if callable_runner:
          retvals = callable_runner(*callable_runner_args,
                                    options=decorated_run_options,
                                    run_metadata=run_metadata)
        elif callable_options:
          # pylint:disable=protected-access
          if callable_options_id in self._cached_callables_from_options:
            callable_object = self._cached_callables_from_options[
                callable_options_id]
          else:
            callable_object = self._sess._make_callable_from_options(
                new_callable_options)
            self._cached_callables_from_options[
                callable_options_id] = callable_object
          # pylint:enable=protected-access
          retvals = callable_object(
              *callable_runner_args, run_metadata=run_metadata)
        else:
          retvals = self._sess.run(fetches,
                                   feed_dict=feed_dict,
                                   options=decorated_run_options,
                                   run_metadata=run_metadata)
      except errors.OpError as op_error:
        if self._pass_through_operrors:
          # Bare raise keeps the original traceback.
          raise
        # The error is handed to on_run_end and also becomes the return value.
        tf_error = op_error
        retvals = op_error

      run_end_req = OnRunEndRequest(
          run_start_resp.action,
          run_metadata=run_metadata,
          client_graph_def=self._sess.graph.as_graph_def(),
          tf_error=tf_error)

    elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
      if callable_options:
        # Decorate a copy so that the caller's CallableOptions is not mutated.
        profiled_callable_options = config_pb2.CallableOptions()
        profiled_callable_options.CopyFrom(callable_options)
        decorated_run_options = profiled_callable_options.run_options
      else:
        decorated_run_options = options or config_pb2.RunOptions()
      run_metadata = run_metadata or config_pb2.RunMetadata()
      self._decorate_run_options_for_profile(decorated_run_options)
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      elif callable_options:
        # pylint:disable=protected-access
        callable_object = self._sess._make_callable_from_options(
            profiled_callable_options)
        # pylint:enable=protected-access
        retvals = callable_object(
            *callable_runner_args, run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
      run_end_req = OnRunEndRequest(
          run_start_resp.action,
          run_metadata=run_metadata,
          client_graph_def=self._sess.graph.as_graph_def())
    elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
          run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
      if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
        # Only the stepper is unsupported for callables; a plain non-debug run
        # of a callable is handled below.
        if callable_runner:
          raise NotImplementedError(
              "Stepper mode is not implemented for callables created by "
              "Session.make_callable().")

        with stepper.NodeStepper(
            self._sess, fetches, feed_dict) as node_stepper:
          retvals = self.invoke_node_stepper(
              node_stepper, restore_variable_values_on_exit=True)

      # Invoke run() method of the wrapped session.
      # NOTE(review): this also executes after the stepper returns and
      # overwrites its return value, although the module docstring states that
      # INVOKE_STEPPER issues no run() call. Kept as-is; confirm intent.
      retvals = self._run_without_debugging(
          fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)

      # Prepare arg for the on-run-end callback.
      run_end_req = OnRunEndRequest(run_start_resp.action)
    else:
      raise ValueError(
          "Invalid OnRunStartAction value: %s" % run_start_resp.action)

    # Invoke on-run-end callback and obtain response.
    run_end_resp = self.on_run_end(run_end_req)
    _check_type(run_end_resp, OnRunEndResponse)
    # Currently run_end_resp is only a placeholder. No action is taken on it.

    return retvals

  def _run_without_debugging(self,
                             fetches,
                             feed_dict,
                             options,
                             run_metadata,
                             callable_runner,
                             callable_runner_args,
                             callable_options):
    """Runs on the wrapped session without any debug/profile decoration.

    Dispatches to `callable_runner`, to a callable made from
    `callable_options`, or to the wrapped session's `run()`, in that order of
    precedence.

    Args:
      fetches: Same as the `fetches` arg to `run()`.
      feed_dict: Same as the `feed_dict` arg to `run()`.
      options: Same as the `options` arg to `run()`. Used only by the plain
        `run()` path.
      run_metadata: Same as the `run_metadata` arg to `run()`. Used only by the
        plain `run()` path.
      callable_runner: Same as the `callable_runner` arg to `run()`.
      callable_runner_args: Positional arguments for the callable; must not be
        `None`.
      callable_options: Same as the `callable_options` arg to `run()`.

    Returns:
      The return value of whichever call was dispatched to.
    """
    if callable_runner:
      return callable_runner(*callable_runner_args)
    elif callable_options:
      # pylint:disable=protected-access
      return self._sess._make_callable_from_options(
          callable_options)(*callable_runner_args)
      # pylint:enable=protected-access
    else:
      return self._sess.run(fetches,
                            feed_dict=feed_dict,
                            options=options,
                            run_metadata=run_metadata)

  def _is_disabled_thread(self):
    """Whether the wrapper is inactive on the current thread.

    Returns:
      (`bool`) True if a thread-name filter was given and the current thread's
      name does not match it (start-anchored); False otherwise.
    """
    if self._thread_name_filter_pattern is None:
      return False
    thread_name = threading.current_thread().name or ""
    return not self._thread_name_filter_pattern.match(thread_name)

  def run_step_fn(self, step_fn):
    """Calls `step_fn` with a `StepContext` whose run function is this wrapper's.

    Args:
      step_fn: A callable taking a `MonitoredSession.StepContext`.

    Returns:
      The return value of `step_fn`.
    """
    return step_fn(
        monitored_session.MonitoredSession.StepContext(self._sess, self.run))

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError(
        "partial_run_setup is not implemented for debug-wrapper sessions.")

  def partial_run(self, handle, fetches, feed_dict=None):
    """Not supported by debug-wrapper sessions; always raises."""
    raise NotImplementedError(
        "partial_run is not implemented for debug-wrapper sessions.")

  def list_devices(self, *args, **kwargs):
    """Forwards to `list_devices` of the wrapped session."""
    return self._sess.list_devices(*args, **kwargs)

  def reset(self, *args, **kwargs):
    """Forwards to `reset` of the wrapped session."""
    return self._sess.reset(*args, **kwargs)

  def make_callable(self,
                    fetches,
                    feed_list=None,
                    accept_options=False):
    """Creates a callable that routes through this wrapper's `run()`.

    Args:
      fetches: Forwarded to `make_callable` of the wrapped session.
      feed_list: Forwarded to `make_callable` of the wrapped session.
      accept_options: Not forwarded; the underlying callable is always created
        with `accept_options=True` so that the wrapper can pass decorated
        options and run metadata to it.

    Returns:
      A function taking the feed values as positional arguments and, optionally,
      `options` and `run_metadata` keyword arguments.
    """
    runner = self._sess.make_callable(
        fetches, feed_list=feed_list, accept_options=True)
    def wrapped_runner(*runner_args, **kwargs):
      return self.run(None,
                      feed_dict=None,
                      options=kwargs.get("options", None),
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_runner=runner,
                      callable_runner_args=runner_args)
    return wrapped_runner

  def _make_callable_from_options(self, callable_options):
    """Creates a callable from `CallableOptions` that routes through `run()`.

    Args:
      callable_options: A `config_pb2.CallableOptions` instance.

    Returns:
      A function taking the feed values as positional arguments and,
      optionally, a `run_metadata` keyword argument.
    """
    def wrapped_runner(*feed_values, **kwargs):
      return self.run(None,
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_options=callable_options,
                      callable_runner_args=feed_values)
    return wrapped_runner

  @property
  def run_call_count(self):
    """Number of `run()` calls made on this wrapper so far."""
    return self._run_call_count

  def increment_run_call_count(self):
    """Increments the run-call counter by one."""
    self._run_call_count += 1

  def _is_disk_usage_reset_each_run(self):
    """Indicates whether disk usage is reset after each Session.run.

    Subclasses that clean up the disk usage after every run should
    override this protected method.

    Returns:
      (`bool`) Whether the disk usage amount is reset to zero after
        each Session.run.
    """
    return False

  def _decorate_run_options_for_debug(
      self,
      run_options,
      debug_urls,
      debug_ops="DebugIdentity",
      node_name_regex_whitelist=None,
      op_type_regex_whitelist=None,
      tensor_dtype_regex_whitelist=None,
      tolerate_debug_op_creation_failures=False):
    """Modify a RunOptions object for debug tensor watching.

    Specifies request for outputting partition graphs. Adds
    debug_tensor_watch_opts with proper debug URLs.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
      debug_urls: (list of str) debug URLs to be entered in run_options.
        debug_tensor_watch_opts.
      debug_ops: (str or list of str) debug op(s) to be used by the debugger.
      node_name_regex_whitelist: Regular-expression whitelist for node
        name.
      op_type_regex_whitelist: Regular-expression whitelist for op type.
      tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.
    """

    run_options.output_partition_graphs = True
    debug_utils.watch_graph(
        run_options,
        self._sess.graph,
        debug_urls=debug_urls,
        debug_ops=debug_ops,
        node_name_regex_whitelist=node_name_regex_whitelist,
        op_type_regex_whitelist=op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
        # Reset on the first run, or on every run if the subclass says so.
        reset_disk_byte_usage=(self._run_call_count == 1 or
                               self._is_disk_usage_reset_each_run()))

  def _decorate_run_options_for_profile(self, run_options):
    """Modify a RunOptions object for profiling TensorFlow graph execution.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
    """

    run_options.trace_level = config_pb2.RunOptions.FULL_TRACE

  @abc.abstractmethod
  def on_session_init(self, request):
    """Callback invoked during construction of the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the constructor ends.

    Args:
      request: (`OnSessionInitRequest`) callback request carrying information
        such as the session being wrapped.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

  @abc.abstractmethod
  def on_run_start(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens after the wrapper's run() call is entered,
    after an increment of run call counter.

    Args:
      request: (`OnRunStartRequest`) callback request object carrying
        information about the run call such as the fetches, feed dict, run
        options, run metadata, and how many `run()` calls to this wrapper
        session have occurred.

    Returns:
      An instance of `OnRunStartResponse`, carrying information to
        1) direct the wrapper session to perform a specified action (e.g., run
          with or without debug tensor watching, invoking the stepper.)
        2) debug URLs used to watch the tensors.
    """

  @abc.abstractmethod
  def on_run_end(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the wrapper exits its run() call.

    Args:
      request: (`OnRunEndRequest`) callback request object carrying information
        such as the actual action performed by the session wrapper for the
        run() call.

    Returns:
      An instance of `OnRunEndResponse`.
    """

  def as_default(self):
    """Returns a context manager making this wrapper the default session."""
    return ops.default_session(self)

  def __enter__(self):
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    self._default_session_context_manager.__exit__(
        exec_type, exec_value, exec_tb)

  def __del__(self):
    # `_sess` is missing if __init__ raised before assigning it (e.g., on a
    # session of unsupported type); do not raise from the finalizer then.
    sess = getattr(self, "_sess", None)
    if hasattr(sess, "__del__"):
      sess.__del__()

  def close(self):
    """Closes the wrapped session."""
    self._sess.close()

  # TODO(cais): Add _node_name_regex_whitelist and
  #   _node_op_type_regex_whitelist.

  def invoke_node_stepper(self,
                          node_stepper,
                          restore_variable_values_on_exit=True):
    """Callback invoked when the client intends to step through graph nodes.

    Args:
      node_stepper: (stepper.NodeStepper) An instance of NodeStepper to be used
        in this stepping session.
      restore_variable_values_on_exit: (bool) Whether any variables whose values
        have been altered during this node-stepper invocation should be restored
        to their old values when this invocation ends.

    Returns:
      The same return values as the `Session.run()` call on the same fetches as
        the NodeStepper.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError(
        self.__class__.__name__ + " does not support node-stepper mode.")

  def should_stop(self):
    """Forwards to `should_stop` of the wrapped session.

    Returns:
      The return value of the wrapped session's `should_stop()`.

    Raises:
      ValueError: If the wrapped session has no `should_stop` attribute.
    """
    if hasattr(self._sess, "should_stop"):
      return self._sess.should_stop()
    else:
      raise ValueError(
          "The wrapped session %r does not have a method called 'should_stop'. "
          "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)


class WatchOptions(object):
  """Type for return values of watch_fn."""

  def __init__(self,
               debug_ops=None,
               node_name_regex_whitelist=None,
               op_type_regex_whitelist=None,
               tensor_dtype_regex_whitelist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of WatchOptions: Debug watch options.

    Used as return values of `watch_fn`s.

    Args:
      debug_ops: (`str` or `list of str`) Debug ops to be used. A falsy value
        (e.g., `None`) selects the default `["DebugIdentity"]`.
      node_name_regex_whitelist: Regular-expression whitelist for node_name,
        e.g., `"(weight_[0-9]+|bias_.*)"`
      op_type_regex_whitelist: Regular-expression whitelist for the op type of
        nodes, e.g., `"(Variable|Add)"`.
        If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
        are set, the two filtering operations will occur in a logical `AND`
        relation. In other words, a node will be included if and only if it
        hits both whitelists.
      tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
        data type, e.g., `"^int.*"`.
        This whitelist operates in logical `AND` relations to the two whitelists
        above.
      tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
        failures (e.g., due to dtype incompatibility) are to be tolerated by not
        throwing exceptions.
    """
    if debug_ops:
      self.debug_ops = debug_ops
    else:
      self.debug_ops = ["DebugIdentity"]
    self.node_name_regex_whitelist = node_name_regex_whitelist
    self.op_type_regex_whitelist = op_type_regex_whitelist
    self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)

  def __repr__(self):
    return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, "
            "op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, "
            "tolerate_debug_op_creation_failures=%r)" % (
                self.debug_ops, self.node_name_regex_whitelist,
                self.op_type_regex_whitelist, self.tensor_dtype_regex_whitelist,
                self.tolerate_debug_op_creation_failures))


class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
  """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""

  def __init__(self, sess, watch_fn=None, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of NonInteractiveDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a
        debugged `Session.run()` call to `WatchOptions.`
        * Args:
          * `fetches`: the fetches to the `Session.run()` call.
          * `feeds`: the feeds to the `Session.run()` call.

        * Returns:
         (`tf_debug.WatchOptions`) An object containing debug options including
           the debug ops to use, the node names, op types and/or tensor data
           types to watch, etc. See the documentation of `tf_debug.WatchOptions`
           for more details.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated.  By default this captures all OpErrors.

    Raises:
       TypeError: If a non-None `watch_fn` is specified and it is not callable.
    """

    BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    self._watch_fn = None
    if watch_fn is not None:
      if not callable(watch_fn):
        raise TypeError("watch_fn is not callable")
      self._watch_fn = watch_fn

  def on_session_init(self, request):
    """See doc of BaseDebugWrapperSession.on_session_init."""

    return OnSessionInitResponse(OnSessionInitAction.PROCEED)

  @abc.abstractmethod
  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Abstract method to be implemented by concrete subclasses.

    This method prepares the run-specific debug URL(s).

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
        this `Session.run()` call.
    """

  def on_run_start(self, request):
    """See doc of BaseDebugWrapperSession.on_run_start."""

    debug_urls, watch_opts = self._prepare_run_watch_config(
        request.fetches, request.feed_dict)

    return OnRunStartResponse(
        OnRunStartAction.DEBUG_RUN,
        debug_urls,
        debug_ops=watch_opts.debug_ops,
        node_name_regex_whitelist=watch_opts.node_name_regex_whitelist,
        op_type_regex_whitelist=watch_opts.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=(
            watch_opts.tolerate_debug_op_creation_failures))

  def _prepare_run_watch_config(self, fetches, feed_dict):
    """Get the debug_urls, and node/op whitelists for the current run() call.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`.
      feed_dict: Same as the `feed_dict argument` to `Session.run()`.

    Returns:
      debug_urls: (str or list of str) Debug URLs for the current run() call.
        Currently, the list consists of only one URL that is a file:// URL.
      watch_options: (WatchOptions) The return value of a watch_fn, containing
        options including debug_ops, and whitelists.
    """

    debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
    if self._watch_fn is None:
      watch_options = WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
      if isinstance(watch_options, tuple):
        # For legacy return type (tuples).
        watch_options = WatchOptions(*watch_options)

    return debug_urls, watch_options

  def on_run_end(self, request):
    """See doc of BaseDebugWrapperSession.on_run_end."""

    return OnRunEndResponse()

  def invoke_node_stepper(self,
                          node_stepper,
                          restore_variable_values_on_exit=True):
    """See doc of BaseDebugWrapperSession.invoke_node_stepper."""

    raise NotImplementedError(
        "NonInteractiveDebugWrapperSession does not support node-stepper mode.")