1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for ast_util module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import ast
22import collections
23import textwrap
24
25import gast
26
27from tensorflow.python.autograph.pyct import anno
28from tensorflow.python.autograph.pyct import ast_util
29from tensorflow.python.autograph.pyct import compiler
30from tensorflow.python.autograph.pyct import parser
31from tensorflow.python.autograph.pyct import qual_names
32from tensorflow.python.platform import test
33
34
class AstUtilTest(test.TestCase):
  """Unit tests for the helpers exposed by `ast_util`."""

  def setUp(self):
    super(AstUtilTest, self).setUp()
    # Maps (target_source, value_source) pairs to the number of times
    # `_mock_apply_fn` was invoked with them; unseen keys start at zero.
    self._invocation_counts = collections.defaultdict(int)

  def test_rename_symbols_basic(self):
    node = parser.parse_str('a + b')
    node = qual_names.resolve(node)

    node = ast_util.rename_symbols(
        node, {qual_names.QN('a'): qual_names.QN('renamed_a')})

    self.assertIsInstance(node.body[0].value.left.id, str)
    source = compiler.ast_to_source(node)
    self.assertEqual(source.strip(), 'renamed_a + b')

  def test_rename_symbols_attributes(self):
    node = parser.parse_str('b.c = b.c.d')
    node = qual_names.resolve(node)

    node = ast_util.rename_symbols(
        node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})

    source = compiler.ast_to_source(node)
    self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')

  def test_rename_symbols_annotations(self):
    node = parser.parse_str('a[i]')
    node = qual_names.resolve(node)
    anno.setanno(node, 'foo', 'bar')
    orig_anno = anno.getanno(node, 'foo')

    node = ast_util.rename_symbols(node,
                                   {qual_names.QN('a'): qual_names.QN('b')})

    self.assertIs(anno.getanno(node, 'foo'), orig_anno)

  def test_copy_clean(self):
    node = parser.parse_str(
        textwrap.dedent("""
      def f(a):
        return a + 1
    """))
    setattr(node.body[0], '__foo', 'bar')
    new_node = ast_util.copy_clean(node)
    self.assertIsNot(new_node, node)
    self.assertIsNot(new_node.body[0], node.body[0])
    self.assertFalse(hasattr(new_node.body[0], '__foo'))

  def test_copy_clean_preserves_annotations(self):
    node = parser.parse_str(
        textwrap.dedent("""
      def f(a):
        return a + 1
    """))
    anno.setanno(node.body[0], 'foo', 'bar')
    anno.setanno(node.body[0], 'baz', 1)
    new_node = ast_util.copy_clean(node, preserve_annos={'foo'})
    self.assertEqual(anno.getanno(new_node.body[0], 'foo'), 'bar')
    self.assertFalse(anno.hasanno(new_node.body[0], 'baz'))

  def test_keywords_to_dict(self):
    keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
    d = ast_util.keywords_to_dict(keywords)
    # Make sure we generate a usable dict node by returning it from a function
    # and compiling everything.
    node = parser.parse_str('def f(b): pass').body[0]
    node.body.append(ast.Return(d))
    result, _ = compiler.ast_to_object(node)
    self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})

  def assertMatch(self, target_str, pattern_str):
    """Asserts that expression `target_str` matches expression `pattern_str`."""
    node = parser.parse_expression(target_str)
    pattern = parser.parse_expression(pattern_str)
    self.assertTrue(ast_util.matches(node, pattern))

  def assertNoMatch(self, target_str, pattern_str):
    """Asserts that `target_str` does not match expression `pattern_str`."""
    node = parser.parse_expression(target_str)
    pattern = parser.parse_expression(pattern_str)
    self.assertFalse(ast_util.matches(node, pattern))

  def test_matches_symbols(self):
    self.assertMatch('foo', '_')
    self.assertNoMatch('foo()', '_')
    self.assertMatch('foo + bar', 'foo + _')
    self.assertNoMatch('bar + bar', 'foo + _')
    self.assertNoMatch('foo - bar', 'foo + _')

  def test_matches_function_args(self):
    self.assertMatch('super(Foo, self).__init__(arg1, arg2)',
                     'super(_).__init__(_)')
    self.assertMatch('super().__init__()', 'super(_).__init__(_)')
    self.assertNoMatch('super(Foo, self).bar(arg1, arg2)',
                       'super(_).__init__(_)')
    self.assertMatch('super(Foo, self).__init__()', 'super(Foo, _).__init__(_)')
    self.assertNoMatch('super(Foo, self).__init__()',
                       'super(Bar, _).__init__(_)')

  def _mock_apply_fn(self, target, source):
    """Records one (target, source) pair, keyed by their stripped source."""
    target = compiler.ast_to_source(target)
    source = compiler.ast_to_source(source)
    self._invocation_counts[(target.strip(), source.strip())] += 1

  def test_apply_to_single_assignments_dynamic_unpack(self):
    node = parser.parse_str('a, b, c = d')
    node = node.body[0]
    ast_util.apply_to_single_assignments(node.targets, node.value,
                                         self._mock_apply_fn)
    self.assertDictEqual(self._invocation_counts, {
        ('a', 'd[0]'): 1,
        ('b', 'd[1]'): 1,
        ('c', 'd[2]'): 1,
    })

  def test_apply_to_single_assignments_static_unpack(self):
    node = parser.parse_str('a, b, c = d, e, f')
    node = node.body[0]
    ast_util.apply_to_single_assignments(node.targets, node.value,
                                         self._mock_apply_fn)
    self.assertDictEqual(self._invocation_counts, {
        ('a', 'd'): 1,
        ('b', 'e'): 1,
        ('c', 'f'): 1,
    })

  def test_parallel_walk(self):
    src = """
      def f(a):
        return a + 1
    """
    node = parser.parse_str(textwrap.dedent(src))
    for child_a, child_b in ast_util.parallel_walk(node, node):
      self.assertEqual(child_a, child_b)

  def test_parallel_walk_string_leaves(self):
    src = """
      def f(a):
        global g
    """
    node = parser.parse_str(textwrap.dedent(src))
    for child_a, child_b in ast_util.parallel_walk(node, node):
      self.assertEqual(child_a, child_b)

  def test_parallel_walk_inconsistent_trees(self):
    node_1 = parser.parse_str(
        textwrap.dedent("""
      def f(a):
        return a + 1
    """))
    node_2 = parser.parse_str(
        textwrap.dedent("""
      def f(a):
        return a + (a * 2)
    """))
    node_3 = parser.parse_str(
        textwrap.dedent("""
      def f(a):
        return a + 2
    """))
    with self.assertRaises(ValueError):
      for _ in ast_util.parallel_walk(node_1, node_2):
        pass
    # There is no particular reason to reject trees that differ only in the
    # value of a constant.
    # TODO(mdan): This should probably be allowed.
    with self.assertRaises(ValueError):
      for _ in ast_util.parallel_walk(node_1, node_3):
        pass

  def assertLambdaNodes(self, matching_nodes, expected_bodies):
    """Asserts that `matching_nodes` are lambdas with exactly the given bodies.

    Args:
      matching_nodes: the nodes returned by `ast_util.find_matching_definitions`.
      expected_bodies: the expected (stripped) source of each lambda body, in
        any order.
    """
    for node in matching_nodes:
      self.assertIsInstance(node, gast.Lambda)
    actual_bodies = [
        compiler.ast_to_source(node.body).strip() for node in matching_nodes
    ]
    # Compare as multisets: a plain membership check would let two matches
    # with the same body satisfy an expectation of two distinct bodies.
    self.assertEqual(sorted(actual_bodies), sorted(expected_bodies))

  def test_find_matching_definitions_lambda(self):
    node = parser.parse_str(
        textwrap.dedent("""
      f = lambda x: 1
    """))
    f = lambda x: x
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('(1)',))

  def test_find_matching_definitions_lambda_multiple_matches(self):
    node = parser.parse_str(
        textwrap.dedent("""
      f = lambda x: 1, lambda x: 2
    """))
    f = lambda x: x
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('(1)', '(2)'))

  def test_find_matching_definitions_lambda_uses_arg_names(self):
    node = parser.parse_str(
        textwrap.dedent("""
      f = lambda x: 1, lambda y: 2
    """))
    f = lambda x: x
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('(1)',))

    f = lambda y: y
    nodes = ast_util.find_matching_definitions(node, f)
    self.assertLambdaNodes(nodes, ('(2)',))
241
242
if __name__ == '__main__':
  # Run every test case in this module via the TensorFlow test entry point.
  test.main()
245