1import asyncio
2
3import pytest
4
5from jinja2 import DictLoader
6from jinja2 import Environment
7from jinja2 import Template
8from jinja2.asyncsupport import auto_aiter
9from jinja2.exceptions import TemplateNotFound
10from jinja2.exceptions import TemplatesNotFound
11from jinja2.exceptions import UndefinedError
12
13
def run(coro):
    """Drive ``coro`` to completion on the loop from ``asyncio.get_event_loop()``.

    Returns whatever the coroutine returns.
    """
    return asyncio.get_event_loop().run_until_complete(coro)
17
18
def test_basic_async():
    """A plain for-loop renders correctly through ``render_async``."""
    template = Template(
        "{% for item in [1, 2, 3] %}[{{ item }}]{% endfor %}", enable_async=True
    )

    async def render():
        return await template.render_async()

    assert run(render()) == "[1][2][3]"
29
30
def test_await_on_calls():
    """Results of async callables are awaited inside template expressions."""

    async def async_func():
        return 42

    def normal_func():
        return 23

    template = Template("{{ async_func() + normal_func() }}", enable_async=True)
    context = {"async_func": async_func, "normal_func": normal_func}

    async def render():
        return await template.render_async(**context)

    assert run(render()) == "65"
45
46
def test_await_on_calls_normal_render():
    """The synchronous ``render`` of an async template also awaits async callables."""

    async def async_func():
        return 42

    def normal_func():
        return 23

    template = Template("{{ async_func() + normal_func() }}", enable_async=True)
    output = template.render(async_func=async_func, normal_func=normal_func)
    assert output == "65"
59
60
def test_await_and_macros():
    """An async callable can be awaited from inside a macro body."""
    source = (
        "{% macro foo(x) %}[{{ x }}][{{ async_func() }}]{% endmacro %}{{ foo(42) }}"
    )
    template = Template(source, enable_async=True)

    async def async_func():
        return 42

    async def render():
        return await template.render_async(async_func=async_func)

    assert run(render()) == "[42][42]"
75
76
def test_async_blocks():
    """``self.<block>()`` works in async mode and its output is not escaped again.

    Autoescaping is enabled, yet the second ``<Test>`` (from ``self.foo()``)
    must come out unescaped.
    """
    template = Template(
        "{% block foo %}<Test>{% endblock %}{{ self.foo() }}",
        enable_async=True,
        autoescape=True,
    )

    async def render():
        return await template.render_async()

    assert run(render()) == "<Test><Test>"
89
90
def test_async_generate():
    """``generate`` on an async template can be consumed as a normal iterator."""
    template = Template(
        "{% for x in [1, 2, 3] %}{{ x }}{% endfor %}", enable_async=True
    )
    chunks = [chunk for chunk in template.generate()]
    assert chunks == ["1", "2", "3"]
95
96
def test_async_iteration_in_templates():
    """A ``{% for %}`` loop can consume an async generator passed in the context."""

    async def async_iterator():
        for value in (1, 2, 3):
            yield value

    template = Template("{% for x in rng %}{{ x }}{% endfor %}", enable_async=True)
    assert list(template.generate(rng=async_iterator())) == ["1", "2", "3"]
106
107
def test_async_iteration_in_templates_extended():
    """``loop`` attributes work over an ``auto_aiter`` source.

    The generated stream is consumed in two steps: the first chunk alone,
    then everything that remains.
    """
    template = Template(
        "{% for x in rng %}{{ loop.index0 }}/{{ x }}{% endfor %}", enable_async=True
    )
    chunks = template.generate(rng=auto_aiter(range(1, 4)))
    first = next(chunks)
    assert first == "0"
    remainder = "".join(chunks)
    assert remainder == "/11/22/3"
115
116
@pytest.fixture
def test_env_async():
    """Async-enabled environment with three in-memory templates and a ``bar`` global.

    Templates: ``module`` (defines macro ``test``), ``header`` and
    ``o_printer``. The global ``bar`` is set to ``23``.
    """
    templates = {
        "module": "{% macro test() %}[{{ foo }}|{{ bar }}]{% endmacro %}",
        "header": "[{{ foo }}|{{ 23 }}]",
        "o_printer": "({{ o }})",
    }
    env = Environment(loader=DictLoader(templates), enable_async=True)
    env.globals["bar"] = 23
    return env
131
132
class TestAsyncImports:
    """``{% import %}`` and ``{% from ... import %}`` in async-enabled environments."""

    def test_context_imports(self, test_env_async):
        """Imports are context-free unless ``with context`` is given.

        The global ``bar`` is visible either way; the render-time ``foo``
        only shows up when the caller's context is passed through.
        """
        t = test_env_async.from_string('{% import "module" as m %}{{ m.test() }}')
        assert t.render(foo=42) == "[|23]"
        t = test_env_async.from_string(
            '{% import "module" as m without context %}{{ m.test() }}'
        )
        assert t.render(foo=42) == "[|23]"
        t = test_env_async.from_string(
            '{% import "module" as m with context %}{{ m.test() }}'
        )
        assert t.render(foo=42) == "[42|23]"
        t = test_env_async.from_string('{% from "module" import test %}{{ test() }}')
        assert t.render(foo=42) == "[|23]"
        t = test_env_async.from_string(
            '{% from "module" import test without context %}{{ test() }}'
        )
        assert t.render(foo=42) == "[|23]"
        t = test_env_async.from_string(
            '{% from "module" import test with context %}{{ test() }}'
        )
        assert t.render(foo=42) == "[42|23]"

    def test_trailing_comma(self, test_env_async):
        """Trailing commas in ``from ... import`` name lists compile.

        The templates are only compiled, never rendered, so the missing
        ``"foo"`` template is irrelevant.
        """
        test_env_async.from_string('{% from "foo" import bar, baz with context %}')
        test_env_async.from_string('{% from "foo" import bar, baz, with context %}')
        test_env_async.from_string('{% from "foo" import bar, with context %}')
        test_env_async.from_string('{% from "foo" import bar, with, context %}')
        test_env_async.from_string('{% from "foo" import bar, with with context %}')

    def test_exports(self, test_env_async):
        """Only public, top-level names become attributes of the module object."""
        m = run(
            test_env_async.from_string(
                """
            {% macro toplevel() %}...{% endmacro %}
            {% macro __private() %}...{% endmacro %}
            {% set variable = 42 %}
            {% for item in [1] %}
                {% macro notthere() %}{% endmacro %}
            {% endfor %}
        """
            )._get_default_module_async()
        )
        # Macros of an async module return awaitables, hence the second run().
        assert run(m.toplevel()) == "..."
        # ``__private`` is defined above but is underscore-prefixed, so it must
        # not be exported (checking a name the template never defines would
        # pass vacuously).
        assert not hasattr(m, "__private")
        assert m.variable == 42
        # A macro defined inside a loop body is not top-level, so not exported.
        assert not hasattr(m, "notthere")
180
181
class TestAsyncIncludes:
    """``{% include %}`` in async-enabled environments."""

    def test_context_include(self, test_env_async):
        """Includes receive the caller's context unless ``without context``."""
        t = test_env_async.from_string('{% include "header" %}')
        assert t.render(foo=42) == "[42|23]"
        t = test_env_async.from_string('{% include "header" with context %}')
        assert t.render(foo=42) == "[42|23]"
        t = test_env_async.from_string('{% include "header" without context %}')
        assert t.render(foo=42) == "[|23]"

    def test_choice_includes(self, test_env_async):
        """A list of candidate names includes the first template that exists."""
        t = test_env_async.from_string('{% include ["missing", "header"] %}')
        assert t.render(foo=42) == "[42|23]"

        t = test_env_async.from_string(
            '{% include ["missing", "missing2"] ignore missing %}'
        )
        assert t.render(foo=42) == ""

        t = test_env_async.from_string('{% include ["missing", "missing2"] %}')
        pytest.raises(TemplateNotFound, t.render)
        # The same failure is also catchable as the more specific
        # ``TemplatesNotFound``, which records every candidate that was tried.
        with pytest.raises(TemplatesNotFound) as e:
            t.render()

        assert e.value.templates == ["missing", "missing2"]
        assert e.value.name == "missing2"

        def assert_renders_header(t, **ctx):
            # Render with ``foo`` set and expect the "header" template output.
            ctx["foo"] = 42
            assert t.render(ctx) == "[42|23]"

        t = test_env_async.from_string('{% include ["missing", "header"] %}')
        assert_renders_header(t)
        t = test_env_async.from_string("{% include x %}")
        assert_renders_header(t, x=["missing", "header"])
        t = test_env_async.from_string('{% include [x, "header"] %}')
        assert_renders_header(t, x="missing")
        t = test_env_async.from_string("{% include x %}")
        assert_renders_header(t, x="header")
        t = test_env_async.from_string("{% include [x] %}")
        assert_renders_header(t, x="header")

    def test_include_ignoring_missing(self, test_env_async):
        """``ignore missing`` renders nothing instead of raising, for every context mode."""
        t = test_env_async.from_string('{% include "missing" %}')
        pytest.raises(TemplateNotFound, t.render)
        for extra in "", "with context", "without context":
            t = test_env_async.from_string(
                '{% include "missing" ignore missing ' + extra + " %}"
            )
            assert t.render() == ""

    def test_context_include_with_overrides(self, test_env_async):
        """Loop variables of the including template are visible to the include.

        The environment must be async-enabled; otherwise this test only covers
        the synchronous include path.
        """
        env = Environment(
            loader=DictLoader(
                dict(
                    main="{% for item in [1, 2, 3] %}{% include 'item' %}{% endfor %}",
                    item="{{ item }}",
                )
            ),
            enable_async=True,
        )
        assert env.get_template("main").render() == "123"

    def test_unoptimized_scopes(self, test_env_async):
        """An include inside nested macros sees the outer macro's arguments."""
        t = test_env_async.from_string(
            """
            {% macro outer(o) %}
            {% macro inner() %}
            {% include "o_printer" %}
            {% endmacro %}
            {{ inner() }}
            {% endmacro %}
            {{ outer("FOO") }}
        """
        )
        assert t.render().strip() == "(FOO)"

    def test_unoptimized_scopes_autoescape(self):
        """Same as ``test_unoptimized_scopes`` with autoescaping enabled."""
        env = Environment(
            loader=DictLoader({"o_printer": "({{ o }})"}),
            autoescape=True,
            enable_async=True,
        )
        t = env.from_string(
            """
            {% macro outer(o) %}
            {% macro inner() %}
            {% include "o_printer" %}
            {% endmacro %}
            {{ inner() }}
            {% endmacro %}
            {{ outer("FOO") }}
        """
        )
        assert t.render().strip() == "(FOO)"
277
278
class TestAsyncForLoop:
    """``{% for %}`` loop semantics in an async-enabled environment.

    Many templates below are whitespace-sensitive multi-line literals; their
    exact text determines the expected output.
    """

    def test_simple(self, test_env_async):
        """A basic loop over a list renders every item in order."""
        tmpl = test_env_async.from_string("{% for item in seq %}{{ item }}{% endfor %}")
        assert tmpl.render(seq=list(range(10))) == "0123456789"

    def test_else(self, test_env_async):
        """The ``else`` branch renders when ``seq`` is not supplied."""
        tmpl = test_env_async.from_string(
            "{% for item in seq %}XXX{% else %}...{% endfor %}"
        )
        assert tmpl.render() == "..."

    def test_empty_blocks(self, test_env_async):
        """Empty loop and ``else`` bodies render nothing."""
        tmpl = test_env_async.from_string(
            "<{% for item in seq %}{% else %}{% endfor %}>"
        )
        assert tmpl.render() == "<>"

    @pytest.mark.parametrize(
        "transform", [lambda x: x, iter, reversed, lambda x: (i for i in x), auto_aiter]
    )
    def test_context_vars(self, test_env_async, transform):
        """``loop`` attributes are correct for lists, iterators, generators and async iterables."""
        t = test_env_async.from_string(
            "{% for item in seq %}{{ loop.index }}|{{ loop.index0 }}"
            "|{{ loop.revindex }}|{{ loop.revindex0 }}|{{ loop.first }}"
            "|{{ loop.last }}|{{ loop.length }}\n{% endfor %}"
        )
        out = t.render(seq=transform([42, 24]))
        assert out == "1|0|2|1|True|False|2\n2|1|1|0|False|True|2\n"

    def test_cycling(self, test_env_async):
        """``loop.cycle`` alternates over literal and splatted arguments."""
        tmpl = test_env_async.from_string(
            """{% for item in seq %}{{
            loop.cycle('<1>', '<2>') }}{% endfor %}{%
            for item in seq %}{{ loop.cycle(*through) }}{% endfor %}"""
        )
        output = tmpl.render(seq=list(range(4)), through=("<1>", "<2>"))
        assert output == "<1><2>" * 4

    def test_lookaround(self, test_env_async):
        """``loop.previtem`` / ``loop.nextitem`` fall back to the default at the edges."""
        tmpl = test_env_async.from_string(
            """{% for item in seq -%}
            {{ loop.previtem|default('x') }}-{{ item }}-{{
            loop.nextitem|default('x') }}|
        {%- endfor %}"""
        )
        output = tmpl.render(seq=list(range(4)))
        assert output == "x-0-1|0-1-2|1-2-3|2-3-x|"

    def test_changed(self, test_env_async):
        """``loop.changed`` is true only when the value differs from the previous call."""
        tmpl = test_env_async.from_string(
            """{% for item in seq -%}
            {{ loop.changed(item) }},
        {%- endfor %}"""
        )
        output = tmpl.render(seq=[None, None, 1, 2, 2, 3, 4, 4, 4])
        assert output == "True,False,True,True,False,True,True,False,False,"

    def test_scope(self, test_env_async):
        """The loop target is not visible after the loop ends."""
        tmpl = test_env_async.from_string("{% for item in seq %}{% endfor %}{{ item }}")
        output = tmpl.render(seq=list(range(10)))
        assert not output

    def test_varlen(self, test_env_async):
        """A plain generator (no known length) can be iterated."""
        def inner():
            yield from range(5)

        tmpl = test_env_async.from_string(
            "{% for item in iter %}{{ item }}{% endfor %}"
        )
        output = tmpl.render(iter=inner())
        assert output == "01234"

    def test_noniter(self, test_env_async):
        """Looping over ``none`` raises ``TypeError``."""
        tmpl = test_env_async.from_string("{% for item in none %}...{% endfor %}")
        pytest.raises(TypeError, tmpl.render)

    def test_recursive(self, test_env_async):
        """A ``recursive`` loop renders nested children via ``loop(...)``."""
        tmpl = test_env_async.from_string(
            """{% for item in seq recursive -%}
            [{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}]
        {%- endfor %}"""
        )
        assert (
            tmpl.render(
                seq=[
                    dict(a=1, b=[dict(a=1), dict(a=2)]),
                    dict(a=2, b=[dict(a=1), dict(a=2)]),
                    dict(a=3, b=[dict(a="a")]),
                ]
            )
            == "[1<[1][2]>][2<[1][2]>][3<[a]>]"
        )

    def test_recursive_lookaround(self, test_env_async):
        """``previtem`` / ``nextitem`` are scoped to each recursion level."""
        tmpl = test_env_async.from_string(
            """{% for item in seq recursive -%}
            [{{ loop.previtem.a if loop.previtem is defined else 'x' }}.{{
            item.a }}.{{ loop.nextitem.a if loop.nextitem is defined else 'x'
            }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}]
        {%- endfor %}"""
        )
        assert (
            tmpl.render(
                seq=[
                    dict(a=1, b=[dict(a=1), dict(a=2)]),
                    dict(a=2, b=[dict(a=1), dict(a=2)]),
                    dict(a=3, b=[dict(a="a")]),
                ]
            )
            == "[x.1.2<[x.1.2][1.2.x]>][1.2.3<[x.1.2][1.2.x]>][2.3.x<[x.a.x]>]"
        )

    def test_recursive_depth0(self, test_env_async):
        """``loop.depth0`` is 0 at the top level and increases per recursion level."""
        tmpl = test_env_async.from_string(
            "{% for item in seq recursive %}[{{ loop.depth0 }}:{{ item.a }}"
            "{% if item.b %}<{{ loop(item.b) }}>{% endif %}]{% endfor %}"
        )
        assert (
            tmpl.render(
                seq=[
                    dict(a=1, b=[dict(a=1), dict(a=2)]),
                    dict(a=2, b=[dict(a=1), dict(a=2)]),
                    dict(a=3, b=[dict(a="a")]),
                ]
            )
            == "[0:1<[1:1][1:2]>][0:2<[1:1][1:2]>][0:3<[1:a]>]"
        )

    def test_recursive_depth(self, test_env_async):
        """``loop.depth`` is 1 at the top level and increases per recursion level."""
        tmpl = test_env_async.from_string(
            "{% for item in seq recursive %}[{{ loop.depth }}:{{ item.a }}"
            "{% if item.b %}<{{ loop(item.b) }}>{% endif %}]{% endfor %}"
        )
        assert (
            tmpl.render(
                seq=[
                    dict(a=1, b=[dict(a=1), dict(a=2)]),
                    dict(a=2, b=[dict(a=1), dict(a=2)]),
                    dict(a=3, b=[dict(a="a")]),
                ]
            )
            == "[1:1<[2:1][2:2]>][1:2<[2:1][2:2]>][1:3<[2:a]>]"
        )

    def test_looploop(self, test_env_async):
        """An outer ``loop`` captured with ``set`` stays usable inside an inner loop."""
        tmpl = test_env_async.from_string(
            """{% for row in table %}
            {%- set rowloop = loop -%}
            {% for cell in row -%}
                [{{ rowloop.index }}|{{ loop.index }}]
            {%- endfor %}
        {%- endfor %}"""
        )
        assert tmpl.render(table=["ab", "cd"]) == "[1|1][1|2][2|1][2|2]"

    def test_reversed_bug(self, test_env_async):
        """``loop.last`` is correct when iterating a ``reversed`` iterator."""
        tmpl = test_env_async.from_string(
            "{% for i in items %}{{ i }}"
            "{% if not loop.last %}"
            ",{% endif %}{% endfor %}"
        )
        assert tmpl.render(items=reversed([3, 2, 1])) == "1,2,3"

    def test_loop_errors(self, test_env_async):
        """``loop`` is undefined in the loop filter and renders empty in ``else``."""
        tmpl = test_env_async.from_string(
            """{% for item in [1] if loop.index
                                      == 0 %}...{% endfor %}"""
        )
        pytest.raises(UndefinedError, tmpl.render)
        tmpl = test_env_async.from_string(
            """{% for item in [] %}...{% else
            %}{{ loop }}{% endfor %}"""
        )
        assert tmpl.render() == ""

    def test_loop_filter(self, test_env_async):
        """An ``if`` filter skips items, and ``loop.index`` counts only kept items."""
        tmpl = test_env_async.from_string(
            "{% for item in range(10) if item is even %}[{{ item }}]{% endfor %}"
        )
        assert tmpl.render() == "[0][2][4][6][8]"
        tmpl = test_env_async.from_string(
            """
            {%- for item in range(10) if item is even %}[{{
                loop.index }}:{{ item }}]{% endfor %}"""
        )
        assert tmpl.render() == "[1:0][2:2][3:4][4:6][5:8]"

    def test_scoped_special_var(self, test_env_async):
        """Each nested loop has its own ``loop.first``."""
        t = test_env_async.from_string(
            "{% for s in seq %}[{{ loop.first }}{% for c in s %}"
            "|{{ loop.first }}{% endfor %}]{% endfor %}"
        )
        assert t.render(seq=("ab", "cd")) == "[True|True|False][False|True|False]"

    def test_scoped_loop_var(self, test_env_async):
        """An inner loop shadows ``loop`` only inside its own body."""
        t = test_env_async.from_string(
            "{% for x in seq %}{{ loop.first }}"
            "{% for y in seq %}{% endfor %}{% endfor %}"
        )
        assert t.render(seq="ab") == "TrueFalse"
        t = test_env_async.from_string(
            "{% for x in seq %}{% for y in seq %}"
            "{{ loop.first }}{% endfor %}{% endfor %}"
        )
        assert t.render(seq="ab") == "TrueFalseTrueFalse"

    def test_recursive_empty_loop_iter(self, test_env_async):
        """A recursive loop over an empty list renders nothing."""
        t = test_env_async.from_string(
            """
        {%- for item in foo recursive -%}{%- endfor -%}
        """
        )
        assert t.render(dict(foo=[])) == ""

    def test_call_in_loop(self, test_env_async):
        """A ``{% call %}`` block inside a loop sees the current loop item."""
        t = test_env_async.from_string(
            """
        {%- macro do_something() -%}
            [{{ caller() }}]
        {%- endmacro %}

        {%- for i in [1, 2, 3] %}
            {%- call do_something() -%}
                {{ i }}
            {%- endcall %}
        {%- endfor -%}
        """
        )
        assert t.render() == "[1][2][3]"

    def test_scoping_bug(self, test_env_async):
        """A loop target and a later macro with the same name do not collide."""
        t = test_env_async.from_string(
            """
        {%- for item in foo %}...{{ item }}...{% endfor %}
        {%- macro item(a) %}...{{ a }}...{% endmacro %}
        {{- item(2) -}}
        """
        )
        assert t.render(foo=(1,)) == "...1......2..."

    def test_unpacking(self, test_env_async):
        """Tuple unpacking in the loop target."""
        tmpl = test_env_async.from_string(
            "{% for a, b, c in [[1, 2, 3]] %}{{ a }}|{{ b }}|{{ c }}{% endfor %}"
        )
        assert tmpl.render() == "1|2|3"

    def test_recursive_loop_filter(self, test_env_async):
        """The loop filter is applied at every recursion level (``/foo`` is dropped)."""
        t = test_env_async.from_string(
            """
        <?xml version="1.0" encoding="UTF-8"?>
        <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
          {%- for page in [site.root] if page.url != this recursive %}
          <url><loc>{{ page.url }}</loc></url>
          {{- loop(page.children) }}
          {%- endfor %}
        </urlset>
        """
        )
        sm = t.render(
            this="/foo",
            site={"root": {"url": "/", "children": [{"url": "/foo"}, {"url": "/bar"}]}},
        )
        lines = [x.strip() for x in sm.splitlines() if x.strip()]
        assert lines == [
            '<?xml version="1.0" encoding="UTF-8"?>',
            '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">',
            "<url><loc>/</loc></url>",
            "<url><loc>/bar</loc></url>",
            "</urlset>",
        ]

    def test_nonrecursive_loop_filter(self, test_env_async):
        """A non-recursive loop filter excludes the matching item (``/foo``)."""
        t = test_env_async.from_string(
            """
        <?xml version="1.0" encoding="UTF-8"?>
        <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
          {%- for page in items if page.url != this %}
          <url><loc>{{ page.url }}</loc></url>
          {%- endfor %}
        </urlset>
        """
        )
        sm = t.render(
            this="/foo", items=[{"url": "/"}, {"url": "/foo"}, {"url": "/bar"}]
        )
        lines = [x.strip() for x in sm.splitlines() if x.strip()]
        assert lines == [
            '<?xml version="1.0" encoding="UTF-8"?>',
            '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">',
            "<url><loc>/</loc></url>",
            "<url><loc>/bar</loc></url>",
            "</urlset>",
        ]

    def test_bare_async(self, test_env_async):
        """A template consisting only of ``extends`` renders its parent.

        NOTE(review): this exercises inheritance rather than for-loops; it may
        belong in a different test class.
        """
        t = test_env_async.from_string('{% extends "header" %}')
        assert t.render(foo=42) == "[42|23]"

    def test_awaitable_property_slicing(self, test_env_async):
        """A sliced attribute lookup can be used as the loop iterable."""
        t = test_env_async.from_string("{% for x in a.b[:1] %}{{ x }}{% endfor %}")
        assert t.render(a=dict(b=[1, 2, 3])) == "1"
580
581
def test_namespace_awaitable(test_env_async):
    """``namespace`` attributes resolve when rendering with ``render_async``."""

    async def check():
        template = test_env_async.from_string(
            '{% set ns = namespace(foo="Bar") %}{{ ns.foo }}'
        )
        assert await template.render_async() == "Bar"

    run(check())
591