1"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
6import logging
7import os
8import re
9import selectors
10import socket
11import socketserver
12import sys
13import tempfile
14import threading
15import time
16import unittest
17import weakref
18
19from unittest import mock
20
21from http.server import HTTPServer
22from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
23
24try:
25    import ssl
26except ImportError:  # pragma: no cover
27    ssl = None
28
29from asyncio import base_events
30from asyncio import events
31from asyncio import format_helpers
32from asyncio import futures
33from asyncio import tasks
34from asyncio.log import logger
35from test import support
36
37
def data_file(filename):
    """Return the path of the test data file *filename*.

    The directory named by ``support.TEST_HOME_DIR`` is searched first
    (only when the running ``test.support`` defines that attribute), then
    the parent directory of this module.

    Raises:
        FileNotFoundError: if *filename* exists in neither location.
    """
    candidates = []
    if hasattr(support, 'TEST_HOME_DIR'):
        candidates.append(os.path.join(support.TEST_HOME_DIR, filename))
    candidates.append(os.path.join(os.path.dirname(__file__), '..', filename))
    for path in candidates:
        if os.path.isfile(path):
            return path
    raise FileNotFoundError(filename)
47
48
# TLS fixture files, located via data_file(); importing this module raises
# FileNotFoundError if any of them cannot be found.
# ONLYCERT/ONLYKEY are the certificate and key loaded by
# simple_server_sslcontext() and by SSLWSGIServerMixin.finish_request().
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
# Presumably a certificate+key signed by SIGNING_CA (inferred from the file
# names and PEERCERT's issuer below) -- confirm against the fixture files.
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# Expected SSLSocket.getpeercert() dictionary -- presumably for a peer that
# presents SIGNED_CERTFILE (subject CN 'localhost', issuer CN
# 'our-ca-server'). Must be updated whenever that certificate is regenerated.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
70
71
def simple_server_sslcontext():
    """Build a server-side TLS context that serves ONLYCERT/ONLYKEY.

    Client certificates are not requested (``CERT_NONE``) and hostname
    checking is off.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(ONLYCERT, ONLYKEY)
    # check_hostname must be cleared before verify_mode can be CERT_NONE.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
78
79
def simple_client_sslcontext(*, disable_verify=True):
    """Build a client-side TLS context with hostname checking disabled.

    Keyword Args:
        disable_verify: when true (the default), server certificates are
            not verified at all (``CERT_NONE``); otherwise the context keeps
            the PROTOCOL_TLS_CLIENT default verification mode.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # Must be cleared first: CERT_NONE is rejected while check_hostname is on.
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
86
87
def dummy_ssl_context():
    """Return a bare PROTOCOL_TLS context, or None when ssl is unavailable."""
    return None if ssl is None else ssl.SSLContext(ssl.PROTOCOL_TLS)
93
94
def run_briefly(loop):
    """Give *loop* a brief chance to run already-scheduled callbacks.

    A task wrapping a do-nothing coroutine is scheduled and the loop is run
    until that task completes. The coroutine object is always closed
    afterwards, even if running the loop failed.
    """
    async def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # Suppress the "task was destroyed but it is pending" warning in case
    # run_until_complete() returns with the task unfinished (the loop was
    # stopped, or a task raised a BaseException).
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
107
108
def run_until(loop, pred, timeout=30):
    """Drive *loop* in short steps until ``pred()`` returns true.

    Each step runs the loop until a 1 ms ``sleep`` completes.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable, evaluated before every step.
        timeout: maximum number of seconds to wait, or None to wait
            without a limit.

    Raises:
        futures.TimeoutError: if *timeout* elapses before ``pred()`` is true.
    """
    # Only compute a deadline when there is one: ``time.monotonic() + None``
    # raised TypeError, so ``timeout=None`` could never actually be used.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        # The loop is handed to sleep() explicitly. TestCase.setUp() in this
        # module stubs out events._get_running_loop, so implicit loop lookup
        # presumably cannot be relied on here.
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
117
118
def run_once(loop):
    """Run a single iteration of *loop*.

    ``loop.stop`` is scheduled with ``call_soon`` and the loop is then run
    forever. The loop polls the selector once, runs the callbacks that are
    ready (including any triggered by I/O events), and exits. This is the
    recommended pattern for test code that needs exactly one pass through
    the event loop.
    """
    stop = loop.stop
    loop.call_soon(stop)
    loop.run_forever()
128
129
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards all of its diagnostic output."""

    def get_stderr(self):
        # Hand the handler a throwaway in-memory buffer instead of stderr.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Deliberately a no-op: no per-request access logging.
        return None
137
138
class SilentWSGIServer(WSGIServer):
    """WSGIServer with a per-connection timeout that ignores handler errors."""

    # Applied with socket.settimeout() to every accepted connection.
    request_timeout = 2

    def get_request(self):
        conn, addr = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, addr

    def handle_error(self, request, client_address):
        # Deliberately a no-op: request-handling errors are not reported.
        return None
150
151
class SSLWSGIServerMixin:
    """Mixin for socketserver-based servers that serves each connection over TLS.

    Each accepted connection is wrapped server-side with the context from
    simple_server_sslcontext() (ONLYCERT/ONLYKEY) before being handed to
    ``RequestHandlerClass``.
    """

    def finish_request(self, request, client_address):
        # Same server context as the rest of this module: an explicit
        # PROTOCOL_TLS_SERVER context with ONLYCERT/ONLYKEY loaded. The old
        # ``ssl.SSLContext()`` call without a protocol argument is deprecated.
        context = simple_server_sslcontext()

        ssock = context.wrap_socket(request, server_side=True)
        try:
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                # Always release the TLS socket. Previously it leaked
                # whenever the handler raised.
                ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass
169
170
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose connections are wrapped in TLS by the mixin."""
173
174
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a tiny WSGI test server for the duration of a yield.

    The server (``server_ssl_cls`` if *use_ssl* else ``server_cls``) is bound
    to *address* and answers every request with ``200 OK`` and the body
    ``b'Test message'``. It runs in a background thread so that it does not
    interfere with event handling in the main thread. The server object is
    yielded with an extra ``address`` attribute copied from
    ``server_address``. On exit the server is shut down, closed, and its
    thread joined.
    """

    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    cls = server_ssl_cls if use_ssl else server_cls
    server = cls(address, SilentWSGIRequestHandler)
    server.set_app(app)
    server.address = server.server_address
    worker = threading.Thread(
        target=lambda: server.serve_forever(poll_interval=0.05))
    worker.start()
    try:
        yield server
    finally:
        server.shutdown()
        server.server_close()
        worker.join()
198
199
# The Unix-domain-socket test helpers below are only defined when the
# platform's socket module provides AF_UNIX.
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer variant that listens on a Unix domain socket path."""

        def server_bind(self):
            # Bind through UnixStreamServer explicitly rather than following
            # the MRO to HTTPServer.server_bind -- presumably because that
            # version expects a (host, port) address, not a path; confirm.
            # server_name/server_port are filled with fixed placeholders.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a Unix domain socket with a per-connection timeout."""

        # Applied with socket.settimeout() to every accepted connection.
        request_timeout = 2

        def server_bind(self):
            # Explicit UnixHTTPServer call (not super(), which would follow
            # the MRO), then WSGIServer.setup_environ() to prepare the WSGI
            # base environment.
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            # Deliberately a no-op, overriding the base class's error reporting.
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose connections are wrapped in TLS by the mixin."""
        pass


    def gen_unix_socket_path():
        """Return a fresh path name suitable for binding a Unix socket.

        The temporary file exists only to obtain a unique name. It is
        deleted when the ``with`` block exits, so the returned path does not
        exist yet. NOTE(review): another process could create the same path
        before the caller binds to it.
        """
        with tempfile.NamedTemporaryFile() as file:
            return file.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary Unix socket path and unlink it afterwards.

        Errors from the final unlink (e.g. the socket was never created)
        are ignored.
        """
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run a silent WSGI test server on a temporary Unix socket.

        Yields the running server (TLS-wrapped when *use_ssl* is true). The
        server is stopped and the socket path removed on exit.
        """
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
263
264
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a silent WSGI test server on TCP *host*:*port* while in context.

    ``port=0`` lets the OS pick a free port; the actual address is available
    as the yielded server's ``address`` attribute. With *use_ssl* the
    connections are wrapped in TLS.
    """
    address = (host, port)
    yield from _run_test_server(
        address=address,
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer,
    )
270
271
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose non-dunder attributes are mocks.

    Every name from ``dir(base)`` except ``__magic__`` names is replaced by a
    fresh MockCallback returning None, so the protocol's callbacks can be
    asserted on.
    """
    attrs = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        if not (name.startswith('__') and name.endswith('__'))
    }
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, attrs)()
280
281
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub.

    It records registrations but never reports any file object as ready.
    """

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0: file objects are not resolved to
        # descriptors by this stub.
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        # Raises KeyError when *fileobj* was never registered.
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        # The default matches selectors.BaseSelector.select(timeout=None), so
        # the stub can be called like a real selector; it never has events.
        return []

    def get_map(self):
        return self.keys
300
301
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests whose clock is driven by a generator.

    ``time()`` never consults a real clock. It returns an internal counter
    that starts at 0 and only moves forward through ``advance_time()``.
    Every ``when`` passed to ``call_at()`` is recorded. After each loop
    iteration (once the ready callbacks have run), the recorded values are
    sent one at a time into the generator passed to ``__init__``, and each
    value the generator yields back is added to the loop's time.

    Generator should be like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    ``__init__`` primes the generator with ``next()``, so its first yielded
    value is ignored. After that, the value received from ``yield`` is the
    absolute time of a scheduled callback, and the value yielded is how far
    to move the loop's time forward. When a generator is supplied, it must
    finish (raise StopIteration) when ``close()`` sends it a final 0.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # NOTE(review): this default generator finishes on its first
            # send(), so _run_once() raises StopIteration as soon as any
            # call_at() has been recorded -- pass *gen* when timers are used.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # advance the generator to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        # ``when`` values passed to call_at() since the last iteration.
        self._timers = []
        self._selector = TestSelector()

        # fd -> events.Handle for the registered reader/writer callbacks.
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the loop's simulated time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (no-op when it is falsy)."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and, if a generator was supplied, check it is done.

        Raises:
            AssertionError: if the generator yields again instead of
                finishing when it is sent 0.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        # Counted even when *fd* has no reader, so tests can assert on calls.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has reader *callback*(*args)."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if a reader is registered for *fd*."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        # Counted even when *fd* has no writer, so tests can assert on calls.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has writer *callback*(*args)."""
        # Raise explicitly, like assert_reader(): bare ``assert`` statements
        # are stripped under ``python -O`` and the check would silently pass.
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a registered transport.

        Raises:
            ValueError: if *fd* is neither an int nor has a usable fileno().
            RuntimeError: if *fd* is in use by a transport of this loop.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of _remove_reader/_remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Let the clock generator advance time once per recorded timer.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        # Record *when* for the clock generator, then schedule normally.
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass
464
465
def MockCallback(**kwargs):
    """Return a Mock restricted to being called (``spec=['__call__']``).

    Keyword arguments (e.g. ``return_value``) are forwarded to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
468
469
class MockPattern(str):
    """A regex-based str with a fuzzy ``==``.

    Use this helper with ``mock.assert_called_with``, or anywhere a regex
    comparison against a string is needed. The pattern is applied with
    ``re.search`` and ``re.S``: it may match anywhere in the other string,
    and ``.`` also matches newlines. ``!=`` is the exact negation of ``==``,
    and comparing against a non-str object is simply unequal.

    For instance::

       mock_call.assert_called_with(MockPattern('spam.*ham'))
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() would raise TypeError here; returning
            # NotImplemented makes the comparison fall back to "not equal".
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this override, ``!=`` used str.__ne__ and compared the raw
        # pattern text, so ``p == s`` and ``p != s`` could both be true.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
481
482
class MockInstanceOf:
    """Placeholder that compares equal to any instance of a given type.

    Useful as an expected argument with ``mock.assert_called_with`` when
    only the argument's type matters.
    """

    def __init__(self, type):
        # Forwarded to isinstance(), so a tuple of types also works.
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Shown in mock assertion failure messages instead of the opaque
        # default ``<MockInstanceOf object at 0x...>``.
        return f'{self.__class__.__name__}({self._type!r})'
489
490
def get_function_source(func):
    """Return asyncio's ``(filename, first_line)`` source location for *func*.

    Raises:
        ValueError: if asyncio cannot determine where *func* is defined.
    """
    source = format_helpers._get_function_source(func)
    if source is not None:
        return source
    raise ValueError(f"unable to get the source of {func!r}")
496
497
class TestCase(unittest.TestCase):
    """Base class for asyncio unit tests.

    setUp() replaces ``events._get_running_loop`` with a stub that always
    returns None and snapshots thread state via ``test.support``.
    tearDown() restores the original function, clears the current event
    loop, runs the registered cleanups, and then checks for leftover threads
    and child processes via ``test.support``.
    """
    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor, if any, then close *loop*."""
        # _default_executor is a private BaseEventLoop attribute; wait=True
        # blocks until the executor's pending work has finished.
        executor = loop._default_executor
        if executor is not None:
            executor.shutdown(wait=True)
        loop.close()

    def set_event_loop(self, loop, *, cleanup=True):
        """Prepare *loop* for use by the test.

        Despite the name, this does not install *loop* as the current event
        loop. It sets the current loop to None, so code under test must
        receive the loop explicitly. When *cleanup* is true, close_loop(loop)
        is registered as a test cleanup.
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen*, register its cleanup, return it."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def unpatch_get_running_loop(self):
        """Restore the events._get_running_loop saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        # Save the real function so tearDown() can restore it, then stub it
        # out so that no loop is ever reported as "running".
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None
        # Snapshot of thread state, compared in tearDown().
        self._thread_cleanup = support.threading_setup()

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        # Run cleanups now rather than after tearDown() returns, so loops
        # registered by set_event_loop() are closed (and their executors shut
        # down) before the thread-leak check below.
        self.doCleanups()
        support.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
538
539
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of the ``with`` block.

    For example, it can be used to ignore warnings in debug mode. The
    logger's previous level is restored on exit, even if the block raises.
    """
    previous = logger.level
    try:
        # One above CRITICAL, so even critical records are dropped.
        logger.setLevel(logging.CRITICAL + 1)
        yield
    finally:
        logger.setLevel(previous)
552
553
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is spec'd on ``socket.socket``, carries the given *proto*,
    *type* and *family* attributes, and reports a timeout of ``0.0``.
    """
    sock = mock.MagicMock(socket.socket)
    for attr, value in (('proto', proto), ('type', type), ('family', family)):
        setattr(sock, attr, value)
    # A timeout of 0.0 is what a socket in non-blocking mode reports.
    sock.gettimeout.return_value = 0.0
    return sock
563