1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for ast_edits which is used in tf upgraders.
16
Most of the tests assume that we want to change from an API containing
18
19    import foo as f
20
21    def f(a, b, kw1, kw2): ...
22    def g(a, b, kw1, c, kw1_alias): ...
23    def g2(a, b, kw1, c, d, kw1_alias): ...
24    def h(a, kw1, kw2, kw1_alias, kw2_alias): ...
25
26and the changes to the API consist of renaming, reordering, and/or removing
27arguments. Thus, we want to be able to generate changes to produce each of the
28following new APIs:
29
30    import bar as f
31
32    def f(a, b, kw1, kw3): ...
33    def f(a, b, kw2, kw1): ...
34    def f(a, b, kw3, kw1): ...
35    def g(a, b, kw1, c): ...
36    def g(a, b, c, kw1): ...
37    def g2(a, b, kw1, c, d): ...
38    def g2(a, b, c, d, kw1): ...
39    def h(a, kw1, kw2): ...
40
41"""
42
43from __future__ import absolute_import
44from __future__ import division
45from __future__ import print_function
46
47import ast
48import os
49
50import six
51
52from tensorflow.python.framework import test_util
53from tensorflow.python.platform import test as test_lib
54from tensorflow.tools.compatibility import ast_edits
55
56
class ModuleDeprecationSpec(ast_edits.NoUpdateSpec):
  """Spec that marks the module `a.b` as deprecated at ERROR level."""

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    # (level, message) pair registered for the deprecated module.
    deprecation = (ast_edits.ERROR, "a.b is evil.")
    self.module_deprecations["a.b"] = deprecation
63
64
class RenameKeywordSpec(ast_edits.NoUpdateSpec):
  """Spec that renames the keyword argument `kw2` of `f` to `kw3`.

  Resulting API:

    def f(a, b, kw1, kw3): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    self.update_renames()

  def update_renames(self):
    """Registers the kw2 -> kw3 keyword rename for `f`."""
    renames_for_f = {"kw2": "kw3"}
    self.function_keyword_renames["f"] = renames_for_f
80
81
class ReorderKeywordSpec(ast_edits.NoUpdateSpec):
  """Spec that moves the argument `kw2` of `f` in front of `kw1`.

  Resulting API:

    def f(a, b, kw2, kw1): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    self.update_reorders()

  def update_reorders(self):
    """Registers the argument order of `f` for reordering."""
    # Listed in the order of the OLD signature, not the new one.
    old_signature_order = ["a", "b", "kw1", "kw2"]
    self.function_reorders["f"] = old_signature_order
98
99
class ReorderAndRenameKeywordSpec(ReorderKeywordSpec, RenameKeywordSpec):
  """A specification where kw2 gets moved in front of kw1 and is changed to kw3.

  The new API is

    def f(a, b, kw3, kw1): ...

  """

  def __init__(self):
    # Initialize the shared NoUpdateSpec state exactly once. Then apply the
    # rename (from RenameKeywordSpec) and the reorder (from
    # ReorderKeywordSpec). Calling both parent constructors would run
    # NoUpdateSpec.__init__ twice, forcing both updates to be re-applied
    # afterwards in case the second call reset the spec dictionaries.
    ast_edits.NoUpdateSpec.__init__(self)
    self.update_renames()
    self.update_reorders()
114
115
class RemoveDeprecatedAliasKeyword(ast_edits.NoUpdateSpec):
  """A specification where kw1_alias is removed in g and g2.

  Calls that pass `kw1_alias=` have that keyword renamed to `kw1`.
  The new API is

    def g(a, b, kw1, c): ...
    def g2(a, b, kw1, c, d): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    self.function_keyword_renames["g"] = {"kw1_alias": "kw1"}
    self.function_keyword_renames["g2"] = {"kw1_alias": "kw1"}
130
131
class RemoveDeprecatedAliasAndReorderRest(RemoveDeprecatedAliasKeyword):
  """A specification where kw1_alias is removed and kw1 moves to the end.

  For both g and g2, `kw1_alias` is renamed to `kw1` (inherited from
  RemoveDeprecatedAliasKeyword). In addition, `kw1` is moved after the
  remaining arguments. The new API is

    def g(a, b, c, kw1): ...
    def g2(a, b, c, d, kw1): ...

  """

  def __init__(self):
    RemoveDeprecatedAliasKeyword.__init__(self)
    # Note that these should be in the old order.
    self.function_reorders["g"] = ["a", "b", "kw1", "c"]
    self.function_reorders["g2"] = ["a", "b", "kw1", "c", "d"]
147
148
class RemoveMultipleKeywordArguments(ast_edits.NoUpdateSpec):
  """Spec that drops both deprecated keyword aliases of `h`.

  `kw1_alias` becomes `kw1` and `kw2_alias` becomes `kw2`, giving

    def h(a, kw1, kw2): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    alias_to_name = dict(kw1_alias="kw1", kw2_alias="kw2")
    self.function_keyword_renames["h"] = alias_to_name
164
165
class RenameImports(ast_edits.NoUpdateSpec):
  """Spec that renames imports of `foo` to `bar`, except under `foo.baz`."""

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    foo_to_bar = ast_edits.ImportRename("bar", excluded_prefixes=["foo.baz"])
    self.import_renames = {"foo": foo_to_bar}
176
177
class TestAstEdits(test_util.TensorFlowTestCase):
  """Exercises ast_edits.ASTCodeUpgrader with the specs defined above."""

  def _upgrade(self, spec, old_file_text):
    """Runs an upgrader built from `spec` over `old_file_text`.

    Args:
      spec: The change spec handed to `ast_edits.ASTCodeUpgrader`.
      old_file_text: Source text to upgrade.

    Returns:
      A tuple `((count, report, errors), new_text)`. The inner triple is what
      `process_opened_file` returned, and `new_text` is the text written to
      the output buffer.
    """
    in_file = six.StringIO(old_file_text)
    out_file = six.StringIO()
    upgrader = ast_edits.ASTCodeUpgrader(spec)
    count, report, errors = (
        upgrader.process_opened_file("test.py", in_file,
                                     "test_out.py", out_file))
    return (count, report, errors), out_file.getvalue()

  def testModuleDeprecation(self):
    """Each use of a member of the deprecated module a.b yields an error."""
    text = "a.b.c(a.b.x)"
    (_, _, errors), new_text = self._upgrade(ModuleDeprecationSpec(), text)
    self.assertEqual(text, new_text)
    self.assertIn("Using member a.b.c", errors[0])
    self.assertIn("1:0", errors[0])
    # The second error refers to the argument a.b.x at column 6.
    self.assertIn("Using member a.b.x", errors[1])
    self.assertIn("1:6", errors[1])

  def testNoTransformIfNothingIsSupplied(self):
    """NoUpdateSpec leaves calls to f unchanged."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    _, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
    self.assertEqual(new_text, text)

    text = "f(a, b, c, d)\n"
    _, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
    self.assertEqual(new_text, text)

  def testKeywordRename(self):
    """Test that we get the expected result if renaming kw2 to kw3."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    expected = "f(a, b, kw1=c, kw3=d)\n"
    (_, report, _), new_text = self._upgrade(RenameKeywordSpec(), text)
    self.assertEqual(new_text, expected)
    self.assertNotIn("Manual check required", report)

    # No keywords specified, no reordering, so we should get input as output
    text = "f(a, b, c, d)\n"
    (_, report, _), new_text = self._upgrade(RenameKeywordSpec(), text)
    self.assertEqual(new_text, text)
    self.assertNotIn("Manual check required", report)

    # Positional *args cannot carry keyword names, and this spec only renames
    # keywords (no reordering), so no warning is expected.
    text = "f(a, *args)\n"
    (_, report, _), _ = self._upgrade(RenameKeywordSpec(), text)
    self.assertNotIn("Manual check required", report)

    # **kwargs passed in that we cannot inspect, should warn
    text = "f(a, b, kw1=c, **kwargs)\n"
    (_, report, _), _ = self._upgrade(RenameKeywordSpec(), text)
    self.assertIn("Manual check required", report)

  def testKeywordReorderWithParens(self):
    """Test that we get the expected result if there are parens around args."""
    text = "f((a), ( ( b ) ))\n"
    acceptable_outputs = [
        # No change is a valid output
        text,
        # Also cases where all arguments are fully specified are allowed
        "f(a=(a), b=( ( b ) ))\n",
        # Making the parens canonical is ok
        "f(a=(a), b=((b)))\n",
    ]
    _, new_text = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testKeywordReorder(self):
    """Test that we get the expected result if kw2 is now before kw1."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    acceptable_outputs = [
        # No change is a valid output
        text,
        # Just reordering the kw.. args is also ok
        "f(a, b, kw2=d, kw1=c)\n",
        # Also cases where all arguments are fully specified are allowed
        "f(a=a, b=b, kw1=c, kw2=d)\n",
        "f(a=a, b=b, kw2=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Keywords are reordered, so we should reorder arguments too
    text = "f(a, b, c, d)\n"
    acceptable_outputs = [
        "f(a, b, d, c)\n",
        "f(a=a, b=b, kw1=c, kw2=d)\n",
        "f(a=a, b=b, kw2=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Positional *args passed in that we cannot inspect, should warn
    text = "f(a, b, *args)\n"
    (_, report, _), _ = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn("Manual check required", report)

    # Every positional argument is visible here, so the reorder can be applied
    # without inspecting **kwargs; no warning is expected.
    text = "f(a, b, kw1=c, **kwargs)\n"
    (_, report, _), _ = self._upgrade(ReorderKeywordSpec(), text)
    self.assertNotIn("Manual check required", report)

  def testKeywordReorderAndRename(self):
    """Test that we get the expected result if kw2 is renamed and moved."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    acceptable_outputs = [
        "f(a, b, kw3=d, kw1=c)\n",
        "f(a=a, b=b, kw1=c, kw3=d)\n",
        "f(a=a, b=b, kw3=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(
        ReorderAndRenameKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Keywords are reordered, so we should reorder arguments too
    text = "f(a, b, c, d)\n"
    acceptable_outputs = [
        "f(a, b, d, c)\n",
        "f(a=a, b=b, kw1=c, kw3=d)\n",
        "f(a=a, b=b, kw3=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(
        ReorderAndRenameKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Positional *args passed in that we cannot inspect, should warn
    text = "f(a, *args, kw1=c)\n"
    (_, report, _), _ = self._upgrade(ReorderAndRenameKeywordSpec(), text)
    self.assertIn("Manual check required", report)

    # **kwargs passed in that we cannot inspect, should warn
    text = "f(a, b, kw1=c, **kwargs)\n"
    (_, report, _), _ = self._upgrade(ReorderAndRenameKeywordSpec(), text)
    self.assertIn("Manual check required", report)

  def testRemoveDeprecatedKeywordAlias(self):
    """Test that we get the expected result if a keyword alias is removed."""
    text = "g(a, b, kw1=x, c=c)\n"
    acceptable_outputs = [
        # Not using deprecated alias, so original is ok
        text,
        "g(a=a, b=b, kw1=x, c=c)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertIn(new_text, acceptable_outputs)

    # No keyword used, should be no change
    text = "g(a, b, x, c)\n"
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertEqual(new_text, text)

    # If we used the alias, it should get renamed
    text = "g(a, b, kw1_alias=x, c=c)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertIn(new_text, acceptable_outputs)

    # It should get renamed even if it's last
    text = "g(a, b, c=c, kw1_alias=x)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testRemoveDeprecatedKeywordAndReorder(self):
    """Test for when a keyword alias is removed and args are reordered."""
    text = "g(a, b, kw1=x, c=c)\n"
    acceptable_outputs = [
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Keywords are reordered, so we should reorder arguments too
    text = "g(a, b, x, c)\n"
    # Don't accept an output which doesn't move x after c
    acceptable_outputs = [
        "g(a, b, c, x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # If we used the alias, it should get renamed
    text = "g(a, b, kw1_alias=x, c=c)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # It should get renamed and reordered even if it's last
    text = "g(a, b, c=c, kw1_alias=x)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testRemoveDeprecatedKeywordAndReorder2(self):
    """Same as testRemoveDeprecatedKeywordAndReorder but on g2 (more args)."""
    text = "g2(a, b, kw1=x, c=c, d=d)\n"
    acceptable_outputs = [
        "g2(a, b, c=c, d=d, kw1=x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Keywords are reordered, so we should reorder arguments too
    text = "g2(a, b, x, c, d)\n"
    # Don't accept an output which doesn't reorder c and d
    acceptable_outputs = [
        "g2(a, b, c, d, x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # If we used the alias, it should get renamed
    text = "g2(a, b, kw1_alias=x, c=c, d=d)\n"
    acceptable_outputs = [
        "g2(a, b, kw1=x, c=c, d=d)\n",
        "g2(a, b, c=c, d=d, kw1=x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
        "g2(a=a, b=b, c=c, d=d, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # It should get renamed and reordered even if it's not in order
    text = "g2(a, b, d=d, c=c, kw1_alias=x)\n"
    acceptable_outputs = [
        "g2(a, b, kw1=x, c=c, d=d)\n",
        "g2(a, b, c=c, d=d, kw1=x)\n",
        "g2(a, b, d=d, c=c, kw1=x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
        "g2(a=a, b=b, c=c, d=d, kw1=x)\n",
        "g2(a=a, b=b, d=d, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testRemoveMultipleKeywords(self):
    """Remove multiple keywords at once."""
    # Not using deprecated keywords -> no rename
    text = "h(a, kw1=x, kw2=y)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertEqual(new_text, text)

    # Using positional arguments (in proper order) -> no change
    text = "h(a, x, y)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertEqual(new_text, text)

    # Use only the old names, in order
    text = "h(a, kw1_alias=x, kw2_alias=y)\n"
    acceptable_outputs = [
        "h(a, x, y)\n",
        "h(a, kw1=x, kw2=y)\n",
        "h(a=a, kw1=x, kw2=y)\n",
        "h(a, kw2=y, kw1=x)\n",
        "h(a=a, kw2=y, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Use only the old names, in reverse order, should give one of same outputs
    text = "h(a, kw2_alias=y, kw1_alias=x)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Mix old and new names
    text = "h(a, kw1=x, kw2_alias=y)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testUnrestrictedFunctionWarnings(self):
    """A "*.foo" warning fires for calls to an attribute named foo only."""

    class FooWarningSpec(ast_edits.NoUpdateSpec):
      """Usages of function attribute foo() prints out a warning."""

      def __init__(self):
        ast_edits.NoUpdateSpec.__init__(self)
        self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")}

    texts = ["object.foo()", "get_object().foo()", "object.foo().bar()"]
    for text in texts:
      (_, report, _), _ = self._upgrade(FooWarningSpec(), text)
      self.assertIn("not good", report)

    # Note that foo() won't result in a warning, because in this case foo is
    # not an attribute, but a name.
    false_alarms = ["foo", "foo()", "foo.bar()", "obj.run_foo()", "obj.foo"]
    for text in false_alarms:
      (_, report, _), _ = self._upgrade(FooWarningSpec(), text)
      self.assertNotIn("not good", report)

  def testFullNameNode(self):
    """full_name_node builds nested Attribute/Name nodes for a dotted name."""
    t = ast_edits.full_name_node("a.b.c")
    self.assertEqual(
        ast.dump(t),
        "Attribute(value=Attribute(value=Name(id='a', ctx=Load()), attr='b', "
        "ctx=Load()), attr='c', ctx=Load())")

  def testImport(self):
    """`import foo...` statements are rewritten to import bar instead."""
    # foo should be renamed to bar.
    text = "import foo as f"
    expected_text = "import bar as f"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "import foo"
    expected_text = "import bar as foo"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "import foo.test"
    expected_text = "import bar.test"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "import foo.test as t"
    expected_text = "import bar.test as t"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "import foo as f, a as b"
    expected_text = "import bar as f, a as b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testFromImport(self):
    """`from foo... import ...` statements are rewritten to use bar."""
    # foo should be renamed to bar.
    text = "from foo import a"
    expected_text = "from bar import a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "from foo.a import b"
    expected_text = "from bar.a import b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "from foo import *"
    expected_text = "from bar import *"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "from foo import a, b"
    expected_text = "from bar import a, b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testImport_NoChangeNeeded(self):
    """An `import` that does not mention foo is left unchanged."""
    text = "import bar as b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

  def testFromImport_NoChangeNeeded(self):
    """A `from ... import` that does not mention foo is left unchanged."""
    text = "from bar import a as b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

  def testExcludedImport(self):
    """Imports under the excluded prefix foo.baz are not renamed."""
    # foo.baz module is excluded from changes.
    text = "import foo.baz"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

    text = "import foo.baz as a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

    text = "from foo import baz as a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

    text = "from foo.baz import a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

  def testMultipleImports(self):
    """Statements mixing excluded and renamable modules are handled per name."""
    text = "import foo.bar as a, foo.baz as b, foo.baz.c, foo.d"
    expected_text = "import bar.bar as a, foo.baz as b, foo.baz.c, bar.d"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "from foo import baz, a, c"
    expected_text = """from foo import baz
from bar import a, c"""
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testImportInsideFunction(self):
    """Imports nested inside a function body are rewritten as well."""
    text = """
def t():
  from c import d
  from foo import baz, a
  from e import y
"""
    expected_text = """
def t():
  from c import d
  from foo import baz
  from bar import a
  from e import y
"""
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testUpgradeInplaceWithSymlink(self):
    """In-place upgrade rewrites the link target and keeps the symlink."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    os.mkdir(upgrade_dir)
    file_a = os.path.join(upgrade_dir, "a.py")
    file_b = os.path.join(upgrade_dir, "b.py")

    with open(file_a, "a") as f:
      f.write("import foo as f")
    os.symlink(file_a, file_b)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree_inplace(upgrade_dir)

    self.assertTrue(os.path.islink(file_b))
    self.assertEqual(file_a, os.readlink(file_b))
    with open(file_a, "r") as f:
      self.assertEqual("import bar as f", f.read())

  def testUpgradeInPlaceWithSymlinkInDifferentDir(self):
    """In-place upgrade leaves link targets outside the tree untouched."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    other_dir = os.path.join(self.get_temp_dir(), "bar")
    os.mkdir(upgrade_dir)
    os.mkdir(other_dir)
    file_c = os.path.join(other_dir, "c.py")
    file_d = os.path.join(upgrade_dir, "d.py")

    with open(file_c, "a") as f:
      f.write("import foo as f")
    os.symlink(file_c, file_d)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree_inplace(upgrade_dir)

    self.assertTrue(os.path.islink(file_d))
    self.assertEqual(file_c, os.readlink(file_d))
    # File pointed to by symlink is in a different directory.
    # Therefore, it should not be upgraded.
    with open(file_c, "r") as f:
      self.assertEqual("import foo as f", f.read())

  def testUpgradeCopyWithSymlink(self):
    """Copying a tree recreates in-tree symlinks pointing at the copies."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    output_dir = os.path.join(self.get_temp_dir(), "bar")
    os.mkdir(upgrade_dir)
    file_a = os.path.join(upgrade_dir, "a.py")
    file_b = os.path.join(upgrade_dir, "b.py")

    with open(file_a, "a") as f:
      f.write("import foo as f")
    os.symlink(file_a, file_b)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True)

    new_file_a = os.path.join(output_dir, "a.py")
    new_file_b = os.path.join(output_dir, "b.py")
    self.assertTrue(os.path.islink(new_file_b))
    self.assertEqual(new_file_a, os.readlink(new_file_b))
    with open(new_file_a, "r") as f:
      self.assertEqual("import bar as f", f.read())

  def testUpgradeCopyWithSymlinkInDifferentDir(self):
    """Copying keeps links to out-of-tree files and leaves those files as-is."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    other_dir = os.path.join(self.get_temp_dir(), "bar")
    output_dir = os.path.join(self.get_temp_dir(), "baz")
    os.mkdir(upgrade_dir)
    os.mkdir(other_dir)
    file_a = os.path.join(other_dir, "a.py")
    file_b = os.path.join(upgrade_dir, "b.py")

    with open(file_a, "a") as f:
      f.write("import foo as f")
    os.symlink(file_a, file_b)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True)

    new_file_b = os.path.join(output_dir, "b.py")
    self.assertTrue(os.path.islink(new_file_b))
    self.assertEqual(file_a, os.readlink(new_file_b))
    with open(file_a, "r") as f:
      self.assertEqual("import foo as f", f.read())
703
704
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's test entry point.
  test_lib.main()
707