1import asyncio
2from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
3import functools
4from test import support
5import unittest
6
7from test.test_contextlib import TestBaseExitStack
8
9
10def _async_test(func):
11    """Decorator to turn an async function into a test case."""
12    @functools.wraps(func)
13    def wrapper(*args, **kwargs):
14        coro = func(*args, **kwargs)
15        loop = asyncio.new_event_loop()
16        asyncio.set_event_loop(loop)
17        try:
18            return loop.run_until_complete(coro)
19        finally:
20            loop.close()
21            asyncio.set_event_loop(None)
22    return wrapper
23
24
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for the contextlib.AbstractAsyncContextManager ABC."""

    @_async_test
    async def test_enter(self):
        # The inherited __aenter__ must hand back the manager itself, both
        # when awaited directly and when driven by ``async with``.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        cm = DefaultEnter()
        entered = await cm.__aenter__()
        self.assertIs(entered, cm)

        async with cm as bound:
            self.assertIs(cm, bound)

    def test_exit_is_abstract(self):
        # A subclass that does not override __aexit__ cannot be instantiated.
        class MissingAexit(AbstractAsyncContextManager):
            pass

        self.assertRaises(TypeError, MissingAexit)

    def test_structural_subclassing(self):
        # A class that merely defines both async hooks is a virtual subclass.
        class ManagerFromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))

        # An explicit subclass relying on the default __aenter__ also counts.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))

        # Setting either hook to None opts the class back out again.
        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        for opted_out in (NoneAenter, NoneAexit):
            self.assertFalse(issubclass(opted_out, AbstractAsyncContextManager))
70
71
class AsyncContextManagerTestCase(unittest.TestCase):
    """Behavioural tests for the contextlib.asynccontextmanager decorator."""

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry; code after it on a clean exit.
        log = []

        @asynccontextmanager
        async def woohoo():
            log.append(1)
            yield 42
            log.append(999)

        async with woohoo() as value:
            self.assertEqual(log, [1])
            self.assertEqual(value, 42)
            log.append(value)
        self.assertEqual(log, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # A ``finally`` around the yield still runs when the body raises,
        # and the body's exception reaches the caller.
        log = []

        @asynccontextmanager
        async def woohoo():
            log.append(1)
            try:
                yield 42
            finally:
                log.append(999)

        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as value:
                self.assertEqual(log, [1])
                self.assertEqual(value, 42)
                log.append(value)
                raise ZeroDivisionError()
        self.assertEqual(log, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        # __aexit__ must return a falsy "not suppressed" result instead of
        # re-raising the exception it was handed.
        @asynccontextmanager
        async def whee():
            yield

        cm = whee()
        await cm.__aenter__()
        suppressed = await cm.__aexit__(TypeError, TypeError("foo"), None)
        self.assertFalse(suppressed)

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # Yielding a second time after an exception is thrown in is an error.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except BaseException:  # deliberately catch everything
                yield

        cm = whoo()
        await cm.__aenter__()
        with self.assertRaises(RuntimeError):
            await cm.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # A generator that finishes without ever yielding cannot be entered.
        @asynccontextmanager
        async def whoo():
            if False:
                yield  # only here to make this an async generator

        cm = whoo()
        with self.assertRaises(RuntimeError):
            await cm.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A generator that yields again on a clean exit is an error.
        @asynccontextmanager
        async def whoo():
            yield
            yield

        cm = whoo()
        await cm.__aenter__()
        with self.assertRaises(RuntimeError):
            await cm.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # The exception is handed to __aexit__ as a bare type with no
        # instance; it must still reach the generator's RuntimeError handler.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        cm = whoo()
        await cm.__aenter__()
        with self.assertRaises(SyntaxError):
            await cm.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # An exception caught inside the generator is suppressed.
        log = []

        @asynccontextmanager
        async def woohoo():
            log.append(1)
            try:
                yield 42
            except ZeroDivisionError as exc:
                log.append(exc.args[0])
                self.assertEqual(log, [1, 42, 999])

        async with woohoo() as value:
            self.assertEqual(log, [1])
            self.assertEqual(value, 42)
            log.append(value)
            raise ZeroDivisionError(999)
        self.assertEqual(log, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration raised in the body must come
        # out as the very same object, not be suppressed.
        @asynccontextmanager
        async def woohoo():
            yield

        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as caught:
                    self.assertIs(caught, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        # An ordinary error re-raised as RuntimeError by the generator
        # surfaces as that RuntimeError.
        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build a decorated factory carrying a custom attribute (foo='bar')
        # and a docstring, for the metadata-propagation tests below.
        def attribs(**kw):
            def decorate(func):
                for name, value in kw.items():
                    setattr(func, name, value)
                return func
            return decorate

        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__name__, 'baz')
        self.assertEqual(factory.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        cm = self._create_contextmanager_attribs()(None)
        self.assertEqual(cm.__doc__, "Whee!")
        async with cm:
            pass  # enter and exit the manager to suppress a warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited, even names such as
        # ``self`` or ``func`` that the wrapper might use internally.
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)

        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))
258
259
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """Run the shared ExitStack tests against AsyncExitStack, plus async-only tests."""

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack exposing synchronous close/__enter__/__exit__.

        Each synchronous method drives the matching coroutine to completion
        on the current event loop. This lets the synchronous tests inherited
        from TestBaseExitStack exercise AsyncExitStack.
        """

        @staticmethod
        def run_coroutine(coro):
            """Run *coro* on the current event loop and return its result.

            If the coroutine raises, the same exception object is re-raised
            with its original ``__context__`` restored.
            """
            loop = asyncio.get_event_loop()

            fut = asyncio.ensure_future(coro)
            fut.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = fut.exception()
            if exc is None:
                return fut.result()

            context = exc.__context__
            try:
                raise exc
            except BaseException:
                # Raising above may have replaced __context__ with whatever
                # exception our caller is handling (e.g. inside __exit__).
                # Restore it. Re-raising from this handler cannot overwrite
                # it again, because the exception being handled is exc itself.
                exc.__context__ = context
                raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # The synchronous wrappers above look up the current event loop, so
        # give every test a fresh one.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        # Cleanups run last-in first-out: uninstall the loop before it is
        # closed so a dead loop is not left installed after the test.
        self.addCleanup(asyncio.set_event_loop, None)

    @_async_test
    async def test_async_callback(self):
        # push_async_callback must forward positional and keyword arguments
        # and return the callback it was given.
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            # Registered in reverse, so unwinding last-in first-out records
            # the calls in the order of ``expected``.
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                # The wrapper must not copy the callback's docstring.
                self.assertIsNone(wrapper[1].__doc__)

        self.assertEqual(result, expected)

    @_async_test
    async def test_async_push(self):
        # Callbacks unwind last-in first-out. The last three see the
        # ZeroDivisionError, _suppress_exc swallows it, and the first two
        # then observe a clean exit.
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        # Inside ExitCM methods ``self`` is the ExitCM, so keep a handle on
        # the test case for reporting failures.
        testcase = self
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                testcase.fail("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_async_enter_context(self):
        # enter_async_context awaits __aenter__ immediately and registers
        # __aexit__ to run before earlier-registered callbacks.
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
421
422
if '__main__' == __name__:
    # Run this module's test cases when the file is executed as a script.
    unittest.main()
425