1import asyncio 2import unittest 3 4from test.test_asyncio import functional as func_tests 5 6 7class ReceiveStuffProto(asyncio.BufferedProtocol): 8 def __init__(self, cb, con_lost_fut): 9 self.cb = cb 10 self.con_lost_fut = con_lost_fut 11 12 def get_buffer(self, sizehint): 13 self.buffer = bytearray(100) 14 return self.buffer 15 16 def buffer_updated(self, nbytes): 17 self.cb(self.buffer[:nbytes]) 18 19 def connection_lost(self, exc): 20 if exc is None: 21 self.con_lost_fut.set_result(None) 22 else: 23 self.con_lost_fut.set_exception(exc) 24 25 26class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin): 27 28 def new_loop(self): 29 raise NotImplementedError 30 31 def test_buffered_proto_create_connection(self): 32 33 NOISE = b'12345678+' * 1024 34 35 async def client(addr): 36 data = b'' 37 38 def on_buf(buf): 39 nonlocal data 40 data += buf 41 if data == NOISE: 42 tr.write(b'1') 43 44 conn_lost_fut = self.loop.create_future() 45 46 tr, pr = await self.loop.create_connection( 47 lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr) 48 49 await conn_lost_fut 50 51 async def on_server_client(reader, writer): 52 writer.write(NOISE) 53 await reader.readexactly(1) 54 writer.close() 55 await writer.wait_closed() 56 57 srv = self.loop.run_until_complete( 58 asyncio.start_server( 59 on_server_client, '127.0.0.1', 0)) 60 61 addr = srv.sockets[0].getsockname() 62 self.loop.run_until_complete( 63 asyncio.wait_for(client(addr), 5, loop=self.loop)) 64 65 srv.close() 66 self.loop.run_until_complete(srv.wait_closed()) 67 68 69class BufferedProtocolSelectorTests(BaseTestBufferedProtocol, 70 unittest.TestCase): 71 72 def new_loop(self): 73 return asyncio.SelectorEventLoop() 74 75 76@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') 77class BufferedProtocolProactorTests(BaseTestBufferedProtocol, 78 unittest.TestCase): 79 80 def new_loop(self): 81 return asyncio.ProactorEventLoop() 82 83 84if __name__ == '__main__': 85 unittest.main() 86