1from jinja2.visitor import NodeVisitor
2from jinja2._compat import iteritems
3
4
# Load instructions: the first element of every ``(instruction, payload)``
# tuple stored in ``Symbols.loads``.
(
    VAR_LOAD_PARAMETER,  # declared via Symbols.declare_parameter; payload None
    VAR_LOAD_RESOLVE,    # payload is the variable name to look up by name
    VAR_LOAD_ALIAS,      # payload is an identifier from an outer scope
    VAR_LOAD_UNDEFINED,  # stored before any visible reference; payload None
) = ('param', 'resolve', 'alias', 'undefined')
9
10
def find_symbols(nodes, parent_symbols=None):
    """Collect the symbols used by a sequence of nodes.

    A fresh ``Symbols`` scope is created beneath ``parent_symbols`` (if
    given), and every node is fed through one shared
    ``FrameSymbolVisitor`` that records into it.  The scope is returned.
    """
    scope = Symbols(parent=parent_symbols)
    record = FrameSymbolVisitor(scope).visit
    for item in nodes:
        record(item)
    return scope
17
18
def symbols_for_node(node, parent_symbols=None):
    """Return a new ``Symbols`` scope populated from a single node.

    The scope is nested under ``parent_symbols`` (if given) and filled
    via ``Symbols.analyze_node``.
    """
    scope = Symbols(parent_symbols)
    scope.analyze_node(node)
    return scope
23
24
class Symbols(object):
    """Identifier bookkeeping for a single scope level.

    Attributes:
        level: nesting depth -- 0 for a scope without a parent, otherwise
            ``parent.level + 1`` unless an explicit ``level`` is passed.
        parent: the enclosing ``Symbols`` instance, or None.
        refs: maps a template-level name to its identifier, which has the
            form ``l_<level>_<name>`` (see ``_define_ref``).
        loads: maps an identifier to an ``(instruction, payload)`` tuple
            whose ``instruction`` is one of the ``VAR_LOAD_*`` constants.
        stores: set of names assigned or declared as parameters here.
    """

    def __init__(self, parent=None, level=None):
        if level is None:
            # Derive the depth from the parent chain unless given explicitly.
            level = 0 if parent is None else parent.level + 1
        self.level = level
        self.parent = parent
        self.refs = {}
        self.loads = {}
        self.stores = set()

    def analyze_node(self, node, **kwargs):
        """Record the symbols of ``node`` into this instance.

        Extra keyword arguments are forwarded to ``RootVisitor.visit``
        (for example ``for_branch`` when ``node`` is a For node).
        """
        visitor = RootVisitor(self)
        visitor.visit(node, **kwargs)

    def _define_ref(self, name, load=None):
        """Bind ``name`` to a fresh level-qualified identifier.

        If ``load`` is given it is stored as the identifier's load
        instruction.  Returns the identifier.
        """
        ident = 'l_%d_%s' % (self.level, name)
        self.refs[name] = ident
        if load is not None:
            self.loads[ident] = load
        return ident

    def find_load(self, target):
        """Return the load instruction for identifier ``target``.

        This scope is searched first, then its ancestors.  Returns None
        when no scope has an entry.
        """
        if target in self.loads:
            return self.loads[target]
        if self.parent is not None:
            return self.parent.find_load(target)
        return None

    def find_ref(self, name):
        """Return the identifier ``name`` is bound to, or None.

        This scope is searched first, then its ancestors.
        """
        if name in self.refs:
            return self.refs[name]
        if self.parent is not None:
            return self.parent.find_ref(name)
        return None

    def ref(self, name):
        """Like ``find_ref`` but raise AssertionError for unknown names."""
        rv = self.find_ref(name)
        if rv is None:
            raise AssertionError('Tried to resolve a name to a reference that '
                                 'was unknown to the frame (%r)' % name)
        return rv

    def copy(self):
        """Return a copy whose refs, loads and stores can be mutated freely.

        Every other attribute (``parent``, ``level``) is carried over
        as-is, and ``__init__`` is bypassed.
        """
        rv = object.__new__(self.__class__)
        rv.__dict__.update(self.__dict__)
        rv.refs = self.refs.copy()
        rv.loads = self.loads.copy()
        rv.stores = self.stores.copy()
        return rv

    def store(self, name):
        """Record an assignment to ``name`` in this scope.

        If ``name`` has no identifier at this level yet, one is defined.
        It aliases the parent chain's identifier when one exists and
        otherwise starts out undefined.
        """
        self.stores.add(name)

        # If we have not seen the name referenced yet, we need to figure
        # out what to set it to.
        if name not in self.refs:
            # If there is a parent scope we check if the name has a
            # reference there.  If it does it means we might have to alias
            # to a variable there.
            if self.parent is not None:
                outer_ref = self.parent.find_ref(name)
                if outer_ref is not None:
                    self._define_ref(name, load=(VAR_LOAD_ALIAS, outer_ref))
                    return

            # Otherwise we can just set it to undefined.
            self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))

    def declare_parameter(self, name):
        """Record ``name`` as a parameter of this scope and return its identifier."""
        self.stores.add(name)
        return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))

    def load(self, name):
        """Record a read of ``name``.

        Only when no identifier is visible from this scope is one defined
        here, with a load that resolves the name.
        """
        target = self.find_ref(name)
        if target is None:
            self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))

    def branch_update(self, branch_symbols):
        """Merge the scopes of alternative branches back into this one.

        ``branch_symbols`` must be a sized, re-iterable sequence: it is
        iterated twice and passed to ``len``.  Every branch's refs, loads
        and stores are copied in.  Names newly stored in only some of the
        branches get their load replaced.  The replacement aliases the
        parent's identifier when one exists (``VAR_LOAD_ALIAS``) and
        otherwise resolves the name (``VAR_LOAD_RESOLVE``).
        """
        # Count, per name, how many branches store it, ignoring names this
        # scope had already stored before the branches were visited.
        stores = {}
        for branch in branch_symbols:
            for target in branch.stores:
                if target in self.stores:
                    continue
                stores[target] = stores.get(target, 0) + 1

        for sym in branch_symbols:
            self.refs.update(sym.refs)
            self.loads.update(sym.loads)
            self.stores.update(sym.stores)

        branch_total = len(branch_symbols)
        for name, branch_count in iteritems(stores):
            # Stored on every path: keep the load merged in from the branches.
            if branch_count == branch_total:
                continue
            target = self.find_ref(name)
            assert target is not None, 'should not happen'

            if self.parent is not None:
                outer_target = self.parent.find_ref(name)
                if outer_target is not None:
                    self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
                    continue
            self.loads[target] = (VAR_LOAD_RESOLVE, name)

    def dump_stores(self):
        """Map each name stored here or in an ancestor to its identifier.

        Identifiers are resolved from this scope, so the innermost
        binding of a name wins.
        """
        rv = {}
        node = self
        while node is not None:
            for name in node.stores:
                if name not in rv:
                    rv[name] = self.find_ref(name)
            node = node.parent
        return rv

    def dump_param_targets(self):
        """Return the set of identifiers declared as parameters in this scope.

        NOTE(review): only this scope's own ``loads`` are inspected.  The
        earlier version walked the parent chain but re-read ``self.loads``
        at every level, so ancestors never contributed and the walk only
        repeated work.  This single scan returns the same set.  Confirm
        that callers do expect only this scope's parameters.
        """
        return set(target for target, (instr, _) in iteritems(self.loads)
                   if instr == VAR_LOAD_PARAMETER)
149
150
class RootVisitor(NodeVisitor):
    """Entry-point visitor created by ``Symbols.analyze_node``.

    It is dispatched on the node that opens a frame.  It forwards the
    relevant children of that node to a ``FrameSymbolVisitor`` bound to
    the same ``Symbols`` instance.  Node types without a handler here
    raise NotImplementedError.
    """

    def __init__(self, symbols):
        self.sym_visitor = FrameSymbolVisitor(symbols)

    def _visit_each(self, children):
        # Hand every node in ``children`` to the frame-level symbol visitor.
        record = self.sym_visitor.visit
        for child in children:
            record(child)

    def _simple_visit(self, node, **kwargs):
        self._visit_each(node.iter_child_nodes())

    visit_Template = visit_Block = visit_Macro = visit_FilterBlock = \
        visit_Scope = visit_If = visit_ScopedEvalContextModifier = \
        _simple_visit

    def _visit_body(self, node, **kwargs):
        # Only the ``body`` list is scanned; other fields are skipped.
        self._visit_each(node.body)

    visit_AssignBlock = visit_OverlayScope = _visit_body

    def visit_CallBlock(self, node, **kwargs):
        self._visit_each(node.iter_child_nodes(exclude=('call',)))

    def visit_For(self, node, for_branch='body', **kwargs):
        # ``for_branch`` selects which part of the loop this frame covers:
        # 'body', 'else' or 'test'.
        if for_branch == 'test':
            self.sym_visitor.visit(node.target, store_as_param=True)
            if node.test is not None:
                self.sym_visitor.visit(node.test)
            return

        if for_branch == 'body':
            self.sym_visitor.visit(node.target, store_as_param=True)
            statements = node.body
        elif for_branch == 'else':
            statements = node.else_
        else:
            raise RuntimeError('Unknown for branch')
        self._visit_each(statements or ())

    def visit_With(self, node, **kwargs):
        self._visit_each(node.targets)
        self._visit_each(node.body)

    def generic_visit(self, node, *args, **kwargs):
        raise NotImplementedError('Cannot find symbols for %r' %
                                  node.__class__.__name__)
201
202
class FrameSymbolVisitor(NodeVisitor):
    """Record the names a frame loads and stores into a ``Symbols`` object.

    Nodes that open their own scope (Block, Scope, OverlayScope) are not
    entered at all.  Compound nodes such as For, CallBlock, FilterBlock,
    AssignBlock and With only have selected children visited.
    """

    def __init__(self, symbols):
        self.symbols = symbols

    def visit_Name(self, node, store_as_param=False, **kwargs):
        """All assignments to names go through this function."""
        if store_as_param or node.ctx == 'param':
            self.symbols.declare_parameter(node.name)
            return
        if node.ctx == 'store':
            self.symbols.store(node.name)
            return
        if node.ctx == 'load':
            self.symbols.load(node.name)

    def visit_NSRef(self, node, **kwargs):
        self.symbols.load(node.name)

    def visit_If(self, node, **kwargs):
        self.visit(node.test, **kwargs)

        outer = self.symbols

        def scan_branch(statements):
            # Visit ``statements`` against a private copy of the outer
            # symbols, restore the outer symbols, and return the copy.
            self.symbols = branch = outer.copy()
            for statement in statements:
                self.visit(statement, **kwargs)
            self.symbols = outer
            return branch

        branches = [
            scan_branch(node.body),
            scan_branch(node.elif_),
            scan_branch(node.else_ or ()),
        ]
        self.symbols.branch_update(branches)

    def visit_Macro(self, node, **kwargs):
        self.symbols.store(node.name)

    def visit_Import(self, node, **kwargs):
        self.generic_visit(node, **kwargs)
        self.symbols.store(node.target)

    def visit_FromImport(self, node, **kwargs):
        self.generic_visit(node, **kwargs)
        for entry in node.names:
            # A tuple entry binds its second element (the alias).
            self.symbols.store(entry[1] if isinstance(entry, tuple) else entry)

    def visit_Assign(self, node, **kwargs):
        """Visit assignments in the correct order."""
        self.visit(node.node, **kwargs)
        self.visit(node.target, **kwargs)

    def visit_For(self, node, **kwargs):
        """Visiting stops at for blocks.  However the block sequence
        is visited as part of the outer scope.
        """
        self.visit(node.iter, **kwargs)

    def visit_CallBlock(self, node, **kwargs):
        self.visit(node.call, **kwargs)

    def visit_FilterBlock(self, node, **kwargs):
        self.visit(node.filter, **kwargs)

    def visit_With(self, node, **kwargs):
        # Only the bound values are evaluated in this frame; keyword
        # arguments are deliberately not forwarded here.
        for value in node.values:
            self.visit(value)

    def visit_AssignBlock(self, node, **kwargs):
        """Stop visiting at block assigns."""
        self.visit(node.target, **kwargs)

    def visit_Scope(self, node, **kwargs):
        """Stop visiting at scopes."""

    def visit_Block(self, node, **kwargs):
        """Stop visiting at blocks."""

    def visit_OverlayScope(self, node, **kwargs):
        """Do not visit into overlay scopes."""
287