1import asyncio
2import inspect
3
4from .case import TestCase
5
6
7
class IsolatedAsyncioTestCase(TestCase):
    """TestCase whose setup, test, teardown and cleanups may be coroutines.

    run() creates a fresh event loop for every test and closes it afterwards.
    Awaitables produced by asyncSetUp(), the test method, asyncTearDown() and
    async cleanups are all awaited inside one long-running task, so they share
    a single contextvars context.
    """

    # Private names intentionally carry a long "_asyncio" prefix to reduce
    # the chance of clashing with attributes defined by user subclasses.
    #
    # The class does not call loop.run_until_complete(self.asyncSetUp()) and
    # family directly: every run_until_complete() call on a coroutine creates
    # a new task with its own embedded ContextVar context.  To share
    # contextvars between setUp(), the test and tearDown(), they must all be
    # executed inside the same task.  Instead:
    # 1. _setupAsyncioLoop() starts a long-running task (_asyncioLoopRunner)
    #    that reads (future, awaitable) pairs from a queue.
    # 2. Outer code puts the awaitable together with a fresh future into the
    #    queue and runs the loop until that future is done.
    # 3. The runner task awaits the awaitable and moves its result (or
    #    exception) into the future.
    #
    # Note: the test case modifies the event loop policy if the policy was not
    # instantiated yet; asyncio.get_event_loop_policy() creates a default
    # policy on demand and never returns None.  This should not be an issue
    # for user-level tests, but Python's own test suite should reset the
    # policy in every test module by calling
    # asyncio.set_event_loop_policy(None) in tearDownModule().

    def __init__(self, methodName='runTest'):
        super().__init__(methodName)
        # Event loop owned by the currently running test; None outside run().
        self._asyncioTestLoop = None
        # Queue of (future, awaitable) pairs consumed by the runner task.
        self._asyncioCallsQueue = None
        # Long-running task executing _asyncioLoopRunner().
        self._asyncioCallsTask = None

    async def asyncSetUp(self):
        """Async hook awaited right after setUp(); override in subclasses."""

    async def asyncTearDown(self):
        """Async hook awaited right before tearDown(); override in subclasses."""

    def addAsyncCleanup(self, func, /, *args, **kwargs):
        """Register *func* as a cleanup whose returned awaitable is awaited.

        This is a trivial trampoline to addCleanup().  It exists because its
        intended semantics differ: addCleanup() takes regular functions,
        while addAsyncCleanup() takes callables that return awaitables.

        There is intentionally no inspect.iscoroutinefunction() check on
        *func*, because async callables cannot be detected reliably:
        1. it can be an "async def func()" itself;
        2. a class can implement an "async def __call__()" method;
        3. a regular "def func()" can return an awaitable object.
        Cleanups go through _callCleanup(), which awaits the result only if
        it is awaitable.
        """
        self.addCleanup(func, *args, **kwargs)

    def _callSetUp(self):
        # Synchronous setUp() first, then the async hook in the runner task.
        self.setUp()
        self._callAsync(self.asyncSetUp)

    def _callTestMethod(self, method):
        # Test methods may be plain functions or coroutine functions.
        self._callMaybeAsync(method)

    def _callTearDown(self):
        # Mirror image of _callSetUp(): async hook first, then tearDown().
        self._callAsync(self.asyncTearDown)
        self.tearDown()

    def _callCleanup(self, function, *args, **kwargs):
        # Cleanups registered via addCleanup() or addAsyncCleanup().
        self._callMaybeAsync(function, *args, **kwargs)

    def _callAsync(self, func, /, *args, **kwargs):
        """Call *func*, which must return an awaitable, and await the result."""
        assert self._asyncioTestLoop is not None, \
            'asyncio test loop is not initialized'
        ret = func(*args, **kwargs)
        assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
        return self._asyncioAwaitInCallsTask(ret)

    def _callMaybeAsync(self, func, /, *args, **kwargs):
        """Call *func*; await its result in the runner task if awaitable."""
        assert self._asyncioTestLoop is not None, \
            'asyncio test loop is not initialized'
        ret = func(*args, **kwargs)
        if inspect.isawaitable(ret):
            return self._asyncioAwaitInCallsTask(ret)
        return ret

    def _asyncioAwaitInCallsTask(self, awaitable):
        """Await *awaitable* inside the runner task and return its result.

        Runs the loop until the runner has resolved the future; an exception
        raised by the awaitable is re-raised here, in the caller.
        """
        fut = self._asyncioTestLoop.create_future()
        self._asyncioCallsQueue.put_nowait((fut, awaitable))
        return self._asyncioTestLoop.run_until_complete(fut)

    async def _asyncioLoopRunner(self, fut):
        """Body of the long-running task that awaits queued calls.

        *fut* is resolved once the calls queue exists, which tells
        _setupAsyncioLoop() that calls can be submitted.  A None item in the
        queue makes the runner return.
        """
        self._asyncioCallsQueue = queue = asyncio.Queue()
        fut.set_result(None)
        while True:
            query = await queue.get()
            queue.task_done()
            if query is None:
                return
            call_fut, awaitable = query
            try:
                ret = await awaitable
            except (SystemExit, KeyboardInterrupt):
                # Do not turn these into test errors; let them propagate.
                raise
            except BaseException as ex:
                # This includes asyncio.CancelledError raised by the awaited
                # code.  It must reach the caller through the future:
                # re-raising here would kill the runner and leave call_fut
                # pending forever, hanging run_until_complete() in the caller.
                if not call_fut.cancelled():
                    call_fut.set_exception(ex)
            else:
                if not call_fut.cancelled():
                    call_fut.set_result(ret)

    def _setupAsyncioLoop(self):
        """Create a fresh debug-mode event loop and start the runner task."""
        assert self._asyncioTestLoop is None, \
            'asyncio test loop already initialized'
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.set_debug(True)
        self._asyncioTestLoop = loop
        fut = loop.create_future()
        self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
        # Run until the runner task has created the calls queue.
        loop.run_until_complete(fut)

    def _tearDownAsyncioLoop(self):
        """Stop the runner, cancel leftover tasks and close the loop."""
        assert self._asyncioTestLoop is not None, \
            'asyncio test loop is not initialized'
        loop = self._asyncioTestLoop
        runner = self._asyncioCallsTask
        self._asyncioTestLoop = None
        self._asyncioCallsTask = None

        try:
            # Ask the runner to stop and wait until it consumed the sentinel.
            # If the runner already died, nobody would consume the sentinel
            # and join() would block forever, so skip the drain in that case.
            if runner is not None and not runner.done():
                self._asyncioCallsQueue.put_nowait(None)
                loop.run_until_complete(self._asyncioCallsQueue.join())

            # Cancel every task the test left behind and wait for them.
            to_cancel = asyncio.all_tasks(loop)
            if to_cancel:
                for task in to_cancel:
                    task.cancel()

                # No loop= argument: it is deprecated since 3.8 and rejected
                # since 3.10; gather() takes the loop from the tasks.
                loop.run_until_complete(
                    asyncio.gather(*to_cancel, return_exceptions=True))

                for task in to_cancel:
                    if task.cancelled():
                        continue
                    if task.exception() is not None:
                        loop.call_exception_handler({
                            'message': 'unhandled exception during test shutdown',
                            'exception': task.exception(),
                            'task': task,
                        })

            # Finalize async generators even when no tasks were left over.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def run(self, result=None):
        # Each test run gets its own loop, torn down even if the run fails.
        self._setupAsyncioLoop()
        try:
            return super().run(result)
        finally:
            self._tearDownAsyncioLoop()
161