1"""Utilities shared by tests."""
2
3import asyncio
4import collections
5import contextlib
6import io
7import logging
8import os
9import re
10import selectors
11import socket
12import socketserver
13import sys
14import tempfile
15import threading
16import time
17import unittest
18import weakref
19
20from unittest import mock
21
22from http.server import HTTPServer
23from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
24
25try:
26    import ssl
27except ImportError:  # pragma: no cover
28    ssl = None
29
30from asyncio import base_events
31from asyncio import events
32from asyncio import format_helpers
33from asyncio import futures
34from asyncio import tasks
35from asyncio.log import logger
36from test import support
37
38
def data_file(filename):
    """Return the path of a test data file named *filename*.

    The directory ``support.TEST_HOME_DIR`` is tried first (when the
    ``support`` module defines it), then the parent of this module's
    directory.  The first candidate that is an existing regular file wins.

    Raises FileNotFoundError when neither location has the file.
    """
    candidates = []
    if hasattr(support, 'TEST_HOME_DIR'):
        candidates.append(os.path.join(support.TEST_HOME_DIR, filename))
    candidates.append(os.path.join(os.path.dirname(__file__), '..', filename))
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(filename)
48
49
# Certificate, key and CA files used by the SSL tests, resolved through
# data_file() (which raises FileNotFoundError at import time if missing).
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# NOTE(review): presumably the decoded form of the certificate in
# SIGNED_CERTFILE, as reported by SSLSocket.getpeercert(); it must be kept in
# sync by hand whenever that file is regenerated (validity dates, serial).
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
71
72
def simple_server_sslcontext():
    """Return a server-side SSLContext serving ONLYCERT with ONLYKEY.

    Hostname checking is disabled and peer certificates are not requested
    (``CERT_NONE``).
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
79
80
def simple_client_sslcontext(*, disable_verify=True):
    """Return a client-side SSLContext with hostname checking disabled.

    When *disable_verify* is true (the default) the server certificate is
    not verified either; otherwise the PROTOCOL_TLS_CLIENT default
    verification mode is kept.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname must be off before verify_mode may be set to CERT_NONE.
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
87
88
def dummy_ssl_context():
    """Return a permissive client SSLContext, or None if ssl is unavailable.

    The context performs no hostname check and no certificate
    verification, i.e. the same configuration as
    ``simple_client_sslcontext(disable_verify=True)``.  It uses
    PROTOCOL_TLS_CLIENT because the generic PROTOCOL_TLS constant is
    deprecated (DeprecationWarning since Python 3.10).
    """
    if ssl is None:
        return None
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname must be disabled before verify_mode can be CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context
94
95
def run_briefly(loop):
    """Give *loop* a short run: until a trivial no-op task has finished.

    Callbacks that are already scheduled get a chance to execute while the
    no-op task is driven to completion.
    """
    async def noop():
        pass

    coro = noop()
    task = loop.create_task(coro)
    # The task may still be pending when run_until_complete() returns
    # (the loop was stopped, or a BaseException escaped); don't let its
    # destruction log a "Task was destroyed but it is pending" warning.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
108
109
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in short slices until ``pred()`` returns a true value.

    Between checks of *pred*, ``tasks.sleep(0.001)`` is run to completion
    on *loop*, giving pending callbacks a chance to execute.

    *timeout* is in seconds.  If it elapses before *pred* is satisfied,
    asyncio.TimeoutError is raised.  ``timeout=None`` waits without any
    deadline (previously this crashed with TypeError when computing the
    deadline).
    """
    # Only compute a deadline when one was requested; None means "no limit".
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            # asyncio.futures does not export TimeoutError on modern
            # Pythons; asyncio.TimeoutError is the public name everywhere.
            raise asyncio.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001))
118
119
def run_once(loop):
    """Run roughly one iteration of *loop* and return.

    ``loop.stop`` is queued with ``call_soon()`` before ``run_forever()``
    starts, so the loop polls for I/O once, runs the callbacks that are
    ready (including the stop request), and then ``run_forever()``
    returns.  Kept for older tests that rely on this pattern.
    """
    stop_soon = loop.stop
    loop.call_soon(stop_soon)
    loop.run_forever()
129
130
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards error output and access logs."""

    def get_stderr(self):
        # Each call gets a fresh, throw-away buffer.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Intentionally a no-op: no per-request log lines.
        return None
138
139
class SilentWSGIServer(WSGIServer):
    """WSGI server with a per-connection timeout and a no-op error hook."""

    # Timeout applied to every accepted connection socket.
    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Deliberately ignore errors raised while handling a request.
        return None
151
152
class SSLWSGIServerMixin:
    """Server mixin that wraps every accepted connection in server-side TLS.

    The TLS context comes from simple_server_sslcontext(), i.e. it serves
    ONLYCERT/ONLYKEY with an explicit PROTOCOL_TLS_SERVER context (calling
    ``ssl.SSLContext()`` without a protocol is deprecated since 3.10).
    """

    def finish_request(self, request, client_address):
        context = simple_server_sslcontext()

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # ``with`` closes the TLS socket on every exit path, including
            # when the handler raises; previously it leaked in that case.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
170
171
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose connections are served over TLS."""
174
175
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Run a small WSGI test application in a background thread.

    This is a generator meant to be delegated to with ``yield from`` from a
    ``contextlib.contextmanager`` function.  It instantiates
    ``server_ssl_cls`` (when *use_ssl* is true) or ``server_cls`` on
    *address*, yields the running server (with an ``address`` attribute
    mirroring ``server_address``), and on exit shuts the server down,
    closes it and joins the serving thread.

    Every request gets ``200 OK`` with a ``text/plain`` body: for the path
    ``/loop`` the request body is echoed back, for any other path the body
    is ``b'Test message'``.
    """

    def loop(environ):
        # Echo the request body back in chunks of at most 64 KiB.
        # CONTENT_LENGTH may be empty or absent when there is no body
        # (PEP 3333), so treat that as zero instead of failing in int().
        size = int(environ.get('CONTENT_LENGTH') or 0)
        body = environ['wsgi.input']
        while size > 0:
            data = body.read(min(size, 0x10000))
            if not data:
                # The input ended before the announced length was read;
                # stop instead of spinning forever on empty reads.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    try:
        httpd.set_app(app)
        httpd.address = httpd.server_address
        server_thread = threading.Thread(
            target=httpd.serve_forever, kwargs={'poll_interval': 0.05})
        server_thread.start()
    except BaseException:
        # Don't leak the listening socket if startup fails.
        httpd.server_close()
        raise
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
209
210
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer listening on a UNIX-domain socket path."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # A socket path has no host name or port; use placeholders.
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a per-connection timeout."""

        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # Stdlib code expects get_request() to return a socket plus a
            # (host, port) tuple, but a UNIX socket reports a path instead.
            # Substitute placeholder data that is good enough for the tests.
            fake_peer = ('127.0.0.1', '')
            return conn, fake_peer


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer whose error hook does nothing."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose connections are served over TLS."""


    def gen_unix_socket_path():
        # Borrow a unique name from a temporary file; the file itself is
        # removed when the with-block exits and only its name is kept.
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a fresh socket path and unlink it afterwards (best effort)."""
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            with contextlib.suppress(OSError):
                os.unlink(sock_path)


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Serve the test WSGI app on a temporary UNIX socket path."""
        with unix_socket_path() as sock_path:
            yield from _run_test_server(address=sock_path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
274
275
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Serve the test WSGI app on (*host*, *port*), optionally over TLS.

    ``port=0`` lets the OS choose a free port; read the actual address from
    the yielded server's ``address`` attribute.
    """
    tcp_address = (host, port)
    yield from _run_test_server(address=tcp_address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
281
282
def make_test_protocol(base):
    """Return an instance of a *base* subclass whose attributes are mocks.

    Every non-dunder name reported by ``dir(base)`` is overridden with a
    MockCallback that returns None.  The generated class lists *base*
    followed by *base*'s own bases.
    """
    def is_magic(attr):
        return attr.startswith('__') and attr.endswith('__')

    namespace = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        if not is_magic(attr)
    }
    bases = (base, *base.__bases__)
    return type('TestProtocol', bases, namespace)()
291
292
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations but never has events.

    Every key is created with fd 0, and select() always returns an empty
    list.
    """

    def __init__(self):
        # fileobj -> SelectorKey
        self.keys = {}

    def register(self, fileobj, events, data=None):
        new_key = selectors.SelectorKey(
            fileobj=fileobj, fd=0, events=events, data=data)
        self.keys[fileobj] = new_key
        return new_key

    def unregister(self, fileobj):
        removed = self.keys.pop(fileobj)
        return removed

    def select(self, timeout):
        return list()

    def get_map(self):
        return self.keys
311
312
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a manually driven clock.

    ``time()`` never reads a real clock.  It returns an internal counter
    that starts at 0 and only moves through ``advance_time()``.  Every
    ``call_at()`` records its ``when``.  After each loop iteration, the
    values recorded since the previous iteration are sent, in scheduling
    order, into the generator passed to ``__init__``.  Whatever the
    generator yields back is added to the clock.

    The generator is primed with ``next()`` on construction, so it should
    look like this::

        def gen():
            when = yield          # receives the first scheduled deadline
            ...
            when = yield advance  # move the clock by *advance*, get next

    The value *received* from ``yield`` is the absolute time of a scheduled
    handle.  The value *yielded* is how far to move time forward; a falsy
    value leaves the clock unchanged.

    When a generator is supplied, ``close()`` sends it one final ``0`` and
    raises AssertionError unless the generator is exhausted by then.

    Reader/writer callbacks are stored in the ``readers`` / ``writers``
    dicts (fd -> Handle) so tests can inspect them with the ``assert_*``
    helpers.  The loop uses TestSelector, which never reports I/O events.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # Default generator: never advances time, not checked on close.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # prime the generator so send() can be used
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []  # `when` values passed to call_at() since last iteration
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and, if a generator was given, check it finished."""
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this reader callback/args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if a reader is registered for *fd*."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this writer callback/args."""
        # Raise explicitly, like assert_reader(): plain ``assert`` statements
        # would be stripped when running under ``python -O``.
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_writer(self, fd):
        """Raise AssertionError if a writer is registered for *fd*."""
        if fd in self.writers:
            raise AssertionError(f'fd {fd} is registered')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a registered transport.

        Non-int arguments are converted through ``fileno()``; objects that
        cannot be converted raise ValueError.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed every deadline scheduled since the last iteration to the
        # time generator and move the clock by the amount it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        # TestSelector.select() always returns [], so there is nothing to do.
        return

    def _write_to_self(self):
        # No self-pipe is needed to wake this loop up in tests.
        pass
475
476
def MockCallback(**kwargs):
    """Return a Mock whose spec only allows it to be called.

    Keyword arguments are forwarded to ``mock.Mock`` (e.g. return_value).
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
479
480
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.  The pattern is
    searched anywhere in the other string with ``re.search`` and ``re.S``,
    so ``.`` also matches newlines.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    A non-str operand never compares equal (instead of making ``re.search``
    raise TypeError).  ``!=`` is the exact negation of ``==``.
    """

    # Defining __eq__ already makes the class unhashable; state it
    # explicitly, since fuzzy equality cannot agree with str.__hash__.
    __hash__ = None

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would compare the pattern text
        # literally and contradict __eq__; derive it from __eq__ instead.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
492
493
class MockInstanceOf:
    """Placeholder that compares equal to any instance of a given type.

    Useful with ``mock.assert_called_with`` when only an argument's type
    matters, e.g. ``MockInstanceOf(bytes)``.  *type* is passed as the second
    argument of ``isinstance()``, so a class or a tuple of classes works.
    """

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Shown in mock assertion failure messages; the default object repr
        # would not say which type was expected.
        return f'{type(self).__name__}({self._type!r})'
500
501
def get_function_source(func):
    """Return where *func* is defined, as reported by asyncio's helpers.

    The value comes from ``format_helpers._get_function_source`` (presumably
    a ``(filename, lineno)`` pair).  Raises ValueError when that helper
    cannot locate the source.
    """
    location = format_helpers._get_function_source(func)
    if location is not None:
        return location
    raise ValueError(f"unable to get the source of {func!r}")
507
508
class TestCase(unittest.TestCase):
    """Base class for asyncio tests.

    setUp() patches ``events._get_running_loop`` to always return None and
    records thread state.  tearDown() restores the patch, clears the global
    event loop, runs cleanups, and then checks for leaked threads and child
    processes through ``support``.
    """

    @staticmethod
    def close_loop(loop):
        """Close *loop* after shutting down its default executor, if any.

        If the current event loop policy provides a ThreadedChildWatcher,
        its remaining threads are joined as well.
        """
        if loop._default_executor is not None:
            if not loop.is_closed():
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                # The loop can no longer run coroutines; shut the executor
                # down synchronously instead.
                loop._default_executor.shutdown(wait=True)
        loop.close()
        policy = support.maybe_get_event_loop_policy()
        if policy is not None:
            try:
                watcher = policy.get_child_watcher()
            except NotImplementedError:
                # watcher is not implemented by EventLoopPolicy, e.g. Windows
                pass
            else:
                # NOTE(review): presumably joined so that the thread-leak
                # check in tearDown() does not report the watcher's threads.
                if isinstance(watcher, asyncio.ThreadedChildWatcher):
                    threads = list(watcher._threads.values())
                    for thread in threads:
                        thread.join()

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the global event loop; optionally schedule close_loop(loop).

        *loop* itself is NOT installed as the current loop.  When *cleanup*
        is true, close_loop(loop) is registered with addCleanup().
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def unpatch_get_running_loop(self):
        """Restore the events._get_running_loop saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        # Save the real implementation, then make every lookup of the
        # running loop return None for the duration of the test.
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None
        self._thread_cleanup = support.threading_setup()

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        # Run cleanups now (e.g. close_loop) so the thread and child-process
        # checks below see the state after the loops have been closed.
        self.doCleanups()
        support.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
563
564
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of the ``with`` block.

    The logger level is raised above CRITICAL and restored on exit, even if
    the block raises.  Handy, for example, to hide debug-mode warnings.
    """
    saved_level = logger.level
    silenced = logging.CRITICAL + 1
    try:
        logger.setLevel(silenced)
        yield
    finally:
        logger.setLevel(saved_level)
577
578
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Return a MagicMock specced on socket.socket that looks non-blocking.

    ``proto``, ``type`` and ``family`` are set from the arguments, and
    ``gettimeout()`` returns 0.0.
    """
    fake = mock.MagicMock(spec=socket.socket)
    fake.configure_mock(proto=proto, type=type, family=family)
    fake.gettimeout.return_value = 0.0
    return fake
588