1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Communicating tracebacks and source code with debug server."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import socket
22
23import grpc
24
25from tensorflow.core.debug import debug_service_pb2
26from tensorflow.core.protobuf import debug_pb2
27from tensorflow.python.debug.lib import common
28from tensorflow.python.debug.lib import debug_service_pb2_grpc
29from tensorflow.python.debug.lib import source_utils
30from tensorflow.python.platform import gfile
31from tensorflow.python.platform import tf_logging
32from tensorflow.python.profiler import tfprof_logger
33
34
def _load_debugged_source_file(file_path, source_file_proto):
  """Fill in a source-file proto with metadata and content of a file.

  The host name, the file path, the last-modified timestamp and the file size
  are always recorded. The text lines of the file are appended to the `lines`
  field only if the file can be opened and read.

  Args:
    file_path: Path to the source file to load.
    source_file_proto: The proto message to populate; its `host`, `file_path`,
      `last_modified`, `bytes` and `lines` fields are written.
  """
  stat_result = gfile.Stat(file_path)
  source_file_proto.host = socket.gethostname()
  source_file_proto.file_path = file_path
  source_file_proto.last_modified = stat_result.mtime_nsec
  source_file_proto.bytes = stat_result.length
  try:
    with gfile.Open(file_path, "r") as source_file:
      content = source_file.read()
      source_file_proto.lines.extend(content.splitlines())
  except IOError:
    # Best effort: if the content cannot be read, the proto still carries the
    # metadata set above, just without any lines.
    pass
46
47
48def _string_to_id(string, string_to_id):
49  if string not in string_to_id:
50    string_to_id[string] = len(string_to_id)
51  return string_to_id[string]
52
53
def _format_origin_stack(origin_stack, call_traceback_proto):
  """Write a Python traceback stack into a `CallTraceback` proto.

  Every distinct string in the stack (file path, function name, line text) is
  replaced by an integer id. The frames go to `origin_stack.traces` and the
  id-to-string table goes to `origin_id_to_string`.

  Args:
    origin_stack: The stack list as returned by `traceback.extract_stack()`,
      i.e., an iterable of (file path, line number, function name, line text)
      entries.
    call_traceback_proto: A `CallTraceback` proto whose `origin_stack` and
      `origin_id_to_string` fields are to be populated.
  """
  # Id 0 is reserved for `None`, which is emitted as an empty string below.
  ids_by_string = {None: 0}
  for file_path, lineno, func_name, line_text in origin_stack:
    # Ids are handed out in first-seen order: file, then function, then line.
    file_id = _string_to_id(file_path, ids_by_string)
    function_id = _string_to_id(func_name, ids_by_string)
    line_id = _string_to_id(line_text, ids_by_string)
    call_traceback_proto.origin_stack.traces.add(
        file_id=file_id,
        lineno=lineno,
        function_id=function_id,
        line_id=line_id)

  id_to_string = call_traceback_proto.origin_id_to_string
  for string, string_id in ids_by_string.items():
    id_to_string[string_id] = "" if string is None else string
75
76
def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
  """Find existing source files that are not part of the TF Python library.

  Args:
    code_defs: An iterable of `CodeDef` protos, i.e., an iterable of stack
      traces.
    id_to_string: A proto map from integer ids to strings, used to resolve the
      `file_id` of each trace into a file path.

  Returns:
    A lazily-evaluated iterable of the file paths that are referenced by the
    traces, are not guessed to belong to the TensorFlow Python library, and
    exist on the file system.
  """
  unique_file_ids = {
      trace.file_id for code_def in code_defs for trace in code_def.traces}
  candidate_paths = (id_to_string[file_id] for file_id in unique_file_ids)
  # The library check runs first, so the file system is only queried for
  # paths that are not guessed to be part of TensorFlow.
  return (
      path for path in candidate_paths
      if not source_utils.guess_is_tensorflow_py_library(path)
      and gfile.Exists(path))
97
98
def grpc_message_length_bytes():
  """Return the maximum gRPC message length, in bytes (4 MiB)."""
  bytes_per_mebibyte = 1024 * 1024
  return 4 * bytes_per_mebibyte
102
103
def _send_call_tracebacks(destinations,
                          origin_stack,
                          is_eager_execution=False,
                          call_key=None,
                          graph=None,
                          send_source=True):
  """Send the tracebacks of a TensorFlow execution call.

  To gRPC debug server(s). This applies to graph execution (`tf.Session.run()`)
  calls and eager execution calls.

  If `send_source`, also sends the underlying source files outside the
  TensorFlow library.

  The gRPC channel opened for each destination is closed once all messages
  for that destination have been sent (or sending has failed), so repeated
  calls do not accumulate open channels.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
    origin_stack: The traceback stack for the origin of the execution call. For
      graph execution, this is the traceback of the `tf.Session.run()`
      invocation. For eager execution, this is the traceback of the Python
      line that executes the eager operation.
    is_eager_execution: (`bool`) whether an eager execution call (i.e., not a
      `tf.Session.run` or derived methods) is being sent.
    call_key: The key of the execution call, as a string. For graph execution,
      this is a string describing the feeds, fetches (and targets) names of the
      `tf.Session.run` call. For eager execution, this is ignored.
    graph: A Python `tf.Graph` object (i.e., *not* a `tf.GraphDef`), which
      contains op tracebacks, if applicable.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  if not isinstance(destinations, list):
    destinations = [destinations]
  # Strip grpc:// prefix, if any is present.
  destinations = [
      dest[len(common.GRPC_URL_PREFIX):]
      if dest.startswith(common.GRPC_URL_PREFIX) else dest
      for dest in destinations]

  call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
               if is_eager_execution
               else debug_service_pb2.CallTraceback.GRAPH_EXECUTION)
  graph_traceback = tfprof_logger.merge_default_with_oplog(
      graph, add_trainable_var=False) if graph else None
  call_traceback = debug_service_pb2.CallTraceback(
      call_type=call_type, call_key=call_key, graph_traceback=graph_traceback,
      graph_version=graph.version if graph else None)

  _format_origin_stack(origin_stack, call_traceback)

  # (file path, `DebuggedSourceFiles` message) pairs. Built once up front so
  # that each file is read a single time regardless of the number of
  # destinations. Stays empty if `send_source` is False.
  debugged_source_files = []
  if send_source:
    source_file_paths = set()
    source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
        (log_entry.code_def for log_entry
         in call_traceback.graph_traceback.log_entries),
        call_traceback.graph_traceback.id_to_string))
    source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
        [call_traceback.origin_stack], call_traceback.origin_id_to_string))

    for file_path in source_file_paths:
      source_files = debug_pb2.DebuggedSourceFiles()
      _load_debugged_source_file(
          file_path, source_files.source_files.add())
      debugged_source_files.append((file_path, source_files))

  max_message_bytes = grpc_message_length_bytes()
  for destination in destinations:
    channel = grpc.insecure_channel(destination)
    try:
      stub = debug_service_pb2_grpc.EventListenerStub(channel)
      stub.SendTracebacks(call_traceback)
      for path, source_files in debugged_source_files:
        if source_files.ByteSize() < max_message_bytes:
          stub.SendSourceFiles(source_files)
        else:
          # Oversized files are skipped with a warning rather than failing
          # the whole call.
          tf_logging.warn(
              "The content of the source file at %s is not sent to "
              "gRPC debug server %s, because the message size exceeds "
              "gRPC message length limit (%d bytes)." % (
                  path, destination, max_message_bytes))
    finally:
      # Release the channel even if one of the RPCs above raised.
      channel.close()
186
187
def send_graph_tracebacks(destinations,
                          run_key,
                          origin_stack,
                          graph,
                          send_source=True):
  """Send the tracebacks of a `tf.Session.run()` call to debug server(s).

  This is the graph-execution entry point: it forwards to
  `_send_call_tracebacks` with `is_eager_execution=False` and uses `run_key`
  as the call key.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
    run_key: A string describing the feeds, fetches (and targets) names of the
      `tf.Session.run` call.
    origin_stack: The traceback of the `tf.Session.run()` invocation.
    graph: A Python `tf.Graph` object (i.e., *not* a `tf.GraphDef`), which
      contains op tracebacks.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  _send_call_tracebacks(
      destinations,
      origin_stack,
      is_eager_execution=False,
      call_key=run_key,
      graph=graph,
      send_source=send_source)
210
211
def send_eager_tracebacks(destinations,
                          origin_stack,
                          send_source=True):
  """Send the tracebacks of an eager execution call to debug server(s).

  This is the eager-execution entry point: it forwards to
  `_send_call_tracebacks` with `is_eager_execution=True` and without a call
  key or a graph.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
    origin_stack: The traceback of the eager operation invocation.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  _send_call_tracebacks(
      destinations,
      origin_stack,
      is_eager_execution=True,
      send_source=send_source)
227