1""" 2Unit tests for refactor.py. 3""" 4 5from __future__ import with_statement 6 7import sys 8import os 9import codecs 10import operator 11import StringIO 12import tempfile 13import shutil 14import unittest 15import warnings 16 17from lib2to3 import refactor, pygram, fixer_base 18from lib2to3.pgen2 import token 19 20from . import support 21 22 23TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 24FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers") 25 26sys.path.append(FIXER_DIR) 27try: 28 _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes") 29finally: 30 sys.path.pop() 31 32_2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes") 33 34class TestRefactoringTool(unittest.TestCase): 35 36 def setUp(self): 37 sys.path.append(FIXER_DIR) 38 39 def tearDown(self): 40 sys.path.pop() 41 42 def check_instances(self, instances, classes): 43 for inst, cls in zip(instances, classes): 44 if not isinstance(inst, cls): 45 self.fail("%s are not instances of %s" % instances, classes) 46 47 def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): 48 return refactor.RefactoringTool(fixers, options, explicit) 49 50 def test_print_function_option(self): 51 rt = self.rt({"print_function" : True}) 52 self.assertTrue(rt.grammar is pygram.python_grammar_no_print_statement) 53 self.assertTrue(rt.driver.grammar is 54 pygram.python_grammar_no_print_statement) 55 56 def test_fixer_loading_helpers(self): 57 contents = ["explicit", "first", "last", "parrot", "preorder"] 58 non_prefixed = refactor.get_all_fix_names("myfixes") 59 prefixed = refactor.get_all_fix_names("myfixes", False) 60 full_names = refactor.get_fixers_from_package("myfixes") 61 self.assertEqual(prefixed, ["fix_" + name for name in contents]) 62 self.assertEqual(non_prefixed, contents) 63 self.assertEqual(full_names, 64 ["myfixes.fix_" + name for name in contents]) 65 66 def test_detect_future_features(self): 67 run = refactor._detect_future_features 68 fs = frozenset 69 empty = fs() 70 self.assertEqual(run(""), empty) 71 self.assertEqual(run("from __future__ import print_function"), 72 fs(("print_function",))) 73 self.assertEqual(run("from __future__ import generators"), 74 fs(("generators",))) 75 self.assertEqual(run("from __future__ import generators, feature"), 76 fs(("generators", "feature"))) 77 inp = "from __future__ import generators, print_function" 78 self.assertEqual(run(inp), fs(("generators", "print_function"))) 79 inp ="from __future__ import print_function, generators" 80 self.assertEqual(run(inp), fs(("print_function", "generators"))) 81 inp = "from __future__ import (print_function,)" 82 self.assertEqual(run(inp), fs(("print_function",))) 83 inp = "from __future__ import (generators, print_function)" 84 self.assertEqual(run(inp), fs(("generators", "print_function"))) 85 inp = "from __future__ import (generators, nested_scopes)" 86 self.assertEqual(run(inp), fs(("generators", "nested_scopes"))) 87 inp = """from __future__ import generators 88from __future__ import print_function""" 89 self.assertEqual(run(inp), fs(("generators", "print_function"))) 90 invalid = ("from", 91 "from 4", 92 "from x", 93 "from x 5", 94 "from x im", 95 "from x import", 96 "from x import 4", 97 ) 98 for inp in invalid: 99 self.assertEqual(run(inp), empty) 100 inp = "'docstring'\nfrom __future__ import print_function" 101 self.assertEqual(run(inp), fs(("print_function",))) 102 inp = "'docstring'\n'somng'\nfrom __future__ import print_function" 103 self.assertEqual(run(inp), empty) 104 inp = "# comment\nfrom __future__ import print_function" 105 self.assertEqual(run(inp), fs(("print_function",))) 106 inp = "# comment\n'doc'\nfrom __future__ import print_function" 107 self.assertEqual(run(inp), fs(("print_function",))) 108 inp = "class x: pass\nfrom __future__ import print_function" 109 self.assertEqual(run(inp), empty) 110 111 def test_get_headnode_dict(self): 112 class NoneFix(fixer_base.BaseFix): 113 pass 114 115 class FileInputFix(fixer_base.BaseFix): 116 PATTERN = "file_input< any * >" 117 118 class SimpleFix(fixer_base.BaseFix): 119 PATTERN = "'name'" 120 121 no_head = NoneFix({}, []) 122 with_head = FileInputFix({}, []) 123 simple = SimpleFix({}, []) 124 d = refactor._get_headnode_dict([no_head, with_head, simple]) 125 top_fixes = d.pop(pygram.python_symbols.file_input) 126 self.assertEqual(top_fixes, [with_head, no_head]) 127 name_fixes = d.pop(token.NAME) 128 self.assertEqual(name_fixes, [simple, no_head]) 129 for fixes in d.itervalues(): 130 self.assertEqual(fixes, [no_head]) 131 132 def test_fixer_loading(self): 133 from myfixes.fix_first import FixFirst 134 from myfixes.fix_last import FixLast 135 from myfixes.fix_parrot import FixParrot 136 from myfixes.fix_preorder import FixPreorder 137 138 rt = self.rt() 139 pre, post = rt.get_fixers() 140 141 self.check_instances(pre, [FixPreorder]) 142 self.check_instances(post, [FixFirst, FixParrot, FixLast]) 143 144 def test_naughty_fixers(self): 145 self.assertRaises(ImportError, self.rt, fixers=["not_here"]) 146 self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"]) 147 self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"]) 148 149 def test_refactor_string(self): 150 rt = self.rt() 151 input = "def parrot(): pass\n\n" 152 tree = rt.refactor_string(input, "<test>") 153 self.assertNotEqual(str(tree), input) 154 155 input = "def f(): pass\n\n" 156 tree = rt.refactor_string(input, "<test>") 157 self.assertEqual(str(tree), input) 158 159 def test_refactor_stdin(self): 160 161 class MyRT(refactor.RefactoringTool): 162 163 def print_output(self, old_text, new_text, filename, equal): 164 results.extend([old_text, new_text, filename, equal]) 165 166 results = [] 167 rt = MyRT(_DEFAULT_FIXERS) 168 save = sys.stdin 169 sys.stdin = StringIO.StringIO("def parrot(): pass\n\n") 170 try: 171 rt.refactor_stdin() 172 finally: 173 sys.stdin = save 174 expected = ["def parrot(): pass\n\n", 175 "def cheese(): pass\n\n", 176 "<stdin>", False] 177 self.assertEqual(results, expected) 178 179 def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS): 180 def read_file(): 181 with open(test_file, "rb") as fp: 182 return fp.read() 183 old_contents = read_file() 184 rt = self.rt(fixers=fixers) 185 186 rt.refactor_file(test_file) 187 self.assertEqual(old_contents, read_file()) 188 189 try: 190 rt.refactor_file(test_file, True) 191 new_contents = read_file() 192 self.assertNotEqual(old_contents, new_contents) 193 finally: 194 with open(test_file, "wb") as fp: 195 fp.write(old_contents) 196 return new_contents 197 198 def test_refactor_file(self): 199 test_file = os.path.join(FIXER_DIR, "parrot_example.py") 200 self.check_file_refactoring(test_file, _DEFAULT_FIXERS) 201 202 def test_refactor_dir(self): 203 def check(structure, expected): 204 def mock_refactor_file(self, f, *args): 205 got.append(f) 206 save_func = refactor.RefactoringTool.refactor_file 207 refactor.RefactoringTool.refactor_file = mock_refactor_file 208 rt = self.rt() 209 got = [] 210 dir = tempfile.mkdtemp(prefix="2to3-test_refactor") 211 try: 212 os.mkdir(os.path.join(dir, "a_dir")) 213 for fn in structure: 214 open(os.path.join(dir, fn), "wb").close() 215 rt.refactor_dir(dir) 216 finally: 217 refactor.RefactoringTool.refactor_file = save_func 218 shutil.rmtree(dir) 219 self.assertEqual(got, 220 [os.path.join(dir, path) for path in expected]) 221 check([], []) 222 tree = ["nothing", 223 "hi.py", 224 ".dumb", 225 ".after.py", 226 "notpy.npy", 227 "sappy"] 228 expected = ["hi.py"] 229 check(tree, expected) 230 tree = ["hi.py", 231 os.path.join("a_dir", "stuff.py")] 232 check(tree, tree) 233 234 def test_file_encoding(self): 235 fn = os.path.join(TEST_DATA_DIR, "different_encoding.py") 236 self.check_file_refactoring(fn) 237 238 def test_bom(self): 239 fn = os.path.join(TEST_DATA_DIR, "bom.py") 240 data = self.check_file_refactoring(fn) 241 self.assertTrue(data.startswith(codecs.BOM_UTF8)) 242 243 def test_crlf_newlines(self): 244 old_sep = os.linesep 245 os.linesep = "\r\n" 246 try: 247 fn = os.path.join(TEST_DATA_DIR, "crlf.py") 248 fixes = refactor.get_fixers_from_package("lib2to3.fixes") 249 self.check_file_refactoring(fn, fixes) 250 finally: 251 os.linesep = old_sep 252 253 def test_refactor_docstring(self): 254 rt = self.rt() 255 256 doc = """ 257>>> example() 25842 259""" 260 out = rt.refactor_docstring(doc, "<test>") 261 self.assertEqual(out, doc) 262 263 doc = """ 264>>> def parrot(): 265... return 43 266""" 267 out = rt.refactor_docstring(doc, "<test>") 268 self.assertNotEqual(out, doc) 269 270 def test_explicit(self): 271 from myfixes.fix_explicit import FixExplicit 272 273 rt = self.rt(fixers=["myfixes.fix_explicit"]) 274 self.assertEqual(len(rt.post_order), 0) 275 276 rt = self.rt(explicit=["myfixes.fix_explicit"]) 277 for fix in rt.post_order: 278 if isinstance(fix, FixExplicit): 279 break 280 else: 281 self.fail("explicit fixer not loaded") 282