1# Copyright 2015 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tests for yapf.yapflib.pytree_utils."""
15
16import unittest
17
18from lib2to3 import pygram
19from lib2to3 import pytree
20from lib2to3.pgen2 import token
21
22from yapf.yapflib import pytree_utils
23
# Short alias for the symbol->number mapping that lives inside the grammar
# module, so tests can look up node types without the full attribute path.
_GRAMMAR_SYMBOL2NUMBER = pygram.python_grammar.symbol2number

# Names used as annotation keys by the annotation tests below.
_FOO, _FOO1, _FOO2 = 'foo', 'foo1', 'foo2'
_FOO3, _FOO4, _FOO5 = 'foo3', 'foo4', 'foo5'
34
35
class NodeNameTest(unittest.TestCase):
  """Tests for pytree_utils.NodeName on a leaf and on an interior node."""

  def testNodeNameForLeaf(self):
    # A leaf built from the LPAR token is expected to be named 'LPAR'.
    open_paren = pytree.Leaf(token.LPAR, '(')
    reported = pytree_utils.NodeName(open_paren)
    self.assertEqual('LPAR', reported)

  def testNodeNameForNode(self):
    # A node built with the 'suite' symbol number is expected to be named
    # 'suite', regardless of the single LPAR leaf it contains.
    suite_type = pygram.python_grammar.symbol2number['suite']
    suite = pytree.Node(suite_type, [pytree.Leaf(token.LPAR, '(')])
    reported = pytree_utils.NodeName(suite)
    self.assertEqual('suite', reported)
46
47
class ParseCodeToTreeTest(unittest.TestCase):
  """Sanity checks for pytree_utils.ParseCodeToTree.

  Since ParseCodeToTree is a thin wrapper around underlying lib2to3
  functionality, each test only inspects the top of the resulting tree.
  """

  def _AssertTopLevelShape(self, code, first_child_name):
    """Parses code and checks the root node and its first child.

    Arguments:
      code: source text handed to ParseCodeToTree.
      first_child_name: expected NodeName of the root's first child.
    """
    root = pytree_utils.ParseCodeToTree(code)
    self.assertEqual('file_input', pytree_utils.NodeName(root))
    self.assertEqual(2, len(root.children))
    first_child = root.children[0]
    self.assertEqual(first_child_name, pytree_utils.NodeName(first_child))

  def testParseCodeToTree(self):
    # Plain assignment.
    self._AssertTopLevelShape('foo = 2\n', 'simple_stmt')

  def testPrintFunctionToTree(self):
    # print called as a function, with a keyword argument.
    self._AssertTopLevelShape('print("hello world", file=sys.stderr)\n',
                              'simple_stmt')

  def testPrintStatementToTree(self):
    # print written in statement form (no parentheses).
    self._AssertTopLevelShape('print "hello world"\n', 'simple_stmt')

  def testClassNotLocal(self):
    # 'nonlocal' used as a class name must still parse as a classdef.
    self._AssertTopLevelShape('class nonlocal: pass\n', 'classdef')
76
77
class InsertNodesBeforeAfterTest(unittest.TestCase):
  """Tests for pytree_utils.InsertNodesBefore and InsertNodesAfter."""

  def _BuildSimpleTree(self):
    """Returns a fresh small tree for the tests to mutate.

    The tree looks like this:

      suite:
        LPAR
        LPAR
        simple_stmt:
          NAME('foo')
    """
    name_leaf = pytree.Leaf(token.NAME, 'foo')
    statement = pytree.Node(_GRAMMAR_SYMBOL2NUMBER['simple_stmt'], [name_leaf])
    children = [
        pytree.Leaf(token.LPAR, '('),
        pytree.Leaf(token.LPAR, '('),
        statement,
    ]
    return pytree.Node(_GRAMMAR_SYMBOL2NUMBER['suite'], children)

  def _MakeNewNodeRPAR(self):
    """Returns a new, parentless RPAR leaf to be inserted into the tree."""
    return pytree.Leaf(token.RPAR, ')')

  def setUp(self):
    # Every test starts from its own unmodified copy of the tree.
    self._simple_tree = self._BuildSimpleTree()

  def testInsertNodesBefore(self):
    # Inserting before simple_stmt should push it from index 2 to index 3.
    target = self._simple_tree.children[2]
    pytree_utils.InsertNodesBefore([self._MakeNewNodeRPAR()], target)

    siblings = self._simple_tree.children
    self.assertEqual(4, len(siblings))
    self.assertEqual('RPAR', pytree_utils.NodeName(siblings[2]))
    self.assertEqual('simple_stmt', pytree_utils.NodeName(siblings[3]))

  def testInsertNodesBeforeFirstChild(self):
    # The target is the first (and only) child of simple_stmt, so the new leaf
    # must land inside simple_stmt and leave the outer suite untouched.
    simple_stmt = self._simple_tree.children[2]
    target = simple_stmt.children[0]
    pytree_utils.InsertNodesBefore([self._MakeNewNodeRPAR()], target)

    self.assertEqual(3, len(self._simple_tree.children))
    inner = simple_stmt.children
    self.assertEqual(2, len(inner))
    self.assertEqual('RPAR', pytree_utils.NodeName(inner[0]))
    self.assertEqual('NAME', pytree_utils.NodeName(inner[1]))

  def testInsertNodesAfter(self):
    # Inserting after simple_stmt should leave it at index 2 and append RPAR.
    target = self._simple_tree.children[2]
    pytree_utils.InsertNodesAfter([self._MakeNewNodeRPAR()], target)

    siblings = self._simple_tree.children
    self.assertEqual(4, len(siblings))
    self.assertEqual('simple_stmt', pytree_utils.NodeName(siblings[2]))
    self.assertEqual('RPAR', pytree_utils.NodeName(siblings[3]))

  def testInsertNodesAfterLastChild(self):
    # The target is the last (and only) child of simple_stmt, so the new leaf
    # must be appended inside simple_stmt and leave the outer suite untouched.
    simple_stmt = self._simple_tree.children[2]
    target = simple_stmt.children[0]
    pytree_utils.InsertNodesAfter([self._MakeNewNodeRPAR()], target)

    self.assertEqual(3, len(self._simple_tree.children))
    inner = simple_stmt.children
    self.assertEqual(2, len(inner))
    self.assertEqual('NAME', pytree_utils.NodeName(inner[0]))
    self.assertEqual('RPAR', pytree_utils.NodeName(inner[1]))

  def testInsertNodesWhichHasParent(self):
    # A node that already lives in the tree cannot be inserted elsewhere; the
    # attempt is expected to raise RuntimeError.
    first_lpar = self._simple_tree.children[0]
    second_lpar = self._simple_tree.children[1]
    self.assertRaises(RuntimeError, pytree_utils.InsertNodesAfter,
                      [second_lpar], first_lpar)
148
149
class AnnotationsTest(unittest.TestCase):
  """Tests for getting, setting and removing annotations on pytree nodes."""

  def setUp(self):
    # One leaf and one interior node, both without any annotations yet.
    name_leaf = pytree.Leaf(token.NAME, 'foo')
    self._node = pytree.Node(_GRAMMAR_SYMBOL2NUMBER['simple_stmt'], [name_leaf])
    self._leaf = pytree.Leaf(token.LPAR, '(')

  def testGetWhenNone(self):
    # Reading an annotation that was never set is expected to give None.
    value = pytree_utils.GetNodeAnnotation(self._leaf, _FOO)
    self.assertIsNone(value)

  def testSetWhenNone(self):
    # A value that was set can be read back.
    pytree_utils.SetNodeAnnotation(self._leaf, _FOO, 20)
    value = pytree_utils.GetNodeAnnotation(self._leaf, _FOO)
    self.assertEqual(value, 20)

  def testSetAgain(self):
    # Setting the same annotation a second time replaces the first value.
    for new_value in (20, 30):
      pytree_utils.SetNodeAnnotation(self._leaf, _FOO, new_value)
      value = pytree_utils.GetNodeAnnotation(self._leaf, _FOO)
      self.assertEqual(value, new_value)

  def testMultiple(self):
    # Several differently named annotations can coexist on one leaf. All of
    # them are set first and only then read back.
    expected = (
        (_FOO, 20),
        (_FOO1, 1),
        (_FOO2, 2),
        (_FOO3, 3),
        (_FOO4, 4),
        (_FOO5, 5),
    )
    for name, value in expected:
      pytree_utils.SetNodeAnnotation(self._leaf, name, value)

    for name, value in expected:
      self.assertEqual(pytree_utils.GetNodeAnnotation(self._leaf, name), value)

  def testSubtype(self):
    # Appending a subtype yields a one-element set; removing it leaves an
    # empty set behind.
    pytree_utils.AppendNodeAnnotation(self._leaf,
                                      pytree_utils.Annotation.SUBTYPE, _FOO)
    subtypes = pytree_utils.GetNodeAnnotation(self._leaf,
                                              pytree_utils.Annotation.SUBTYPE)
    self.assertSetEqual(subtypes, {_FOO})

    pytree_utils.RemoveSubtypeAnnotation(self._leaf, _FOO)
    subtypes = pytree_utils.GetNodeAnnotation(self._leaf,
                                              pytree_utils.Annotation.SUBTYPE)
    self.assertSetEqual(subtypes, set())

  def testSetOnNode(self):
    # Annotations work on interior nodes as well as on leaves.
    pytree_utils.SetNodeAnnotation(self._node, _FOO, 20)
    value = pytree_utils.GetNodeAnnotation(self._node, _FOO)
    self.assertEqual(value, 20)
202
203
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
206