1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Function for interpolating formatted errors from the TensorFlow runtime.
16
17Exposes the function `interpolate` to interpolate messages with tags of the form
18{{type name}}.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import collections
26import itertools
27import os
28import re
29
30import six
31
32from tensorflow.core.protobuf import graph_debug_info_pb2
33
# One component of a tag: a first character from [A-Za-z0-9_.] followed by a
# lazily matched run that may additionally contain "-" and "/".
_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?"
# A whole tag of the form "{{type name}}". The braces are doubled once more
# because the template goes through str.format; the two captured groups are the
# type and the name, separated by a single space.
_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
# Anchored at the start of the string: group 1 is the (lazy) text in front of
# the next tag, group 2 is the tag itself, and groups 3 and 4 are the tag's type
# and name.
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
# DOTALL lets "." in group 1 also match newlines.
_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)

# A parsed "{{type name}}" tag.
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
40
41
# Remove the last three path components from this module's file (i.e.
# python/framework/error_interpolation.py) so that we have the path prefix of
# the root of the installation.
_FRAMEWORK_COMMON_PREFIX = os.path.dirname(
    os.path.dirname(os.path.dirname(__file__)))

# Sub-directories under the common prefix that are considered part of the
# framework. Each entry ends with os.sep so that a prefix test only matches
# whole directory names.
_FRAMEWORK_PATH_PREFIXES = [
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep,
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep,
]

# Filename patterns that should be considered internal to the TensorFlow
# framework.
_FRAMEWORK_FILENAME_PATTERNS = [
    re.compile(r"<embedded"),
]

# Filename patterns that should be considered external to TensorFlow
# regardless of any framework prefix or pattern match.
_EXTERNAL_FILENAME_PATTERNS = [
    # Explicitly treat test frames as not part of the framework.
    re.compile(r"_test\.py$"),
]
67
68
def parse_message(message):
  """Splits a message into plain-text separators and `{{type name}}` tags.

  The message is consumed from left to right. Every tag that is found
  contributes the text in front of it as a separator and the tag itself as a
  `_ParseTag`; whatever is left after the last tag becomes the final
  separator. For example, "123{{node Foo}}456{{node Bar}}789" contains two
  tags and three separators ("123", "456" and "789").

  Args:
    message: String to parse.

  Returns:
    A `(separators, tags)` tuple: a list of separator strings and a list of
    `_ParseTag`s. There is always exactly one more separator than there are
    tags.

    For example, if message is "123{{node Foo}}456" then this function
    returns (["123", "456"], [_ParseTag("node", "Foo")])
  """
  separators = []
  parsed_tags = []
  remaining = message
  while remaining:
    found = _INTERPOLATION_PATTERN.match(remaining)
    if found is None:
      # No further tag; the rest of the message is the trailing separator.
      break
    separators.append(found.group(1))
    parsed_tags.append(_ParseTag(found.group(3), found.group(4)))
    remaining = remaining[found.end():]
  separators.append(remaining)
  return separators, parsed_tags
99
100
def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
  """Builds a human-readable summary of an op's device function stack.

  Args:
    name: The name of the op.
    device_assignment_list: The op._device_assignments list. Each entry is read
        for its `obj`, `filename` and `lineno` attributes.
    prefix:  An optional string placed in front of every line of the returned
        (possibly multi-line) string.

  Returns:
    If the list is empty (or None), a single line stating that no device
    assignments were active. Otherwise a multi-line string similar to:
        Device assignments active during op 'foo' creation:
          with tf.device(/cpu:0): <test_1.py:27>
          with tf.device(some_func<foo.py, 123>): <test_2.py:38>
    By default the first line has no left padding and the following lines are
    indented by two spaces; `prefix` is added in front of all of them.
  """
  if not device_assignment_list:
    no_devices = "No device assignments were active during op '%s' creation."
    return prefix + no_devices % name

  entry_template = "{prefix}{indent}with tf.device({dev_name}): {loc}"
  lines = [
      "%sDevice assignments active during op '%s' creation:" % (prefix, name)
  ]
  for assignment in device_assignment_list:
    where = "<{file}:{line}>".format(
        file=assignment.filename, line=assignment.lineno)
    lines.append(
        entry_template.format(
            prefix=prefix, indent="  ", dev_name=assignment.obj, loc=where))
  return "\n".join(lines)
141
142
def _compute_device_assignment_summary_from_op(op, prefix=""):
  """Returns the device assignment summary string for `op`.

  Args:
    op: An object exposing `name` and `_device_assignments`.
    prefix: An optional string placed in front of every line of the summary.

  Returns:
    The string produced by `_compute_device_summary_from_list` for the op's
    name and device assignments.
  """
  op_name = op.name
  assignments = op._device_assignments  # pylint: disable=protected-access
  return _compute_device_summary_from_list(op_name, assignments, prefix)
148
149
def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
  """Builds a human-readable summary of an op's colocation stack.

  Args:
    name: The op name.
    colocation_dict: The op._colocation_dict. It maps a colocation name to a
        location object that is read for its `filename` and `lineno`
        attributes.
    prefix:  An optional string placed in front of every line of the returned
        (possibly multi-line) string.

  Returns:
    If the dict is empty (or None), a single line stating that no node-device
    colocations were active. Otherwise a multi-line string similar to:
        Node-device colocations active during op 'foo' creation:
          with tf.colocate_with(test_node_1): <test_1.py:27>
          with tf.colocate_with(test_node_2): <test_2.py:38>
    By default the first line has no left padding and the following lines are
    indented by two spaces; `prefix` is added in front of all of them.
  """
  if not colocation_dict:
    no_colocations = (
        "No node-device colocations were active during op '%s' creation.")
    return prefix + no_colocations % name

  entry_template = "{prefix}{indent}with tf.colocate_with({name}): {loc}"
  lines = [
      "%sNode-device colocations active during op '%s' creation:" %
      (prefix, name)
  ]
  for coloc_name, location in colocation_dict.items():
    where = "<{file}:{line}>".format(
        file=location.filename, line=location.lineno)
    lines.append(
        entry_template.format(
            prefix=prefix, indent="  ", name=coloc_name, loc=where))
  return "\n".join(lines)
190
191
def _compute_colocation_summary_from_op(op, prefix=""):
  """Returns the colocation summary string for `op`.

  Args:
    op: An object exposing `name` and `_colocation_dict`.
    prefix: An optional string placed in front of every line of the summary.

  Returns:
    The string produced by `_compute_colocation_summary_from_dict` for the
    op's name and colocation dict.
  """
  op_name = op.name
  colocations = op._colocation_dict  # pylint: disable=protected-access
  return _compute_colocation_summary_from_dict(op_name, colocations, prefix)
198
199
def _is_framework_filename(filename):
  """Decides whether a filename counts as part of the TensorFlow framework.

  The checks are applied in this order:
    1. A match against _EXTERNAL_FILENAME_PATTERNS means "not framework",
       regardless of the remaining checks.
    2. A match against _FRAMEWORK_FILENAME_PATTERNS means "framework".
    3. Otherwise the file is framework only if it starts with one of the
       _FRAMEWORK_PATH_PREFIXES.

  Args:
    filename: A filename string.

  Returns:
    True if the filename should be treated as internal to the TensorFlow
    framework for the purposes of reporting errors, False otherwise.
  """
  if any(p.search(filename) for p in _EXTERNAL_FILENAME_PATTERNS):
    return False
  if any(p.search(filename) for p in _FRAMEWORK_FILENAME_PATTERNS):
    return True
  return any(filename.startswith(p) for p in _FRAMEWORK_PATH_PREFIXES)
224
225
def _find_index_of_defining_frame(traceback):
  """Returns the index of the innermost frame that is not framework code.

  Frames are judged by filename only, via `_is_framework_filename`. The search
  starts at the innermost frame (the last element) and moves outwards, so the
  result is the frame closest to op creation that (hopefully) belongs to the
  caller.

  Args:
    traceback: A list of traceback frames (as from Operation.traceback). Index
        0 is the outermost frame.

  Returns:
    Integer index into `traceback` of the first non-framework frame found when
    scanning from innermost to outermost, or 0 (the outermost frame) if every
    frame is framework code.
  """
  innermost = len(traceback) - 1
  filenames = [frame.filename for frame in traceback]
  for frame_index in range(innermost, -1, -1):
    if not _is_framework_filename(filenames[frame_index]):
      # Consider this to be the defining frame.
      return frame_index
  return 0
252
253
def _get_defining_frame(traceback):
  """Returns the frame selected by `_find_index_of_defining_frame`."""
  return traceback[_find_index_of_defining_frame(traceback)]
258
259
def _compute_useful_frames(traceback, num):
  """Selects the contiguous 'useful' part of a stack trace.

  The selection ends two frames further in than the defining frame (clamped to
  the end of the traceback) and extends from there towards the outermost frame
  for at most `num` frames. The two extra inner frames give context about which
  node is being defined.

  Args:
    traceback: A list of traceback frames (as from Operation.traceback).
    num: Maximum number of frames to return.

  Returns:
    A slice of `traceback` containing the selected frames, in the same order as
    in `traceback`.
  """
  defining_index = _find_index_of_defining_frame(traceback)
  # Exclusive end: the defining frame itself plus two frames further in.
  stop = min(defining_index + 3, len(traceback))
  start = max(stop - num, 0)
  return traceback[start:stop]
282
283
def create_graph_debug_info_def(func_named_operations):
  """Constructs and returns a `GraphDebugInfo` protocol buffer.

  For every operation that has a traceback, up to 10 useful frames are recorded
  under the key "<op name>@<func_name>". File names are stored once in the
  proto's `files` field and each frame refers to its file by index.

  Args:
    func_named_operations: An iterable of (func_name, op.Operation) tuples.
      Operations whose `traceback` attribute is missing are skipped. The
      func_name should be the empty string for operations in the top-level
      Graph.

  Returns:
    GraphDebugInfo protocol buffer.
  """
  debug_info = graph_debug_info_pb2.GraphDebugInfo()

  # First pass: pick the frames to keep per node and gather the distinct file
  # names they mention.
  frames_by_node = {}
  unique_files = set()
  for func_name, op in func_named_operations:
    try:
      op_traceback = op.traceback
    except AttributeError:
      # Some ops synthesized as part of function or control flow definition
      # do not have tracebacks.
      continue

    node_key = op.name + "@" + func_name
    useful_frames = _compute_useful_frames(op_traceback, 10)
    frames_by_node[node_key] = useful_frames
    unique_files.update(frame.filename for frame in useful_frames)

  debug_info.files.extend(unique_files)

  # Frames store an index into `files` rather than the file name itself to save
  # storage space.
  index_of_file = {
      file_name: index for index, file_name in enumerate(debug_info.files)
  }

  # Second pass: emit one FileLineCol per kept frame, innermost frame first.
  for node_key, frames in frames_by_node.items():
    trace_def = debug_info.traces[node_key]
    for frame in reversed(frames):
      trace_def.file_line_cols.add(
          file_index=index_of_file[frame.filename], line=frame.lineno)

  return debug_info
338
339
def _compute_field_dict(op, strip_file_prefix=""):
  """Returns a dictionary mapping interpolation tokens to values.

  Args:
    op: op.Operation object. Its `traceback` attribute is optional; when it is
      missing the location fields fall back to "<unknown>" / 0.
    strip_file_prefix: The common path in the stacktrace. If the defining
      frame's file name starts with this prefix, the prefix is removed.

  Returns:
    A dictionary with the keys shown below along with example values. All
    values are strings except "line", which is an integer.
    {
      "file": "tool_utils.py",
      "line": 124,
      "defined_at": " (defined at tool_utils.py:124)",
      "colocations":
          '''Node-device colocations active during op 'foo' creation:
               with tf.colocate_with(test_node_1): <test_1.py:27>
               with tf.colocate_with(test_node_2): <test_2.py:38>'''
      "devices":
          '''Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
      "devs_and_colocs": The colocations summary, a newline, and the devices
          summary.
    }
  """
  colocations = _compute_colocation_summary_from_op(op)
  devices = _compute_device_assignment_summary_from_op(op)

  # Location defaults, used when the op carries no traceback.
  filename = "<unknown>"
  lineno = 0
  defined_at = " (defined at <unknown>)"
  try:
    traceback = op.traceback
  except AttributeError:
    # Some ops synthesized as part of function or control flow definition
    # do not have tracebacks; keep the defaults above.
    pass
  else:
    frame = _get_defining_frame(traceback)
    lineno = frame.lineno
    filename = frame.filename
    if filename.startswith(strip_file_prefix):
      filename = filename[len(strip_file_prefix):]
    defined_at = " (defined at %s:%d)" % (filename, lineno)

  return {
      "colocations": colocations,
      "devices": devices,
      "devs_and_colocs": colocations + "\n" + devices,
      "defined_at": defined_at,
      "file": filename,
      "line": lineno,
  }
402
403
def traceback_files_common_prefix(all_ops):
  """Determines the common prefix from the paths of the stacktrace of 'all_ops'.

  For example, if the paths are '/foo/bar/baz/' and '/foo/car', this would
  return '/foo'.

  File names containing "<embedded" are ignored. Ops without a `traceback`
  attribute are skipped, matching how the other helpers in this module treat
  such ops, so one of them no longer aborts the computation.

  Args:
    all_ops: All the input nodes in the form of a list of lists of ops. Entries
      that are None are skipped.

  Returns:
    The common prefix, or the empty string if no file names were collected.
  """
  files = set()
  for ops in all_ops:
    if ops is None:
      continue
    for op in ops:
      try:
        op_traceback = op.traceback
      except AttributeError:
        # Some ops synthesized as part of function or control flow definition
        # do not have tracebacks; they contribute no paths.
        continue
      # TODO(slebedev): switch to .filename once 2.X support is dropped.
      for filename, _, _, _ in op_traceback:
        if "<embedded" not in filename:
          files.add(filename)
  # commonprefix compares character by character, so the result may end in the
  # middle of a path component; os.path.split drops that trailing partial
  # component.
  return os.path.split(os.path.commonprefix(list(files)))[0]
426
427
def _sources_for_node(node, graph):
  """Gets the input op nodes for 'node'.

  Each entry of `node.node_def.input` is resolved first as a tensor name and,
  if that fails, as an operation name. A leading "^" is stripped before the
  lookup. Names that cannot be resolved either way are ignored.

  Args:
    node: The node.
    graph: The graph containing the node.

  Returns:
    The unique input nodes as a list, in the order in which they first appear
    in `node.node_def.input`. The order is deterministic so that messages
    built from this list are stable between runs.
  """
  sources = []
  seen = set()
  for name in node.node_def.input:
    if name.startswith("^"):
      name = name[1:]
    try:
      op = graph.get_tensor_by_name(name).op
    except (KeyError, ValueError):
      try:
        op = graph.get_operation_by_name(name)
      except KeyError:
        continue
    if op not in seen:
      seen.add(op)
      sources.append(op)

  return sources
453
454
def _build_error_message(op, input_ops, common_prefix):
  """Builds the message fragments describing `op` and its input sources.

  Args:
    op: The node.
    input_ops: The input nodes to the 'op' node.
    common_prefix: The prefix path common to the stacktrace of inputs.

  Returns:
    A `(msg, end_msg)` tuple. `msg` is "node <name><defined_at> " for `op`.
    `end_msg` lists the input ops whose "defined at" text differs from that of
    `op` and from every input already listed; it is the empty string when no
    such input exists.
  """
  defined_at = _compute_field_dict(op, common_prefix)["defined_at"]
  msg = "node %s%s " % (op.name, defined_at)

  # "defined at" strings that have already been reported.
  reported = {defined_at}
  source_lines = []
  for source_op in input_ops:
    source_defined_at = _compute_field_dict(source_op,
                                            common_prefix)["defined_at"]
    if source_defined_at in reported:
      continue
    reported.add(source_defined_at)
    source_lines.append(" %s%s" % (source_op.name, source_defined_at))

  if not source_lines:
    return msg, ""

  header = "\nInput Source operations connected to node %s:\n" % op.name
  # NOTE(review): the separator is a tab followed by a newline, which leaves a
  # trailing tab on each line; confirm whether "\n\t" was intended.
  return msg, header + "\t\n".join(source_lines)
485
486
def interpolate(error_message, graph):
  """Interpolates an error message.

  The error message can contain tags of the form `{{type name}}` which will be
  replaced. For example: "{{node <name>}}" would get expanded to:
  "node <name>(defined at <path>)".

  Tags are handled by type:
    * "node": replaced with the op's name and definition site; information
      about its input ops is collected and appended after the message.
    * "colocation_node": replaced with the op's name, definition site and
      device summary; device and colocation details are appended after the
      message.
    * "function_node": replaced with the empty string.
    * Any other type, or a tag whose op is not found in `graph`: left in the
      message as `{{type name}}`.

  Args:
    error_message: A string to interpolate.
    graph: ops.Graph object containing all nodes referenced in the error
        message.

  Returns:
    The string with tags of the form {{type name}} interpolated.
  """
  seps, tags = parse_message(error_message)

  # For each tag: [op, *input ops], or None when the op is not in the graph.
  tagged_ops = []
  for tag in tags:
    try:
      op = graph.get_operation_by_name(tag.name)
    except KeyError:
      op = None
    tagged_ops.append(
        None if op is None else [op] + _sources_for_node(op, graph))

  common_prefix = traceback_files_common_prefix(tagged_ops)

  subs = []
  trailing_info = collections.defaultdict(list)
  for tag, ops in zip(tags, tagged_ops):
    if tag.type == "function_node":
      subs.append("")
      continue
    if ops is None or tag.type not in ("node", "colocation_node"):
      # Nothing to substitute; keep the tag text.
      subs.append("{{%s %s}}" % (tag.type, tag.name))
      continue

    tagged_op = ops[0]
    if tag.type == "node":
      msg, source_msg = _build_error_message(tagged_op, ops[1:], common_prefix)
      if source_msg:
        trailing_info["source_nodes"].append(source_msg)
    else:
      field_dict = _compute_field_dict(tagged_op, common_prefix)
      msg = "node %s%s placed on device %s " % (
          tagged_op.name, field_dict["defined_at"], field_dict["devices"])
      trailing_info["colocations"].append(field_dict["devs_and_colocs"])
    subs.append(msg)

  # Source-node information goes first and has its own heading.
  source_msgs = trailing_info.pop("source_nodes", None)
  if source_msgs is not None:
    subs.append("\n\nErrors may have originated from an input operation.")
    subs.append("\n".join(source_msgs))
  for kind, messages in trailing_info.items():
    subs.append("Additional information about %s:" % kind)
    subs.append("\n".join(messages))

  # Interleave separators and substitutions; the shorter list is padded with
  # empty strings.
  interleaved = six.moves.zip_longest(seps, subs, fillvalue="")
  return "".join(itertools.chain.from_iterable(interleaved))
544