1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Container for origin source code information before AutoGraph compilation.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import difflib 22import os 23import tokenize 24 25import gast 26import six 27 28from tensorflow.python.autograph.pyct import anno 29from tensorflow.python.autograph.pyct import ast_util 30from tensorflow.python.autograph.pyct import parser 31from tensorflow.python.autograph.pyct import pretty_printer 32from tensorflow.python.util import tf_inspect 33 34 35class LineLocation( 36 collections.namedtuple('LineLocation', ('filename', 'lineno'))): 37 """Similar to Location, but without column information. 38 39 Attributes: 40 filename: Text 41 lineno: int, 1-based 42 """ 43 pass 44 45 46class Location( 47 collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))): 48 """Encodes code location information. 49 50 Attributes: 51 filename: Text 52 lineno: int, 1-based 53 col_offset: int 54 line_loc: LineLocation 55 """ 56 57 @property 58 def line_loc(self): 59 return LineLocation(self.filename, self.lineno) 60 61 62class OriginInfo( 63 collections.namedtuple( 64 'OriginInfo', 65 ('loc', 'function_name', 'source_code_line', 'comment'))): 66 """Container for information about the source code before conversion. 67 68 Attributes: 69 loc: Location 70 function_name: Optional[Text] 71 source_code_line: Text 72 comment: Optional[Text] 73 """ 74 75 def as_frame(self): 76 """Returns a 4-tuple consistent with the return of traceback.extract_tb.""" 77 return (self.loc.filename, self.loc.lineno, self.function_name, 78 self.source_code_line) 79 80 def __repr__(self): 81 if self.loc.filename: 82 return '{}:{}:{}'.format( 83 os.path.split(self.loc.filename)[1], self.loc.lineno, 84 self.loc.col_offset) 85 return '<no file>:{}:{}'.format(self.loc.lineno, self.loc.col_offset) 86 87 88# TODO(mdan): This source map should be a class - easier to refer to. 89def create_source_map(nodes, code, filepath): 90 """Creates a source map between an annotated AST and the code it compiles to. 91 92 Note: this function assumes nodes nodes, code and filepath correspond to the 93 same code. 94 95 Args: 96 nodes: Iterable[ast.AST, ...], one or more AST modes. 97 code: Text, the source code in which nodes are found. 98 filepath: Text 99 100 Returns: 101 Dict[LineLocation, OriginInfo], mapping locations in code to locations 102 indicated by origin annotations in node. 103 """ 104 reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False) 105 for node in reparsed_nodes: 106 resolve(node, code, filepath, node.lineno, node.col_offset) 107 108 source_map = {} 109 110 try: 111 for before, after in ast_util.parallel_walk(nodes, reparsed_nodes): 112 # Note: generated code might not be mapped back to its origin. 113 # TODO(mdan): Generated code should always be mapped to something. 114 origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None) 115 final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None) 116 if origin_info is None or final_info is None: 117 continue 118 119 # Note: the keys are by line only, excluding the column offset. 120 line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno) 121 122 existing_origin = source_map.get(line_loc) 123 if existing_origin is not None: 124 # Overlaps may exist because of child nodes, but almost never to 125 # different line locations. Exception make decorated functions, where 126 # both lines are mapped to the same line in the AST. 127 128 # Line overlaps: keep bottom node. 129 if existing_origin.loc.line_loc == origin_info.loc.line_loc: 130 if existing_origin.loc.lineno >= origin_info.loc.lineno: 131 continue 132 133 # In case of column overlaps, keep the leftmost node. 134 if existing_origin.loc.col_offset <= origin_info.loc.col_offset: 135 continue 136 137 source_map[line_loc] = origin_info 138 139 except ValueError as err: 140 new_msg = 'Inconsistent ASTs detected. This is a bug. Cause: \n' 141 new_msg += str(err) 142 new_msg += 'Diff:\n' 143 144 for n, rn in zip(nodes, reparsed_nodes): 145 nodes_str = pretty_printer.fmt(n, color=False, noanno=True) 146 reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True) 147 diff = difflib.context_diff( 148 nodes_str.split('\n'), 149 reparsed_nodes_str.split('\n'), 150 fromfile='Original nodes', 151 tofile='Reparsed nodes', 152 n=7) 153 diff = '\n'.join(diff) 154 new_msg += diff + '\n' 155 raise ValueError(new_msg) 156 157 return source_map 158 159 160class _Function(object): 161 162 def __init__(self, name): 163 self.name = name 164 165 166class OriginResolver(gast.NodeVisitor): 167 """Annotates an AST with additional source information like file name.""" 168 169 def __init__(self, root_node, source_lines, comments_map, 170 context_lineno, context_col_offset, 171 filepath): 172 self._source_lines = source_lines 173 self._comments_map = comments_map 174 175 if (hasattr(root_node, 'decorator_list') and root_node.decorator_list and 176 hasattr(root_node.decorator_list[0], 'lineno')): 177 # Typical case: functions. The line number of the first decorator 178 # is more accurate than the line number of the function itself in 179 # 3.8+. In earier versions they coincide. 180 self._lineno_offset = context_lineno - root_node.decorator_list[0].lineno 181 else: 182 # Fall back to the line number of the root node. 183 self._lineno_offset = context_lineno - root_node.lineno 184 185 self._col_offset = context_col_offset - root_node.col_offset 186 187 self._filepath = filepath 188 189 self._function_stack = [] 190 191 def _absolute_lineno(self, node): 192 return node.lineno + self._lineno_offset 193 194 def _absolute_col_offset(self, node): 195 return node.col_offset + self._col_offset 196 197 def _attach_origin_info(self, node): 198 if self._function_stack: 199 function_name = self._function_stack[-1].name 200 else: 201 function_name = None 202 203 source_code_line = self._source_lines[node.lineno - 1] 204 comment = self._comments_map.get(node.lineno) 205 206 loc = Location(self._filepath, self._absolute_lineno(node), 207 self._absolute_col_offset(node)) 208 origin = OriginInfo(loc, function_name, source_code_line, comment) 209 anno.setanno(node, 'lineno', node.lineno) 210 anno.setanno(node, anno.Basic.ORIGIN, origin) 211 212 def visit(self, node): 213 entered_function = False 214 if isinstance(node, gast.FunctionDef): 215 entered_function = True 216 self._function_stack.append(_Function(node.name)) 217 218 if hasattr(node, 'lineno'): 219 self._attach_origin_info(node) 220 self.generic_visit(node) 221 222 if entered_function: 223 self._function_stack.pop() 224 225 226def resolve(node, source, context_filepath, context_lineno, context_col_offset): 227 """Adds origin information to an AST, based on the source it was loaded from. 228 229 This allows us to map the original source code line numbers to generated 230 source code. 231 232 Note: the AST may be a part of a larger context (e.g. a function is part of 233 a module that may contain other things). However, this function does not 234 assume the source argument contains the entire context, nor that it contains 235 only code corresponding to node itself. However, it assumes that node was 236 parsed from the given source code. 237 For this reason, two extra arguments are required, and they indicate the 238 location of the node in the original context. 239 240 Args: 241 node: gast.AST, the AST to annotate. 242 source: Text, the source code representing node. 243 context_filepath: Text 244 context_lineno: int 245 context_col_offset: int 246 """ 247 # TODO(mdan): Pull this to a separate utility. 248 code_reader = six.StringIO(source) 249 comments_map = {} 250 try: 251 for token in tokenize.generate_tokens(code_reader.readline): 252 tok_type, tok_string, loc, _, _ = token 253 srow, _ = loc 254 if tok_type == tokenize.COMMENT: 255 comments_map[srow] = tok_string.strip()[1:].strip() 256 except tokenize.TokenError: 257 if isinstance(node, gast.Lambda): 258 # Source code resolution in older Python versions is brittle for 259 # lambda functions, and may contain garbage. 260 pass 261 else: 262 raise 263 264 source_lines = source.split('\n') 265 visitor = OriginResolver(node, source_lines, comments_map, 266 context_lineno, context_col_offset, 267 context_filepath) 268 visitor.visit(node) 269 270 271def resolve_entity(node, source, entity): 272 """Like resolve, but extracts the context information from an entity.""" 273 lines, lineno = tf_inspect.getsourcelines(entity) 274 filepath = tf_inspect.getsourcefile(entity) 275 276 # Poor man's attempt at guessing the column offset: count the leading 277 # whitespace. This might not work well with tabs. 278 definition_line = lines[0] 279 col_offset = len(definition_line) - len(definition_line.lstrip()) 280 281 resolve(node, source, filepath, lineno, col_offset) 282 283 284def copy_origin(from_node, to_node): 285 """Copies the origin info from a node to another, recursively.""" 286 origin = anno.Basic.ORIGIN.of(from_node, default=None) 287 if origin is None: 288 return 289 if not isinstance(to_node, (list, tuple)): 290 to_node = (to_node,) 291 for node in to_node: 292 for n in gast.walk(node): 293 anno.setanno(n, anno.Basic.ORIGIN, origin) 294