1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Debugger wrapper sessions that send debug data to gRPC (grpc://) URLs."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import signal
21import sys
22import traceback
23
24import six
25
26# Google-internal import(s).
27from tensorflow.python.debug.lib import common
28from tensorflow.python.debug.wrappers import framework
29
30
def publish_traceback(debug_server_urls,
                      graph,
                      feed_dict,
                      fetches,
                      old_graph_version):
  """Sends graph tracebacks and source code when the graph has changed.

  Nothing is sent unless `graph.version` is strictly greater than
  `old_graph_version`. When it is, the tracebacks of the graph and the
  associated source code are sent to the debug server(s) at the given gRPC
  URL(s).

  Args:
    debug_server_urls: A single gRPC debug server URL as a `str`, or a `list`
      of such URLs.
    graph: The Python `tf.Graph` object whose tracebacks may be sent.
    feed_dict: The feed dictionary passed to the `Session.run()` call.
    fetches: The fetches passed to the `Session.run()` call.
    old_graph_version: The previously published graph version to compare
      against.

  Returns:
    `graph.version` (an `int`) when it exceeds `old_graph_version`, i.e. when
    the tracebacks were sent; otherwise `old_graph_version`, unchanged.
  """
  # source_remote is imported lazily because it depends on gRPC.
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  # pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top
  if graph.version <= old_graph_version:
    # Graph unchanged since the last publication: nothing new to send.
    return old_graph_version

  source_remote.send_graph_tracebacks(
      debug_server_urls,
      common.get_run_key(feed_dict, fetches),
      traceback.extract_stack(),
      graph,
      send_source=True)
  return graph.version
67
68
class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that sends debug data to gRPC stream(s)."""

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructor of GrpcDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      grpc_debug_server_addresses: (`str` or `list` of `str`) Single or a list
        of the gRPC debug server addresses, in the format of
        <host:port>, with or without the "grpc://" prefix. For example:
          "localhost:7000",
          ["localhost:7000", "192.168.0.2:8000"]
        A `tuple` of `str` is accepted in place of a `list`.
      watch_fn: (`Callable`) A Callable that can be used to define per-run
        debug ops and watched tensors. See the doc of
        `NonInteractiveDebugWrapperSession.__init__()` for details.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (`bool`) whether the usage of this class is to be logged.

    Raises:
      TypeError: If `grpc_debug_server_addresses` is not a `str` or a `list`
        (or `tuple`) of `str`.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.NonInteractiveDebugWrapperSession.__init__(
        self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter)

    # six.string_types also admits `unicode` on Python 2, where a plain `str`
    # check would wrongly reject unicode addresses.
    if isinstance(grpc_debug_server_addresses, six.string_types):
      addresses = [grpc_debug_server_addresses]
    elif isinstance(grpc_debug_server_addresses, (list, tuple)):
      addresses = grpc_debug_server_addresses
      for address in addresses:
        if not isinstance(address, six.string_types):
          raise TypeError(
              "Expected type str in list grpc_debug_server_addresses, "
              "received type %s" % type(address))
    else:
      raise TypeError(
          "Expected type str or list in grpc_debug_server_addresses, "
          "received type %s" % type(grpc_debug_server_addresses))

    # Every stored URL carries the grpc:// prefix exactly once.
    self._grpc_debug_server_urls = [
        self._normalize_grpc_url(address) for address in addresses]

  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Implementation of abstract method in superclass.

    See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
    for details.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`list` of `str`) grpc:// debug URLs to be used in this
        `Session.run()` call. The same URLs are returned for every call,
        regardless of `fetches` and `feed_dict`.
    """

    return self._grpc_debug_server_urls

  def _normalize_grpc_url(self, address):
    """Returns `address` with `common.GRPC_URL_PREFIX` prepended if absent."""
    return (common.GRPC_URL_PREFIX + address
            if not address.startswith(common.GRPC_URL_PREFIX) else address)
142
143
def _signal_handler(unused_signal, unused_frame):
  """SIGINT handler that asks the user whether to quit the program.

  The answers are handled as follows:
    - An empty answer, "Y" or "y" exits the process via `sys.exit(0)`.
    - "N" or "n" returns and lets the interrupted program continue.
    - Any other answer re-prompts.

  If stdin reaches end-of-file while prompting, the prompt's default answer
  (quit) is assumed. This keeps `EOFError` from escaping the handler into
  whatever code was interrupted.
  """
  while True:
    try:
      response = six.moves.input(
          "\nSIGINT received. Quit program? (Y/n): ").strip()
    except EOFError:
      # No further input can arrive; fall back to the default answer ("Y").
      sys.exit(0)
    if response in ("", "Y", "y"):
      sys.exit(0)
    elif response in ("N", "n"):
      break
152
153
def register_signal_handler():
  """Installs `_signal_handler` as the SIGINT handler, on a best-effort basis.

  `signal.signal()` raises `ValueError` when it is not called from the main
  thread. In that case the handler is simply not installed and no error is
  reported.
  """
  handler = _signal_handler
  try:
    signal.signal(signal.SIGINT, handler)
  except ValueError:
    # Not in the MainThread: signal handlers cannot be installed here.
    return
160
161
class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
  """A tfdbg Session wrapper that can be used with TensorBoard Debugger Plugin.

  This wrapper is the same as `GrpcDebugWrapperSession`, except that it uses a
  predefined `watch_fn` that
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
       `True` to allow the interactive enabling and disabling of tensor
       breakpoints.
    2) watches all tensors in the graph.
  This saves the need for the user to define a `watch_fn`.
  """

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugWrapperSession.

    Also installs the interactive SIGINT handler via
    `register_signal_handler()`.

    Args:
      sess: The `tf.Session` instance to be wrapped.
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """
    # The parameter names `fetches` and `feeds` are kept as-is, since the
    # watch_fn may be invoked with keyword arguments.
    def _gated_grpc_watch_all(fetches, feeds):
      """Returns identical gated-gRPC WatchOptions for every run."""
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    super(TensorBoardDebugWrapperSession, self).__init__(
        sess,
        grpc_debug_server_addresses,
        watch_fn=_gated_grpc_watch_all,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Version of the Python graph most recently published to the debug
    # servers; -1 means nothing has been published yet.
    self._sent_graph_version = -1

    register_signal_handler()

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Publishes graph tracebacks if enabled, then delegates to the base run.

    When `send_traceback_and_source_code` was set at construction, the
    tracebacks and source code are sent via `publish_traceback()`. They are
    sent only if the graph version has increased since the last publication.
    All arguments are forwarded unchanged to the superclass `run()`, whose
    return value is returned.
    """
    if self._send_traceback_and_source_code:
      self._sent_graph_version = publish_traceback(
          self._grpc_debug_server_urls,
          self.graph,
          feed_dict,
          fetches,
          self._sent_graph_version)
    parent = super(TensorBoardDebugWrapperSession, self)
    return parent.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata,
        callable_runner=callable_runner,
        callable_runner_args=callable_runner_args,
        callable_options=callable_options)
232