1from jinja2 import nodes
2from jinja2.ext import Extension
3
4
class FragmentCacheExtension(Extension):
    """Jinja extension providing a ``{% cache %}`` block tag.

    Template syntax::

        {% cache key_expression[, timeout_expression] %}
            ...body...
        {% endcache %}

    The rendered body is stored in ``environment.fragment_cache`` under
    ``environment.fragment_cache_prefix + key``.  The cache object must
    provide ``get(key)`` and ``add(key, value, timeout)``.  If no cache is
    configured (the default, ``None``), the body is simply rendered on
    every call.
    """

    # a set of names that trigger the extension.
    tags = {"cache"}

    def __init__(self, environment):
        """Register the extension's default settings on *environment*."""
        super().__init__(environment)

        # add the defaults to the environment
        environment.extend(fragment_cache_prefix="", fragment_cache=None)

    def parse(self, parser):
        """Parse a ``cache`` tag and return the node that renders it.

        :param parser: the parser, positioned on the ``cache`` name token.
        :return: a ``CallBlock`` node that invokes :meth:`_cache_support`
            with the key and timeout expressions, wrapping the tag body.
        """
        # the first token is the token that started the tag.  In our case
        # we only listen to ``'cache'`` so this will be a name token with
        # `cache` as value.  We get the line number so that we can give
        # that line number to the nodes we create by hand.
        lineno = next(parser.stream).lineno

        # now we parse a single expression that is used as cache key.
        args = [parser.parse_expression()]

        # if there is a comma, the user provided a timeout.  If not use
        # None as second parameter.
        if parser.stream.skip_if("comma"):
            args.append(parser.parse_expression())
        else:
            args.append(nodes.Const(None))

        # now we parse the body of the cache block up to `endcache` and
        # drop the needle (which would always be `endcache` in that case)
        body = parser.parse_statements(["name:endcache"], drop_needle=True)

        # now return a `CallBlock` node that calls our _cache_support
        # helper method on this extension.
        return nodes.CallBlock(
            self.call_method("_cache_support", args), [], [], body
        ).set_lineno(lineno)

    def _cache_support(self, name, timeout, caller):
        """Return the cached fragment for *name*, rendering it on a miss.

        :param name: value of the key expression given in the template.
        :param timeout: value of the timeout expression, or ``None`` when
            the tag had no timeout; passed through to the cache's ``add``.
        :param caller: callable that renders the body of the cache block.
        :return: the cached or freshly rendered fragment.
        """
        cache = self.environment.fragment_cache

        # No cache backend configured (the default registered in
        # ``__init__``): render the block directly instead of failing with
        # an AttributeError on ``None``.
        if cache is None:
            return caller()

        # The key expression may evaluate to a non-string (e.g. an integer
        # id).  Coerce it so the prefix concatenation cannot raise; real
        # strings (including str subclasses) are left untouched.
        if not isinstance(name, str):
            name = str(name)
        key = self.environment.fragment_cache_prefix + name

        # try to load the block from the cache
        # if there is no fragment in the cache, render it and store
        # it in the cache.
        rv = cache.get(key)
        if rv is not None:
            return rv
        rv = caller()
        cache.add(key, rv, timeout)
        return rv
55