1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST conversion templates.
16
17Adapted from Tangent.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import ast
25import textwrap
26
27import gast
28
29from tensorflow.python.autograph.pyct import anno
30from tensorflow.python.autograph.pyct import ast_util
31from tensorflow.python.autograph.pyct import parser
32from tensorflow.python.autograph.pyct import qual_names
33
34
class ContextAdjuster(gast.NodeTransformer):
  """Adjusts the ctx field of nodes to ensure consistency.

  This transformer can change the ctx fields of a variable, tuple and other
  AST elements that allow one, based on whether the element is being read or
  written.

  The transformer carries a current override: a ctx node class such as
  `gast.Load` or `gast.Store`, or None.
  - Nodes that accept a ctx (names, tuples, lists, attributes, subscripts,
    starred elements) receive a fresh instance of the override.
  - Sub-expressions that are always read, such as the object of an attribute
    access or the value and index of a subscript, are switched to
    `gast.Load`.
  - Under calls, dicts, lambdas and comprehensions the override is cleared.
    There, existing ctx fields are left untouched and only checked to be set.
  """

  def __init__(self, override_value):
    """Creates a new ContextAdjuster.

    Args:
      override_value: a ctx node class (e.g. `gast.Load`, `gast.Store`) that
          is instantiated and assigned to the ctx of the visited nodes, or
          None to leave ctx fields unchanged.
    """
    self._ctx_override = override_value

  def visit(self, node):
    # Save and restore the override around every node. A visit_* method may
    # change it for its own children, and that change must not leak to the
    # node's siblings.
    original_override = self._ctx_override
    node = super(ContextAdjuster, self).visit(node)
    if hasattr(node, 'ctx'):
      assert node.ctx is not None, 'node {} has ctx unset'.format(node)
    self._ctx_override = original_override
    return node

  def _apply_override(self, node):
    """Sets node.ctx to a new instance of the current override, if any."""
    if self._ctx_override is not None:
      node.ctx = self._ctx_override()

  def visit_Attribute(self, node):
    self._apply_override(node)
    # In `a.b = x` the attribute is written, but `a` itself is only read.
    self._ctx_override = gast.Load
    node = self.generic_visit(node)
    return node

  def visit_Tuple(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_List(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Name(self, node):
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Starred(self, node):
    # A starred element (e.g. `*a` in `*a, b = x`) carries the same ctx as
    # the target it belongs to, and so does its value.
    self._apply_override(node)
    return self.generic_visit(node)

  def visit_Call(self, node):
    self._apply_override(node)
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Dict(self, node):
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Subscript(self, node):
    # The subscript as a whole takes the target context (e.g. Store in
    # `a[i] = x`), but the subscripted value and the index are always read.
    # Both are visited exactly once, by generic_visit, under Load.
    self._apply_override(node)
    self._ctx_override = gast.Load
    return self.generic_visit(node)

  def visit_comprehension(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Lambda(self, node):
    # We may be able to override some of these, but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)
105
106
class ReplaceTransformer(gast.NodeTransformer):
  """Replaces placeholder names in an AST with the corresponding AST nodes.

  Placeholders are matched by identifier against the keys of the
  replacements mapping. The following are substituted:
  - `Name` nodes;
  - keyword argument names;
  - function names;
  - attribute names.
  """

  def __init__(self, replacements):
    """Create a new ReplaceTransformer.

    Args:
      replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by.
    """
    self.replacements = replacements
    # NOTE(review): not read anywhere in this class; kept in case external
    # code inspects it.
    self.in_replacements = False
    # Annotations that are kept when a replacement is copied into the tree.
    self.preserved_annos = {
        anno.Basic.ORIGIN,
        anno.Basic.SKIP_PROCESSING,
        anno.Static.ORIG_DEFINITIONS,
        'extra_test',
    }

  def _prepare_replacement(self, replaced, key):
    """Prepares a replacement AST that's safe to swap in for a node.

    The replacement is passed through `ast_util.copy_clean`, presumably a
    copy that drops all annotations except `preserved_annos`. A single node
    is wrapped in a list.

    Args:
      replaced: ast.AST, the node being replaced (currently unused)
      key: Hashable, the key of the replacement AST
    Returns:
      a sequence of ast.AST, the replacement AST nodes
    """
    repl = self.replacements[key]

    new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
    if isinstance(new_nodes, gast.AST):
      new_nodes = [new_nodes]

    return new_nodes

  def visit_Expr(self, node):
    # When replacing a placeholder with an entire statement, the replacement
    # must stand on its own and not be wrapped in an Expr.
    new_value = self.visit(node.value)
    if new_value is node.value:
      return node
    return new_value

  def visit_keyword(self, node):
    if node.arg not in self.replacements:
      return self.generic_visit(node)

    repl = self._prepare_replacement(node, node.arg)
    # _prepare_replacement always returns a sequence, so even a single
    # keyword arrives here wrapped in a list. Returning a list splices the
    # keywords into the enclosing call.
    if (repl and isinstance(repl, (list, tuple)) and
        all(isinstance(r, gast.keyword) for r in repl)):
      return repl
    # TODO(mdan): We may allow replacing with a string as well.
    # For example, if one wanted to replace foo with bar in foo=baz, then
    # we could allow changing just node arg, so that we end up with bar=baz.
    raise ValueError(
        'a keyword argument may only be replaced by another keyword or a '
        'non-empty list of keywords. Found: {} for keyword {}'.format(
            repl, node.arg))

  def visit_FunctionDef(self, node):
    node = self.generic_visit(node)
    if node.name not in self.replacements:
      return node

    repl = self.replacements[node.name]
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'a function name can only be replaced by a Name node. Found: %s' %
          repl)
    node.name = repl.id
    return node

  def visit_Attribute(self, node):
    node = self.generic_visit(node)
    if node.attr not in self.replacements:
      return node

    repl = self.replacements[node.attr]
    # Accept the same Name node types as visit_FunctionDef.
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'An attribute can only be replaced by a Name node. Found: %s' % repl)
    node.attr = repl.id
    return node

  def visit_Name(self, node):
    if node.id not in self.replacements:
      # Still visit the children. gast Name nodes may carry an annotation
      # (e.g. function parameters), which can itself contain placeholders.
      return self.generic_visit(node)

    new_nodes = self._prepare_replacement(node, node.id)

    if not new_nodes:
      # An empty replacement removes the placeholder.
      return new_nodes

    # Preserve the target context.
    adjuster = ContextAdjuster(type(node.ctx))
    for n in new_nodes:
      if hasattr(n, 'ctx'):
        adjuster.visit(n)

    if len(new_nodes) == 1:
      new_nodes, = new_nodes

    return new_nodes
213
214
def _convert_to_ast(n):
  """Converts a replacement value of a known data type into AST form.

  Args:
    n: the value to convert.
      - Strings become `gast.Name` nodes with that id.
      - `QN` objects are converted with their `ast()` method.
      - Lists and tuples are converted element by element, producing a
        plain list or tuple respectively.
      - Any other value is returned unchanged; it is presumably already an
        AST node.

  Returns:
    The converted value.
  """
  if isinstance(n, str):
    # The ctx is deliberately left unset. When the name replaces a Name
    # placeholder, ReplaceTransformer.visit_Name assigns it the
    # placeholder's context.
    converted = gast.Name(id=n, ctx=None, annotation=None)
  elif isinstance(n, qual_names.QN):
    converted = n.ast()
  elif isinstance(n, list):
    converted = list(map(_convert_to_ast, n))
  elif isinstance(n, tuple):
    converted = tuple(map(_convert_to_ast, n))
  else:
    converted = n
  return converted
228
229
def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  A replacement for a `Name` placeholder receives the context (Load, Store,
  ...) of the placeholder it replaces. The same applies to the names, tuples,
  lists, attributes and subscripts it contains. Nodes nested under calls,
  dicts, lambdas and comprehensions are not adjusted. When replacing with
  such nodes, the caller is responsible for setting the appropriate context.

  Args:
    template: A string representing Python code. Any symbol name that appears
        in the template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by. String values are also
        supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    A list of the top-level AST nodes of the template's body, with the
    replacements made and qualified names resolved.
    - A placeholder may be replaced by several nodes or by none, so the list
      can be longer or shorter than the template's statement count.
    - An expression statement whose placeholder was replaced appears
      unwrapped, e.g. as a bare `Name` rather than an `Expr`.

  Raises:
    ValueError: if `template` is not a string, or if a replacement is not
        valid for the placeholder it replaces.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))
  tree = parser.parse_str(textwrap.dedent(template))
  for k in replacements:
    replacements[k] = _convert_to_ast(replacements[k])
  results = ReplaceTransformer(replacements).visit(tree).body
  if isinstance(results, list):
    return [qual_names.resolve(r) for r in results]
  # Defensive: only reached if the parsed tree's body is not a list.
  return qual_names.resolve(results)
264
265
def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Args:
    template: A string representing a single Python expression, with the
        same placeholder rules as `replace`.
    **replacements: Same as for `replace`.

  Returns:
    The single expression node that the template generates.

  Raises:
    ValueError: if the template does not generate exactly one node, or if
        that node is not an expression.
  """
  replacement = replace(template, **replacements)
  if len(replacement) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  # `replace` already resolves qualified names on every result node, so the
  # node is not resolved a second time here.
  node, = replacement

  if isinstance(node, gast.Expr):
    return node.value
  elif isinstance(node, gast.expr):
    # ReplaceTransformer.visit_Expr unwraps an expression statement whose
    # placeholder was replaced. The result can therefore be a bare expression
    # here: a Name, an Attribute built from a dotted name, a Call, etc.
    return node

  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % node)
283