1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for ast_util module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import ast 22import collections 23import textwrap 24 25import gast 26 27from tensorflow.python.autograph.pyct import anno 28from tensorflow.python.autograph.pyct import ast_util 29from tensorflow.python.autograph.pyct import compiler 30from tensorflow.python.autograph.pyct import parser 31from tensorflow.python.autograph.pyct import qual_names 32from tensorflow.python.platform import test 33 34 35class AstUtilTest(test.TestCase): 36 37 def setUp(self): 38 super(AstUtilTest, self).setUp() 39 self._invocation_counts = collections.defaultdict(lambda: 0) 40 41 def test_rename_symbols_basic(self): 42 node = parser.parse_str('a + b') 43 node = qual_names.resolve(node) 44 45 node = ast_util.rename_symbols( 46 node, {qual_names.QN('a'): qual_names.QN('renamed_a')}) 47 48 self.assertIsInstance(node.body[0].value.left.id, str) 49 source = compiler.ast_to_source(node) 50 self.assertEqual(source.strip(), 'renamed_a + b') 51 52 def test_rename_symbols_attributes(self): 53 node = parser.parse_str('b.c = b.c.d') 54 node = qual_names.resolve(node) 55 56 node = ast_util.rename_symbols( 57 node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}) 58 59 source = compiler.ast_to_source(node) 60 self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d') 61 62 def test_rename_symbols_annotations(self): 63 node = parser.parse_str('a[i]') 64 node = qual_names.resolve(node) 65 anno.setanno(node, 'foo', 'bar') 66 orig_anno = anno.getanno(node, 'foo') 67 68 node = ast_util.rename_symbols(node, 69 {qual_names.QN('a'): qual_names.QN('b')}) 70 71 self.assertIs(anno.getanno(node, 'foo'), orig_anno) 72 73 def test_copy_clean(self): 74 node = parser.parse_str( 75 textwrap.dedent(""" 76 def f(a): 77 return a + 1 78 """)) 79 setattr(node.body[0], '__foo', 'bar') 80 new_node = ast_util.copy_clean(node) 81 self.assertIsNot(new_node, node) 82 self.assertIsNot(new_node.body[0], node.body[0]) 83 self.assertFalse(hasattr(new_node.body[0], '__foo')) 84 85 def test_copy_clean_preserves_annotations(self): 86 node = parser.parse_str( 87 textwrap.dedent(""" 88 def f(a): 89 return a + 1 90 """)) 91 anno.setanno(node.body[0], 'foo', 'bar') 92 anno.setanno(node.body[0], 'baz', 1) 93 new_node = ast_util.copy_clean(node, preserve_annos={'foo'}) 94 self.assertEqual(anno.getanno(new_node.body[0], 'foo'), 'bar') 95 self.assertFalse(anno.hasanno(new_node.body[0], 'baz')) 96 97 def test_keywords_to_dict(self): 98 keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords 99 d = ast_util.keywords_to_dict(keywords) 100 # Make sure we generate a usable dict node by attaching it to a variable and 101 # compiling everything. 102 node = parser.parse_str('def f(b): pass').body[0] 103 node.body.append(ast.Return(d)) 104 result, _ = compiler.ast_to_object(node) 105 self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'}) 106 107 def assertMatch(self, target_str, pattern_str): 108 node = parser.parse_expression(target_str) 109 pattern = parser.parse_expression(pattern_str) 110 self.assertTrue(ast_util.matches(node, pattern)) 111 112 def assertNoMatch(self, target_str, pattern_str): 113 node = parser.parse_expression(target_str) 114 pattern = parser.parse_expression(pattern_str) 115 self.assertFalse(ast_util.matches(node, pattern)) 116 117 def test_matches_symbols(self): 118 self.assertMatch('foo', '_') 119 self.assertNoMatch('foo()', '_') 120 self.assertMatch('foo + bar', 'foo + _') 121 self.assertNoMatch('bar + bar', 'foo + _') 122 self.assertNoMatch('foo - bar', 'foo + _') 123 124 def test_matches_function_args(self): 125 self.assertMatch('super(Foo, self).__init__(arg1, arg2)', 126 'super(_).__init__(_)') 127 self.assertMatch('super().__init__()', 'super(_).__init__(_)') 128 self.assertNoMatch('super(Foo, self).bar(arg1, arg2)', 129 'super(_).__init__(_)') 130 self.assertMatch('super(Foo, self).__init__()', 'super(Foo, _).__init__(_)') 131 self.assertNoMatch('super(Foo, self).__init__()', 132 'super(Bar, _).__init__(_)') 133 134 def _mock_apply_fn(self, target, source): 135 target = compiler.ast_to_source(target) 136 source = compiler.ast_to_source(source) 137 self._invocation_counts[(target.strip(), source.strip())] += 1 138 139 def test_apply_to_single_assignments_dynamic_unpack(self): 140 node = parser.parse_str('a, b, c = d') 141 node = node.body[0] 142 ast_util.apply_to_single_assignments(node.targets, node.value, 143 self._mock_apply_fn) 144 self.assertDictEqual(self._invocation_counts, { 145 ('a', 'd[0]'): 1, 146 ('b', 'd[1]'): 1, 147 ('c', 'd[2]'): 1, 148 }) 149 150 def test_apply_to_single_assignments_static_unpack(self): 151 node = parser.parse_str('a, b, c = d, e, f') 152 node = node.body[0] 153 ast_util.apply_to_single_assignments(node.targets, node.value, 154 self._mock_apply_fn) 155 self.assertDictEqual(self._invocation_counts, { 156 ('a', 'd'): 1, 157 ('b', 'e'): 1, 158 ('c', 'f'): 1, 159 }) 160 161 def test_parallel_walk(self): 162 src = """ 163 def f(a): 164 return a + 1 165 """ 166 node = parser.parse_str(textwrap.dedent(src)) 167 for child_a, child_b in ast_util.parallel_walk(node, node): 168 self.assertEqual(child_a, child_b) 169 170 def test_parallel_walk_string_leaves(self): 171 src = """ 172 def f(a): 173 global g 174 """ 175 node = parser.parse_str(textwrap.dedent(src)) 176 for child_a, child_b in ast_util.parallel_walk(node, node): 177 self.assertEqual(child_a, child_b) 178 179 def test_parallel_walk_inconsistent_trees(self): 180 node_1 = parser.parse_str( 181 textwrap.dedent(""" 182 def f(a): 183 return a + 1 184 """)) 185 node_2 = parser.parse_str( 186 textwrap.dedent(""" 187 def f(a): 188 return a + (a * 2) 189 """)) 190 node_3 = parser.parse_str( 191 textwrap.dedent(""" 192 def f(a): 193 return a + 2 194 """)) 195 with self.assertRaises(ValueError): 196 for _ in ast_util.parallel_walk(node_1, node_2): 197 pass 198 # There is not particular reason to reject trees that differ only in the 199 # value of a constant. 200 # TODO(mdan): This should probably be allowed. 201 with self.assertRaises(ValueError): 202 for _ in ast_util.parallel_walk(node_1, node_3): 203 pass 204 205 def assertLambdaNodes(self, matching_nodes, expected_bodies): 206 self.assertEqual(len(matching_nodes), len(expected_bodies)) 207 for node in matching_nodes: 208 self.assertIsInstance(node, gast.Lambda) 209 self.assertIn(compiler.ast_to_source(node.body).strip(), expected_bodies) 210 211 def test_find_matching_definitions_lambda(self): 212 node = parser.parse_str( 213 textwrap.dedent(""" 214 f = lambda x: 1 215 """)) 216 f = lambda x: x 217 nodes = ast_util.find_matching_definitions(node, f) 218 self.assertLambdaNodes(nodes, ('(1)',)) 219 220 def test_find_matching_definitions_lambda_multiple_matches(self): 221 node = parser.parse_str( 222 textwrap.dedent(""" 223 f = lambda x: 1, lambda x: 2 224 """)) 225 f = lambda x: x 226 nodes = ast_util.find_matching_definitions(node, f) 227 self.assertLambdaNodes(nodes, ('(1)', '(2)')) 228 229 def test_find_matching_definitions_lambda_uses_arg_names(self): 230 node = parser.parse_str( 231 textwrap.dedent(""" 232 f = lambda x: 1, lambda y: 2 233 """)) 234 f = lambda x: x 235 nodes = ast_util.find_matching_definitions(node, f) 236 self.assertLambdaNodes(nodes, ('(1)',)) 237 238 f = lambda y: y 239 nodes = ast_util.find_matching_definitions(node, f) 240 self.assertLambdaNodes(nodes, ('(2)',)) 241 242 243if __name__ == '__main__': 244 test.main() 245