1"""
2Convert use of sys.exitfunc to use the atexit module.
3"""
4
5# Author: Benjamin Peterson
6
7from lib2to3 import pytree, fixer_base
8from lib2to3.fixer_util import Name, Attr, Call, Comma, Newline, syms
9
10
class FixExitfunc(fixer_base.BaseFix):
    """Rewrite ``sys.exitfunc = func`` as ``atexit.register(func)``.

    The fixer remembers the first ``import sys`` statement it matches and,
    the first time it rewrites an assignment in a tree, adds an ``atexit``
    import next to that statement.  If no ``import sys`` has been seen when
    an assignment is rewritten, a warning asks the user to add the import
    by hand.
    """
    keep_line_order = True
    BM_compatible = True

    PATTERN = """
              (
                  sys_import=import_name<'import'
                      ('sys'
                      |
                      dotted_as_names< (any ',')* 'sys' (',' any)* >
                      )
                  >
              |
                  expr_stmt<
                      power< 'sys' trailer< '.' 'exitfunc' > >
                  '=' func=any >
              )
              """

    def __init__(self, *args):
        """Initialise the fixer and its per-tree state."""
        super(FixExitfunc, self).__init__(*args)
        # Defined here as well as in start_tree() so that transform() never
        # hits a missing attribute.
        self.sys_import = None
        self._atexit_added = False

    def start_tree(self, tree, filename):
        """Reset the per-tree state before a new tree is processed."""
        super(FixExitfunc, self).start_tree(tree, filename)
        self.sys_import = None
        self._atexit_added = False

    def transform(self, node, results):
        """Handle one match of PATTERN.

        A ``sys_import`` match is only recorded (the first one wins) and the
        node is left untouched.  An assignment match is replaced in place by
        an ``atexit.register(<func>)`` call that keeps the original
        statement's prefix.  Always returns None.
        """
        # First, find the sys import. We'll just hope it's global scope.
        if "sys_import" in results:
            if self.sys_import is None:
                self.sys_import = results["sys_import"]
            return

        func = results["func"].clone()
        func.prefix = ""
        register = pytree.Node(syms.power,
                               Attr(Name("atexit"), Name("register"))
                               )
        call = Call(register, [func], node.prefix)
        node.replace(call)

        if self.sys_import is None:
            # That's interesting.
            self.warning(node, "Can't find sys import; Please add an atexit "
                             "import at the top of your file.")
            return

        # Only add the import once per tree; otherwise every additional
        # ``sys.exitfunc = ...`` would insert another ``atexit`` import.
        if self._atexit_added:
            return
        self._atexit_added = True
        self._add_atexit_import()

    def _add_atexit_import(self):
        """Add an ``atexit`` import next to the recorded ``import sys``."""
        names = self.sys_import.children[1]
        if names.type == syms.dotted_as_names:
            # ``import a, sys, b``: extend the existing name list, unless a
            # plain ``atexit`` name is already part of it.
            already_listed = any(
                isinstance(child, pytree.Leaf) and child.value == "atexit"
                for child in names.children
            )
            if not already_listed:
                names.append_child(Comma())
                names.append_child(Name("atexit", " "))
        else:
            # Plain ``import sys``: insert a newline and a new
            # ``import atexit`` statement right after it, inside the
            # statement node that contains the sys import.
            containing_stmt = self.sys_import.parent
            position = containing_stmt.children.index(self.sys_import)
            new_import = pytree.Node(syms.import_name,
                              [Name("import"), Name("atexit", " ")]
                              )
            new = pytree.Node(syms.simple_stmt, [new_import])
            containing_stmt.insert_child(position + 1, Newline())
            containing_stmt.insert_child(position + 2, new)
73