1"""Fixer for reload().
2
3reload(s) -> importlib.reload(s)"""
4
5# Local imports
6from .. import fixer_base
7from ..fixer_util import ImportAndCall, touch_import
8
9
class FixReload(fixer_base.BaseFix):
    """Rewrite ``reload(x)`` calls as ``importlib.reload(x)``.

    Calls whose single argument is a ``*``/``**`` unpacking are left
    unchanged.  Calls with keyword arguments are excluded by PATTERN.
    """

    BM_compatible = True
    order = "pre"

    # Match ``reload(<one positional argument>)``, also with a trailing
    # comma (``reload(x,)`` parses as an arglist of the argument and ','),
    # followed by any number of further trailers captured as ``after``.
    # Keyword arguments (``argument<any '=' any>``) are rejected here.
    PATTERN = """
    power< 'reload'
           trailer< lpar='('
                    ( not(arglist | argument<any '=' any>) obj=any
                      | obj=arglist<(not argument<any '=' any>) any ','> )
                    rpar=')' >
           after=any*
    >
    """

    def transform(self, node, results):
        """Return the replacement for the matched ``reload(...)`` call.

        Delegates to ``ImportAndCall`` to build the ``importlib.reload``
        call from ``results`` and calls ``touch_import`` so the module
        imports ``importlib``.  Returns None (no change) when the argument
        is a ``*`` or ``**`` unpacking.
        """
        if results:
            obj = results['obj']
            # Star-unpacked arguments are not excluded by PATTERN, so they
            # are filtered out here.
            if obj and obj.type == self.syms.argument:
                first = obj.children[0]
                # Not every ``argument`` node starts with a leaf: e.g. a
                # generator argument like ``reload(m.x for m in mods)``
                # starts with an expression Node, which has no ``.value``.
                if getattr(first, 'value', None) in {'**', '*'}:
                    return None  # Make no change.
        names = ('importlib', 'reload')
        new = ImportAndCall(node, results, names)
        touch_import(None, 'importlib', node)
        return new
37