1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Framework of debug wrapper sessions.
16
17A debug wrapper session is a wrapper around a TensorFlow Python Session.
18The wrapper preserves the Session interface, most importantly the run() method,
19while providing abilities to:
20a) Intercept a run() call to a wrapped session and insert debug tensor watches
21   according to externally-specified debug URLs.
22
23b) Release control to an external (i.e., non-Session) object before and after
24   the run() call, so that the external object can perform actions such as
25   launching a UI to let users inspect the intermediate tensors and partition
26   graphs from the run() call.
27
28c) (To be implemented) Intercept a run() call and give control to DebugStepper
29   to let it perform stepping / continuing-to actions on the graph.
30
d) (To be implemented in a future CL) Enter an instruction loop to let an
   external object (e.g., a remote client) launch run() and cont() calls
   remotely.
34
35*** The lifetime of a debug wrapper session: ***
36
371) The wrapper session is created by calling the constructor with a
38   wrapped (normal) session as the argument:
39     wrapper = FooDebugWrapperSession(sess)
40   wherein FooDebugWrapperSession is a concrete subclass implementing the
41   abstract BaseDebugWrapperSession class below.
42
2) Near the end of the constructor call, the on_session_init() callback is
   invoked, with an OnSessionInitRequest object as the argument. The object
   carries the wrapped (normal) session object.
46
3) The callback handles the request and returns an OnSessionInitResponse
   object with an action field, directing the wrapper session what to do next.
49
If the action field in the OnSessionInitResponse is PROCEED, the constructor
returns. Control is released back to the caller of the constructor, which can
invoke the run() method of the wrapper session with the same syntax as a
non-wrapped session, e.g.:
54  wrapper.run(fetches, feed_dict=feeds, options=run_options)
55
56Below, A1 - A2 is the lifetime of a wrapper run() call if the action is
57PROCEED:
58
A1) Right at the start of each run() call, the on_run_start() callback is
    invoked, with an OnRunStartRequest object carrying information such as
    the fetches, the feed dict, the run options and run metadata used in
    this run call, along with a count of how many run calls have occurred
    on this wrapper session. The callback then returns an OnRunStartResponse
    object, whose action field directs what the wrapper session will
    actually do for the run() call.
66
67    If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
68    with the debug URLs supplied in the debug_urls field of the response.
69    These can be file:// or grpc:// URLs, for example.
70
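    If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.

    If the action is PROFILE_RUN, a profiled run (with full tracing enabled
    in the run options) will ensue.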
72
73    If the action is INVOKE_STEPPER, no run() call will be issued to the
74    wrapped session. But instead, a DebugStepper (i.e., "continuation
75    debugger") will be used to perform stepping / continue-to actions on
76    the graph.
77
78TODO(cais): The event loop for the DebugStepper will request additional
79   callbacks including on_cont_start() and on_cont_end(). Add those.
80
A2) Right before the run() returns, the on_run_end() callback is invoked,
    with an OnRunEndRequest object as the argument, which carries information
    including the actual action performed in the wrapper run() call and the
    run_metadata from the run() call.
85
86However, if the action field in OnSessionInitResponse is
87REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
88that gives the control to a remote caller.
89
90In the remote instruction loop, the following steps will happen:
91
92B1) Callback on_instr_start() is invoked. The callback will return an
93    OnInstrStartResponse object with an action field which can order one of
94    the following actions:
95        i) a run() call with fetches, feeds and debug_urls specified.
96       ii) a DebugStepper cont() call with target specified.
97      iii) value overrides in the cached tensors from the DebugStepper.
98       iv) exit the instruction loop.
99
100B2) The wrapper session carries out the action specified above.
101
102B3) If still in the instruction loop, the wrapper session invokes the
103    on_instr_end() callback. After the on_instr_end() callback returns, jump
104    back to B1.
105
TODO(cais): Implement the instruction loop in B1 - B3.
107
108"""
109
110from __future__ import absolute_import
111from __future__ import division
112from __future__ import print_function
113
114import abc
115import re
116import threading
117
118import six
119
120from tensorflow.core.protobuf import config_pb2
121from tensorflow.python.client import session
122from tensorflow.python.debug.lib import debug_utils
123from tensorflow.python.debug.lib import stepper
124from tensorflow.python.framework import errors
125from tensorflow.python.framework import ops
126from tensorflow.python.platform import tf_logging
127from tensorflow.python.training import monitored_session
128from tensorflow.python.util import nest
129
130
131# Helper function.
132def _check_type(obj, expected_types):
133  """Check if an object is of the expected type.
134
135  Args:
136    obj: The object being checked.
137    expected_types: (`type` or an iterable of `type`s) The expected `type`(s)
138      of obj.
139
140  Raises:
141      TypeError: If obj is not an instance of expected_type.
142  """
143  if not isinstance(obj, expected_types):
144    raise TypeError("Expected type %s; got type %s" %
145                    (expected_types, type(obj)))
146
147
class OnSessionInitRequest(object):
  """Request handed to an on-session-init callback.

  A debug-wrapper session builds one of these inside its constructor and
  passes it to its `on_session_init()` callback.
  """

  def __init__(self, sess):
    """Create the request.

    Args:
      sess: The session being wrapped. Must be a `BaseSession` or a
        `MonitoredSession`.

    Raises:
      TypeError: If `sess` is of any other type.
    """
    accepted_types = (session.BaseSession, monitored_session.MonitoredSession)
    _check_type(sess, accepted_types)
    # The wrapped session, exposed to the callback.
    self.session = sess
163
164
class OnSessionInitAction(object):
  """Enum-like string constants naming what to do at session init."""

  # Finish constructing the wrapper normally. The wrapper's caller then decides
  # what happens next, e.g., by calling run().
  PROCEED = "proceed"

  # Rather than leaving the next steps to the wrapper's caller, enter a loop
  # that receives instructions from a remote client (for example, so that a
  # TensorBoard visual debugger can launch session.run() calls remotely).
  # BaseDebugWrapperSession's constructor currently raises
  # NotImplementedError for this action.
  REMOTE_INSTR_LOOP = "remote_instr_loop"
179
180
class OnSessionInitResponse(object):
  """Response returned by an on-session-init callback."""

  def __init__(self, action):
    """Create the response.

    Args:
      action: (`str`) One of the `OnSessionInitAction` constants, telling the
        debug-wrapper session what to do once its constructor finishes.

    Raises:
      TypeError: If `action` is not a `str`.
    """
    _check_type(action, str)
    # The action requested by the callback.
    self.action = action
192
193
class OnRunStartRequest(object):
  """Request handed to an on-run-start callback.

  The debug-wrapper session builds one of these during each run() call, right
  after it has incremented its run-call counter.
  """

  def __init__(self, fetches, feed_dict, run_options, run_metadata,
               run_call_count, is_callable_runner=False):
    """Bundle the arguments of an intercepted run() call.

    Args:
      fetches: Fetches passed to run().
      feed_dict: Feed dictionary passed to run().
      run_options: `RunOptions` passed to run().
      run_metadata: `RunMetadata` passed to run().
        These four mirror the arguments of an unwrapped session's run().
      run_call_count: 1-based number of run() calls made on the wrapper so
        far, this call included.
      is_callable_runner: (`bool`) Whether the call comes from a runner that
        was created by `Session.make_callable`.
    """
    # Arguments of the intercepted run() call.
    self.fetches, self.feed_dict = fetches, feed_dict
    self.run_options, self.run_metadata = run_options, run_metadata
    # Bookkeeping supplied by the wrapper session.
    self.run_call_count = run_call_count
    self.is_callable_runner = is_callable_runner
223
224
class OnRunStartAction(object):
  """Enum-like string constants naming what to do when run() is called."""

  # Perform the run() once, watching tensors with debug ops.
  DEBUG_RUN = "debug_run"

  # Perform the run() once with the profiler.
  PROFILE_RUN = "profile_run"

  # Perform the run() once without any debug tensor watches.
  NON_DEBUG_RUN = "non_debug_run"

  # Rather than running all the fetches in one go, hand control to the
  # (to-be-implemented) debug stepper.
  # TODO(cais): Remove "to-be-implemented".
  INVOKE_STEPPER = "invoke_stepper"
241
242
class OnRunStartResponse(object):
  """Response from an on-run-start callback.

  The caller of the callback can use this response object to specify what
  action the debug-wrapper session actually takes on the run() call.
  """

  def __init__(self,
               action,
               debug_urls,
               debug_ops="DebugIdentity",
               node_name_regex_whitelist=None,
               op_type_regex_whitelist=None,
               tensor_dtype_regex_whitelist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of `OnRunStartResponse`.

    Args:
      action: (`OnRunStartAction`) the action actually taken by the wrapped
        session for the run() call.
      debug_urls: (`list` of `str`) debug_urls used in watching the tensors
        during the run() call.
      debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
        debugger.
      node_name_regex_whitelist: Regular-expression whitelist for node
        name.
      op_type_regex_whitelist: Regular-expression whitelist for op type.
      tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.

    Raises:
      TypeError: If `action` is not a `str` or `debug_urls` is not a `list`.
    """

    _check_type(action, str)
    self.action = action

    _check_type(debug_urls, list)
    self.debug_urls = debug_urls

    self.debug_ops = debug_ops

    self.node_name_regex_whitelist = node_name_regex_whitelist
    self.op_type_regex_whitelist = op_type_regex_whitelist
    self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)
289
290
class OnRunEndRequest(object):
  """Request handed to an on-run-end callback.

  Built just before the wrapped run() call returns.
  """

  def __init__(self,
               performed_action,
               run_metadata=None,
               client_graph_def=None,
               tf_error=None):
    """Describe how a wrapped run() call ended.

    Args:
      performed_action: (`OnRunStartAction`) The action the debug-wrapper
        session actually carried out.
      run_metadata: `RunMetadata` produced by the run() call, if any.
      client_graph_def: (`GraphDef`) Client-side (Python front-end) graph,
        e.g., as obtained from `session.graph.as_graph_def()`.
      tf_error: (`errors.OpError` subtype) Error raised by TensorFlow during
        the run, if any.

    Raises:
      TypeError: If `performed_action` is not a `str`, or if `run_metadata` is
        given and is not a `RunMetadata`.
    """
    _check_type(performed_action, str)
    if run_metadata is not None:
      _check_type(run_metadata, config_pb2.RunMetadata)

    self.performed_action = performed_action
    self.run_metadata = run_metadata
    self.client_graph_def = client_graph_def
    self.tf_error = tf_error
323
324
class OnRunEndResponse(object):
  """Response returned by an on-run-end callback.

  Carries no fields for now; the wrapper session only checks its type.
  """

  def __init__(self):
    # Placeholder constructor: there is no state to initialize yet.
    pass
332
333
@six.add_metaclass(abc.ABCMeta)
class BaseDebugWrapperSession(session.SessionInterface):
  """Base class of debug-wrapper session classes.

  Concrete classes that inherit from this class need to implement the abstract
  methods on_session_init, on_run_start and on_run_end.
  """

  # TODO(cais): Add on_cont_start and on_cont_end callbacks once the stepper is
  # available.

  def __init__(self, sess, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of `BaseDebugWrapperSession`.

    Args:
      sess: An (unwrapped) TensorFlow session instance. It should be a subtype
        of `BaseSession` or `tf.MonitoredSession`.
      thread_name_filter: Regular-expression filter (whitelist) for name(s) of
        thread(s) on which the wrapper session will be active. This regular
        expression is used in a start-anchored fashion on the thread name, i.e.,
        by applying the `match` method of the compiled pattern. The default
        `None` means that the wrapper session will be active on all threads.
        E.g., r"MainThread$", r"QueueRunnerThread.*".
      pass_through_operrors: If True, OpErrors raised during a debugged run()
        are re-raised to the caller. By default they are captured and handed
        to the on_run_end() callback instead.

    Raises:
      TypeError: If `sess` is not a `BaseSession` or `MonitoredSession`, or if
        `on_session_init()` does not return an `OnSessionInitResponse`.
      ValueError: On invalid `OnSessionInitAction` value.
      NotImplementedError: If `on_session_init()` requests
        `OnSessionInitAction.REMOTE_INSTR_LOOP`, which is not implemented yet.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))

    # The session being wrapped.
    self._sess = sess
    self._thread_name_filter_pattern = (re.compile(thread_name_filter)
                                        if thread_name_filter else None)
    # TODO(cais/kstevens): Unittest this pass through feature.
    self._pass_through_operrors = pass_through_operrors

    # Keeps track of number of run calls that have been performed on this
    # debug-wrapper session. The count can be used for purposes such as
    # displaying the state of the Session in a UI and determining a run
    # number-dependent debug URL.
    self._run_call_count = 0

    # Invoke on-session-init callback.
    response = self.on_session_init(OnSessionInitRequest(self._sess))
    _check_type(response, OnSessionInitResponse)

    if response.action == OnSessionInitAction.PROCEED:
      pass
    elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
      # TODO(cais): Implement REMOTE_INSTR_LOOP
      raise NotImplementedError(
          "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
          "implemented.")
    else:
      raise ValueError(
          "Invalid OnSessionInitAction value: %s" % response.action)

    self._default_session_context_manager = None

    # A cache for callables created from CallableOptions.
    self._cached_callables_from_options = dict()

  @property
  def graph(self):
    return self._sess.graph

  @property
  def graph_def(self):
    return self._sess.graph_def

  @property
  def sess_str(self):
    return self._sess.sess_str

  @property
  def session(self):
    return self._sess

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Wrapper around Session.run() that inserts tensor watch options.

    Args:
      fetches: Same as the `fetches` arg to regular `Session.run()`.
      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
      options: Same as the `options` arg to regular `Session.run()`. It is not
        modified; debug or profiling settings are applied to a copy.
      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
      callable_runner: A `callable` returned by `Session.make_callable()`.
        If not `None`, `fetches` and `feed_dict` must both be `None`.
        Mutually exclusive with `callable_options`.
      callable_runner_args: An optional list of arguments to `callable_runner`
        or for `callable_options`. `None` means no arguments.
      callable_options: An instance of `config_pb2.CallableOptions`, to be
        used with `Session._make_callable_from_options()`. Mutually exclusive
        with `callable_runner`.

    Returns:
      Simply forwards the output of the wrapped `Session.run()` call. If a
      debugged run fails with an `errors.OpError` and OpErrors are not passed
      through, the captured error object is returned instead.

    Raises:
      ValueError: On invalid `OnRunStartAction` value; if `callable_runner`
        and `callable_options` are both specified; or if either of them is
        specified together with a non-empty `fetches` or `feed_dict`.
      NotImplementedError: If the stepper is requested for a callable.
    """
    if callable_runner and callable_options:
      raise ValueError(
          "callable_runner and callable_options are mutually exclusive, but "
          "are both specified in this call to BaseDebugWrapperSession.run().")

    if callable_runner and (fetches or feed_dict):
      raise ValueError(
          "callable_runner and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")
    elif callable_options and (fetches or feed_dict):
      raise ValueError(
          "callable_options and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")

    # Allow callers to omit the arguments of a callable that takes no feeds;
    # the argument list is always unpacked with `*` below.
    if callable_runner_args is None:
      callable_runner_args = ()

    self.increment_run_call_count()
    empty_fetches = not nest.flatten(fetches)
    if empty_fetches:
      tf_logging.info(
          "Due to empty fetches, tfdbg Session wrapper is letting a "
          "Session.run pass through without any debugging actions.")
    if self._is_disabled_thread() or empty_fetches:
      return self._run_without_debugging(
          fetches, feed_dict, options, run_metadata, callable_runner,
          callable_runner_args, callable_options)

    # Invoke on-run-start callback and obtain response.
    run_start_resp = self.on_run_start(
        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                          self._run_call_count,
                          is_callable_runner=bool(callable_runner)))
    _check_type(run_start_resp, OnRunStartResponse)
    action = run_start_resp.action

    if action == OnRunStartAction.DEBUG_RUN:
      retvals, run_end_req = self._run_with_debugging(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif action == OnRunStartAction.PROFILE_RUN:
      retvals, run_end_req = self._run_with_profiling(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif action == OnRunStartAction.NON_DEBUG_RUN:
      retvals = self._run_without_debugging(
          fetches, feed_dict, options, run_metadata, callable_runner,
          callable_runner_args, callable_options)
      run_end_req = OnRunEndRequest(action)
    elif action == OnRunStartAction.INVOKE_STEPPER:
      if callable_runner:
        raise NotImplementedError(
            "Stepper mode is not implemented for callables created by "
            "Session.make_callable().")
      if callable_options:
        raise NotImplementedError(
            "Stepper mode is not implemented for callables created from "
            "CallableOptions.")

      with stepper.NodeStepper(
          self._sess, fetches, feed_dict) as node_stepper:
        # The stepper's return value is not used: the regular run() below
        # produces the values returned to the caller.
        self.invoke_node_stepper(
            node_stepper, restore_variable_values_on_exit=True)

      # Invoke run() method of the wrapped session.
      retvals = self._sess.run(
          fetches,
          feed_dict=feed_dict,
          options=options,
          run_metadata=run_metadata)
      run_end_req = OnRunEndRequest(action)
    else:
      raise ValueError(
          "Invalid OnRunStartAction value: %s" % action)

    # Invoke on-run-end callback and obtain response.
    run_end_resp = self.on_run_end(run_end_req)
    _check_type(run_end_resp, OnRunEndResponse)
    # Currently run_end_resp is only a placeholder. No action is taken on it.

    return retvals

  @staticmethod
  def _copy_run_options(options):
    """Returns a new `RunOptions`, copied from `options` unless it is `None`.

    Decorating a copy keeps the caller's `RunOptions` unchanged, so that, e.g.,
    debug tensor watches do not accumulate on it across run() calls.
    """
    run_options = config_pb2.RunOptions()
    if options is not None:
      run_options.CopyFrom(options)
    return run_options

  def _run_without_debugging(self, fetches, feed_dict, options, run_metadata,
                             callable_runner, callable_runner_args,
                             callable_options):
    """Forwards a run() call to the wrapped session without any decoration.

    Args:
      fetches, feed_dict, options, run_metadata, callable_runner,
        callable_runner_args, callable_options: Same as the arguments of
        `run()`; `callable_runner_args` must not be `None`.

    Returns:
      The output of the wrapped run call.
    """
    if callable_runner:
      return callable_runner(*callable_runner_args,
                             options=options,
                             run_metadata=run_metadata)
    elif callable_options:
      # pylint:disable=protected-access
      callable_object = self._sess._make_callable_from_options(
          callable_options)
      # pylint:enable=protected-access
      return callable_object(*callable_runner_args, run_metadata=run_metadata)
    else:
      return self._sess.run(fetches,
                            feed_dict=feed_dict,
                            options=options,
                            run_metadata=run_metadata)

  def _run_with_debugging(self, run_start_resp, fetches, feed_dict, options,
                          run_metadata, callable_runner, callable_runner_args,
                          callable_options):
    """Performs a run() call with debug tensor watches inserted.

    Args:
      run_start_resp: (`OnRunStartResponse`) Response of the on-run-start
        callback, carrying the debug URLs, debug ops and whitelists.
      fetches, feed_dict, options, run_metadata, callable_runner,
        callable_runner_args, callable_options: Same as the arguments of
        `run()`; `callable_runner_args` must not be `None`.

    Returns:
      A `(retvals, run_end_req)` tuple: the output of the wrapped run call (or
      the captured `errors.OpError`, if one occurred and OpErrors are not
      passed through) and the `OnRunEndRequest` for the on-run-end callback.

    Raises:
      errors.OpError: If the run fails and OpError pass-through is enabled.
    """
    decorated_run_options = None
    callable_options_id = None
    new_callable_options = None
    if callable_options:
      # NOTE(review): the cache is keyed on id(callable_options), which may be
      # reused after the original object is garbage-collected. Also, watches
      # are decorated only when the callable is first created, so later runs
      # reuse the debug settings of that first run.
      callable_options_id = id(callable_options)
      if callable_options_id not in self._cached_callables_from_options:
        # Make a copy of callable_options to avoid mutating it.
        new_callable_options = config_pb2.CallableOptions()
        new_callable_options.CopyFrom(callable_options)
        decorated_run_options = new_callable_options.run_options
    else:
      decorated_run_options = self._copy_run_options(options)

    if run_metadata is None:
      run_metadata = config_pb2.RunMetadata()

    if decorated_run_options is not None:
      self._decorate_run_options_for_debug(
          decorated_run_options,
          run_start_resp.debug_urls,
          debug_ops=run_start_resp.debug_ops,
          node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
          op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
          tensor_dtype_regex_whitelist=(
              run_start_resp.tensor_dtype_regex_whitelist),
          tolerate_debug_op_creation_failures=(
              run_start_resp.tolerate_debug_op_creation_failures))

    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    # runtime errors.
    tf_error = None
    try:
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      elif callable_options:
        callable_object = self._cached_callables_from_options.get(
            callable_options_id)
        if callable_object is None:
          # pylint:disable=protected-access
          callable_object = self._sess._make_callable_from_options(
              new_callable_options)
          # pylint:enable=protected-access
          self._cached_callables_from_options[
              callable_options_id] = callable_object
        retvals = callable_object(
            *callable_runner_args, run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
    except errors.OpError as op_error:
      if self._pass_through_operrors:
        # Bare raise keeps the original traceback.
        raise
      tf_error = op_error
      retvals = op_error

    run_end_req = OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)
    return retvals, run_end_req

  def _run_with_profiling(self, run_start_resp, fetches, feed_dict, options,
                          run_metadata, callable_runner, callable_runner_args,
                          callable_options):
    """Performs a run() call with full tracing enabled in its run options.

    Args:
      run_start_resp: (`OnRunStartResponse`) Response of the on-run-start
        callback.
      fetches, feed_dict, options, run_metadata, callable_runner,
        callable_runner_args, callable_options: Same as the arguments of
        `run()`; `callable_runner_args` must not be `None`.

    Returns:
      A `(retvals, run_end_req)` tuple: the output of the wrapped run call and
      the `OnRunEndRequest` for the on-run-end callback.
    """
    if run_metadata is None:
      run_metadata = config_pb2.RunMetadata()

    if callable_options:
      # CallableOptions carry their own RunOptions; trace a copy of them.
      profiled_callable_options = config_pb2.CallableOptions()
      profiled_callable_options.CopyFrom(callable_options)
      self._decorate_run_options_for_profile(
          profiled_callable_options.run_options)
      # pylint:disable=protected-access
      callable_object = self._sess._make_callable_from_options(
          profiled_callable_options)
      # pylint:enable=protected-access
      retvals = callable_object(
          *callable_runner_args, run_metadata=run_metadata)
    else:
      decorated_run_options = self._copy_run_options(options)
      self._decorate_run_options_for_profile(decorated_run_options)
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)

    run_end_req = OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def())
    return retvals, run_end_req

  def _is_disabled_thread(self):
    """Whether the current thread is excluded by `thread_name_filter`."""
    thread_name = threading.current_thread().name or ""
    return (self._thread_name_filter_pattern and
            not self._thread_name_filter_pattern.match(thread_name))

  def run_step_fn(self, step_fn):
    return step_fn(
        monitored_session.MonitoredSession.StepContext(self._sess, self.run))

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError(
        "partial_run_setup is not implemented for debug-wrapper sessions.")

  def partial_run(self, handle, fetches, feed_dict=None):
    raise NotImplementedError(
        "partial_run is not implemented for debug-wrapper sessions.")

  def list_devices(self, *args, **kwargs):
    return self._sess.list_devices(*args, **kwargs)

  def reset(self, *args, **kwargs):
    return self._sess.reset(*args, **kwargs)

  def make_callable(self,
                    fetches,
                    feed_list=None,
                    accept_options=False):
    """Creates a runner that executes `fetches` through this wrapper session.

    The underlying runner is always created with `accept_options=True`, so the
    returned runner accepts optional `options` and `run_metadata` keyword
    arguments regardless of `accept_options`.

    Args:
      fetches: Same as the `fetches` arg to `Session.make_callable()`.
      feed_list: Same as the `feed_list` arg to `Session.make_callable()`.
      accept_options: Unused; kept for interface compatibility.

    Returns:
      A function that routes each call through `run()`, so that the
      debug-wrapper callbacks are invoked.
    """
    runner = self._sess.make_callable(
        fetches, feed_list=feed_list, accept_options=True)
    def wrapped_runner(*runner_args, **kwargs):
      return self.run(None,
                      feed_dict=None,
                      options=kwargs.get("options", None),
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_runner=runner,
                      callable_runner_args=runner_args)
    return wrapped_runner

  def _make_callable_from_options(self, callable_options):
    def wrapped_runner(*feed_values, **kwargs):
      return self.run(None,
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_options=callable_options,
                      callable_runner_args=feed_values)
    return wrapped_runner

  @property
  def run_call_count(self):
    return self._run_call_count

  def increment_run_call_count(self):
    self._run_call_count += 1

  def _is_disk_usage_reset_each_run(self):
    """Indicates whether disk usage is reset after each Session.run.

    Subclasses that clean up the disk usage after every run should
    override this protected method.

    Returns:
      (`bool`) Whether the disk usage amount is reset to zero after
        each Session.run.
    """
    return False

  def _decorate_run_options_for_debug(
      self,
      run_options,
      debug_urls,
      debug_ops="DebugIdentity",
      node_name_regex_whitelist=None,
      op_type_regex_whitelist=None,
      tensor_dtype_regex_whitelist=None,
      tolerate_debug_op_creation_failures=False):
    """Modify a RunOptions object for debug tensor watching.

    Specifies request for outputting partition graphs. Adds
    debug_tensor_watch_opts with proper debug URLs.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
      debug_urls: (list of str) debug URLs to be entered in run_options.
        debug_tensor_watch_opts.
      debug_ops: (str or list of str) debug op(s) to be used by the debugger.
      node_name_regex_whitelist: Regular-expression whitelist for node
        name.
      op_type_regex_whitelist: Regular-expression whitelist for op type.
      tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.
    """

    run_options.output_partition_graphs = True
    debug_utils.watch_graph(
        run_options,
        self._sess.graph,
        debug_urls=debug_urls,
        debug_ops=debug_ops,
        node_name_regex_whitelist=node_name_regex_whitelist,
        op_type_regex_whitelist=op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
        reset_disk_byte_usage=(self._run_call_count == 1 or
                               self._is_disk_usage_reset_each_run()))

  def _decorate_run_options_for_profile(self, run_options):
    """Modify a RunOptions object for profiling TensorFlow graph execution.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
    """

    run_options.trace_level = config_pb2.RunOptions.FULL_TRACE

  @abc.abstractmethod
  def on_session_init(self, request):
    """Callback invoked during construction of the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the constructor ends.

    Args:
      request: (`OnSessionInitRequest`) callback request carrying information
        such as the session being wrapped.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

  @abc.abstractmethod
  def on_run_start(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens after the wrapper's run() call is entered,
    after an increment of run call counter.

    Args:
      request: (`OnRunStartRequest`) callback request object carrying
        information about the run call such as the fetches, feed dict, run
        options, run metadata, and how many `run()` calls to this wrapper
        session have occurred.

    Returns:
      An instance of `OnRunStartResponse`, carrying information to
        1) direct the wrapper session to perform a specified action (e.g., run
          with or without debug tensor watching, invoking the stepper.)
        2) debug URLs used to watch the tensors.
    """

  @abc.abstractmethod
  def on_run_end(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the wrapper exits its run() call.

    Args:
      request: (`OnRunEndRequest`) callback request object carrying information
        such as the actual action performed by the session wrapper for the
        run() call.

    Returns:
      An instance of `OnRunEndResponse`.
    """

  def as_default(self):
    return ops.default_session(self)

  def __enter__(self):
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    self._default_session_context_manager.__exit__(
        exec_type, exec_value, exec_tb)
    # Drop the manager so that a later `with` block creates a fresh one
    # instead of re-entering a manager that has already exited.
    self._default_session_context_manager = None

  def __del__(self):
    # `_sess` is absent if the constructor raised before assigning it.
    wrapped_sess = getattr(self, "_sess", None)
    if wrapped_sess is not None and hasattr(wrapped_sess, "__del__"):
      wrapped_sess.__del__()

  def close(self):
    self._sess.close()

  # TODO(cais): Add _node_name_regex_whitelist and
  #   _node_op_type_regex_whitelist.

  def invoke_node_stepper(self,
                          node_stepper,
                          restore_variable_values_on_exit=True):
    """Callback invoked when the client intends to step through graph nodes.

    Args:
      node_stepper: (stepper.NodeStepper) An instance of NodeStepper to be used
        in this stepping session.
      restore_variable_values_on_exit: (bool) Whether any variables whose values
        have been altered during this node-stepper invocation should be restored
        to their old values when this invocation ends.

    Returns:
      The same return values as the `Session.run()` call on the same fetches as
        the NodeStepper.
    """
    raise NotImplementedError(
        self.__class__.__name__ + " does not support node-stepper mode.")

  def should_stop(self):
    """Forwards to the wrapped session's `should_stop()`, if it has one.

    Raises:
      ValueError: If the wrapped session has no `should_stop` method.
    """
    if hasattr(self._sess, "should_stop"):
      return self._sess.should_stop()
    else:
      raise ValueError(
          "The wrapped session %r does not have a method called 'should_stop'. "
          "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)
821
822
class WatchOptions(object):
  """Debug-watch options; the value type that a `watch_fn` returns."""

  def __init__(self,
               debug_ops=None,
               node_name_regex_whitelist=None,
               op_type_regex_whitelist=None,
               tensor_dtype_regex_whitelist=None,
               tolerate_debug_op_creation_failures=False):
    """Builds a WatchOptions object.

    Instances are meant to be returned by `watch_fn`s.

    Args:
      debug_ops: (`str` or `list of str`) Debug ops to use. Any falsy value
        (e.g., None or an empty list) selects `["DebugIdentity"]`.
      node_name_regex_whitelist: Regular-expression whitelist applied to node
        names, e.g., `"(weight_[0-9]+|bias_.*)"`.
      op_type_regex_whitelist: Regular-expression whitelist applied to node op
        types, e.g., `"(Variable|Add)"`.
        When both `node_name_regex_whitelist` and `op_type_regex_whitelist`
        are given, they are combined with a logical `AND`: a node is included
        only if it matches both whitelists.
      tensor_dtype_regex_whitelist: Regular-expression whitelist applied to
        Tensor data types, e.g., `"^int.*"`.
        It is combined with the two whitelists above with a logical `AND`.
      tolerate_debug_op_creation_failures: (`bool`) Whether failures to create
        debug ops (e.g., due to dtype incompatibility) should be tolerated
        rather than raised as exceptions.
    """
    # `or` keeps a truthy `debug_ops` as-is and falls back to the default op.
    self.debug_ops = debug_ops or ["DebugIdentity"]
    self.node_name_regex_whitelist = node_name_regex_whitelist
    self.op_type_regex_whitelist = op_type_regex_whitelist
    self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)

  def __repr__(self):
    field_values = (
        self.debug_ops,
        self.node_name_regex_whitelist,
        self.op_type_regex_whitelist,
        self.tensor_dtype_regex_whitelist,
        self.tolerate_debug_op_creation_failures,
    )
    return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, "
            "op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, "
            "tolerate_debug_op_creation_failures=%r)" % field_values)
871
872
class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
  """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions.

  Concrete subclasses implement `prepare_run_debug_urls()`. The watch options
  for each `run()` call come from the optional `watch_fn` given to the
  constructor, or default to `WatchOptions()` when no `watch_fn` is given.
  """

  def __init__(self, sess, watch_fn=None, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of NonInteractiveDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a
        debugged `Session.run()` call to `WatchOptions`.
        * Args:
          * `fetches`: the fetches to the `Session.run()` call.
          * `feeds`: the feeds to the `Session.run()` call.

        * Returns:
          (`tf_debug.WatchOptions`) An object containing debug options including
          the debug ops to use, the node names, op types and/or tensor data
          types to watch, etc. See the documentation of `tf_debug.WatchOptions`
          for more details. For backward compatibility, a `tuple` return value
          is also accepted and is unpacked positionally into `WatchOptions`.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated.  By default this captures all OpErrors.

    Raises:
      TypeError: If a non-None `watch_fn` is specified and it is not callable.
    """

    BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    self._watch_fn = None
    if watch_fn is not None:
      if not callable(watch_fn):
        raise TypeError("watch_fn is not callable")
      self._watch_fn = watch_fn

  def on_session_init(self, request):
    """See doc of BaseDebugWrapperSession.on_session_init.

    Always lets session initialization proceed, regardless of `request`.
    """
    del request  # Unused.
    return OnSessionInitResponse(OnSessionInitAction.PROCEED)

  @abc.abstractmethod
  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Abstract method to be implemented by concrete subclasses.

    This method prepares the run-specific debug URL(s).

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
        this `Session.run()` call.
    """

  def on_run_start(self, request):
    """See doc of BaseDebugWrapperSession.on_run_start.

    Always requests a `DEBUG_RUN`, using the debug URLs and watch options
    computed by `_prepare_run_watch_config()` for this call.
    """

    debug_urls, watch_opts = self._prepare_run_watch_config(
        request.fetches, request.feed_dict)

    return OnRunStartResponse(
        OnRunStartAction.DEBUG_RUN,
        debug_urls,
        debug_ops=watch_opts.debug_ops,
        node_name_regex_whitelist=watch_opts.node_name_regex_whitelist,
        op_type_regex_whitelist=watch_opts.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=(
            watch_opts.tolerate_debug_op_creation_failures))

  def _prepare_run_watch_config(self, fetches, feed_dict):
    """Get the debug_urls, and node/op whitelists for the current run() call.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`.
      feed_dict: Same as the `feed_dict` argument to `Session.run()`.

    Returns:
      debug_urls: (str or list of str) Debug URLs for the current run() call,
        as returned by `prepare_run_debug_urls()`.
      watch_options: (WatchOptions) The return value of `watch_fn` (a legacy
        tuple return value is converted via `WatchOptions(*tuple)`), or a
        default `WatchOptions()` when no `watch_fn` was given.
    """

    debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
    if self._watch_fn is None:
      watch_options = WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
      if isinstance(watch_options, tuple):
        # For legacy return type (tuples).
        watch_options = WatchOptions(*watch_options)

    return debug_urls, watch_options

  def on_run_end(self, request):
    """See doc of BaseDebugWrapperSession.on_run_end.

    Returns an empty `OnRunEndResponse`, regardless of `request`.
    """
    del request  # Unused.
    return OnRunEndResponse()

  def invoke_node_stepper(self,
                          node_stepper,
                          restore_variable_values_on_exit=True):
    """See doc of BaseDebugWrapperSession.invoke_node_stepper.

    Raises:
      NotImplementedError: Always; node-stepper mode is not supported here.
    """
    del node_stepper, restore_variable_values_on_exit  # Unused.
    raise NotImplementedError(
        "NonInteractiveDebugWrapperSession does not support node-stepper mode.")
985