1"""Tests for unix_events.py."""
2
3import collections
4import contextlib
5import errno
6import io
7import os
8import pathlib
9import signal
10import socket
11import stat
12import sys
13import tempfile
14import threading
15import unittest
16from unittest import mock
17from test import support
18
19if sys.platform == 'win32':
20    raise unittest.SkipTest('UNIX only')
21
22
23import asyncio
24from asyncio import log
25from asyncio import base_events
26from asyncio import events
27from asyncio import unix_events
28from test.test_asyncio import utils as test_utils
29
30
31MOCK_ANY = mock.ANY
32
33
def close_pipe_transport(transport):
    """Close the pipe held by *transport* and drop the reference to it.

    transport.close() is deliberately not used because the event loop and
    the selector are mocked.  Nothing happens when the transport no longer
    holds a pipe.
    """
    if transport._pipe is not None:
        transport._pipe.close()
        transport._pipe = None
41
42
@unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopSignalTests(test_utils.TestCase):
    """Tests for the signal-handler support of SelectorEventLoop.

    Tests decorated with mock.patch('asyncio.unix_events.signal') run
    against a mocked ``signal`` module; only the attributes a test assigns
    on the mock (NSIG, SIGINT, ...) carry real values.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.SelectorEventLoop()
        self.set_event_loop(self.loop)

    def test_check_signal(self):
        # A str is rejected with TypeError, an out-of-range number with
        # ValueError.
        self.assertRaises(
            TypeError, self.loop._check_signal, '1')
        self.assertRaises(
            ValueError, self.loop._check_signal, signal.NSIG + 1)

    def test_handle_signal_no_handler(self):
        # Must not raise when no handler is registered for the signal.
        self.loop._handle_signal(signal.NSIG + 1)

    def test_handle_signal_cancelled_handler(self):
        # Handling a signal whose handle was cancelled removes the handler.
        h = asyncio.Handle(mock.Mock(), (),
                           loop=mock.Mock())
        h.cancel()
        self.loop._signal_handlers[signal.NSIG + 1] = h
        self.loop.remove_signal_handler = mock.Mock()
        self.loop._handle_signal(signal.NSIG + 1)
        self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_setup_error(self, m_signal):
        # A ValueError from set_wakeup_fd() surfaces as RuntimeError.
        m_signal.NSIG = signal.NSIG
        m_signal.set_wakeup_fd.side_effect = ValueError

        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_coroutine_error(self, m_signal):
        m_signal.NSIG = signal.NSIG

        async def simple_coroutine():
            pass

        # callback must not be a coroutine function
        coro_func = simple_coroutine
        coro_obj = coro_func()
        # Close the never-awaited coroutine object at the end of the test.
        self.addCleanup(coro_obj.close)
        for func in (coro_func, coro_obj):
            self.assertRaisesRegex(
                TypeError, 'coroutines cannot be used with add_signal_handler',
                self.loop.add_signal_handler,
                signal.SIGINT, func)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler(self, m_signal):
        # The callback ends up wrapped in a Handle stored per signal number.
        m_signal.NSIG = signal.NSIG

        cb = lambda: True
        self.loop.add_signal_handler(signal.SIGHUP, cb)
        h = self.loop._signal_handlers.get(signal.SIGHUP)
        self.assertIsInstance(h, asyncio.Handle)
        self.assertEqual(h._callback, cb)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_install_error(self, m_signal):
        # signal.signal() fails with EFAULT and resetting the wakeup fd
        # (fd == -1) raises ValueError; the original error must propagate.
        m_signal.NSIG = signal.NSIG

        def set_wakeup_fd(fd):
            if fd == -1:
                raise ValueError()
        m_signal.set_wakeup_fd = set_wakeup_fd

        class Err(OSError):
            errno = errno.EFAULT
        m_signal.signal.side_effect = Err

        self.assertRaises(
            Err,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)

    # logger.info is patched on the shared asyncio.log.logger object, so the
    # assertion observes a call made from any asyncio module.  Patching
    # 'asyncio.base_events.logger' would only rebind that one module's name.
    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.log.logger.info')
    def test_add_signal_handler_install_error2(self, m_info, m_signal):
        # EINVAL from signal.signal() becomes RuntimeError.  Another handler
        # is still registered, so set_wakeup_fd() is called only once.
        m_signal.NSIG = signal.NSIG

        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err

        self.loop._signal_handlers[signal.SIGHUP] = lambda: True
        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)
        self.assertFalse(m_info.called)
        self.assertEqual(1, m_signal.set_wakeup_fd.call_count)

    # See the note on test_add_signal_handler_install_error2 about why
    # logger.info is patched on the shared logger object.
    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.log.logger.info')
    def test_add_signal_handler_install_error3(self, m_info, m_signal):
        # Same failure with no other handler registered: set_wakeup_fd() is
        # called a second time.
        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err
        m_signal.NSIG = signal.NSIG

        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)
        self.assertFalse(m_info.called)
        self.assertEqual(2, m_signal.set_wakeup_fd.call_count)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler(self, m_signal):
        # Removing the only handler restores SIG_DFL for the signal.
        m_signal.NSIG = signal.NSIG

        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        self.assertTrue(
            self.loop.remove_signal_handler(signal.SIGHUP))
        self.assertTrue(m_signal.set_wakeup_fd.called)
        self.assertTrue(m_signal.signal.called)
        self.assertEqual(
            (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_2(self, m_signal):
        # SIGINT is restored to default_int_handler; with another handler
        # left registered set_wakeup_fd() is not touched.
        m_signal.NSIG = signal.NSIG
        m_signal.SIGINT = signal.SIGINT

        self.loop.add_signal_handler(signal.SIGINT, lambda: True)
        self.loop._signal_handlers[signal.SIGHUP] = object()
        m_signal.set_wakeup_fd.reset_mock()

        self.assertTrue(
            self.loop.remove_signal_handler(signal.SIGINT))
        self.assertFalse(m_signal.set_wakeup_fd.called)
        self.assertTrue(m_signal.signal.called)
        self.assertEqual(
            (signal.SIGINT, m_signal.default_int_handler),
            m_signal.signal.call_args[0])

    # logger.info is patched on the shared asyncio.log.logger object so the
    # call is observed regardless of which asyncio module emits it.
    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.log.logger.info')
    def test_remove_signal_handler_cleanup_error(self, m_info, m_signal):
        # A ValueError from set_wakeup_fd() while removing the handler must
        # not escape; it is expected to be reported through logger.info.
        m_signal.NSIG = signal.NSIG
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        m_signal.set_wakeup_fd.side_effect = ValueError

        self.loop.remove_signal_handler(signal.SIGHUP)
        # Check the call itself: a bare Mock attribute is always truthy.
        self.assertTrue(m_info.called)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_error(self, m_signal):
        # A plain OSError from signal.signal() propagates unchanged.
        m_signal.NSIG = signal.NSIG
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        m_signal.signal.side_effect = OSError

        self.assertRaises(
            OSError, self.loop.remove_signal_handler, signal.SIGHUP)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_error2(self, m_signal):
        # EINVAL from signal.signal() is converted to RuntimeError.
        m_signal.NSIG = signal.NSIG
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err

        self.assertRaises(
            RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP)

    @mock.patch('asyncio.unix_events.signal')
    def test_close(self, m_signal):
        # Closing the loop drops every handler and resets the wakeup fd
        # exactly once.
        m_signal.NSIG = signal.NSIG

        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
        self.loop.add_signal_handler(signal.SIGCHLD, lambda: True)

        self.assertEqual(len(self.loop._signal_handlers), 2)

        m_signal.set_wakeup_fd.reset_mock()

        self.loop.close()

        self.assertEqual(len(self.loop._signal_handlers), 0)
        m_signal.set_wakeup_fd.assert_called_once_with(-1)

    @mock.patch('asyncio.unix_events.sys')
    @mock.patch('asyncio.unix_events.signal')
    def test_close_on_finalizing(self, m_signal, m_sys):
        # While the interpreter reports it is finalizing, close() warns and
        # forgets the handlers without calling signal.signal().
        m_signal.NSIG = signal.NSIG
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        self.assertEqual(len(self.loop._signal_handlers), 1)
        m_sys.is_finalizing.return_value = True
        m_signal.signal.reset_mock()

        with self.assertWarnsRegex(ResourceWarning,
                                   "skipping signal handlers removal"):
            self.loop.close()

        self.assertEqual(len(self.loop._signal_handlers), 0)
        self.assertFalse(m_signal.signal.called)
251
252
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
                     'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
    """Tests for create_unix_server() and create_unix_connection()."""

    def setUp(self):
        super().setUp()
        self.loop = asyncio.SelectorEventLoop()
        self.set_event_loop(self.loop)

    @support.skip_unless_bind_unix_socket
    def test_create_unix_server_existing_path_sock(self):
        # A socket file left behind at the path must not prevent creating
        # a server on that path.
        with test_utils.unix_socket_path() as path:
            # The context manager closes the socket even if bind()/listen()
            # raises.
            with socket.socket(socket.AF_UNIX) as sock:
                sock.bind(path)
                sock.listen(1)

            coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    @support.skip_unless_bind_unix_socket
    def test_create_unix_server_pathlib(self):
        # The path may be given as a pathlib.Path.
        with test_utils.unix_socket_path() as path:
            path = pathlib.Path(path)
            srv_coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(srv_coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_create_unix_connection_pathlib(self):
        with test_utils.unix_socket_path() as path:
            path = pathlib.Path(path)
            coro = self.loop.create_unix_connection(lambda: None, path)
            with self.assertRaises(FileNotFoundError):
                # If pathlib.Path wasn't supported, the exception would be
                # different.
                self.loop.run_until_complete(coro)

    def test_create_unix_server_existing_path_nonsock(self):
        # An existing regular file at the path is reported as address in use.
        with tempfile.NamedTemporaryFile() as file:
            coro = self.loop.create_unix_server(lambda: None, file.name)
            with self.assertRaisesRegex(OSError,
                                        'Address.*is already in use'):
                self.loop.run_until_complete(coro)

    def test_create_unix_server_ssl_bool(self):
        # ssl=True is not accepted for servers; an SSLContext is required.
        coro = self.loop.create_unix_server(lambda: None, path='spam',
                                            ssl=True)
        with self.assertRaisesRegex(TypeError,
                                    'ssl argument must be an SSLContext'):
            self.loop.run_until_complete(coro)

    def test_create_unix_server_nopath_nosock(self):
        coro = self.loop.create_unix_server(lambda: None, path=None)
        with self.assertRaisesRegex(ValueError,
                                    'path was not specified, and no sock'):
            self.loop.run_until_complete(coro)

    def test_create_unix_server_path_inetsock(self):
        # socket.socket() with default arguments is not an AF_UNIX socket.
        sock = socket.socket()
        with sock:
            coro = self.loop.create_unix_server(lambda: None, path=None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    def test_create_unix_server_path_dgram(self):
        # An AF_UNIX socket of type SOCK_DGRAM is rejected as well.
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        with sock:
            coro = self.loop.create_unix_server(lambda: None, path=None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
                         'no socket.SOCK_NONBLOCK (linux only)')
    @support.skip_unless_bind_unix_socket
    def test_create_unix_server_path_stream_bittype(self):
        # A stream socket whose type has SOCK_NONBLOCK or-ed in is accepted.
        sock = socket.socket(
            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
        # Only the generated name is kept; the temporary file itself is
        # removed when the with block exits.
        with tempfile.NamedTemporaryFile() as file:
            fn = file.name
        try:
            with sock:
                sock.bind(fn)
                coro = self.loop.create_unix_server(lambda: None, path=None,
                                                    sock=sock)
                srv = self.loop.run_until_complete(coro)
                srv.close()
                self.loop.run_until_complete(srv.wait_closed())
        finally:
            os.unlink(fn)

    def test_create_unix_server_ssl_timeout_with_plain_sock(self):
        coro = self.loop.create_unix_server(lambda: None, path='spam',
                                            ssl_handshake_timeout=1)
        with self.assertRaisesRegex(
                ValueError,
                'ssl_handshake_timeout is only meaningful with ssl'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_path_inetsock(self):
        sock = socket.socket()
        with sock:
            coro = self.loop.create_unix_connection(lambda: None,
                                                    sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @mock.patch('asyncio.unix_events.socket')
    def test_create_unix_server_bind_error(self, m_socket):
        # Ensure that the socket is closed on any bind error
        for exc_type in (OSError, MemoryError):
            # A fresh mock per case: reusing one socket would let the
            # close() call of the first case satisfy the second check.
            sock = mock.Mock()
            m_socket.socket.return_value = sock
            sock.bind.side_effect = exc_type

            coro = self.loop.create_unix_server(lambda: None, path="/test")
            with self.assertRaises(exc_type):
                self.loop.run_until_complete(coro)
            self.assertTrue(sock.close.called, exc_type)

    def test_create_unix_connection_path_sock(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, sock=object())
        with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_nopath_nosock(self):
        coro = self.loop.create_unix_connection(
            lambda: None, None)
        with self.assertRaisesRegex(ValueError,
                                    'no path and sock were specified'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_nossl_serverhost(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, server_hostname='spam')
        with self.assertRaisesRegex(ValueError,
                                    'server_hostname is only meaningful'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_ssl_noserverhost(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, ssl=True)

        with self.assertRaisesRegex(
                ValueError,
                'you have to pass server_hostname when using ssl'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_ssl_timeout_with_plain_sock(self):
        coro = self.loop.create_unix_connection(lambda: None, path='spam',
                                                ssl_handshake_timeout=1)
        with self.assertRaisesRegex(
                ValueError,
                'ssl_handshake_timeout is only meaningful with ssl'):
            self.loop.run_until_complete(coro)
421
422
@unittest.skipUnless(hasattr(os, 'sendfile'),
                     'sendfile is not supported')
class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
    # Tests for the loop's private _sock_sendfile_native() and
    # _sock_sendfile_native_impl() helpers.  setUpClass writes DATA to
    # support.TESTFN; every test reopens that file as self.file.
    DATA = b"12345abcde" * 16 * 1024  # 160 KiB

    class MyProto(asyncio.Protocol):
        # Server-side protocol: collects received bytes and exposes one
        # future resolved on connection_made and one on connection_lost.

        def __init__(self, loop):
            self.started = False
            self.closed = False
            self.transport = None
            self.data = bytearray()
            self.fut = loop.create_future()
            self._ready = loop.create_future()

        def connection_made(self, transport):
            self.transport = transport
            self.started = True
            self._ready.set_result(None)

        def data_received(self, data):
            self.data += data

        def connection_lost(self, exc):
            self.closed = True
            self.fut.set_result(None)

        async def wait_closed(self):
            await self.fut

    @classmethod
    def setUpClass(cls):
        # Create the shared data file before the base class set-up runs.
        with open(support.TESTFN, 'wb') as data_file:
            data_file.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        support.unlink(support.TESTFN)
        super().tearDownClass()

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)
        self.file = open(support.TESTFN, 'rb')
        self.addCleanup(self.file.close)
        super().setUp()

    def make_socket(self, cleanup=True):
        # Non-blocking TCP socket with 1024-byte send and receive buffers.
        new_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        new_sock.setblocking(False)
        for option in (socket.SO_SNDBUF, socket.SO_RCVBUF):
            new_sock.setsockopt(socket.SOL_SOCKET, option, 1024)
        if cleanup:
            self.addCleanup(new_sock.close)
        return new_sock

    def run_loop(self, coro):
        return self.loop.run_until_complete(coro)

    def prepare(self):
        # Returns (client socket, server protocol) for a connection that is
        # already established; teardown is registered as a cleanup.
        client = self.make_socket()
        proto = self.MyProto(self.loop)
        port = support.find_unused_port()
        listener = self.make_socket(cleanup=False)
        listener.bind((support.HOST, port))
        server = self.run_loop(
            self.loop.create_server(lambda: proto, sock=listener))
        self.run_loop(self.loop.sock_connect(client, (support.HOST, port)))
        self.run_loop(proto._ready)

        def teardown():
            proto.transport.close()
            self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(teardown)

        return client, proto

    def _check_not_a_regular_file(self, f):
        # Shared body of the tests that pass something other than a regular
        # file to _sock_sendfile_native().
        sock, _ = self.prepare()
        with self.assertRaisesRegex(events.SendfileNotAvailableError,
                                    "not a regular file"):
            coro = self.loop._sock_sendfile_native(sock, f, 0, None)
            self.run_loop(coro)
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_not_available(self):
        # With an empty-spec mock in place of unix_events.os there is no
        # os.sendfile attribute to use.
        sock, _ = self.prepare()
        with mock.patch('asyncio.unix_events.os', spec=[]), \
                self.assertRaisesRegex(
                    events.SendfileNotAvailableError,
                    "os[.]sendfile[(][)] is not available"):
            coro = self.loop._sock_sendfile_native(sock, self.file, 0, None)
            self.run_loop(coro)
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_not_a_file(self):
        self._check_not_a_regular_file(object())

    def test_sock_sendfile_iobuffer(self):
        self._check_not_a_regular_file(io.BytesIO())

    def test_sock_sendfile_not_regular_file(self):
        fake_file = mock.Mock()
        fake_file.fileno.return_value = -1
        self._check_not_a_regular_file(fake_file)

    def test_sock_sendfile_cancel1(self):
        # After the future is cancelled the socket is no longer registered
        # with the selector.
        sock, _ = self.prepare()

        pending = self.loop.create_future()
        file_fd = self.file.fileno()
        self.loop._sock_sendfile_native_impl(pending, None, sock, file_fd,
                                             0, None, len(self.DATA), 0)
        pending.cancel()
        try:
            self.run_loop(pending)
        except asyncio.CancelledError:
            pass
        self.assertRaises(KeyError, self.loop._selector.get_key, sock)

    def test_sock_sendfile_cancel2(self):
        # Re-entering the implementation with an already cancelled future
        # leaves the socket unregistered.
        sock, _ = self.prepare()

        pending = self.loop.create_future()
        file_fd = self.file.fileno()
        total = len(self.DATA)
        self.loop._sock_sendfile_native_impl(pending, None, sock, file_fd,
                                             0, None, total, 0)
        pending.cancel()
        self.loop._sock_sendfile_native_impl(pending, sock.fileno(), sock,
                                             file_fd, 0, None, total, 0)
        self.assertRaises(KeyError, self.loop._selector.get_key, sock)

    def test_sock_sendfile_blocking_error(self):
        # BlockingIOError from os.sendfile: the socket stays registered and
        # exactly one done-callback is attached to the future.
        sock, _ = self.prepare()

        file_fd = self.file.fileno()
        fake_fut = mock.Mock()
        fake_fut.cancelled.return_value = False
        with mock.patch('os.sendfile', side_effect=BlockingIOError()):
            self.loop._sock_sendfile_native_impl(fake_fut, None, sock,
                                                 file_fd, 0, None,
                                                 len(self.DATA), 0)
        registration = self.loop._selector.get_key(sock)
        self.assertIsNotNone(registration)
        fake_fut.add_done_callback.assert_called_once_with(mock.ANY)

    def test_sock_sendfile_os_error_first_call(self):
        # OSError with nothing sent yet is reported on the future as
        # SendfileNotAvailableError.
        sock, _ = self.prepare()

        file_fd = self.file.fileno()
        result = self.loop.create_future()
        with mock.patch('os.sendfile', side_effect=OSError()):
            self.loop._sock_sendfile_native_impl(result, None, sock, file_fd,
                                                 0, None, len(self.DATA), 0)
        self.assertRaises(KeyError, self.loop._selector.get_key, sock)
        raised = result.exception()
        self.assertIsInstance(raised, events.SendfileNotAvailableError)
        self.assertEqual(0, self.file.tell())

    def test_sock_sendfile_os_error_next_call(self):
        # With 1000 bytes already accounted for, the OSError itself is set
        # on the future and the file position ends up at 1000.
        sock, _ = self.prepare()

        file_fd = self.file.fileno()
        result = self.loop.create_future()
        failure = OSError()
        with mock.patch('os.sendfile', side_effect=failure):
            self.loop._sock_sendfile_native_impl(
                result, sock.fileno(), sock, file_fd,
                1000, None, len(self.DATA), 1000)
        self.assertRaises(KeyError, self.loop._selector.get_key, sock)
        raised = result.exception()
        self.assertIs(raised, failure)
        self.assertEqual(1000, self.file.tell())

    def test_sock_sendfile_exception(self):
        # A SendfileNotAvailableError raised by os.sendfile is set on the
        # future as is.
        sock, _ = self.prepare()

        file_fd = self.file.fileno()
        result = self.loop.create_future()
        failure = events.SendfileNotAvailableError()
        with mock.patch('os.sendfile', side_effect=failure):
            self.loop._sock_sendfile_native_impl(
                result, sock.fileno(), sock, file_fd,
                1000, None, len(self.DATA), 1000)
        self.assertRaises(KeyError, self.loop._selector.get_key, sock)
        raised = result.exception()
        self.assertIs(raised, failure)
        self.assertEqual(1000, self.file.tell())
628
629
class UnixReadPipeTransportTests(test_utils.TestCase):
    # Tests for unix_events._UnixReadPipeTransport, driven by a mocked pipe
    # whose fileno() is 5, with os.set_blocking and os.fstat patched.

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        set_blocking_patch = mock.patch('os.set_blocking')
        set_blocking_patch.start()
        self.addCleanup(set_blocking_patch.stop)

        # os.fstat() reports the mocked pipe as a FIFO.
        fifo_stat = mock.Mock()
        fifo_stat.st_mode = stat.S_IFIFO
        fstat_patch = mock.patch('os.fstat')
        fstat_patch.start().return_value = fifo_stat
        self.addCleanup(fstat_patch.stop)

    def read_pipe_transport(self, waiter=None):
        # Build a transport on the mocked pipe; its pipe is released by a
        # registered cleanup.
        tr = unix_events._UnixReadPipeTransport(
            self.loop, self.pipe, self.protocol, waiter=waiter)
        self.addCleanup(close_pipe_transport, tr)
        return tr

    def test_ctor(self):
        ready = asyncio.Future(loop=self.loop)
        transport = self.read_pipe_transport(waiter=ready)
        self.loop.run_until_complete(ready)

        self.protocol.connection_made.assert_called_with(transport)
        self.loop.assert_reader(5, transport._read_ready)
        self.assertIsNone(ready.result())

    @mock.patch('os.read')
    def test__read_ready(self, m_read):
        # Bytes returned by os.read() are handed to the protocol.
        transport = self.read_pipe_transport()
        m_read.return_value = b'data'
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        self.protocol.data_received.assert_called_with(b'data')

    @mock.patch('os.read')
    def test__read_ready_eof(self, m_read):
        # An empty read means EOF: the reader is removed and the protocol
        # gets eof_received() followed by connection_lost(None).
        transport = self.read_pipe_transport()
        m_read.return_value = b''
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.eof_received.assert_called_with()
        self.protocol.connection_lost.assert_called_with(None)

    @mock.patch('os.read')
    def test__read_ready_blocked(self, m_read):
        # BlockingIOError from os.read() delivers no data.
        transport = self.read_pipe_transport()
        m_read.side_effect = BlockingIOError
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.data_received.called)

    @mock.patch('asyncio.log.logger.error')
    @mock.patch('os.read')
    def test__read_ready_error(self, m_read, m_logexc):
        # Any other OSError closes the transport with that error and is
        # logged as a fatal read error.
        transport = self.read_pipe_transport()
        failure = OSError()
        m_read.side_effect = failure
        transport._close = mock.Mock()
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        transport._close.assert_called_with(failure)
        expected_message = test_utils.MockPattern(
            'Fatal read error on pipe transport'
            '\nprotocol:.*\ntransport:.*')
        m_logexc.assert_called_with(
            expected_message,
            exc_info=(OSError, MOCK_ANY, MOCK_ANY))

    @mock.patch('os.read')
    def test_pause_reading(self, m_read):
        transport = self.read_pipe_transport()
        callback = mock.Mock()
        self.loop.add_reader(5, callback)
        transport.pause_reading()
        self.assertFalse(self.loop.readers)

    @mock.patch('os.read')
    def test_resume_reading(self, m_read):
        transport = self.read_pipe_transport()
        transport.resume_reading()
        self.loop.assert_reader(5, transport._read_ready)

    @mock.patch('os.read')
    def test_close(self, m_read):
        # close() delegates to _close(None).
        transport = self.read_pipe_transport()
        transport._close = mock.Mock()
        transport.close()
        transport._close.assert_called_with(None)

    @mock.patch('os.read')
    def test_close_already_closing(self, m_read):
        # close() on a transport already marked as closing does nothing.
        transport = self.read_pipe_transport()
        transport._closing = True
        transport._close = mock.Mock()
        transport.close()
        self.assertFalse(transport._close.called)

    @mock.patch('os.read')
    def test__close(self, m_read):
        transport = self.read_pipe_transport()
        reason = object()
        transport._close(reason)
        self.assertTrue(transport.is_closing())
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(reason)

    def _check_call_connection_lost(self, err):
        # _call_connection_lost(err) notifies the protocol, closes the pipe
        # and clears the transport's protocol and loop references.
        transport = self.read_pipe_transport()
        for attr in (transport._protocol, transport._loop):
            self.assertIsNotNone(attr)

        transport._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        self.assertIsNone(transport._protocol)
        self.assertIsNone(transport._loop)

    def test__call_connection_lost(self):
        self._check_call_connection_lost(None)

    def test__call_connection_lost_with_err(self):
        self._check_call_connection_lost(OSError())
778
779
class UnixWritePipeTransportTests(test_utils.TestCase):
    # Tests for unix_events._UnixWritePipeTransport.  The pipe is a mock
    # whose fileno() is 5; os.write() is patched by each test that needs it.

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        # Patch os.set_blocking() for the duration of the test.
        blocking_patcher = mock.patch('os.set_blocking')
        blocking_patcher.start()
        self.addCleanup(blocking_patcher.stop)

        # Make os.fstat() report a socket for whatever fd it is given.
        fstat_patcher = mock.patch('os.fstat')
        m_fstat = fstat_patcher.start()
        st = mock.Mock()
        st.st_mode = stat.S_IFSOCK
        m_fstat.return_value = st
        self.addCleanup(fstat_patcher.stop)

    def write_pipe_transport(self, waiter=None):
        # Build a transport over the mocked pipe and make sure the pipe is
        # released when the test finishes.
        transport = unix_events._UnixWritePipeTransport(self.loop, self.pipe,
                                                        self.protocol,
                                                        waiter=waiter)
        self.addCleanup(close_pipe_transport, transport)
        return transport

    def test_ctor(self):
        # Once the waiter completes, the protocol has been connected and a
        # reader is registered on the pipe fd.
        waiter = asyncio.Future(loop=self.loop)
        tr = self.write_pipe_transport(waiter=waiter)
        self.loop.run_until_complete(waiter)

        self.protocol.connection_made.assert_called_with(tr)
        self.loop.assert_reader(5, tr._read_ready)
        self.assertIsNone(waiter.result())

    def test_can_write_eof(self):
        tr = self.write_pipe_transport()
        self.assertTrue(tr.can_write_eof())

    @mock.patch('os.write')
    def test_write(self, m_write):
        # os.write() takes all 4 bytes: nothing is buffered and no writer
        # is registered.
        tr = self.write_pipe_transport()
        m_write.return_value = 4
        tr.write(b'data')
        m_write.assert_called_with(5, b'data')
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)

    @mock.patch('os.write')
    def test_write_no_data(self, m_write):
        # Writing an empty payload never reaches os.write().
        tr = self.write_pipe_transport()
        tr.write(b'')
        self.assertFalse(m_write.called)
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(b''), tr._buffer)

    @mock.patch('os.write')
    def test_write_partial(self, m_write):
        # Only 2 of 4 bytes are written: the rest is buffered and a writer
        # is registered for the pipe fd.
        tr = self.write_pipe_transport()
        m_write.return_value = 2
        tr.write(b'data')
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'ta'), tr._buffer)

    @mock.patch('os.write')
    def test_write_buffer(self, m_write):
        # With data already pending, new data is appended to the buffer
        # without calling os.write().
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'previous')
        tr.write(b'data')
        self.assertFalse(m_write.called)
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'previousdata'), tr._buffer)

    @mock.patch('os.write')
    def test_write_again(self, m_write):
        # BlockingIOError from os.write(): everything is buffered and a
        # writer is registered.
        tr = self.write_pipe_transport()
        m_write.side_effect = BlockingIOError()
        tr.write(b'data')
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('asyncio.unix_events.logger')
    @mock.patch('os.write')
    def test_write_err(self, m_write, m_log):
        # OSError from os.write() is reported through _fatal_error(); every
        # write attempt from then on bumps _conn_lost.
        tr = self.write_pipe_transport()
        err = OSError()
        m_write.side_effect = err
        tr._fatal_error = mock.Mock()
        tr.write(b'data')
        m_write.assert_called_with(5, b'data')
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)
        tr._fatal_error.assert_called_with(
                            err,
                            'Fatal write error on pipe transport')
        self.assertEqual(1, tr._conn_lost)

        tr.write(b'data')
        self.assertEqual(2, tr._conn_lost)
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        # This is a bit overspecified. :-(
        m_log.warning.assert_called_with(
            'pipe closed by peer or os.write(pipe, data) raised exception.')
        tr.close()

    @mock.patch('os.write')
    def test_write_close(self, m_write):
        tr = self.write_pipe_transport()
        tr._read_ready()  # pipe was closed by peer

        # Each write after the close only increments _conn_lost.
        tr.write(b'data')
        self.assertEqual(tr._conn_lost, 1)
        tr.write(b'data')
        self.assertEqual(tr._conn_lost, 2)

    def test__read_ready(self):
        # _read_ready() closes the transport: no reader or writer is left
        # and the protocol is told the connection was lost without error.
        tr = self.write_pipe_transport()
        tr._read_ready()
        self.assertFalse(self.loop.readers)
        self.assertFalse(self.loop.writers)
        self.assertTrue(tr.is_closing())
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    @mock.patch('os.write')
    def test__write_ready(self, m_write):
        # The whole buffer is flushed: the writer is removed.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 4
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_partial(self, m_write):
        # 3 of 4 bytes are flushed: the writer stays and one byte remains.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 3
        tr._write_ready()
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'a'), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_again(self, m_write):
        # BlockingIOError while flushing leaves buffer and writer untouched.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.side_effect = BlockingIOError()
        tr._write_ready()
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_empty(self, m_write):
        # os.write() reporting 0 bytes leaves buffer and writer untouched.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 0
        tr._write_ready()
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('asyncio.log.logger.error')
    @mock.patch('os.write')
    def test__write_ready_err(self, m_write, m_logexc):
        # OSError while flushing: the buffer is dropped, the transport is
        # closing, the error is logged and passed on to the protocol.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.side_effect = err = OSError()
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertFalse(self.loop.readers)
        self.assertEqual(bytearray(), tr._buffer)
        self.assertTrue(tr.is_closing())
        m_logexc.assert_called_with(
            test_utils.MockPattern(
                'Fatal write error on pipe transport'
                '\nprotocol:.*\ntransport:.*'),
            exc_info=(OSError, MOCK_ANY, MOCK_ANY))
        self.assertEqual(1, tr._conn_lost)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(err)

    @mock.patch('os.write')
    def test__write_ready_closing(self, m_write):
        # Flushing the last bytes of a closing transport finishes the close:
        # the protocol is notified and the pipe is closed.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._closing = True
        tr._buffer = bytearray(b'data')
        m_write.return_value = 4
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertFalse(self.loop.readers)
        self.assertEqual(bytearray(), tr._buffer)
        self.protocol.connection_lost.assert_called_with(None)
        self.pipe.close.assert_called_with()

    @mock.patch('os.write')
    def test_abort(self, m_write):
        # abort() discards pending data without writing it.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        self.loop.add_reader(5, tr._read_ready)
        # Use a bytearray for the pending data, the same buffer type the
        # other tests in this class set and compare against.
        tr._buffer = bytearray(b'data')
        tr.abort()
        self.assertFalse(m_write.called)
        self.assertFalse(self.loop.readers)
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)
        self.assertTrue(tr.is_closing())
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test__call_connection_lost(self):
        # Without an error: the protocol is told None, the pipe is closed
        # and the transport drops its protocol and loop references.
        tr = self.write_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = None
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test__call_connection_lost_with_err(self):
        # Same as above, but the exception object is forwarded.
        tr = self.write_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = OSError()
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test_close(self):
        # close() delegates to write_eof().
        tr = self.write_pipe_transport()
        tr.write_eof = mock.Mock()
        tr.close()
        tr.write_eof.assert_called_with()

        # closing the transport twice must not fail
        tr.close()

    def test_close_closing(self):
        # close() on a transport that is already closing does nothing.
        tr = self.write_pipe_transport()
        tr.write_eof = mock.Mock()
        tr._closing = True
        tr.close()
        self.assertFalse(tr.write_eof.called)

    def test_write_eof(self):
        # With nothing buffered, write_eof() removes the reader and the
        # protocol is told the connection was lost without error.
        tr = self.write_pipe_transport()
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test_write_eof_pending(self):
        # With data still buffered, write_eof() only marks the transport as
        # closing; the protocol is not notified yet.
        tr = self.write_pipe_transport()
        # Pending data is a bytearray, as in the other tests of this class.
        tr._buffer = bytearray(b'data')
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.protocol.connection_lost.called)
1058
1059
class AbstractChildWatcherTests(unittest.TestCase):

    def test_not_implemented(self):
        # Each of these AbstractChildWatcher methods must raise
        # NotImplementedError when called on the base class itself.
        dummy = mock.Mock()
        watcher = asyncio.AbstractChildWatcher()
        calls = (
            (watcher.add_child_handler, (dummy, dummy)),
            (watcher.remove_child_handler, (dummy,)),
            (watcher.attach_loop, (dummy,)),
            (watcher.close, ()),
            (watcher.__enter__, ()),
            (watcher.__exit__, (dummy, dummy, dummy)),
        )
        for method, args in calls:
            with self.assertRaises(NotImplementedError):
                method(*args)
1077
1078
class BaseChildWatcherTests(unittest.TestCase):

    def test_not_implemented(self):
        # BaseChildWatcher itself provides no _do_waitpid() implementation.
        placeholder = mock.Mock()
        watcher = unix_events.BaseChildWatcher()
        with self.assertRaises(NotImplementedError):
            watcher._do_waitpid(placeholder)
1086
1087
# Record type grouping one mock for os.waitpid() and one for each of the
# four os.W* status helpers.
WaitPidMocks = collections.namedtuple(
    "WaitPidMocks",
    ["waitpid", "WIFEXITED", "WIFSIGNALED", "WEXITSTATUS", "WTERMSIG"],
)
1095
1096
1097class ChildWatcherTestsMixin:
1098
    # Patcher that replaces log.logger.warning with a mock; tests enter it
    # as a context manager around _sig_chld() calls.
    ignore_warnings = mock.patch.object(log.logger, "warning")
1100
1101    def setUp(self):
1102        super().setUp()
1103        self.loop = self.new_test_loop()
1104        self.running = False
1105        self.zombies = {}
1106
1107        with mock.patch.object(
1108                self.loop, "add_signal_handler") as self.m_add_signal_handler:
1109            self.watcher = self.create_watcher()
1110            self.watcher.attach_loop(self.loop)
1111
1112    def waitpid(self, pid, flags):
1113        if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1:
1114            self.assertGreater(pid, 0)
1115        try:
1116            if pid < 0:
1117                return self.zombies.popitem()
1118            else:
1119                return pid, self.zombies.pop(pid)
1120        except KeyError:
1121            pass
1122        if self.running:
1123            return 0, 0
1124        else:
1125            raise ChildProcessError()
1126
1127    def add_zombie(self, pid, returncode):
1128        self.zombies[pid] = returncode + 32768
1129
1130    def WIFEXITED(self, status):
1131        return status >= 32768
1132
1133    def WIFSIGNALED(self, status):
1134        return 32700 < status < 32768
1135
1136    def WEXITSTATUS(self, status):
1137        self.assertTrue(self.WIFEXITED(status))
1138        return status - 32768
1139
1140    def WTERMSIG(self, status):
1141        self.assertTrue(self.WIFSIGNALED(status))
1142        return 32768 - status
1143
1144    def test_create_watcher(self):
1145        self.m_add_signal_handler.assert_called_once_with(
1146            signal.SIGCHLD, self.watcher._sig_chld)
1147
1148    def waitpid_mocks(func):
1149        def wrapped_func(self):
1150            def patch(target, wrapper):
1151                return mock.patch(target, wraps=wrapper,
1152                                  new_callable=mock.Mock)
1153
1154            with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \
1155                 patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \
1156                 patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \
1157                 patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \
1158                 patch('os.waitpid', self.waitpid) as m_waitpid:
1159                func(self, WaitPidMocks(m_waitpid,
1160                                        m_WIFEXITED, m_WIFSIGNALED,
1161                                        m_WEXITSTATUS, m_WTERMSIG,
1162                                        ))
1163        return wrapped_func
1164
1165    @waitpid_mocks
1166    def test_sigchld(self, m):
1167        # register a child
1168        callback = mock.Mock()
1169
1170        with self.watcher:
1171            self.running = True
1172            self.watcher.add_child_handler(42, callback, 9, 10, 14)
1173
1174        self.assertFalse(callback.called)
1175        self.assertFalse(m.WIFEXITED.called)
1176        self.assertFalse(m.WIFSIGNALED.called)
1177        self.assertFalse(m.WEXITSTATUS.called)
1178        self.assertFalse(m.WTERMSIG.called)
1179
1180        # child is running
1181        self.watcher._sig_chld()
1182
1183        self.assertFalse(callback.called)
1184        self.assertFalse(m.WIFEXITED.called)
1185        self.assertFalse(m.WIFSIGNALED.called)
1186        self.assertFalse(m.WEXITSTATUS.called)
1187        self.assertFalse(m.WTERMSIG.called)
1188
1189        # child terminates (returncode 12)
1190        self.running = False
1191        self.add_zombie(42, 12)
1192        self.watcher._sig_chld()
1193
1194        self.assertTrue(m.WIFEXITED.called)
1195        self.assertTrue(m.WEXITSTATUS.called)
1196        self.assertFalse(m.WTERMSIG.called)
1197        callback.assert_called_once_with(42, 12, 9, 10, 14)
1198
1199        m.WIFSIGNALED.reset_mock()
1200        m.WIFEXITED.reset_mock()
1201        m.WEXITSTATUS.reset_mock()
1202        callback.reset_mock()
1203
1204        # ensure that the child is effectively reaped
1205        self.add_zombie(42, 13)
1206        with self.ignore_warnings:
1207            self.watcher._sig_chld()
1208
1209        self.assertFalse(callback.called)
1210        self.assertFalse(m.WTERMSIG.called)
1211
1212        m.WIFSIGNALED.reset_mock()
1213        m.WIFEXITED.reset_mock()
1214        m.WEXITSTATUS.reset_mock()
1215
1216        # sigchld called again
1217        self.zombies.clear()
1218        self.watcher._sig_chld()
1219
1220        self.assertFalse(callback.called)
1221        self.assertFalse(m.WIFEXITED.called)
1222        self.assertFalse(m.WIFSIGNALED.called)
1223        self.assertFalse(m.WEXITSTATUS.called)
1224        self.assertFalse(m.WTERMSIG.called)
1225
1226    @waitpid_mocks
1227    def test_sigchld_two_children(self, m):
1228        callback1 = mock.Mock()
1229        callback2 = mock.Mock()
1230
1231        # register child 1
1232        with self.watcher:
1233            self.running = True
1234            self.watcher.add_child_handler(43, callback1, 7, 8)
1235
1236        self.assertFalse(callback1.called)
1237        self.assertFalse(callback2.called)
1238        self.assertFalse(m.WIFEXITED.called)
1239        self.assertFalse(m.WIFSIGNALED.called)
1240        self.assertFalse(m.WEXITSTATUS.called)
1241        self.assertFalse(m.WTERMSIG.called)
1242
1243        # register child 2
1244        with self.watcher:
1245            self.watcher.add_child_handler(44, callback2, 147, 18)
1246
1247        self.assertFalse(callback1.called)
1248        self.assertFalse(callback2.called)
1249        self.assertFalse(m.WIFEXITED.called)
1250        self.assertFalse(m.WIFSIGNALED.called)
1251        self.assertFalse(m.WEXITSTATUS.called)
1252        self.assertFalse(m.WTERMSIG.called)
1253
1254        # children are running
1255        self.watcher._sig_chld()
1256
1257        self.assertFalse(callback1.called)
1258        self.assertFalse(callback2.called)
1259        self.assertFalse(m.WIFEXITED.called)
1260        self.assertFalse(m.WIFSIGNALED.called)
1261        self.assertFalse(m.WEXITSTATUS.called)
1262        self.assertFalse(m.WTERMSIG.called)
1263
1264        # child 1 terminates (signal 3)
1265        self.add_zombie(43, -3)
1266        self.watcher._sig_chld()
1267
1268        callback1.assert_called_once_with(43, -3, 7, 8)
1269        self.assertFalse(callback2.called)
1270        self.assertTrue(m.WIFSIGNALED.called)
1271        self.assertFalse(m.WEXITSTATUS.called)
1272        self.assertTrue(m.WTERMSIG.called)
1273
1274        m.WIFSIGNALED.reset_mock()
1275        m.WIFEXITED.reset_mock()
1276        m.WTERMSIG.reset_mock()
1277        callback1.reset_mock()
1278
1279        # child 2 still running
1280        self.watcher._sig_chld()
1281
1282        self.assertFalse(callback1.called)
1283        self.assertFalse(callback2.called)
1284        self.assertFalse(m.WIFEXITED.called)
1285        self.assertFalse(m.WIFSIGNALED.called)
1286        self.assertFalse(m.WEXITSTATUS.called)
1287        self.assertFalse(m.WTERMSIG.called)
1288
1289        # child 2 terminates (code 108)
1290        self.add_zombie(44, 108)
1291        self.running = False
1292        self.watcher._sig_chld()
1293
1294        callback2.assert_called_once_with(44, 108, 147, 18)
1295        self.assertFalse(callback1.called)
1296        self.assertTrue(m.WIFEXITED.called)
1297        self.assertTrue(m.WEXITSTATUS.called)
1298        self.assertFalse(m.WTERMSIG.called)
1299
1300        m.WIFSIGNALED.reset_mock()
1301        m.WIFEXITED.reset_mock()
1302        m.WEXITSTATUS.reset_mock()
1303        callback2.reset_mock()
1304
1305        # ensure that the children are effectively reaped
1306        self.add_zombie(43, 14)
1307        self.add_zombie(44, 15)
1308        with self.ignore_warnings:
1309            self.watcher._sig_chld()
1310
1311        self.assertFalse(callback1.called)
1312        self.assertFalse(callback2.called)
1313        self.assertFalse(m.WTERMSIG.called)
1314
1315        m.WIFSIGNALED.reset_mock()
1316        m.WIFEXITED.reset_mock()
1317        m.WEXITSTATUS.reset_mock()
1318
1319        # sigchld called again
1320        self.zombies.clear()
1321        self.watcher._sig_chld()
1322
1323        self.assertFalse(callback1.called)
1324        self.assertFalse(callback2.called)
1325        self.assertFalse(m.WIFEXITED.called)
1326        self.assertFalse(m.WIFSIGNALED.called)
1327        self.assertFalse(m.WEXITSTATUS.called)
1328        self.assertFalse(m.WTERMSIG.called)
1329
1330    @waitpid_mocks
1331    def test_sigchld_two_children_terminating_together(self, m):
1332        callback1 = mock.Mock()
1333        callback2 = mock.Mock()
1334
1335        # register child 1
1336        with self.watcher:
1337            self.running = True
1338            self.watcher.add_child_handler(45, callback1, 17, 8)
1339
1340        self.assertFalse(callback1.called)
1341        self.assertFalse(callback2.called)
1342        self.assertFalse(m.WIFEXITED.called)
1343        self.assertFalse(m.WIFSIGNALED.called)
1344        self.assertFalse(m.WEXITSTATUS.called)
1345        self.assertFalse(m.WTERMSIG.called)
1346
1347        # register child 2
1348        with self.watcher:
1349            self.watcher.add_child_handler(46, callback2, 1147, 18)
1350
1351        self.assertFalse(callback1.called)
1352        self.assertFalse(callback2.called)
1353        self.assertFalse(m.WIFEXITED.called)
1354        self.assertFalse(m.WIFSIGNALED.called)
1355        self.assertFalse(m.WEXITSTATUS.called)
1356        self.assertFalse(m.WTERMSIG.called)
1357
1358        # children are running
1359        self.watcher._sig_chld()
1360
1361        self.assertFalse(callback1.called)
1362        self.assertFalse(callback2.called)
1363        self.assertFalse(m.WIFEXITED.called)
1364        self.assertFalse(m.WIFSIGNALED.called)
1365        self.assertFalse(m.WEXITSTATUS.called)
1366        self.assertFalse(m.WTERMSIG.called)
1367
1368        # child 1 terminates (code 78)
1369        # child 2 terminates (signal 5)
1370        self.add_zombie(45, 78)
1371        self.add_zombie(46, -5)
1372        self.running = False
1373        self.watcher._sig_chld()
1374
1375        callback1.assert_called_once_with(45, 78, 17, 8)
1376        callback2.assert_called_once_with(46, -5, 1147, 18)
1377        self.assertTrue(m.WIFSIGNALED.called)
1378        self.assertTrue(m.WIFEXITED.called)
1379        self.assertTrue(m.WEXITSTATUS.called)
1380        self.assertTrue(m.WTERMSIG.called)
1381
1382        m.WIFSIGNALED.reset_mock()
1383        m.WIFEXITED.reset_mock()
1384        m.WTERMSIG.reset_mock()
1385        m.WEXITSTATUS.reset_mock()
1386        callback1.reset_mock()
1387        callback2.reset_mock()
1388
1389        # ensure that the children are effectively reaped
1390        self.add_zombie(45, 14)
1391        self.add_zombie(46, 15)
1392        with self.ignore_warnings:
1393            self.watcher._sig_chld()
1394
1395        self.assertFalse(callback1.called)
1396        self.assertFalse(callback2.called)
1397        self.assertFalse(m.WTERMSIG.called)
1398
1399    @waitpid_mocks
1400    def test_sigchld_race_condition(self, m):
1401        # register a child
1402        callback = mock.Mock()
1403
1404        with self.watcher:
1405            # child terminates before being registered
1406            self.add_zombie(50, 4)
1407            self.watcher._sig_chld()
1408
1409            self.watcher.add_child_handler(50, callback, 1, 12)
1410
1411        callback.assert_called_once_with(50, 4, 1, 12)
1412        callback.reset_mock()
1413
1414        # ensure that the child is effectively reaped
1415        self.add_zombie(50, -1)
1416        with self.ignore_warnings:
1417            self.watcher._sig_chld()
1418
1419        self.assertFalse(callback.called)
1420
1421    @waitpid_mocks
1422    def test_sigchld_replace_handler(self, m):
1423        callback1 = mock.Mock()
1424        callback2 = mock.Mock()
1425
1426        # register a child
1427        with self.watcher:
1428            self.running = True
1429            self.watcher.add_child_handler(51, callback1, 19)
1430
1431        self.assertFalse(callback1.called)
1432        self.assertFalse(callback2.called)
1433        self.assertFalse(m.WIFEXITED.called)
1434        self.assertFalse(m.WIFSIGNALED.called)
1435        self.assertFalse(m.WEXITSTATUS.called)
1436        self.assertFalse(m.WTERMSIG.called)
1437
1438        # register the same child again
1439        with self.watcher:
1440            self.watcher.add_child_handler(51, callback2, 21)
1441
1442        self.assertFalse(callback1.called)
1443        self.assertFalse(callback2.called)
1444        self.assertFalse(m.WIFEXITED.called)
1445        self.assertFalse(m.WIFSIGNALED.called)
1446        self.assertFalse(m.WEXITSTATUS.called)
1447        self.assertFalse(m.WTERMSIG.called)
1448
1449        # child terminates (signal 8)
1450        self.running = False
1451        self.add_zombie(51, -8)
1452        self.watcher._sig_chld()
1453
1454        callback2.assert_called_once_with(51, -8, 21)
1455        self.assertFalse(callback1.called)
1456        self.assertTrue(m.WIFSIGNALED.called)
1457        self.assertFalse(m.WEXITSTATUS.called)
1458        self.assertTrue(m.WTERMSIG.called)
1459
1460        m.WIFSIGNALED.reset_mock()
1461        m.WIFEXITED.reset_mock()
1462        m.WTERMSIG.reset_mock()
1463        callback2.reset_mock()
1464
1465        # ensure that the child is effectively reaped
1466        self.add_zombie(51, 13)
1467        with self.ignore_warnings:
1468            self.watcher._sig_chld()
1469
1470        self.assertFalse(callback1.called)
1471        self.assertFalse(callback2.called)
1472        self.assertFalse(m.WTERMSIG.called)
1473
1474    @waitpid_mocks
1475    def test_sigchld_remove_handler(self, m):
1476        callback = mock.Mock()
1477
1478        # register a child
1479        with self.watcher:
1480            self.running = True
1481            self.watcher.add_child_handler(52, callback, 1984)
1482
1483        self.assertFalse(callback.called)
1484        self.assertFalse(m.WIFEXITED.called)
1485        self.assertFalse(m.WIFSIGNALED.called)
1486        self.assertFalse(m.WEXITSTATUS.called)
1487        self.assertFalse(m.WTERMSIG.called)
1488
1489        # unregister the child
1490        self.watcher.remove_child_handler(52)
1491
1492        self.assertFalse(callback.called)
1493        self.assertFalse(m.WIFEXITED.called)
1494        self.assertFalse(m.WIFSIGNALED.called)
1495        self.assertFalse(m.WEXITSTATUS.called)
1496        self.assertFalse(m.WTERMSIG.called)
1497
1498        # child terminates (code 99)
1499        self.running = False
1500        self.add_zombie(52, 99)
1501        with self.ignore_warnings:
1502            self.watcher._sig_chld()
1503
1504        self.assertFalse(callback.called)
1505
1506    @waitpid_mocks
1507    def test_sigchld_unknown_status(self, m):
1508        callback = mock.Mock()
1509
1510        # register a child
1511        with self.watcher:
1512            self.running = True
1513            self.watcher.add_child_handler(53, callback, -19)
1514
1515        self.assertFalse(callback.called)
1516        self.assertFalse(m.WIFEXITED.called)
1517        self.assertFalse(m.WIFSIGNALED.called)
1518        self.assertFalse(m.WEXITSTATUS.called)
1519        self.assertFalse(m.WTERMSIG.called)
1520
1521        # terminate with unknown status
1522        self.zombies[53] = 1178
1523        self.running = False
1524        self.watcher._sig_chld()
1525
1526        callback.assert_called_once_with(53, 1178, -19)
1527        self.assertTrue(m.WIFEXITED.called)
1528        self.assertTrue(m.WIFSIGNALED.called)
1529        self.assertFalse(m.WEXITSTATUS.called)
1530        self.assertFalse(m.WTERMSIG.called)
1531
1532        callback.reset_mock()
1533        m.WIFEXITED.reset_mock()
1534        m.WIFSIGNALED.reset_mock()
1535
1536        # ensure that the child is effectively reaped
1537        self.add_zombie(53, 101)
1538        with self.ignore_warnings:
1539            self.watcher._sig_chld()
1540
1541        self.assertFalse(callback.called)
1542
1543    @waitpid_mocks
1544    def test_remove_child_handler(self, m):
1545        callback1 = mock.Mock()
1546        callback2 = mock.Mock()
1547        callback3 = mock.Mock()
1548
1549        # register children
1550        with self.watcher:
1551            self.running = True
1552            self.watcher.add_child_handler(54, callback1, 1)
1553            self.watcher.add_child_handler(55, callback2, 2)
1554            self.watcher.add_child_handler(56, callback3, 3)
1555
1556        # remove child handler 1
1557        self.assertTrue(self.watcher.remove_child_handler(54))
1558
1559        # remove child handler 2 multiple times
1560        self.assertTrue(self.watcher.remove_child_handler(55))
1561        self.assertFalse(self.watcher.remove_child_handler(55))
1562        self.assertFalse(self.watcher.remove_child_handler(55))
1563
1564        # all children terminate
1565        self.add_zombie(54, 0)
1566        self.add_zombie(55, 1)
1567        self.add_zombie(56, 2)
1568        self.running = False
1569        with self.ignore_warnings:
1570            self.watcher._sig_chld()
1571
1572        self.assertFalse(callback1.called)
1573        self.assertFalse(callback2.called)
1574        callback3.assert_called_once_with(56, 2, 3)
1575
1576    @waitpid_mocks
1577    def test_sigchld_unhandled_exception(self, m):
1578        callback = mock.Mock()
1579
1580        # register a child
1581        with self.watcher:
1582            self.running = True
1583            self.watcher.add_child_handler(57, callback)
1584
1585        # raise an exception
1586        m.waitpid.side_effect = ValueError
1587
1588        with mock.patch.object(log.logger,
1589                               'error') as m_error:
1590
1591            self.assertEqual(self.watcher._sig_chld(), None)
1592            self.assertTrue(m_error.called)
1593
1594    @waitpid_mocks
1595    def test_sigchld_child_reaped_elsewhere(self, m):
1596        # register a child
1597        callback = mock.Mock()
1598
1599        with self.watcher:
1600            self.running = True
1601            self.watcher.add_child_handler(58, callback)
1602
1603        self.assertFalse(callback.called)
1604        self.assertFalse(m.WIFEXITED.called)
1605        self.assertFalse(m.WIFSIGNALED.called)
1606        self.assertFalse(m.WEXITSTATUS.called)
1607        self.assertFalse(m.WTERMSIG.called)
1608
1609        # child terminates
1610        self.running = False
1611        self.add_zombie(58, 4)
1612
1613        # waitpid is called elsewhere
1614        os.waitpid(58, os.WNOHANG)
1615
1616        m.waitpid.reset_mock()
1617
1618        # sigchld
1619        with self.ignore_warnings:
1620            self.watcher._sig_chld()
1621
1622        if isinstance(self.watcher, asyncio.FastChildWatcher):
1623            # here the FastChildWatche enters a deadlock
1624            # (there is no way to prevent it)
1625            self.assertFalse(callback.called)
1626        else:
1627            callback.assert_called_once_with(58, 255)
1628
1629    @waitpid_mocks
1630    def test_sigchld_unknown_pid_during_registration(self, m):
1631        # register two children
1632        callback1 = mock.Mock()
1633        callback2 = mock.Mock()
1634
1635        with self.ignore_warnings, self.watcher:
1636            self.running = True
1637            # child 1 terminates
1638            self.add_zombie(591, 7)
1639            # an unknown child terminates
1640            self.add_zombie(593, 17)
1641
1642            self.watcher._sig_chld()
1643
1644            self.watcher.add_child_handler(591, callback1)
1645            self.watcher.add_child_handler(592, callback2)
1646
1647        callback1.assert_called_once_with(591, 7)
1648        self.assertFalse(callback2.called)
1649
1650    @waitpid_mocks
1651    def test_set_loop(self, m):
1652        # register a child
1653        callback = mock.Mock()
1654
1655        with self.watcher:
1656            self.running = True
1657            self.watcher.add_child_handler(60, callback)
1658
1659        # attach a new loop
1660        old_loop = self.loop
1661        self.loop = self.new_test_loop()
1662        patch = mock.patch.object
1663
1664        with patch(old_loop, "remove_signal_handler") as m_old_remove, \
1665             patch(self.loop, "add_signal_handler") as m_new_add:
1666
1667            self.watcher.attach_loop(self.loop)
1668
1669            m_old_remove.assert_called_once_with(
1670                signal.SIGCHLD)
1671            m_new_add.assert_called_once_with(
1672                signal.SIGCHLD, self.watcher._sig_chld)
1673
1674        # child terminates
1675        self.running = False
1676        self.add_zombie(60, 9)
1677        self.watcher._sig_chld()
1678
1679        callback.assert_called_once_with(60, 9)
1680
1681    @waitpid_mocks
1682    def test_set_loop_race_condition(self, m):
1683        # register 3 children
1684        callback1 = mock.Mock()
1685        callback2 = mock.Mock()
1686        callback3 = mock.Mock()
1687
1688        with self.watcher:
1689            self.running = True
1690            self.watcher.add_child_handler(61, callback1)
1691            self.watcher.add_child_handler(62, callback2)
1692            self.watcher.add_child_handler(622, callback3)
1693
1694        # detach the loop
1695        old_loop = self.loop
1696        self.loop = None
1697
1698        with mock.patch.object(
1699                old_loop, "remove_signal_handler") as m_remove_signal_handler:
1700
1701            with self.assertWarnsRegex(
1702                    RuntimeWarning, 'A loop is being detached'):
1703                self.watcher.attach_loop(None)
1704
1705            m_remove_signal_handler.assert_called_once_with(
1706                signal.SIGCHLD)
1707
1708        # child 1 & 2 terminate
1709        self.add_zombie(61, 11)
1710        self.add_zombie(62, -5)
1711
1712        # SIGCHLD was not caught
1713        self.assertFalse(callback1.called)
1714        self.assertFalse(callback2.called)
1715        self.assertFalse(callback3.called)
1716
1717        # attach a new loop
1718        self.loop = self.new_test_loop()
1719
1720        with mock.patch.object(
1721                self.loop, "add_signal_handler") as m_add_signal_handler:
1722
1723            self.watcher.attach_loop(self.loop)
1724
1725            m_add_signal_handler.assert_called_once_with(
1726                signal.SIGCHLD, self.watcher._sig_chld)
1727            callback1.assert_called_once_with(61, 11)  # race condition!
1728            callback2.assert_called_once_with(62, -5)  # race condition!
1729            self.assertFalse(callback3.called)
1730
1731        callback1.reset_mock()
1732        callback2.reset_mock()
1733
1734        # child 3 terminates
1735        self.running = False
1736        self.add_zombie(622, 19)
1737        self.watcher._sig_chld()
1738
1739        self.assertFalse(callback1.called)
1740        self.assertFalse(callback2.called)
1741        callback3.assert_called_once_with(622, 19)
1742
1743    @waitpid_mocks
1744    def test_close(self, m):
1745        # register two children
1746        callback1 = mock.Mock()
1747
1748        with self.watcher:
1749            self.running = True
1750            # child 1 terminates
1751            self.add_zombie(63, 9)
1752            # other child terminates
1753            self.add_zombie(65, 18)
1754            self.watcher._sig_chld()
1755
1756            self.watcher.add_child_handler(63, callback1)
1757            self.watcher.add_child_handler(64, callback1)
1758
1759            self.assertEqual(len(self.watcher._callbacks), 1)
1760            if isinstance(self.watcher, asyncio.FastChildWatcher):
1761                self.assertEqual(len(self.watcher._zombies), 1)
1762
1763            with mock.patch.object(
1764                    self.loop,
1765                    "remove_signal_handler") as m_remove_signal_handler:
1766
1767                self.watcher.close()
1768
1769                m_remove_signal_handler.assert_called_once_with(
1770                    signal.SIGCHLD)
1771                self.assertFalse(self.watcher._callbacks)
1772                if isinstance(self.watcher, asyncio.FastChildWatcher):
1773                    self.assertFalse(self.watcher._zombies)
1774
1775    @waitpid_mocks
1776    def test_add_child_handler_with_no_loop_attached(self, m):
1777        callback = mock.Mock()
1778        with self.create_watcher() as watcher:
1779            with self.assertRaisesRegex(
1780                    RuntimeError,
1781                    'the child watcher does not have a loop attached'):
1782                watcher.add_child_handler(100, callback)
1783
1784
class SafeChildWatcherTests(ChildWatcherTestsMixin, test_utils.TestCase):
    """Run the shared child watcher tests against SafeChildWatcher."""

    def create_watcher(self):
        """Return the watcher implementation under test."""
        return asyncio.SafeChildWatcher()
1788
1789
class FastChildWatcherTests(ChildWatcherTestsMixin, test_utils.TestCase):
    """Run the shared child watcher tests against FastChildWatcher."""

    def create_watcher(self):
        """Return the watcher implementation under test."""
        return asyncio.FastChildWatcher()
1793
1794
class PolicyTests(unittest.TestCase):
    """Tests for child watcher handling by the default event loop policy.

    Each test works on its own policy instance built by create_policy().
    """

    def create_policy(self):
        """Return a new asyncio.DefaultEventLoopPolicy instance."""
        return asyncio.DefaultEventLoopPolicy()

    def test_get_child_watcher(self):
        # The watcher is created on first access, stored on the policy and
        # returned again by later calls; no loop is attached to it here.
        policy = self.create_policy()
        self.assertIsNone(policy._watcher)

        watcher = policy.get_child_watcher()
        self.assertIsInstance(watcher, asyncio.SafeChildWatcher)

        self.assertIs(policy._watcher, watcher)

        self.assertIs(watcher, policy.get_child_watcher())
        self.assertIsNone(watcher._loop)

    def test_get_child_watcher_after_set(self):
        # An explicitly installed watcher is the one handed back.
        policy = self.create_policy()
        watcher = asyncio.FastChildWatcher()

        policy.set_child_watcher(watcher)
        self.assertIs(policy._watcher, watcher)
        self.assertIs(watcher, policy.get_child_watcher())

    def test_get_child_watcher_with_mainloop_existing(self):
        policy = self.create_policy()
        loop = policy.get_event_loop()
        # Registered as a cleanup so the loop is closed even when one of
        # the assertions below fails.
        self.addCleanup(loop.close)

        self.assertIsNone(policy._watcher)
        watcher = policy.get_child_watcher()

        self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
        self.assertIs(watcher._loop, loop)

    def test_get_child_watcher_thread(self):
        # Exceptions raised in the worker thread (assertion failures
        # included) do not reach the test runner on their own, so they are
        # collected here and re-raised from the main thread after join().
        failures = []

        def worker():
            loop = None
            try:
                loop = policy.new_event_loop()
                policy.set_event_loop(loop)

                self.assertIsInstance(policy.get_event_loop(),
                                      asyncio.AbstractEventLoop)
                watcher = policy.get_child_watcher()

                # The watcher requested from this thread is expected to
                # have no loop attached.
                self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
                self.assertIsNone(watcher._loop)
            except BaseException as exc:
                failures.append(exc)
            finally:
                # Always release the loop created by this thread.
                if loop is not None:
                    loop.close()

        policy = self.create_policy()

        th = threading.Thread(target=worker)
        th.start()
        th.join()

        if failures:
            raise failures[0]

    def test_child_watcher_replace_mainloop_existing(self):
        # The watcher follows the loop currently set on the policy and is
        # detached once the policy's loop is reset to None.
        policy = self.create_policy()
        loop = policy.get_event_loop()
        self.addCleanup(loop.close)

        watcher = policy.get_child_watcher()

        self.assertIs(watcher._loop, loop)

        new_loop = policy.new_event_loop()
        self.addCleanup(new_loop.close)
        policy.set_event_loop(new_loop)

        self.assertIs(watcher._loop, new_loop)

        policy.set_event_loop(None)

        self.assertIs(watcher._loop, None)
1871
1872
class TestFunctional(unittest.TestCase):
    """Reader/writer registration checks run against an unmocked loop."""

    def setUp(self):
        # Every test gets a fresh loop installed as the current one.
        loop = asyncio.new_event_loop()
        self.loop = loop
        asyncio.set_event_loop(loop)

    def tearDown(self):
        self.loop.close()
        asyncio.set_event_loop(None)

    def test_add_reader_invalid_argument(self):
        # A plain object() is not a usable file object for any of the four
        # reader/writer calls.
        def noop():
            return None

        def rejected():
            return self.assertRaisesRegex(ValueError, r'Invalid file object')

        loop = self.loop

        for register in (loop.add_reader, loop.add_writer):
            with rejected():
                register(object(), noop)

        for unregister in (loop.remove_reader, loop.remove_writer):
            with rejected():
                unregister(object())

    def test_add_reader_or_writer_transport_fd(self):
        # While a transport owns the socket, every reader/writer call must
        # be refused, whether given the socket object or its descriptor.
        loop = self.loop

        def noop():
            return None

        def busy():
            return self.assertRaisesRegex(
                RuntimeError,
                r'File descriptor .* is used by transport')

        async def exercise(sock):
            transport, _protocol = await loop.create_connection(
                asyncio.Protocol, sock=sock)

            operations = (
                (loop.add_reader, (noop,)),
                (loop.remove_reader, ()),
                (loop.add_writer, (noop,)),
                (loop.remove_writer, ()),
            )
            try:
                for operation, extra_args in operations:
                    for fd in (sock, sock.fileno()):
                        with busy():
                            operation(fd, *extra_args)
            finally:
                transport.close()

        read_end, write_end = socket.socketpair()
        try:
            loop.run_until_complete(exercise(read_end))
        finally:
            read_end.close()
            write_end.close()
1941
1942
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
1945