1import asyncio
2import asyncio.events
3import contextlib
4import os
5import pprint
6import select
7import socket
8import tempfile
9import threading
10
11
class FunctionalTestCaseMixin:
    """Test-case mixin providing a private event loop and threaded socket peers.

    ``setUp`` creates a loop through ``new_loop()``, clears the current event
    loop and records every call made to the loop's exception handler.
    ``tearDown`` closes the loop and fails the test if any such call was
    recorded.  The host class must supply ``fail()`` (``unittest.TestCase``
    does).
    """

    def new_loop(self):
        """Create and return the event loop used by the test."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until an ``asyncio.sleep(delay)`` completes."""
        # ``asyncio.sleep()`` no longer accepts a ``loop`` argument (removed
        # in Python 3.10); the coroutine is bound to ``self.loop`` by
        # ``run_until_complete`` itself.
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context* and forward it to the loop's default handler."""
        self.__unhandled_exceptions.append(context)
        # Use the loop we were called with: ``self.loop`` is reset to None by
        # ``tearDown`` and would fail for a late callback.
        loop.default_exception_handler(context)

    def setUp(self):
        """Create the loop, install the recording handler and patch asyncio."""
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)

        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

        # Disable `_get_running_loop`.
        self._old_get_running_loop = asyncio.events._get_running_loop
        asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        """Close the loop and fail if the exception handler was ever called."""
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            # Always undo the global patching done in setUp.
            asyncio.events._get_running_loop = self._old_get_running_loop
            asyncio.set_event_loop(None)
            self.loop = None

    @staticmethod
    def _validate_socket_timeout(timeout):
        """Raise RuntimeError unless *timeout* is a positive number."""
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Return a ``TestThreadedServer`` bound to a listening socket.

        Args:
            server_prog: callable handed to the server thread.
            family: socket address family.
            addr: address to bind; when None, ``('127.0.0.1', 0)`` is used,
                or a fresh temporary path for ``AF_UNIX``.
            timeout: positive socket timeout in seconds.
            backlog: value passed to ``listen()``.
            max_clients: passed through to ``TestThreadedServer``.

        Raises:
            RuntimeError: if *timeout* is None or not positive.
            OSError: if the socket cannot be bound or put into listening mode.
        """
        # Validate before creating the socket so a rejected timeout does not
        # leave an open socket behind.
        self._validate_socket_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Borrow a unique name: the temporary file is removed when the
                # ``with`` block exits, leaving the path free to bind.
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            sock.bind(addr)
            sock.listen(backlog)
        except BaseException:
            # Do not leak the socket; re-raise with the original traceback.
            sock.close()
            raise

        return TestThreadedServer(
            self, sock, server_prog, timeout, max_clients)

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Return a ``TestThreadedClient`` owning a new, unconnected socket.

        Args:
            client_prog: callable handed to the client thread.
            family: socket address family.
            timeout: positive socket timeout in seconds.

        Raises:
            RuntimeError: if *timeout* is None or not positive.
        """
        # Validate first so an invalid timeout never creates a socket.
        self._validate_socket_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
        except BaseException:
            sock.close()
            raise

        return TestThreadedClient(
            self, sock, client_prog, timeout)

    def unix_server(self, *args, **kwargs):
        """Like ``tcp_server`` but with ``family=socket.AF_UNIX``.

        Raises NotImplementedError when the platform has no ``AF_UNIX``.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """Like ``tcp_client`` but with ``family=socket.AF_UNIX``.

        Raises NotImplementedError when the platform has no ``AF_UNIX``.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield the path ``<tmpdir>/sock`` inside a fresh temporary directory.

        On exit the path is unlinked if it exists (errors are ignored) and
        the directory is removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                try:
                    os.unlink(fn)
                except OSError:
                    # The socket file may never have been created.
                    pass

    def _abort_socket_test(self, ex):
        """Stop the loop, then fail the test with *ex*.

        NOTE(review): the socket helper threads call this from their own
        thread; ``loop.stop()`` is used directly here -- confirm that this
        is acceptable for the loops under test.
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)
123
124
125##############################################################################
126# Socket Testing Utilities
127##############################################################################
128
129
class TestSocketWrapper:
    """Thin proxy around a socket that adds ``recv_all`` and ``start_tls``.

    Any attribute not defined here is looked up on the wrapped socket.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises:
            ConnectionAbortedError: if the peer closes the connection (an
                empty read) before *n* bytes have arrived.
        """
        # Collect the pieces and join once instead of repeatedly
        # concatenating an ever-growing bytes object.
        chunks = []
        remaining = n
        while remaining > 0:
            data = self.recv(remaining)
            if not data:
                raise ConnectionAbortedError
            chunks.append(data)
            remaining -= len(data)
        return b''.join(chunks)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Upgrade the connection to TLS and delegate to the SSL socket.

        The handshake is performed explicitly; if it fails the SSL socket is
        closed and the error propagates.  The previously wrapped socket
        object is closed in every case.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            # Cleanup only; the original error is re-raised unchanged.
            ssl_sock.close()
            raise
        finally:
            # NOTE(review): presumably harmless on success because
            # ``wrap_socket`` takes over the underlying descriptor -- confirm.
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped socket."""
        # ``__getattr__`` also runs for the wrapped-socket attribute itself
        # while it is still unset (``__new__`` without ``__init__``, copy,
        # unpickling); reading ``self.__sock`` then would recurse forever.
        if name == '_TestSocketWrapper__sock':
            raise AttributeError(name)
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
168
169
class SocketThread(threading.Thread):
    """Thread usable as a context manager: started on entry, stopped on exit."""

    def __enter__(self):
        """Start the thread and hand it to the ``with`` block."""
        self.start()
        return self

    def __exit__(self, *exc):
        """Stop the thread; exceptions from the ``with`` body are not suppressed."""
        self.stop()

    def stop(self):
        """Clear the ``_active`` flag, then wait for the thread to finish."""
        self._active = False
        self.join()
182
183
class TestThreadedClient(SocketThread):
    """Daemon thread named ``test-client`` that runs a client program.

    The program is called with the socket wrapped in ``TestSocketWrapper``;
    any ``Exception`` it raises is reported to the owning test through
    ``_abort_socket_test``.
    """

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client', daemon=True)

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        self._active = True

    def run(self):
        """Thread body: run the program and forward failures to the test."""
        try:
            wrapped = TestSocketWrapper(self._sock)
            self._prog(wrapped)
        except Exception as ex:
            self._test._abort_socket_test(ex)
201
202
class TestThreadedServer(SocketThread):
    """Daemon thread named ``test-server`` that accepts and serves clients.

    Connections are accepted one at a time on the listening socket and each
    is handed to the server program wrapped in ``TestSocketWrapper``.  A
    socket pair serves as a wake-up channel so that ``stop()`` can interrupt
    the ``select()`` wait.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        super().__init__(name='test-server', daemon=True)

        self._test = test
        self._prog = prog
        self._sock = sock
        self._timeout = timeout
        self._max_clients = max_clients

        self._active = True
        self._clients = 0
        self._finished_clients = 0

        # _s1 is watched by the serving loop; stop() writes to _s2.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

    def stop(self):
        """Wake the serving loop through the socket pair, then join."""
        try:
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    # Best effort: the pair may already be closed.
                    pass
        finally:
            super().stop()

    def run(self):
        """Thread body: serve until done, then close the wake-up pair."""
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept clients until stopped or ``max_clients`` were accepted."""
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            # Data on the wake-up socket means stop() was called.
            if self._s1 in readable:
                return

            if self._sock not in readable:
                continue

            try:
                conn, _peer = self._sock.accept()
            except BlockingIOError:
                continue
            except socket.timeout:
                if self._active:
                    raise
                return

            self._clients += 1
            conn.settimeout(self._timeout)
            try:
                with conn:
                    self._handle_client(conn)
            except Exception as ex:
                self._active = False
                # Re-raise, but report the failure to the test on the way out.
                try:
                    raise
                finally:
                    self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Address the listening socket is bound to (``getsockname()``)."""
        return self._sock.getsockname()
282