1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Container for origin source code information before AutoGraph compilation."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import difflib
22import os
23import tokenize
24
25import gast
26import six
27
28from tensorflow.python.autograph.pyct import anno
29from tensorflow.python.autograph.pyct import ast_util
30from tensorflow.python.autograph.pyct import parser
31from tensorflow.python.autograph.pyct import pretty_printer
32from tensorflow.python.util import tf_inspect
33
34
class LineLocation(
    collections.namedtuple('LineLocation', ['filename', 'lineno'])):
  """A source position given by file and line only.

  Similar to Location, but without column information.

  Attributes:
    filename: Text
    lineno: int, 1-based
  """
44
45
class Location(
    collections.namedtuple('Location', ['filename', 'lineno', 'col_offset'])):
  """A source position given by file, line and column.

  Attributes:
    filename: Text
    lineno: int, 1-based
    col_offset: int
    line_loc: LineLocation, this position with the column dropped
  """

  @property
  def line_loc(self):
    """Returns a LineLocation built from this location's file and line."""
    return LineLocation(filename=self.filename, lineno=self.lineno)
60
61
class OriginInfo(
    collections.namedtuple(
        'OriginInfo',
        ['loc', 'function_name', 'source_code_line', 'comment'])):
  """Describes where a piece of code came from before conversion.

  Attributes:
    loc: Location
    function_name: Optional[Text]
    source_code_line: Text
    comment: Optional[Text]
  """

  def as_frame(self):
    """Returns a 4-tuple consistent with the return of traceback.extract_tb.

    The tuple is (filename, lineno, function_name, source_code_line).
    """
    location = self.loc
    return (location.filename, location.lineno, self.function_name,
            self.source_code_line)

  def __repr__(self):
    """Formats as `<file base name>:<lineno>:<col_offset>`."""
    location = self.loc
    if not location.filename:
      # No usable file name: substitute a fixed placeholder.
      return '<no file>:{}:{}'.format(location.lineno, location.col_offset)
    # Only the last path component is shown, to keep the text short.
    base_name = os.path.basename(location.filename)
    return '{}:{}:{}'.format(base_name, location.lineno, location.col_offset)
86
87
88# TODO(mdan): This source map should be a class - easier to refer to.
def create_source_map(nodes, code, filepath):
  """Creates a source map between an annotated AST and the code it compiles to.

  Note: this function assumes nodes, code and filepath correspond to the
  same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: if walking nodes in parallel with the AST re-parsed from code
      raises ValueError. The new message contains the original cause followed
      by a context diff of the pretty-printed trees.
  """
  # Re-parse the generated code and annotate the fresh AST with locations that
  # are relative to `code` itself (each root node is its own context).
  reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False)
  for node in reparsed_nodes:
    resolve(node, code, filepath, node.lineno, node.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Note: the keys are by line only, excluding the column offset.
      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Decorated functions are the exception:
        # both lines are mapped to the same line in the AST.

        # Line overlaps: keep bottom node.
        # NOTE(review): when both `loc` values are Location instances, equal
        # line_loc implies equal lineno, so the inner test always holds and
        # the first entry recorded for a given origin line is kept.
        if existing_origin.loc.line_loc == origin_info.loc.line_loc:
          if existing_origin.loc.lineno >= origin_info.loc.lineno:
            continue

        # In case of column overlaps, keep the leftmost node.
        if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info

  except ValueError as err:
    new_msg = 'Inconsistent ASTs detected. This is a bug. Cause: \n'
    # Terminate the cause explicitly so it does not run into the diff header.
    new_msg += str(err) + '\n'
    new_msg += 'Diff:\n'

    for n, rn in zip(nodes, reparsed_nodes):
      nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
      reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
      # The inputs carry no trailing newlines, so lineterm must be empty;
      # otherwise the control lines emitted by difflib end in '\n' and the
      # join below produces spurious blank lines.
      diff = difflib.context_diff(
          nodes_str.split('\n'),
          reparsed_nodes_str.split('\n'),
          fromfile='Original nodes',
          tofile='Reparsed nodes',
          n=7,
          lineterm='')
      diff = '\n'.join(diff)
      new_msg += diff + '\n'
    raise ValueError(new_msg)

  return source_map
158
159
class _Function(object):
  """Record for a function being visited; holds only the function's name."""

  def __init__(self, name):
    # Stored unchanged; OriginResolver passes the `name` attribute of a
    # gast.FunctionDef node and later reads it back as the enclosing
    # function's name.
    self.name = name
164
165
class OriginResolver(gast.NodeVisitor):
  """Annotates an AST with additional source information like file name.

  Every visited node that has a `lineno` attribute receives an OriginInfo
  under `anno.Basic.ORIGIN`, with line and column numbers shifted so that the
  root node lands on the context position given to the constructor.
  """

  def __init__(self, root_node, source_lines, comments_map,
               context_lineno, context_col_offset,
               filepath):
    """Initializes the resolver.

    Args:
      root_node: the AST node used as the reference point for the offsets.
      source_lines: sequence of source lines, indexed by node.lineno - 1.
      comments_map: mapping from a node's lineno to the comment on that line.
      context_lineno: line number the reference line maps to.
      context_col_offset: column offset that root_node maps to.
      filepath: file name recorded in every Location.
    """
    self._source_lines = source_lines
    self._comments_map = comments_map

    decorators = getattr(root_node, 'decorator_list', None)
    if decorators and hasattr(decorators[0], 'lineno'):
      # Typical case: functions. The line number of the first decorator
      # is more accurate than the line number of the function itself in
      # 3.8+. In earlier versions they coincide.
      reference_lineno = decorators[0].lineno
    else:
      # Fall back to the line number of the root node.
      reference_lineno = root_node.lineno
    self._lineno_offset = context_lineno - reference_lineno

    self._col_offset = context_col_offset - root_node.col_offset

    self._filepath = filepath

    # _Function entries for the FunctionDef nodes currently being visited,
    # innermost last.
    self._function_stack = []

  def _absolute_lineno(self, node):
    """Returns the node's line number shifted into the context."""
    return self._lineno_offset + node.lineno

  def _absolute_col_offset(self, node):
    """Returns the node's column offset shifted into the context."""
    return self._col_offset + node.col_offset

  def _attach_origin_info(self, node):
    """Builds an OriginInfo for node and stores it as an annotation."""
    enclosing = self._function_stack
    function_name = enclosing[-1].name if enclosing else None

    # Source text and comments are looked up by the node's own (relative)
    # line number; only the recorded Location uses the shifted values.
    relative_lineno = node.lineno
    source_code_line = self._source_lines[relative_lineno - 1]
    comment = self._comments_map.get(relative_lineno)

    origin = OriginInfo(
        Location(self._filepath, self._absolute_lineno(node),
                 self._absolute_col_offset(node)),
        function_name, source_code_line, comment)
    anno.setanno(node, 'lineno', relative_lineno)
    anno.setanno(node, anno.Basic.ORIGIN, origin)

  def visit(self, node):
    """Annotates node (when it has a lineno), then visits its children."""
    is_function_def = isinstance(node, gast.FunctionDef)
    if is_function_def:
      # Pushed before annotating, so the FunctionDef node itself is
      # attributed to its own name.
      self._function_stack.append(_Function(node.name))

    if hasattr(node, 'lineno'):
      self._attach_origin_info(node)
    self.generic_visit(node)

    if is_function_def:
      self._function_stack.pop()
224
225
def resolve(node, source, context_filepath, context_lineno, context_col_offset):
  """Adds origin information to an AST, based on the source it was loaded from.

  This allows us to map the original source code line numbers to generated
  source code.

  Note: the AST may be a part of a larger context (e.g. a function is part of
  a module that may contain other things). This function does not assume that
  the source argument contains the entire context, nor that it contains only
  code corresponding to node itself. It does assume that node was parsed from
  the given source code.
  For this reason, two extra arguments are required, and they indicate the
  location of the node in the original context.

  Args:
    node: gast.AST, the AST to annotate.
    source: Text, the source code representing node.
    context_filepath: Text
    context_lineno: int
    context_col_offset: int

  Raises:
    tokenize.TokenError: if source cannot be tokenized and node is not a
      gast.Lambda.
  """
  # TODO(mdan): Pull this to a separate utility.
  # Maps the starting row of each comment token to the comment text, with the
  # leading '#' and surrounding whitespace removed.
  comments_map = {}
  readline = six.StringIO(source).readline
  try:
    for tok_type, tok_string, (start_row, _), _, _ in (
        tokenize.generate_tokens(readline)):
      if tok_type != tokenize.COMMENT:
        continue
      comments_map[start_row] = tok_string.strip()[1:].strip()
  except tokenize.TokenError:
    # Source code resolution in older Python versions is brittle for
    # lambda functions, and may contain garbage. For lambdas the error is
    # ignored and whatever comments were collected so far are kept.
    if not isinstance(node, gast.Lambda):
      raise

  resolver = OriginResolver(node, source.split('\n'), comments_map,
                            context_lineno, context_col_offset,
                            context_filepath)
  resolver.visit(node)
269
270
def resolve_entity(node, source, entity):
  """Like resolve, but extracts the context information from an entity.

  Args:
    node: the AST to annotate, passed through to resolve.
    source: Text, the source code representing node, passed through to resolve.
    entity: the object whose source lines, starting line number and source
      file are looked up via tf_inspect.
  """
  entity_lines, first_lineno = tf_inspect.getsourcelines(entity)
  entity_filepath = tf_inspect.getsourcefile(entity)

  # Poor man's attempt at guessing the column offset: count the leading
  # whitespace of the first source line. This might not work well with tabs.
  first_line = entity_lines[0]
  leading_whitespace = len(first_line) - len(first_line.lstrip())

  resolve(node, source, entity_filepath, first_lineno, leading_whitespace)
282
283
def copy_origin(from_node, to_node):
  """Copies the origin info from a node to another, recursively.

  Does nothing when from_node carries no origin annotation.

  Args:
    from_node: the node whose `anno.Basic.ORIGIN` annotation is read.
    to_node: a single node, or a list/tuple of nodes. Every node produced by
      `gast.walk` on each of them receives the same origin object.
  """
  origin = anno.Basic.ORIGIN.of(from_node, default=None)
  if origin is None:
    return
  targets = to_node if isinstance(to_node, (list, tuple)) else (to_node,)
  for root in targets:
    for walked in gast.walk(root):
      anno.setanno(walked, anno.Basic.ORIGIN, origin)
294