1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Communicating tracebacks and source code with debug server.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import socket 22 23import grpc 24 25from tensorflow.core.debug import debug_service_pb2 26from tensorflow.core.protobuf import debug_pb2 27from tensorflow.python.debug.lib import common 28from tensorflow.python.debug.lib import debug_service_pb2_grpc 29from tensorflow.python.debug.lib import source_utils 30from tensorflow.python.platform import gfile 31from tensorflow.python.platform import tf_logging 32from tensorflow.python.profiler import tfprof_logger 33 34 35def _load_debugged_source_file(file_path, source_file_proto): 36 file_stat = gfile.Stat(file_path) 37 source_file_proto.host = socket.gethostname() 38 source_file_proto.file_path = file_path 39 source_file_proto.last_modified = file_stat.mtime_nsec 40 source_file_proto.bytes = file_stat.length 41 try: 42 with gfile.Open(file_path, "r") as f: 43 source_file_proto.lines.extend(f.read().splitlines()) 44 except IOError: 45 pass 46 47 48def _string_to_id(string, string_to_id): 49 if string not in string_to_id: 50 string_to_id[string] = len(string_to_id) 51 return string_to_id[string] 52 53 54def _format_origin_stack(origin_stack, call_traceback_proto): 55 """Format a traceback stack for a `CallTraceback` proto. 56 57 Args: 58 origin_stack: The stack list as returned by `traceback.extract_stack()`. 59 call_traceback_proto: A `CallTraceback` proto whose fields are to be 60 populated. 61 """ 62 string_to_id = dict() 63 string_to_id[None] = 0 64 for frame in origin_stack: 65 file_path, lineno, func_name, line_text = frame 66 call_traceback_proto.origin_stack.traces.add( 67 file_id=_string_to_id(file_path, string_to_id), 68 lineno=lineno, 69 function_id=_string_to_id(func_name, string_to_id), 70 line_id=_string_to_id(line_text, string_to_id)) 71 72 id_to_string = call_traceback_proto.origin_id_to_string 73 for key, value in string_to_id.items(): 74 id_to_string[value] = key if key is not None else "" 75 76 77def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string): 78 """Extract source file paths outside TensorFlow Python library. 79 80 Args: 81 code_defs: An iterable of `CodeDef` protos, i.e., an iterable of stack 82 traces. 83 id_to_string: A proto map from integer ids to strings. 84 85 Returns: 86 An iterable of source file paths outside the TensorFlow Python library. 87 """ 88 file_ids = set() 89 for code_def in code_defs: 90 for trace in code_def.traces: 91 file_ids.add(trace.file_id) 92 non_tf_files = (id_to_string[file_id] for file_id in file_ids) 93 non_tf_files = ( 94 f for f in non_tf_files 95 if not source_utils.guess_is_tensorflow_py_library(f) and gfile.Exists(f)) 96 return non_tf_files 97 98 99def grpc_message_length_bytes(): 100 """Maximum gRPC message length in bytes.""" 101 return 4 * 1024 * 1024 102 103 104def _send_call_tracebacks(destinations, 105 origin_stack, 106 is_eager_execution=False, 107 call_key=None, 108 graph=None, 109 send_source=True): 110 """Send the tracebacks of a TensorFlow execution call. 111 112 To gRPC debug server(s). This applies to graph execution (`tf.Session.run()`) 113 calls and eager execution calls. 114 115 If `send_source`, also sends the underlying source files outside the 116 TensorFlow library. 117 118 Args: 119 destinations: gRPC destination addresses, a `str` or a `list` of `str`s, 120 e.g., "localhost:4242". If a `list`, gRPC requests containing the same 121 `CallTraceback` proto payload will be sent to all the destinations. 122 origin_stack: The traceback stack for the origin of the execution call. For 123 graph execution, this is the traceback of the `tf.Session.run()` 124 invocation. For eager execution, this is the traceback of the Python 125 line that executes the eager opertion. 126 is_eager_execution: (`bool`) whether an eager execution call (i.e., not a 127 `tf.Session.run` or derived methods) is being sent. 128 call_key: The key of the execution call, as a string. For graph execution, 129 this is a string describing the feeds, fetches (and targets) names of the 130 `tf.Session.run` call. For eager execution, this is ignored. 131 graph: A Python `tf.Graph` object (i.e., *not* a `tf.GraphDef`), which 132 contains op tracebacks, if applicable. 133 send_source: Whether the source files involved in the op tracebacks but 134 outside the TensorFlow library are to be sent. 135 """ 136 if not isinstance(destinations, list): 137 destinations = [destinations] 138 # Strip grpc:// prefix, if any is present. 139 destinations = [ 140 dest[len(common.GRPC_URL_PREFIX):] 141 if dest.startswith(common.GRPC_URL_PREFIX) else dest 142 for dest in destinations] 143 144 call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION 145 if is_eager_execution 146 else debug_service_pb2.CallTraceback.GRAPH_EXECUTION) 147 graph_traceback = tfprof_logger.merge_default_with_oplog( 148 graph, add_trainable_var=False) if graph else None 149 call_traceback = debug_service_pb2.CallTraceback( 150 call_type=call_type, call_key=call_key, graph_traceback=graph_traceback, 151 graph_version=graph.version if graph else None) 152 153 _format_origin_stack(origin_stack, call_traceback) 154 155 if send_source: 156 source_file_paths = set() 157 source_file_paths.update(_source_file_paths_outside_tensorflow_py_library( 158 (log_entry.code_def for log_entry 159 in call_traceback.graph_traceback.log_entries), 160 call_traceback.graph_traceback.id_to_string)) 161 source_file_paths.update(_source_file_paths_outside_tensorflow_py_library( 162 [call_traceback.origin_stack], call_traceback.origin_id_to_string)) 163 164 debugged_source_files = [] 165 for file_path in source_file_paths: 166 source_files = debug_pb2.DebuggedSourceFiles() 167 _load_debugged_source_file( 168 file_path, source_files.source_files.add()) 169 debugged_source_files.append(source_files) 170 171 for destination in destinations: 172 channel = grpc.insecure_channel(destination) 173 stub = debug_service_pb2_grpc.EventListenerStub(channel) 174 stub.SendTracebacks(call_traceback) 175 if send_source: 176 for path, source_files in zip( 177 source_file_paths, debugged_source_files): 178 if source_files.ByteSize() < grpc_message_length_bytes(): 179 stub.SendSourceFiles(source_files) 180 else: 181 tf_logging.warn( 182 "The content of the source file at %s is not sent to " 183 "gRPC debug server %s, because the message size exceeds " 184 "gRPC message length limit (%d bytes)." % ( 185 path, destination, grpc_message_length_bytes())) 186 187 188def send_graph_tracebacks(destinations, 189 run_key, 190 origin_stack, 191 graph, 192 send_source=True): 193 """Send the tracebacks of a graph execution call to debug server(s). 194 195 Args: 196 destinations: gRPC destination addresses, a `str` or a `list` of `str`s, 197 e.g., "localhost:4242". If a `list`, gRPC requests containing the same 198 `CallTraceback` proto payload will be sent to all the destinations. 199 run_key: A string describing the feeds, fetches (and targets) names of the 200 `tf.Session.run` call. 201 origin_stack: The traceback of the `tf.Session.run()` invocation. 202 graph: A Python `tf.Graph` object (i.e., *not* a `tf.GraphDef`), which 203 contains op tracebacks. 204 send_source: Whether the source files involved in the op tracebacks but 205 outside the TensorFlow library are to be sent. 206 """ 207 _send_call_tracebacks( 208 destinations, origin_stack, is_eager_execution=False, call_key=run_key, 209 graph=graph, send_source=send_source) 210 211 212def send_eager_tracebacks(destinations, 213 origin_stack, 214 send_source=True): 215 """Send the tracebacks of an eager execution call to debug server(s). 216 217 Args: 218 destinations: gRPC destination addresses, a `str` or a `list` of `str`s, 219 e.g., "localhost:4242". If a `list`, gRPC requests containing the same 220 origin_stack: The traceback of the eager operation invocation. 221 send_source: Whether the source files involved in the op tracebacks but 222 outside the TensorFlow library are to be sent. 223 """ 224 _send_call_tracebacks( 225 destinations, origin_stack, is_eager_execution=True, 226 send_source=send_source) 227