import asyncio
from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
import functools
from test import support
import unittest

from test.test_contextlib import TestBaseExitStack


def _async_test(func):
    """Decorator to turn an async function into a test case.

    Each call creates a fresh event loop, installs it as the current loop,
    runs the decorated coroutine function to completion on it and returns
    its result.  Afterwards the loop is closed and uninstalled again.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        coro = func(*args, **kwargs)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            try:
                # Several tests deliberately leave an async generator
                # suspended (e.g. the trap_* tests below).  Run the loop once
                # more so those generators are closed while the loop is still
                # open, instead of being stranded by loop.close().
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                asyncio.set_event_loop(None)
    return wrapper


class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for contextlib.AbstractAsyncContextManager."""

    @_async_test
    async def test_enter(self):
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        manager = DefaultEnter()
        # The inherited __aenter__ must hand back the manager itself.
        self.assertIs(await manager.__aenter__(), manager)

        async with manager as context:
            self.assertIs(manager, context)

    def test_exit_is_abstract(self):
        class MissingAexit(AbstractAsyncContextManager):
            pass

        # __aexit__ is abstract, so the subclass cannot be instantiated.
        with self.assertRaises(TypeError):
            MissingAexit()

    def test_structural_subclassing(self):
        class ManagerFromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))

        # Setting a protocol method to None opts the class out of the ABC.
        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))


class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for the contextlib.asynccontextmanager decorator."""

    @_async_test
    async def test_contextmanager_plain(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        await ctx.__aenter__()
        # Yielding again after the exception was thrown in is an error.
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        # A generator that never yields cannot be entered.
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        # A generator that yields a second time on normal exit is an error.
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        # __aexit__ receives an exception type but no instance.
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        # The generator swallows the exception, so nothing propagates here.
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        @asynccontextmanager
        async def woohoo():
            yield

        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        """Build an asynccontextmanager-decorated function carrying extra
        attributes and a docstring, for the metadata-propagation tests."""
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))


class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """Runs the shared ExitStack tests against AsyncExitStack, plus
    async-specific tests."""

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack driven synchronously on the current event loop,
        so the inherited synchronous ExitStack tests can exercise it."""

        @staticmethod
        def run_coroutine(coro):
            """Run *coro* on the current event loop and return its result,
            re-raising any exception it raised."""
            loop = asyncio.get_event_loop()

            f = asyncio.ensure_future(coro)
            f.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = f.exception()

            if exc is None:
                return f.result()

            context = exc.__context__
            # Presumably: re-raise exc while exc itself is the exception being
            # handled, so that implicit chaining cannot overwrite its
            # __context__ with whatever exception the caller is handling; the
            # saved context is restored explicitly first.
            try:
                raise exc
            except BaseException:
                exc.__context__ = context
                raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # The synchronous tests drive SyncAsyncExitStack on this loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)

    @_async_test
    async def test_async_callback(self):
        expected = [
            ((), {}),
            ((1,), {}),
            ((1, 2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1, 2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                # The second argument is only the failure message.
                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)

        self.assertEqual(result, expected)

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # `self` is the ExitCM here, not the TestCase, so it has no
                # fail(); raise the assertion error directly.
                raise AssertionError("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_async_enter_context(self):
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)


if __name__ == '__main__':
    unittest.main()