1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST manipulation utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import ast
22
23import gast
24
25from tensorflow.python.autograph.pyct import anno
26from tensorflow.python.autograph.pyct import parser
27from tensorflow.python.util import tf_inspect
28
29
class CleanCopier(object):
  """Deep-copies ASTs field by field, keeping only selected annotations.

  Attributes:
    preserve_annos: Optional collection of annotation keys that are carried
      over from each original node to its copy.
  """

  def __init__(self, preserve_annos):
    super(CleanCopier, self).__init__()
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean).

    Args:
      node: an AST node, a list or tuple of values to copy, or any other value
    Returns:
      A new list/tuple/AST node built from recursively copied members. Values
      that are neither an AST node, a list nor a tuple are returned as they
      are.
    """
    if isinstance(node, list):
      return [self.copy(item) for item in node]
    if isinstance(node, tuple):
      return tuple(self.copy(item) for item in node)
    if not isinstance(node, (gast.AST, ast.AST)):
      # Anything that is not an AST, list or tuple is treated as a value type
      # and shared between the original and the copy.
      return node

    # Only the declared fields are rebuilt; fields prefixed with '__' and
    # fields that are not set on the node are left out.
    field_copies = {
        name: self.copy(getattr(node, name))
        for name in node._fields
        if not name.startswith('__') and hasattr(node, name)
    }
    duplicate = type(node)(**field_copies)

    for key in self.preserve_annos or ():
      anno.copyanno(node, duplicate, key)
    return duplicate
61
62
def copy_clean(node, preserve_annos=None):
  """Builds a deep copy of an AST using a CleanCopier.

  Fields whose names start with '__' are not carried over to the copy. Of the
  annotations, only those listed in preserve_annos are copied.

  Args:
    node: ast.AST
    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
        copy
  Returns:
    ast.AST
  """
  copier = CleanCopier(preserve_annos)
  return copier.copy(node)
77
78
class SymbolRenamer(gast.NodeTransformer):
  """Transformer that can rename symbols to simple names.

  Attributes:
    name_map: mapping from the QN annotation of a node to its replacement
      name; the replacement is converted with str().
  """

  def __init__(self, name_map):
    self.name_map = name_map

  def _process(self, node):
    """Swaps node for a plain Name if its QN annotation is in name_map."""
    qn = anno.getanno(node, anno.Basic.QN)
    if qn not in self.name_map:
      return self.generic_visit(node)

    replacement = gast.Name(str(self.name_map[qn]), node.ctx, None)
    # Every annotation of the original node is carried over.
    for key in anno.keys(node):
      anno.copyanno(node, replacement, key)
    return replacement

  def visit_Name(self, node):
    return self._process(node)

  def visit_Attribute(self, node):
    if not anno.hasanno(node, anno.Basic.QN):
      # Attributes of dynamic objects will not have a QN.
      return self.generic_visit(node)
    return self._process(node)
103
104
def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: an AST node, or a list or tuple of AST nodes
    name_map: mapping handed to SymbolRenamer
  Returns:
    The result of visiting node with a SymbolRenamer. A list input produces a
    list and a tuple input produces a tuple, with one visited result per
    element.
  """
  visit = SymbolRenamer(name_map).visit
  if isinstance(node, list):
    return [visit(item) for item in node]
  if isinstance(node, tuple):
    return tuple(visit(item) for item in node)
  return visit(node)
113
114
def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a dict.

  Args:
    keywords: iterable of keyword nodes, each exposing `arg` and `value`
  Returns:
    gast.Dict whose keys are gast.Str nodes wrapping each `arg` and whose
    values are the corresponding `value` nodes, in the same order.
  """
  # One pass over the input, so the keys and values stay aligned.
  entries = [(gast.Str(kw.arg), kw.value) for kw in keywords]
  return gast.Dict(
      keys=[key for key, _ in entries],
      values=[value for _, value in entries])
123
124
class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node.

  Usage: construct with the pattern, call `visit(node)`, then read `matches`.

  The node's fields are compared one by one against the same-named fields of
  the pattern. A pattern field that is a wildcard (see `is_wildcard`) accepts
  any value in the corresponding node field.

  Attributes:
    pattern: the pattern (sub)tree currently being compared against
    pattern_stack: patterns of the enclosing nodes, restored when the
      comparison of a child finishes
    matches: bool, True until the first mismatch is found
  """

  def __init__(self, pattern):
    self.pattern = pattern
    self.pattern_stack = []
    self.matches = True

  def compare_and_visit(self, node, pattern):
    """Compares the fields of node against pattern, then restores the pattern."""
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    """Records a mismatch. Always returns False."""
    self.matches = False
    return False

  def is_wildcard(self, p):
    """Tells whether the pattern value p is a wildcard.

    A wildcard is a Name node with id '_', the value '_' itself, or a list or
    tuple holding a single such element.

    Args:
      p: the value of a pattern field
    Returns:
      bool
    """
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    if p == '_':
      return True
    return False

  def _match_value(self, v, p):
    """Compares a single (non-list) node value against a pattern value.

    AST nodes must be of compatible types and are then compared recursively.
    Everything else is compared by value.

    Args:
      v: a field value, or list item, of the node
      p: the corresponding value of the pattern
    Returns:
      bool, False if a mismatch was recorded
    """
    if isinstance(v, (gast.AST, ast.AST)):
      if not isinstance(v, type(p)) and not isinstance(p, type(v)):
        return self.no_match()
      self.compare_and_visit(v, p)
      return self.matches
    # Assume everything else is a value type.
    if v != p:
      return self.no_match()
    return True

  def generic_visit(self, node):
    if not self.matches:
      return

    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue

      if not hasattr(node, f):
        # A field missing from the node only matches a missing or empty field
        # in the pattern.
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        else:
          continue
      if not hasattr(pattern, f):
        return self.no_match()

      v = getattr(node, f)
      p = getattr(pattern, f)

      if self.is_wildcard(p):
        continue
      if isinstance(v, (list, tuple)):
        if not isinstance(p, (list, tuple)) or len(v) != len(p):
          return self.no_match()
        # List items get the same type and value checks as scalar fields.
        # Without them, field-less nodes of different types (e.g. Lt and Gt)
        # would compare equal, and non-AST items (None, str) would fail with
        # an AttributeError.
        for v_item, p_item in zip(v, p):
          if not self._match_value(v_item, p_item):
            return False
      elif not self._match_value(v, p):
        return False
187
188
def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: Union[ast.AST, str]. A string is parsed with parser.parse_str and
        must contain exactly one top-level statement, which becomes the
        pattern.
  Returns:
    bool
  """
  if isinstance(pattern, str):
    statements = parser.parse_str(pattern).body
    # Unpacking enforces that the source held a single statement.
    pattern, = statements

  pattern_matcher = PatternMatcher(pattern)
  pattern_matcher.visit(node)
  return pattern_matcher.matches
208
209
210# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Calls apply_fn once for every individual assignment.

  Handles possibly-unpacked assignments (e.g. a, b = c, d), breaking down the
  unpacking where possible. The effect is the same as handing the assigned
  values to apply_fn in SSA form.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  Nested tuple/list targets are broken down recursively.

  Args:
    targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST], should be
        used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
        respective nodes of each single assignment
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)

  for target in targets:
    if not isinstance(target, (gast.Tuple, gast.List)):
      # Plain target: the whole value is assigned to it.
      apply_fn(target, values)
      continue

    value_is_literal = isinstance(values, (gast.Tuple, gast.List))
    for position, element in enumerate(target.elts):
      if value_is_literal:
        # Pair the target element with the value element at the same position.
        element_value = values.elts[position]
      else:
        # Opaque value: refer to its elements by subscript, i.e. values[i].
        index_node = parser.parse_expression(str(position))
        element_value = gast.Subscript(
            values, gast.Index(index_node), ctx=gast.Load())
      apply_to_single_assignments(element, element_value, apply_fn)
256
257
def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure.

  Fields that are None or missing on either side are skipped without being
  compared. List entries that are None on both sides (for instance the
  kw_defaults entry of a keyword-only argument without a default) are skipped
  as well. Other non-list, non-AST field values must compare equal.

  Args:
    node: Union[ast.AST, Iterable[ast.AST]]
    other: Union[ast.AST, Iterable[ast.AST]]
  Yields:
    Tuple[ast.AST, ast.AST]. String entries of list fields are yielded as
    pairs of str.
  Raises:
    ValueError: if the two trees don't have identical structure, including
      the case where node and other hold a different number of root nodes.
  """
  node_stack = list(node) if isinstance(node, (list, tuple)) else [node]
  other_stack = list(other) if isinstance(other, (list, tuple)) else [other]

  # Report mismatched roots as the documented error type. A bare assert would
  # be stripped under -O and would not fire at all when one side is empty.
  if len(node_stack) != len(other_stack):
    raise ValueError(
        'inconsistent number of root nodes: {} and {}'.format(
            len(node_stack), len(other_stack)))

  while node_stack:
    # Children are always pushed in pairs, so the stacks stay in lockstep.
    assert len(node_stack) == len(other_stack)
    n = node_stack.pop()
    o = other_stack.pop()

    if (not isinstance(n, (ast.AST, gast.AST, str)) or
        not isinstance(o, (ast.AST, gast.AST, str)) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        for n_item, o_item in zip(n_child, o_child):
          # Placeholders that are absent on both sides are consistent and
          # carry nothing to walk. A None on one side only is still pushed,
          # so that it is reported as an inconsistency when popped.
          if n_item is None and o_item is None:
            continue
          node_stack.append(n_item)
          other_stack.append(o_item)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))
321
322
class LambdaDefinitionMatcher(gast.NodeVisitor):
  """Collects the lambda nodes whose signature agrees with a given lambda.

  Attributes:
    fn: the function object whose definition is being looked for
    matching_nodes: list of the Lambda nodes found so far; nested lambdas are
      recorded before the lambda that encloses them
  """

  def __init__(self, fn):
    self.fn = fn
    self.matching_nodes = []

  def _arg_name(self, node):
    """Returns the name held by an argument entry, or None if there is none."""
    if node is None:
      return None
    if isinstance(node, gast.Name):
      return node.id
    # The only other representation handled here is a plain string.
    assert isinstance(node, str)
    return node

  def _argspec_matches(self, node):
    """Compares the argument names of a Lambda node with those of self.fn.

    Positional arguments, varargs, keyword varargs and keyword-only arguments
    are compared by name, in that order. Default values are not compared.

    Args:
      node: a Lambda node
    Returns:
      bool
    """
    spec = tf_inspect.getfullargspec(self.fn)
    signature = node.args

    # The checks short-circuit in the order listed in the docstring.
    return (
        tuple(self._arg_name(a) for a in signature.args) == tuple(spec.args)
        and spec.varargs == self._arg_name(signature.vararg)
        and spec.varkw == self._arg_name(signature.kwarg)
        and tuple(self._arg_name(a) for a in signature.kwonlyargs) ==
        tuple(spec.kwonlyargs))

  def visit_Lambda(self, node):
    # Descend first, so that lambdas nested inside this one are seen too.
    self.generic_visit(node)

    if self.fn.__name__ == '<lambda>' and self._argspec_matches(node):
      self.matching_nodes.append(node)
366
367
def find_matching_definitions(node, f):
  """Finds the lambda definitions under node that match the signature of f.

  Args:
    node: the AST to search, visited with a LambdaDefinitionMatcher
    f: the function object to look for
  Returns:
    Tuple of the Lambda nodes collected by the matcher.
  """
  finder = LambdaDefinitionMatcher(f)
  finder.visit(node)
  return tuple(finder.matching_nodes)
372