1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Function for interpolating formatted errors from the TensorFlow runtime.
16
17Exposes the function `interpolate` to interpolate messages with tags of the form
18{{type name}}.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import collections
26import itertools
27import os
28import re
29
30import six
31
32from tensorflow.python.util import tf_stack
33
# One component (the `type` or the `name`) of a {{type name}} tag: it starts
# with an alphanumeric character, "_" or ".", and may then also contain "-"
# and "/".
_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?"
# A complete {{type name}} tag. The quadrupled braces are str.format escapes,
# so the resulting pattern starts with "{{" and ends with "}}". Groups:
# 1 = tag type, 2 = tag name.
_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
# The (possibly empty) text preceding the first tag, followed by that tag.
# Groups: 1 = preceding text, 2 = whole tag, 3 = tag type, 4 = tag name.
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
# DOTALL lets the leading (.*?) span newlines in multi-line error messages.
_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)

# One parsed {{type name}} tag.
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])

# Stack frames whose filename contains any of these substrings are treated as
# TensorFlow-internal and skipped when searching for the user frame that
# defined an op (see _find_index_of_defining_frame_for_op).
_BAD_FILE_SUBSTRINGS = [
    os.path.join("tensorflow", "python"),
    os.path.join("tensorflow", "contrib"),
    os.path.join("tensorflow_estimator", "python"),
    os.path.join("tensorflow_estimator", "contrib"),
    "<embedded",  # NOTE(review): presumably embedded/frozen module names.
]
48
49
def parse_message(message):
  """Parses the message.

  Splits the message into separators and tags. Tags are named tuples
  representing the string {{type name}} and they are separated by
  separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
  two tags and three separators. The separators are the numeric characters.

  Args:
    message: String to parse

  Returns:
    (list of separator strings, list of _ParseTags).

    The separator list is always exactly one element longer than the tag
    list. For example, if message is "123{{node Foo}}456" then this function
    returns (["123", "456"], [_ParseTag("node", "Foo")])
  """
  seps = []
  tags = []
  pos = 0
  # Scan the whole message once. Re-matching an anchored pattern against
  # `message[pos:]` copied the unconsumed remainder of the message for every
  # tag; finditer yields the same leftmost, non-overlapping tag matches.
  for match in re.finditer(_TAG_REGEX, message):
    seps.append(message[pos:match.start()])
    tags.append(_ParseTag(match.group(1), match.group(2)))
    pos = match.end()
  seps.append(message[pos:])
  return seps, tags
80
81
def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
  """Describe the device assignments that were active when an op was built.

  Args:
    name: The name of the op.
    device_assignment_list: The op._device_assignments list. Each entry must
      expose `obj` (what was passed to tf.device), `filename` and `lineno`.
    prefix: Optional string placed at the start of every line of the result.

  Returns:
    A single line saying no assignments were active when the list is empty,
    otherwise a header line followed by one line per assignment, e.g.
        Device assignments active during op 'foo' creation:
          with tf.device(/cpu:0): <test_1.py:27>
          with tf.device(some_func<foo.py, 123>): <test_2.py:38>
    Lines after the header carry two extra spaces after `prefix`.
  """
  if not device_assignment_list:
    return prefix + (
        "No device assignments were active during op '%s' creation." % name)

  lines = [
      "%sDevice assignments active during op '%s' creation:" % (prefix, name)
  ]
  for assignment in device_assignment_list:
    where = "<{file}:{line}>".format(
        file=assignment.filename, line=assignment.lineno)
    lines.append("{prefix}{indent}with tf.device({dev_name}): {loc}".format(
        prefix=prefix, indent="  ", dev_name=assignment.obj, loc=where))
  return "\n".join(lines)
122
123
def _compute_device_assignment_summary_from_op(op, prefix=""):
  """Summarize the device assignments recorded on `op`.

  Args:
    op: An Operation exposing `name` and `_device_assignments`.
    prefix: Optional string placed at the start of every line of the result.

  Returns:
    The string produced by `_compute_device_summary_from_list`.
  """
  op_name = op.name
  assignments = op._device_assignments  # pylint: disable=protected-access
  return _compute_device_summary_from_list(op_name, assignments, prefix)
129
130
def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
  """Describe the colocation constraints that were active when an op was built.

  Args:
    name: The op name.
    colocation_dict: The op._colocation_dict, mapping the name of each
      colocated node to an object exposing `filename` and `lineno`.
    prefix: Optional string placed at the start of every line of the result.

  Returns:
    A single line saying no colocations were active when the dict is empty,
    otherwise a header line followed by one line per entry (in the dict's
    iteration order), e.g.
        Node-device colocations active during op 'foo' creation:
          with tf.colocate_with(test_node_1): <test_1.py:27>
          with tf.colocate_with(test_node_2): <test_2.py:38>
    Lines after the header carry two extra spaces after `prefix`.
  """
  if not colocation_dict:
    return prefix + (
        "No node-device colocations were active during op '%s' creation." %
        name)

  lines = ["%sNode-device colocations active during op '%s' creation:" %
           (prefix, name)]
  for colocated_name, location in colocation_dict.items():
    where = "<{file}:{line}>".format(
        file=location.filename, line=location.lineno)
    lines.append(
        "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(
            prefix=prefix, indent="  ", name=colocated_name, loc=where))
  return "\n".join(lines)
171
172
def _compute_colocation_summary_from_op(op, prefix=""):
  """Summarize the colocation constraints recorded on `op`.

  Args:
    op: An Operation exposing `name` and `_colocation_dict`.
    prefix: Optional string placed at the start of every line of the result.

  Returns:
    The string produced by `_compute_colocation_summary_from_dict`.
  """
  op_name = op.name
  colocations = op._colocation_dict  # pylint: disable=protected-access
  return _compute_colocation_summary_from_dict(op_name, colocations, prefix)
179
180
def _find_index_of_defining_frame_for_op(op):
  """Return the index of the innermost non-TensorFlow frame in op.traceback.

  Frames are inspected from innermost (last element) to outermost (index 0).
  A frame is rejected when its filename contains any entry of
  _BAD_FILE_SUBSTRINGS, i.e. when it appears to come from TensorFlow itself.

  Args:
    op: the Operation object for which we would like to find the defining
        location.

  Returns:
    The index into op.traceback of the innermost frame that passed the filter,
    or 0 (the outermost frame) when every frame was rejected or the traceback
    is empty.
  """
  frames = op.traceback
  filenames = [frame[tf_stack.TB_FILENAME] for frame in frames]
  for index in reversed(range(len(filenames))):
    filename = filenames[index]
    if not any(bad in filename for bad in _BAD_FILE_SUBSTRINGS):
      return index
  return 0
208
209
def _get_defining_frame_from_op(op):
  """Return the op.traceback entry chosen as the op's defining frame.

  The index comes from _find_index_of_defining_frame_for_op: the innermost
  frame outside TensorFlow, or the outermost frame if none qualifies.
  """
  defining_index = _find_index_of_defining_frame_for_op(op)
  frames = op.traceback
  return frames[defining_index]
214
def compute_useful_stack(op):
  """Return the 'useful' portion of the stack trace recorded for `op`.

  The collected frames start up to two frames inside the defining frame, so
  that the TensorFlow-library calls showing the node definition stay visible.
  They then extend outward as a contiguous run of at most 10 frames in total.

  Args:
    op: op.Operation object having a _traceback member.

  Returns:
    A list of (filename, lineno, function name, code) tuples, ordered from the
    innermost collected frame to the outermost. The fourth element is the
    value stored under tf_stack.TB_CODEDICT (shown below as source text).
    Example:
    [("tool_utils.py", 124, "func1", "a={}"),
     ("tool_utils.py", 21, "func2", "for i in range(10):"), ....]
  """
  defining_frame_index = _find_index_of_defining_frame_for_op(op)
  # Read the traceback once instead of on every loop iteration.
  # NOTE(review): `op.traceback` may rebuild the stack on each access.
  frames = op.traceback
  # Include `frame_num` frames at most.
  frame_num = 10
  # Number of frames inside the defining frame (TensorFlow library code) that
  # are kept so the node-definition call itself is shown.
  num_library_frames = 2
  innermost_excluded = min(defining_frame_index + num_library_frames + 1,
                           len(frames))
  outermost_included = max(innermost_excluded - frame_num, 0)
  stack_trace = []
  # Index 0 is the outermost frame, so walk the range backwards to emit the
  # innermost collected frame first.
  for index in reversed(range(outermost_included, innermost_excluded)):
    frame = frames[index]
    stack_trace.append((frame[tf_stack.TB_FILENAME],
                        frame[tf_stack.TB_LINENO],
                        frame[tf_stack.TB_FUNCNAME],
                        frame[tf_stack.TB_CODEDICT]))
  return stack_trace
247
248
def compute_field_dict(op, strip_file_prefix=""):
  """Return a dictionary mapping interpolation tokens to values.

  Args:
    op: op.Operation object having a _traceback member.
    strip_file_prefix: Path prefix removed from the defining frame's filename
      when that filename starts with it. Callers in this module pass the value
      returned by traceback_files_common_prefix.

  Returns:
    A dictionary with the keys below (example values shown):
    {
      "file": "tool_utils.py",
      "line": 124,
      "defined_at": " (defined at tool_utils.py:124)",
      "colocations":
          '''Node-device colocations active during op 'foo' creation:
               with tf.colocate_with(test_node_1): <test_1.py:27>
               with tf.colocate_with(test_node_2): <test_2.py:38>''',
      "devices":
          '''Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>''',
      "devs_and_colocs": the "colocations" value and the "devices" value
          joined by a newline.
    }
  """
  defining_frame = _get_defining_frame_from_op(op)
  raw_filename = defining_frame[tf_stack.TB_FILENAME]
  filename = (raw_filename[len(strip_file_prefix):]
              if raw_filename.startswith(strip_file_prefix) else raw_filename)
  lineno = defining_frame[tf_stack.TB_LINENO]
  defined_at = " (defined at %s:%d)" % (filename, lineno)

  colocations = _compute_colocation_summary_from_op(op)
  devices = _compute_device_assignment_summary_from_op(op)
  return {
      "file": filename,
      "line": lineno,
      "defined_at": defined_at,
      "colocations": colocations,
      "devices": devices,
      "devs_and_colocs": "\n".join([colocations, devices]),
  }
300
301
def traceback_files_common_prefix(all_ops):
  """Return the directory prefix shared by the stack-trace files of `all_ops`.

  Filenames containing "<embedded" are ignored. The character-wise common
  prefix of the remaining filenames is trimmed with os.path.split, so a
  partial trailing path component is dropped. For example, '/foo/bar/baz/'
  and '/foo/car' yield '/foo'.

  Args:
    all_ops: A list whose entries are either lists of ops or None; None
      entries are skipped.

  Returns:
    The common prefix, or "" when no filename was collected.
  """
  filenames = set()
  for op_group in all_ops:
    if op_group is None:
      continue
    for op in op_group:
      frame_files = (frame[tf_stack.TB_FILENAME] for frame in op.traceback)
      filenames.update(f for f in frame_files if "<embedded" not in f)
  shared = os.path.commonprefix(list(filenames))
  directory, _ = os.path.split(shared)
  return directory
324
325
def _sources_for_node(node, graph):
  """Gets the input op nodes for 'node'.

  Both data inputs ("name" / "name:i") and control inputs ("^name") are
  resolved. Inputs that cannot be found in `graph` are skipped.

  Args:
    node: The node.
    graph: The graph containing the node.

  Returns:
    The unique input ops, in the order they first appear in
    node.node_def.input. The order is kept deterministic (rather than the
    hash-based order of a set) because callers print and deduplicate the ops
    in this order.
  """
  inputs = []
  seen = set()
  for name in node.node_def.input:
    if name.startswith("^"):
      # Control-dependency inputs are prefixed with "^".
      name = name[1:]
    try:
      op = graph.get_tensor_by_name(name).op
    except (KeyError, ValueError):
      # Not resolvable as a tensor (e.g. a bare op name); try it as an op.
      try:
        op = graph.get_operation_by_name(name)
      except KeyError:
        continue
    if op not in seen:
      seen.add(op)
      inputs.append(op)
  return inputs
351
352
def _build_error_message(op, input_ops, common_prefix):
  """Build the replacement text for a {{node ...}} tag.

  Args:
    op: The node.
    input_ops: The input nodes to the 'op' node.
    common_prefix: The prefix path common to the stacktrace of inputs.

  Returns:
    A (msg, end_msg) tuple. `msg` is "node <name> (defined at file:line) ".
    `end_msg` lists the input ops whose defining location differs from that
    of `op` and of every earlier listed input, or is "" when there is no such
    input.
  """
  defined_at = compute_field_dict(op, common_prefix)["defined_at"]
  msg = "node %s%s " % (op.name, defined_at)

  # "defined at" strings already reported; inputs sharing one are skipped.
  reported_locations = {defined_at}
  input_lines = []
  for input_op in input_ops:
    input_defined_at = compute_field_dict(input_op,
                                          common_prefix)["defined_at"]
    if input_defined_at in reported_locations:
      continue
    reported_locations.add(input_defined_at)
    input_lines.append(" %s%s" % (input_op.name, input_defined_at))

  if not input_lines:
    return msg, ""
  # NOTE(review): the "\t\n" separator leaves a tab before each newline;
  # kept byte-for-byte since callers/tests may match the exact output.
  header = "\nInput Source operations connected to node %s:\n" % op.name
  return msg, header + "\t\n".join(input_lines)
383
384
def interpolate(error_message, graph):
  """Interpolates an error message.

  Tags of the form `{{type name}}` are replaced:
    * "node": "node <name> (defined at <path>:<line>) ", plus a trailing
      section listing input source operations defined elsewhere.
    * "colocation_node": the node location and its device summary, plus a
      trailing "colocations" section.
    * "function_node": removed.
  Tags naming an op that is not in `graph`, and tags of any other type, are
  left as written.

  Args:
    error_message: A string to interpolate.
    graph: ops.Graph object containing all nodes referenced in the error
        message.

  Returns:
    The string with tags of the form {{type name}} interpolated.
  """
  seps, tags = parse_message(error_message)

  # For every tag: [op, *input ops] when the op exists in `graph`, else None.
  tagged_ops = []
  for tag in tags:
    try:
      op = graph.get_operation_by_name(tag.name)
    except KeyError:
      op = None
    tagged_ops.append(
        None if op is None else [op] + _sources_for_node(op, graph))

  common_prefix = traceback_files_common_prefix(tagged_ops)

  subs = []
  trailing_sections = collections.defaultdict(list)
  for tag, ops in zip(tags, tagged_ops):
    replacement = "{{%s %s}}" % (tag.type, tag.name)
    if ops is not None and tag.type == "node":
      replacement, source_msg = _build_error_message(ops[0], ops[1:],
                                                     common_prefix)
      if source_msg:
        trailing_sections["source_nodes"].append(source_msg)
    elif ops is not None and tag.type == "colocation_node":
      field_dict = compute_field_dict(ops[0], common_prefix)
      replacement = "node %s%s placed on device %s " % (
          ops[0].name, field_dict["defined_at"], field_dict["devices"])
      trailing_sections["colocations"].append(field_dict["devs_and_colocs"])
    if tag.type == "function_node":
      replacement = ""
    subs.append(replacement)

  # Entries appended beyond len(tags) land after the final separator when the
  # lists are interleaved below.
  # NOTE(review): these sections are concatenated without separators, so e.g.
  # "Additional information about ...:" is not preceded by a newline.
  source_nodes = trailing_sections.pop("source_nodes", None)
  if source_nodes is not None:
    subs.append("\n\nErrors may have originated from an input operation.")
    subs.append("\n".join(source_nodes))
  for section, messages in trailing_sections.items():
    subs.append("Additional information about %s:" % section)
    subs.append("\n".join(messages))

  interleaved = six.moves.zip_longest(seps, subs, fillvalue="")
  return "".join(itertools.chain.from_iterable(interleaved))
442