1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""AST conversion templates. 16 17Adapted from Tangent. 18""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import ast 25import textwrap 26 27import gast 28 29from tensorflow.python.autograph.pyct import anno 30from tensorflow.python.autograph.pyct import ast_util 31from tensorflow.python.autograph.pyct import parser 32from tensorflow.python.autograph.pyct import qual_names 33 34 35class ContextAdjuster(gast.NodeTransformer): 36 """Adjusts the ctx field of nodes to ensure consistency. 37 38 This transformer can change the ctx fields of a variable, tuple and other 39 AST elements that allow one, based on whether the element is being read or 40 written. 41 """ 42 43 def __init__(self, override_value): 44 self._ctx_override = override_value 45 46 def visit(self, node): 47 original_override = self._ctx_override 48 node = super(ContextAdjuster, self).visit(node) 49 if hasattr(node, 'ctx'): 50 assert node.ctx is not None, 'node {} has ctx unset'.format(node) 51 self._ctx_override = original_override 52 return node 53 54 def _apply_override(self, node): 55 if self._ctx_override is not None: 56 node.ctx = self._ctx_override() 57 58 def visit_Attribute(self, node): 59 self._apply_override(node) 60 self._ctx_override = gast.Load 61 node = self.generic_visit(node) 62 return node 63 64 def visit_Tuple(self, node): 65 self._apply_override(node) 66 return self.generic_visit(node) 67 68 def visit_List(self, node): 69 self._apply_override(node) 70 return self.generic_visit(node) 71 72 def visit_Name(self, node): 73 self._apply_override(node) 74 return self.generic_visit(node) 75 76 def visit_Call(self, node): 77 self._apply_override(node) 78 # We may be able to override these to Load(), but for now it's simpler 79 # to just assert that they're set. 80 self._ctx_override = None 81 return self.generic_visit(node) 82 83 def visit_Dict(self, node): 84 # We may be able to override these to Load(), but for now it's simpler 85 # to just assert that they're set. 86 self._ctx_override = None 87 return self.generic_visit(node) 88 89 def visit_Subscript(self, node): 90 node.value = self.visit(node.value) 91 self._ctx_override = None 92 return self.generic_visit(node) 93 94 def visit_comprehension(self, node): 95 # We may be able to override some of these, but for now it's simpler 96 # to just assert that they're set. 97 self._ctx_override = None 98 return self.generic_visit(node) 99 100 def visit_Lambda(self, node): 101 # We may be able to override some of these, but for now it's simpler 102 # to just assert that they're set. 103 self._ctx_override = None 104 return self.generic_visit(node) 105 106 107class ReplaceTransformer(gast.NodeTransformer): 108 """Replace AST nodes.""" 109 110 def __init__(self, replacements): 111 """Create a new ReplaceTransformer. 112 113 Args: 114 replacements: A mapping from placeholder names to (lists of) AST nodes 115 that these placeholders will be replaced by. 116 """ 117 self.replacements = replacements 118 self.in_replacements = False 119 self.preserved_annos = { 120 anno.Basic.ORIGIN, 121 anno.Basic.SKIP_PROCESSING, 122 anno.Static.ORIG_DEFINITIONS, 123 'extra_test', 124 } 125 126 def _prepare_replacement(self, replaced, key): 127 """Prepares a replacement AST that's safe to swap in for a node. 128 129 Args: 130 replaced: ast.AST, the node being replaced 131 key: Hashable, the key of the replacement AST 132 Returns: 133 ast.AST, the replacement AST 134 """ 135 repl = self.replacements[key] 136 137 new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos) 138 if isinstance(new_nodes, gast.AST): 139 new_nodes = [new_nodes] 140 141 return new_nodes 142 143 def visit_Expr(self, node): 144 # When replacing a placeholder with an entire statement, the replacement 145 # must stand on its own and not be wrapped in an Expr. 146 new_value = self.visit(node.value) 147 if new_value is node.value: 148 return node 149 return new_value 150 151 def visit_keyword(self, node): 152 if node.arg not in self.replacements: 153 return self.generic_visit(node) 154 155 repl = self._prepare_replacement(node, node.arg) 156 if isinstance(repl, gast.keyword): 157 return repl 158 elif (repl and isinstance(repl, (list, tuple)) and 159 all(isinstance(r, gast.keyword) for r in repl)): 160 return repl 161 # TODO(mdan): We may allow replacing with a string as well. 162 # For example, if one wanted to replace foo with bar in foo=baz, then 163 # we could allow changing just node arg, so that we end up with bar=baz. 164 raise ValueError( 165 'a keyword argument may only be replaced by another keyword or a ' 166 'non-empty list of keywords. Found: {} for keyword {}'.format( 167 repl, node.arg)) 168 169 def visit_FunctionDef(self, node): 170 node = self.generic_visit(node) 171 if node.name not in self.replacements: 172 return node 173 174 repl = self.replacements[node.name] 175 if not isinstance(repl, (gast.Name, ast.Name)): 176 raise ValueError( 177 'a function name can only be replaced by a Name node. Found: %s' % 178 repl) 179 node.name = repl.id 180 return node 181 182 def visit_Attribute(self, node): 183 node = self.generic_visit(node) 184 if node.attr not in self.replacements: 185 return node 186 187 repl = self.replacements[node.attr] 188 if not isinstance(repl, gast.Name): 189 raise ValueError( 190 'An attribute can only be replaced by a Name node. Found: %s' % repl) 191 node.attr = repl.id 192 return node 193 194 def visit_Name(self, node): 195 if node.id not in self.replacements: 196 return node 197 198 new_nodes = self._prepare_replacement(node, node.id) 199 200 if not new_nodes: 201 return new_nodes 202 203 # Preserve the target context. 204 adjuster = ContextAdjuster(type(node.ctx)) 205 for n in new_nodes: 206 if hasattr(n, 'ctx'): 207 adjuster.visit(n) 208 209 if len(new_nodes) == 1: 210 new_nodes, = new_nodes 211 212 return new_nodes 213 214 215def _convert_to_ast(n): 216 """Converts from a known data type to AST.""" 217 if isinstance(n, str): 218 # Note: the node will receive the ctx value from the template, see 219 # ReplaceTransformer.visit_Name. 220 return gast.Name(id=n, ctx=None, annotation=None) 221 if isinstance(n, qual_names.QN): 222 return n.ast() 223 if isinstance(n, list): 224 return [_convert_to_ast(e) for e in n] 225 if isinstance(n, tuple): 226 return tuple(_convert_to_ast(e) for e in n) 227 return n 228 229 230def replace(template, **replacements): 231 """Replaces placeholders in a Python template. 232 233 AST Name and Tuple nodes always receive the context that inferred from 234 the template. However, when replacing more complex nodes (that can potentially 235 contain Name children), then the caller is responsible for setting the 236 appropriate context. 237 238 Args: 239 template: A string representing Python code. Any symbol name can be used 240 that appears in the template code can be used as placeholder. 241 **replacements: A mapping from placeholder names to (lists of) AST nodes 242 that these placeholders will be replaced by. String values are also 243 supported as a shorthand for AST Name nodes with the respective ID. 244 245 Returns: 246 An AST node or list of AST nodes with the replacements made. If the 247 template was a function, a list will be returned. If the template was a 248 node, the same node will be returned. If the template was a string, an 249 AST node will be returned (a `Module` node in the case of a multi-line 250 string, an `Expr` node otherwise). 251 252 Raises: 253 ValueError: if the arguments are incorrect. 254 """ 255 if not isinstance(template, str): 256 raise ValueError('Expected string template, got %s' % type(template)) 257 tree = parser.parse_str(textwrap.dedent(template)) 258 for k in replacements: 259 replacements[k] = _convert_to_ast(replacements[k]) 260 results = ReplaceTransformer(replacements).visit(tree).body 261 if isinstance(results, list): 262 return [qual_names.resolve(r) for r in results] 263 return qual_names.resolve(results) 264 265 266def replace_as_expression(template, **replacements): 267 """Variant of replace that generates expressions, instead of code blocks.""" 268 replacement = replace(template, **replacements) 269 if len(replacement) != 1: 270 raise ValueError( 271 'single expression expected; for more general templates use replace') 272 node = replacement[0] 273 node = qual_names.resolve(node) 274 275 if isinstance(node, gast.Expr): 276 return node.value 277 elif isinstance(node, gast.Name): 278 return node 279 280 raise ValueError( 281 'the template is expected to generate an expression or a name node;' 282 ' instead found %s' % node) 283