1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Debugger wrapper session that sends debug data to file:// URLs.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import signal 21import sys 22import traceback 23 24import six 25 26# Google-internal import(s). 27from tensorflow.python.debug.lib import common 28from tensorflow.python.debug.wrappers import framework 29 30 31def publish_traceback(debug_server_urls, 32 graph, 33 feed_dict, 34 fetches, 35 old_graph_version): 36 """Publish traceback and source code if graph version is new. 37 38 `graph.version` is compared with `old_graph_version`. If the former is higher 39 (i.e., newer), the graph traceback and the associated source code is sent to 40 the debug server at the specified gRPC URLs. 41 42 Args: 43 debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of 44 debug server URLs. 45 graph: A Python `tf.Graph` object. 46 feed_dict: Feed dictionary given to the `Session.run()` call. 47 fetches: Fetches from the `Session.run()` call. 48 old_graph_version: Old graph version to compare to. 49 50 Returns: 51 If `graph.version > old_graph_version`, the new graph version as an `int`. 52 Else, the `old_graph_version` is returned. 53 """ 54 # TODO(cais): Consider moving this back to the top, after grpc becomes a 55 # pip dependency of tensorflow or tf_debug. 56 # pylint:disable=g-import-not-at-top 57 from tensorflow.python.debug.lib import source_remote 58 # pylint:enable=g-import-not-at-top 59 if graph.version > old_graph_version: 60 run_key = common.get_run_key(feed_dict, fetches) 61 source_remote.send_graph_tracebacks( 62 debug_server_urls, run_key, traceback.extract_stack(), graph, 63 send_source=True) 64 return graph.version 65 else: 66 return old_graph_version 67 68 69class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): 70 """Debug Session wrapper that send debug data to gRPC stream(s).""" 71 72 def __init__(self, 73 sess, 74 grpc_debug_server_addresses, 75 watch_fn=None, 76 thread_name_filter=None, 77 log_usage=True): 78 """Constructor of DumpingDebugWrapperSession. 79 80 Args: 81 sess: The TensorFlow `Session` object being wrapped. 82 grpc_debug_server_addresses: (`str` or `list` of `str`) Single or a list 83 of the gRPC debug server addresses, in the format of 84 <host:port>, with or without the "grpc://" prefix. For example: 85 "localhost:7000", 86 ["localhost:7000", "192.168.0.2:8000"] 87 watch_fn: (`Callable`) A Callable that can be used to define per-run 88 debug ops and watched tensors. See the doc of 89 `NonInteractiveDebugWrapperSession.__init__()` for details. 90 thread_name_filter: Regular-expression white list for threads on which the 91 wrapper session will be active. See doc of `BaseDebugWrapperSession` for 92 more details. 93 log_usage: (`bool`) whether the usage of this class is to be logged. 94 95 Raises: 96 TypeError: If `grpc_debug_server_addresses` is not a `str` or a `list` 97 of `str`. 98 """ 99 100 if log_usage: 101 pass # No logging for open-source. 102 103 framework.NonInteractiveDebugWrapperSession.__init__( 104 self, sess, watch_fn=watch_fn, thread_name_filter=thread_name_filter) 105 106 if isinstance(grpc_debug_server_addresses, str): 107 self._grpc_debug_server_urls = [ 108 self._normalize_grpc_url(grpc_debug_server_addresses)] 109 elif isinstance(grpc_debug_server_addresses, list): 110 self._grpc_debug_server_urls = [] 111 for address in grpc_debug_server_addresses: 112 if not isinstance(address, str): 113 raise TypeError( 114 "Expected type str in list grpc_debug_server_addresses, " 115 "received type %s" % type(address)) 116 self._grpc_debug_server_urls.append(self._normalize_grpc_url(address)) 117 else: 118 raise TypeError( 119 "Expected type str or list in grpc_debug_server_addresses, " 120 "received type %s" % type(grpc_debug_server_addresses)) 121 122 def prepare_run_debug_urls(self, fetches, feed_dict): 123 """Implementation of abstract method in superclass. 124 125 See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()` 126 for details. 127 128 Args: 129 fetches: Same as the `fetches` argument to `Session.run()` 130 feed_dict: Same as the `feed_dict` argument to `Session.run()` 131 132 Returns: 133 debug_urls: (`str` or `list` of `str`) file:// debug URLs to be used in 134 this `Session.run()` call. 135 """ 136 137 return self._grpc_debug_server_urls 138 139 def _normalize_grpc_url(self, address): 140 return (common.GRPC_URL_PREFIX + address 141 if not address.startswith(common.GRPC_URL_PREFIX) else address) 142 143 144def _signal_handler(unused_signal, unused_frame): 145 while True: 146 response = six.moves.input( 147 "\nSIGINT received. Quit program? (Y/n): ").strip() 148 if response in ("", "Y", "y"): 149 sys.exit(0) 150 elif response in ("N", "n"): 151 break 152 153 154def register_signal_handler(): 155 try: 156 signal.signal(signal.SIGINT, _signal_handler) 157 except ValueError: 158 # This can happen if we are not in the MainThread. 159 pass 160 161 162class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): 163 """A tfdbg Session wrapper that can be used with TensorBoard Debugger Plugin. 164 165 This wrapper is the same as `GrpcDebugWrapperSession`, except that it uses a 166 predefined `watch_fn` that 167 1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to 168 `True` to allow the interactive enabling and disabling of tensor 169 breakpoints. 170 2) watches all tensors in the graph. 171 This saves the need for the user to define a `watch_fn`. 172 """ 173 174 def __init__(self, 175 sess, 176 grpc_debug_server_addresses, 177 thread_name_filter=None, 178 send_traceback_and_source_code=True, 179 log_usage=True): 180 """Constructor of TensorBoardDebugWrapperSession. 181 182 Args: 183 sess: The `tf.Session` instance to be wrapped. 184 grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a 185 `str` or a `list` of `str`s. E.g., "localhost:2333", 186 "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"]. 187 thread_name_filter: Optional filter for thread names. 188 send_traceback_and_source_code: Whether traceback of graph elements and 189 the source code are to be sent to the debug server(s). 190 log_usage: Whether the usage of this class is to be logged (if 191 applicable). 192 """ 193 def _gated_grpc_watch_fn(fetches, feeds): 194 del fetches, feeds # Unused. 195 return framework.WatchOptions( 196 debug_ops=["DebugIdentity(gated_grpc=true)"]) 197 198 super(TensorBoardDebugWrapperSession, self).__init__( 199 sess, 200 grpc_debug_server_addresses, 201 watch_fn=_gated_grpc_watch_fn, 202 thread_name_filter=thread_name_filter, 203 log_usage=log_usage) 204 205 self._send_traceback_and_source_code = send_traceback_and_source_code 206 # Keeps track of the latest version of Python graph object that has been 207 # sent to the debug servers. 208 self._sent_graph_version = -1 209 210 register_signal_handler() 211 212 def run(self, 213 fetches, 214 feed_dict=None, 215 options=None, 216 run_metadata=None, 217 callable_runner=None, 218 callable_runner_args=None, 219 callable_options=None): 220 if self._send_traceback_and_source_code: 221 self._sent_graph_version = publish_traceback( 222 self._grpc_debug_server_urls, self.graph, feed_dict, fetches, 223 self._sent_graph_version) 224 return super(TensorBoardDebugWrapperSession, self).run( 225 fetches, 226 feed_dict=feed_dict, 227 options=options, 228 run_metadata=run_metadata, 229 callable_runner=callable_runner, 230 callable_runner_args=callable_runner_args, 231 callable_options=callable_options) 232