import asyncio
import inspect

from .case import TestCase


class IsolatedAsyncioTestCase(TestCase):
    """TestCase whose setup, test method, teardown and cleanups may be async.

    Every test runs on a fresh event loop that run() creates before the test
    and closes afterwards.  asyncSetUp(), the test method, asyncTearDown() and
    async cleanups are all awaited inside one long-running task, so context
    variables set in one phase are visible in the following ones.
    """
    # Names intentionally have a long prefix
    # to reduce a chance of clashing with user-defined attributes
    # from inherited test case
    #
    # The class doesn't call loop.run_until_complete(self.asyncSetUp()) and
    # family but uses a different approach:
    # 1. create a long-running task that reads (future, awaitable) pairs
    #    from a queue
    # 2. the task awaits each awaitable and stores its result (or exception)
    #    into the paired future
    # 3. outer code puts the awaitable and the future object into the queue
    #    and waits for the future
    # The trick is necessary because every run_until_complete() call
    # creates a new task with embedded ContextVar context.
    # To share contextvars between setUp(), test and tearDown() we need to
    # execute them inside the same task.

    # Note: the test case modifies event loop policy if the policy was not
    # instantiated yet.
    # asyncio.get_event_loop_policy() creates a default policy on demand but
    # never returns None.
    # This is not expected to matter for user-level tests, but Python's own
    # test suite should reset the policy in every test module
    # by calling asyncio.set_event_loop_policy(None) in tearDownModule()

    def __init__(self, methodName='runTest'):
        super().__init__(methodName)
        # All three are populated by _setupAsyncioLoop() / _asyncioLoopRunner().
        self._asyncioTestLoop = None
        self._asyncioCallsQueue = None
        self._asyncioCallsTask = None

    async def asyncSetUp(self):
        """Hook awaited right after setUp(); override for async fixtures."""
        pass

    async def asyncTearDown(self):
        """Hook awaited right before tearDown(); override for async cleanup."""
        pass

    def addAsyncCleanup(self, func, /, *args, **kwargs):
        """Register *func(*args, **kwargs)* as a cleanup whose result is awaited.

        Cleanups are invoked through _callCleanup(), which awaits the result
        when it is awaitable.
        """
        # A trivial trampoline to addCleanup()
        # the function exists because it has a different semantics
        # and signature:
        # addCleanup() accepts regular functions
        # but addAsyncCleanup() accepts coroutines
        #
        # We intentionally don't add inspect.iscoroutinefunction() check
        # for func argument because there is no way
        # to check for async function reliably:
        # 1. It can be "async def func()" itself
        # 2. Class can implement "async def __call__()" method
        # 3. Regular "def func()" that returns awaitable object
        self.addCleanup(func, *args, **kwargs)

    def _callSetUp(self):
        # Synchronous setUp() first, then the async hook.
        self.setUp()
        self._callAsync(self.asyncSetUp)

    def _callTestMethod(self, method):
        # Test methods may be plain functions or coroutine functions.
        self._callMaybeAsync(method)

    def _callTearDown(self):
        # Mirror image of _callSetUp(): async hook first, then tearDown().
        self._callAsync(self.asyncTearDown)
        self.tearDown()

    def _callCleanup(self, function, *args, **kwargs):
        self._callMaybeAsync(function, *args, **kwargs)

    def _callAsync(self, func, /, *args, **kwargs):
        """Call *func*, which must return an awaitable, and run it to completion.

        The awaitable is handed to the runner task so it executes in the
        shared task context.  Returns its result or raises its exception.
        """
        assert self._asyncioTestLoop is not None
        ret = func(*args, **kwargs)
        assert inspect.isawaitable(ret)
        fut = self._asyncioTestLoop.create_future()
        self._asyncioCallsQueue.put_nowait((fut, ret))
        return self._asyncioTestLoop.run_until_complete(fut)

    def _callMaybeAsync(self, func, /, *args, **kwargs):
        """Call *func*; if it returns an awaitable, await it in the runner task.

        Non-awaitable results are returned unchanged.
        """
        assert self._asyncioTestLoop is not None
        ret = func(*args, **kwargs)
        if inspect.isawaitable(ret):
            fut = self._asyncioTestLoop.create_future()
            self._asyncioCallsQueue.put_nowait((fut, ret))
            return self._asyncioTestLoop.run_until_complete(fut)
        else:
            return ret

    async def _asyncioLoopRunner(self, fut):
        """Await queued (future, awaitable) pairs until a None sentinel arrives.

        *fut* is resolved once the queue exists, which tells
        _setupAsyncioLoop() that the runner is ready to accept work.
        """
        # The queue is created here, inside a coroutine running on the test
        # loop, rather than in _setupAsyncioLoop().
        self._asyncioCallsQueue = queue = asyncio.Queue()
        fut.set_result(None)
        while True:
            query = await queue.get()
            # Mark the item done immediately: _tearDownAsyncioLoop() joins the
            # queue only to wait until the None sentinel has been taken.
            queue.task_done()
            if query is None:
                return
            fut, awaitable = query
            try:
                ret = await awaitable
                if not fut.cancelled():
                    fut.set_result(ret)
            except (SystemExit, KeyboardInterrupt):
                # Propagate these out of the runner task instead of recording
                # them on ``fut``.
                raise
            except BaseException as ex:
                # Forward everything else -- including asyncio.CancelledError
                # and other BaseException subclasses -- to the waiting caller.
                # If they escaped, this task would die with ``fut`` unresolved
                # and the caller's run_until_complete(fut) would block forever.
                if not fut.cancelled():
                    fut.set_exception(ex)

    def _setupAsyncioLoop(self):
        """Create a new debug-mode event loop and start the runner task on it."""
        assert self._asyncioTestLoop is None
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.set_debug(True)
        self._asyncioTestLoop = loop
        fut = loop.create_future()
        self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
        # Block until the runner has created the calls queue.
        loop.run_until_complete(fut)

    def _tearDownAsyncioLoop(self):
        """Stop the runner, cancel leftover tasks, finalize asyncgens, close loop."""
        assert self._asyncioTestLoop is not None
        loop = self._asyncioTestLoop
        self._asyncioTestLoop = None
        # The None sentinel makes _asyncioLoopRunner() return.
        self._asyncioCallsQueue.put_nowait(None)

        try:
            # Wait until the runner has dequeued the sentinel.  This is inside
            # the try so the loop is still closed if the wait is interrupted.
            loop.run_until_complete(self._asyncioCallsQueue.join())

            # Cancel every task still pending on the loop (e.g. tasks a test
            # started but never awaited) and wait for them to finish.
            to_cancel = asyncio.all_tasks(loop)
            if to_cancel:
                for task in to_cancel:
                    task.cancel()

                # No loop= argument: it was deprecated in Python 3.8 and
                # removed in 3.10; gather() takes the loop from the tasks.
                loop.run_until_complete(
                    asyncio.gather(*to_cancel, return_exceptions=True))

                for task in to_cancel:
                    if task.cancelled():
                        continue
                    if task.exception() is not None:
                        loop.call_exception_handler({
                            'message': 'unhandled exception during test shutdown',
                            'exception': task.exception(),
                            'task': task,
                        })

            # Finalize async generators.  This must run even when no tasks
            # were pending, so the branch above must not return early.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def run(self, result=None):
        """Run the test on a dedicated event loop that is torn down afterwards."""
        self._setupAsyncioLoop()
        try:
            return super().run(result)
        finally:
            self._tearDownAsyncioLoop()