import asyncio
import asyncio.events
import contextlib
import os
import pprint
import select
import socket
import tempfile
import threading


class FunctionalTestCaseMixin:
    """Per-test event loop plus helpers for threaded blocking-socket peers.

    Meant to be mixed into a ``unittest.TestCase``: it relies on
    ``self.fail`` and on the ``setUp``/``tearDown`` hooks.  Each test gets a
    fresh loop in ``self.loop``.  Any call to the loop's exception handler
    during the test makes ``tearDown`` fail the test.
    """

    def new_loop(self):
        """Return the event loop to use for a test (hook for subclasses)."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until a ``delay``-second sleep has completed."""
        # ``asyncio.sleep`` lost its ``loop`` argument in Python 3.10.
        # run_until_complete() already runs the sleep on ``self.loop``.
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record ``context`` for tearDown, then pass it to the default handler."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        """Create ``self.loop`` and install the recording exception handler."""
        self.loop = self.new_loop()
        # Leave no "current" event loop set for this thread.
        asyncio.set_event_loop(None)

        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

        # Disable `_get_running_loop`: the module-level hook now always
        # reports "no running loop".  tearDown restores the saved original.
        self._old_get_running_loop = asyncio.events._get_running_loop
        asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        """Close the loop and fail if the exception handler was ever called.

        The patched ``_get_running_loop`` is restored and the current event
        loop is cleared even if closing the loop or the check fails.
        """
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.events._get_running_loop = self._old_get_running_loop
            asyncio.set_event_loop(None)
            self.loop = None

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Create a listening socket and wrap it in an unstarted server thread.

        Args:
            server_prog: callable invoked with a ``TestSocketWrapper`` for
                each accepted connection.
            family: socket address family (``AF_INET`` by default).
            addr: address to bind.  ``None`` picks ``127.0.0.1`` with an
                ephemeral port, or a fresh temporary path for ``AF_UNIX``.
            timeout: positive timeout (seconds) for the sockets.
            backlog: ``listen()`` backlog.
            max_clients: the server stops after accepting this many clients.

        Returns:
            A ``TestThreadedServer``.  Start it, e.g. with a ``with`` block.

        Raises:
            RuntimeError: if ``timeout`` is ``None`` or not positive.
            OSError: if binding or listening fails.
        """
        # Validate before creating the socket so that rejecting a bad
        # timeout does not leave an unclosed socket behind.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # The temporary file is removed when the ``with`` exits.
                # Only its unique name is reused as the socket path.
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            sock.bind(addr)
            sock.listen(backlog)
        except BaseException:
            # Close on any failure (OSError, or TypeError for a malformed
            # addr), then propagate the original exception unchanged.
            sock.close()
            raise

        return TestThreadedServer(
            self, sock, server_prog, timeout, max_clients)

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Create an unconnected socket and wrap it in an unstarted client thread.

        ``client_prog`` is called with a ``TestSocketWrapper`` around the
        socket and is responsible for connecting it.

        Raises:
            RuntimeError: if ``timeout`` is ``None`` or not positive.
        """
        # Validate first so that an invalid timeout does not leak a socket.
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        return TestThreadedClient(
            self, sock, client_prog, timeout)

    def unix_server(self, *args, **kwargs):
        """``tcp_server`` with ``family=AF_UNIX``.

        Raises NotImplementedError where ``AF_UNIX`` is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """``tcp_client`` with ``family=AF_UNIX``.

        Raises NotImplementedError where ``AF_UNIX`` is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a socket path inside a temporary directory.

        On exit the path is unlinked (best effort; it may never have been
        created) and the directory is removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                try:
                    os.unlink(fn)
                except OSError:
                    pass

    def _abort_socket_test(self, ex):
        """Stop ``self.loop`` and fail the test with ``ex``.

        Called from the client/server threads when their program raises.
        """
        # NOTE(review): this runs on a worker thread.  asyncio documents
        # loop.call_soon_threadsafe() for cross-thread calls, and a loop
        # blocked in select() may not notice a plain stop() until its next
        # wakeup.  Consider switching to call_soon_threadsafe().
        try:
            self.loop.stop()
        finally:
            self.fail(ex)


##############################################################################
# Socket Testing Utilities
##############################################################################


class TestSocketWrapper:
    """Thin wrapper around a socket that can be upgraded to TLS in place.

    Unknown attributes are forwarded to the current underlying socket, so
    ``send``, ``recv``, ``close`` and friends work directly on the wrapper.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly ``n`` bytes and return them as ``bytes``.

        Raises:
            ConnectionAbortedError: if the peer closes the connection
                before ``n`` bytes have arrived.
        """
        # Accumulate in a bytearray: repeated ``bytes +=`` is quadratic.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if data == b'':
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Wrap the socket with ``ssl_context`` and perform the handshake.

        On success the wrapper uses the SSL socket from then on.  If the
        handshake fails, the SSL socket is closed and the error propagates.
        The original socket object is closed in either case.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself.
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)


class SocketThread(threading.Thread):
    """Thread base usable as a context manager: start on enter, stop on exit."""

    def stop(self):
        """Clear the ``_active`` flag and wait for the thread to finish."""
        self._active = False
        self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()


class TestThreadedClient(SocketThread):
    """Daemon thread that runs ``prog`` against a wrapped client socket.

    If ``prog`` raises, the error is reported through
    ``test._abort_socket_test``.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, None, None, 'test-client')
        self.daemon = True

        self._timeout = timeout
        self._sock = sock
        self._active = True
        self._prog = prog
        self._test = test

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except Exception as ex:
            self._test._abort_socket_test(ex)


class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and runs ``prog`` for each.

    Clients are handled one at a time on the server thread.  The thread
    exits when ``stop()`` is called, after ``max_clients`` connections have
    been accepted, or when handling a client raises.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, None, None, 'test-server')
        self.daemon = True

        self._clients = 0
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # stop() writes to _s2 to wake the select() in _run() through _s1.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the server loop (if its wake-up socket is open) and join it."""
        try:
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    pass
        finally:
            super().stop()

    def run(self):
        # The listening socket and the wake-up socketpair are closed
        # however the accept loop exits.
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept-and-handle loop; see the class docstring for exit conditions."""
        while self._active:
            if self._clients >= self._max_clients:
                return

            r, w, x = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in r:
                # stop() asked us to exit.
                return

            if self._sock in r:
                try:
                    conn, addr = self._sock.accept()
                except BlockingIOError:
                    continue
                except socket.timeout:
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except Exception as ex:
                        self._active = False
                        # Re-raise the client error.  _abort_socket_test()
                        # runs in the finally clause and may itself raise
                        # (via test.fail), in which case its exception
                        # propagates with this one as its context.
                        try:
                            raise
                        finally:
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Address the listening socket is bound to (``getsockname()``)."""
        return self._sock.getsockname()