1"""Tests for proactor_events.py"""
2
3import socket
4import unittest
5from unittest import mock
6
7import asyncio
8from asyncio.proactor_events import BaseProactorEventLoop
9from asyncio.proactor_events import _ProactorSocketTransport
10from asyncio.proactor_events import _ProactorWritePipeTransport
11from asyncio.proactor_events import _ProactorDuplexPipeTransport
12from asyncio import test_utils
13
14
def close_transport(transport):
    """Close the transport's socket directly and forget it.

    transport.close() is deliberately not used because the event loop and
    the IOCP proactor are mocked. Calling this on a transport whose socket
    is already None is a no-op.
    """
    sock = transport._sock
    if sock is not None:
        sock.close()
        transport._sock = None
22
23
class ProactorSocketTransportTests(test_utils.TestCase):
    """Tests for the proactor transports using a mocked proactor and socket."""

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.addCleanup(self.loop.close)
        # Replace the loop's proactor with a mock so the tests can inspect
        # the recv()/send() calls issued by the transport.
        self.proactor = mock.Mock()
        self.loop._proactor = self.proactor
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.sock = mock.Mock(socket.socket)

    def socket_transport(self, waiter=None):
        """Return a _ProactorSocketTransport built on the mocked socket.

        The transport is registered for cleanup through close_transport().
        """
        transport = _ProactorSocketTransport(self.loop, self.sock,
                                             self.protocol, waiter=waiter)
        self.addCleanup(close_transport, transport)
        return transport

    def test_ctor(self):
        fut = asyncio.Future(loop=self.loop)
        tr = self.socket_transport(waiter=fut)
        test_utils.run_briefly(self.loop)
        self.assertIsNone(fut.result())
        # Assert on the mock instead of calling it: invoking
        # connection_made(tr) here would verify nothing.
        self.protocol.connection_made.assert_called_with(tr)
        self.proactor.recv.assert_called_with(self.sock, 4096)

    def test_loop_reading(self):
        # Without a completed future, only a new recv() is started.
        tr = self.socket_transport()
        tr._loop_reading()
        self.loop._proactor.recv.assert_called_with(self.sock, 4096)
        self.assertFalse(self.protocol.data_received.called)
        self.assertFalse(self.protocol.eof_received.called)

    def test_loop_reading_data(self):
        res = asyncio.Future(loop=self.loop)
        res.set_result(b'data')

        tr = self.socket_transport()
        tr._read_fut = res
        tr._loop_reading(res)
        self.loop._proactor.recv.assert_called_with(self.sock, 4096)
        self.protocol.data_received.assert_called_with(b'data')

    def test_loop_reading_no_data(self):
        res = asyncio.Future(loop=self.loop)
        res.set_result(b'')

        tr = self.socket_transport()
        # Before tr._read_fut is set to res, the call is expected to fail
        # with an AssertionError.
        self.assertRaises(AssertionError, tr._loop_reading, res)

        tr.close = mock.Mock()
        tr._read_fut = res
        tr._loop_reading(res)
        # An empty result is reported as EOF and no further recv() is issued.
        self.assertFalse(self.loop._proactor.recv.called)
        self.assertTrue(self.protocol.eof_received.called)
        self.assertTrue(tr.close.called)

    def test_loop_reading_aborted(self):
        err = self.loop._proactor.recv.side_effect = ConnectionAbortedError()

        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal read error on pipe transport')

    def test_loop_reading_aborted_closing(self):
        self.loop._proactor.recv.side_effect = ConnectionAbortedError()

        tr = self.socket_transport()
        # With the transport marked as closing, the aborted read is not
        # expected to be reported as a fatal error.
        tr._closing = True
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        self.assertFalse(tr._fatal_error.called)

    def test_loop_reading_aborted_is_fatal(self):
        self.loop._proactor.recv.side_effect = ConnectionAbortedError()
        tr = self.socket_transport()
        tr._closing = False
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        self.assertTrue(tr._fatal_error.called)

    def test_loop_reading_conn_reset_lost(self):
        err = self.loop._proactor.recv.side_effect = ConnectionResetError()

        tr = self.socket_transport()
        tr._closing = False
        tr._fatal_error = mock.Mock()
        tr._force_close = mock.Mock()
        tr._loop_reading()
        # A connection reset goes through _force_close(), not _fatal_error().
        self.assertFalse(tr._fatal_error.called)
        tr._force_close.assert_called_with(err)

    def test_loop_reading_exception(self):
        err = self.loop._proactor.recv.side_effect = (OSError())

        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal read error on pipe transport')

    def test_write(self):
        tr = self.socket_transport()
        tr._loop_writing = mock.Mock()
        tr.write(b'data')
        # With no write in flight the data is handed to _loop_writing()
        # directly and nothing is buffered.
        self.assertIsNone(tr._buffer)
        tr._loop_writing.assert_called_with(data=b'data')

    def test_write_no_data(self):
        tr = self.socket_transport()
        tr.write(b'')
        self.assertFalse(tr._buffer)

    def test_write_more(self):
        tr = self.socket_transport()
        # Simulate a write already in flight: the new data must be buffered.
        tr._write_fut = mock.Mock()
        tr._loop_writing = mock.Mock()
        tr.write(b'data')
        self.assertEqual(tr._buffer, b'data')
        self.assertFalse(tr._loop_writing.called)

    def test_loop_writing(self):
        tr = self.socket_transport()
        tr._buffer = bytearray(b'data')
        tr._loop_writing()
        self.loop._proactor.send.assert_called_with(self.sock, b'data')
        send_fut = self.loop._proactor.send.return_value
        send_fut.add_done_callback.assert_called_with(tr._loop_writing)

    @mock.patch('asyncio.proactor_events.logger')
    def test_loop_writing_err(self, m_log):
        err = self.loop._proactor.send.side_effect = OSError()
        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._buffer = [b'da', b'ta']
        tr._loop_writing()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal write error on pipe transport')
        tr._conn_lost = 1

        # Writes issued after the connection is lost are not buffered; by
        # the end of these repeated writes a warning must have been logged.
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        self.assertIsNone(tr._buffer)
        m_log.warning.assert_called_with('socket.send() raised exception.')

    def test_loop_writing_stop(self):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(b'data')

        tr = self.socket_transport()
        tr._write_fut = fut
        tr._loop_writing(fut)
        self.assertIsNone(tr._write_fut)

    def test_loop_writing_closing(self):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(1)

        tr = self.socket_transport()
        tr._write_fut = fut
        tr.close()
        tr._loop_writing(fut)
        self.assertIsNone(tr._write_fut)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test_abort(self):
        tr = self.socket_transport()
        tr._force_close = mock.Mock()
        tr.abort()
        tr._force_close.assert_called_with(None)

    def test_close(self):
        tr = self.socket_transport()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertTrue(tr.is_closing())
        self.assertEqual(tr._conn_lost, 1)

        # A second close() must not report the lost connection again.
        self.protocol.connection_lost.reset_mock()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_close_write_fut(self):
        tr = self.socket_transport()
        # With a write future set, close() is not expected to call
        # connection_lost.
        tr._write_fut = mock.Mock()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_close_buffer(self):
        tr = self.socket_transport()
        # With buffered data, close() is not expected to call
        # connection_lost.
        tr._buffer = [b'data']
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    @mock.patch('asyncio.base_events.logger')
    def test_fatal_error(self, m_logging):
        tr = self.socket_transport()
        tr._force_close = mock.Mock()
        tr._fatal_error(None)
        self.assertTrue(tr._force_close.called)
        self.assertTrue(m_logging.error.called)

    def test_force_close(self):
        tr = self.socket_transport()
        tr._buffer = [b'data']
        read_fut = tr._read_fut = mock.Mock()
        write_fut = tr._write_fut = mock.Mock()
        tr._force_close(None)

        # Both pending futures are cancelled and the buffer is dropped.
        read_fut.cancel.assert_called_with()
        write_fut.cancel.assert_called_with()
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertIsNone(tr._buffer)
        self.assertEqual(tr._conn_lost, 1)

    def test_force_close_idempotent(self):
        tr = self.socket_transport()
        # When the transport is already marked as closing, _force_close()
        # is expected not to schedule connection_lost.
        tr._closing = True
        tr._force_close(None)
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_fatal_error_2(self):
        # Despite its name, this test exercises _force_close() with a
        # non-empty buffer.
        tr = self.socket_transport()
        tr._buffer = [b'data']
        tr._force_close(None)

        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertIsNone(tr._buffer)

    def test_call_connection_lost(self):
        tr = self.socket_transport()
        tr._call_connection_lost(None)
        self.assertTrue(self.protocol.connection_lost.called)
        self.assertTrue(self.sock.close.called)

    def test_write_eof(self):
        tr = self.socket_transport()
        self.assertTrue(tr.can_write_eof())
        tr.write_eof()
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        # A second write_eof() must not shut the socket down again.
        tr.write_eof()
        self.assertEqual(self.sock.shutdown.call_count, 1)
        tr.close()

    def test_write_eof_buffer(self):
        tr = self.socket_transport()
        f = asyncio.Future(loop=self.loop)
        tr._loop._proactor.send.return_value = f
        tr.write(b'data')
        tr.write_eof()
        # While the send is still pending, the socket is not shut down.
        self.assertTrue(tr._eof_written)
        self.assertFalse(self.sock.shutdown.called)
        tr._loop._proactor.send.assert_called_with(self.sock, b'data')
        f.set_result(4)
        self.loop._run_once()
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        tr.close()

    def test_write_eof_write_pipe(self):
        tr = _ProactorWritePipeTransport(
            self.loop, self.sock, self.protocol)
        self.assertTrue(tr.can_write_eof())
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.loop._run_once()
        self.assertTrue(self.sock.close.called)
        tr.close()

    def test_write_eof_buffer_write_pipe(self):
        tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
        f = asyncio.Future(loop=self.loop)
        tr._loop._proactor.send.return_value = f
        tr.write(b'data')
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.sock.shutdown.called)
        tr._loop._proactor.send.assert_called_with(self.sock, b'data')
        f.set_result(4)
        self.loop._run_once()
        self.loop._run_once()
        self.assertTrue(self.sock.close.called)
        tr.close()

    def test_write_eof_duplex_pipe(self):
        tr = _ProactorDuplexPipeTransport(
            self.loop, self.sock, self.protocol)
        self.assertFalse(tr.can_write_eof())
        with self.assertRaises(NotImplementedError):
            tr.write_eof()
        close_transport(tr)

    def test_pause_resume_reading(self):
        tr = self.socket_transport()
        # Each recv() call returns the next already-completed future.
        futures = []
        for msg in [b'data1', b'data2', b'data3', b'data4', b'']:
            f = asyncio.Future(loop=self.loop)
            f.set_result(msg)
            futures.append(f)
        self.loop._proactor.recv.side_effect = futures
        self.loop._run_once()
        self.assertFalse(tr._paused)
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data1')
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data2')
        tr.pause_reading()
        self.assertTrue(tr._paused)
        # While paused, running the loop must not deliver any new data.
        for _ in range(10):
            self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data2')
        tr.resume_reading()
        self.assertFalse(tr._paused)
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data3')
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data4')
        tr.close()

    def pause_writing_transport(self, high):
        """Return a socket transport whose write buffer high limit is *high*.

        Also checks that the transport starts with an empty write buffer and
        that the protocol has not been paused or resumed yet.
        """
        tr = self.socket_transport()
        tr.set_write_buffer_limits(high=high)

        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertFalse(self.protocol.pause_writing.called)
        self.assertFalse(self.protocol.resume_writing.called)
        return tr

    def test_pause_resume_writing(self):
        tr = self.pause_writing_transport(high=4)

        # write a large chunk, must pause writing
        fut = asyncio.Future(loop=self.loop)
        self.loop._proactor.send.return_value = fut
        tr.write(b'large data')
        self.loop._run_once()
        self.assertTrue(self.protocol.pause_writing.called)

        # flush the buffer
        fut.set_result(None)
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertTrue(self.protocol.resume_writing.called)

    def test_pause_writing_2write(self):
        tr = self.pause_writing_transport(high=4)

        # first short write, the buffer is not full (3 <= 4)
        fut1 = asyncio.Future(loop=self.loop)
        self.loop._proactor.send.return_value = fut1
        tr.write(b'123')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 3)
        self.assertFalse(self.protocol.pause_writing.called)

        # fill the buffer, must pause writing (6 > 4)
        tr.write(b'abc')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 6)
        self.assertTrue(self.protocol.pause_writing.called)

    def test_pause_writing_3write(self):
        tr = self.pause_writing_transport(high=4)

        # first short write, the buffer is not full (1 <= 4)
        fut = asyncio.Future(loop=self.loop)
        self.loop._proactor.send.return_value = fut
        tr.write(b'1')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 1)
        self.assertFalse(self.protocol.pause_writing.called)

        # second short write, the buffer is not full (3 <= 4)
        tr.write(b'23')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 3)
        self.assertFalse(self.protocol.pause_writing.called)

        # fill the buffer, must pause writing (6 > 4)
        tr.write(b'abc')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 6)
        self.assertTrue(self.protocol.pause_writing.called)

    def test_dont_pause_writing(self):
        tr = self.pause_writing_transport(high=4)

        # write a large chunk which completes immediately,
        # it should not pause writing
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(None)
        self.loop._proactor.send.return_value = fut
        tr.write(b'very large data')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertFalse(self.protocol.pause_writing.called)
435
436
class BaseProactorEventLoopTests(test_utils.TestCase):
    """Tests for BaseProactorEventLoop driven by a mocked proactor."""

    def setUp(self):
        super().setUp()

        self.proactor = mock.Mock()
        self.sock = test_utils.mock_nonblocking_socket()

        # Mocked ends of the self-pipe, handed out by the subclass below.
        self.ssock = mock.Mock()
        self.csock = mock.Mock()

        test_case = self

        class EventLoop(BaseProactorEventLoop):
            def _socketpair(loop_self):
                return (test_case.ssock, test_case.csock)

        self.loop = EventLoop(self.proactor)
        self.set_event_loop(self.loop)

    @mock.patch.object(BaseProactorEventLoop, 'call_soon')
    @mock.patch.object(BaseProactorEventLoop, '_socketpair')
    def test_ctor(self, socketpair, call_soon):
        ssock = mock.Mock()
        csock = mock.Mock()
        socketpair.return_value = (ssock, csock)

        new_loop = BaseProactorEventLoop(self.proactor)
        self.assertIs(new_loop._ssock, ssock)
        self.assertIs(new_loop._csock, csock)
        self.assertEqual(new_loop._internal_fds, 1)
        call_soon.assert_called_with(new_loop._loop_self_reading)
        new_loop.close()

    def test_close_self_pipe(self):
        event_loop = self.loop
        event_loop._close_self_pipe()
        self.assertEqual(event_loop._internal_fds, 0)
        self.assertTrue(self.ssock.close.called)
        self.assertTrue(self.csock.close.called)
        self.assertIsNone(event_loop._ssock)
        self.assertIsNone(event_loop._csock)

        # Mark the loop as closed instead of calling close():
        # _close_self_pipe() cannot be called twice
        event_loop._closed = True

    def test_close(self):
        close_self_pipe = mock.Mock()
        self.loop._close_self_pipe = close_self_pipe
        self.loop.close()
        self.assertTrue(close_self_pipe.called)
        self.assertTrue(self.proactor.close.called)
        self.assertIsNone(self.loop._proactor)

        # Closing an already closed loop must not close the self-pipe again.
        close_self_pipe.reset_mock()
        self.loop.close()
        self.assertFalse(close_self_pipe.called)

    def test_sock_recv(self):
        nbytes = 1024
        self.loop.sock_recv(self.sock, nbytes)
        self.proactor.recv.assert_called_with(self.sock, nbytes)

    def test_sock_sendall(self):
        payload = b'data'
        self.loop.sock_sendall(self.sock, payload)
        self.proactor.send.assert_called_with(self.sock, payload)

    def test_sock_connect(self):
        address = ('1.2.3.4', 123)
        self.loop.sock_connect(self.sock, address)
        self.proactor.connect.assert_called_with(self.sock, address)

    def test_sock_accept(self):
        listener = self.sock
        self.loop.sock_accept(listener)
        self.proactor.accept.assert_called_with(listener)

    def test_socketpair(self):
        class EventLoop(BaseProactorEventLoop):
            # override the destructor to not log a ResourceWarning
            def __del__(self):
                pass

        # A loop class that does not provide _socketpair() cannot be built.
        with self.assertRaises(NotImplementedError):
            EventLoop(self.proactor)

    def test_make_socket_transport(self):
        protocol = asyncio.Protocol()
        transport = self.loop._make_socket_transport(self.sock, protocol)
        self.assertIsInstance(transport, _ProactorSocketTransport)
        close_transport(transport)

    def test_loop_self_reading(self):
        self.loop._loop_self_reading()
        self.proactor.recv.assert_called_with(self.ssock, 4096)
        recv_fut = self.proactor.recv.return_value
        recv_fut.add_done_callback.assert_called_with(
            self.loop._loop_self_reading)

    def test_loop_self_reading_fut(self):
        done_fut = mock.Mock()
        self.loop._loop_self_reading(done_fut)
        # The completed future is consumed before the next recv() starts.
        self.assertTrue(done_fut.result.called)
        self.proactor.recv.assert_called_with(self.ssock, 4096)
        recv_fut = self.proactor.recv.return_value
        recv_fut.add_done_callback.assert_called_with(
            self.loop._loop_self_reading)

    def test_loop_self_reading_exception(self):
        self.loop.close = mock.Mock()
        exception_handler = mock.Mock()
        self.loop.call_exception_handler = exception_handler
        self.proactor.recv.side_effect = OSError()
        self.loop._loop_self_reading()
        self.assertTrue(exception_handler.called)

    def test_write_to_self(self):
        self.loop._write_to_self()
        self.csock.send.assert_called_with(b'\0')

    def test_process_events(self):
        # Only checks that the call does not raise.
        self.loop._process_events([])

    @mock.patch('asyncio.base_events.logger')
    def test_create_server(self, m_log):
        protocol_factory = mock.Mock()
        call_soon = mock.Mock()
        self.loop.call_soon = call_soon

        self.loop._start_serving(protocol_factory, self.sock)
        self.assertTrue(call_soon.called)

        # callback: the first positional argument given to call_soon()
        accept_loop = call_soon.call_args[0][0]
        accept_loop()
        self.proactor.accept.assert_called_with(self.sock)

        # conn
        conn_fut = mock.Mock()
        conn_fut.result.return_value = (mock.Mock(), mock.Mock())

        make_transport = mock.Mock()
        self.loop._make_socket_transport = make_transport
        accept_loop(conn_fut)
        self.assertTrue(conn_fut.result.called)
        self.assertTrue(make_transport.called)

        # exception
        conn_fut.result.side_effect = OSError()
        accept_loop(conn_fut)
        self.assertTrue(self.sock.close.called)
        self.assertTrue(m_log.error.called)

    def test_create_server_cancel(self):
        protocol_factory = mock.Mock()
        call_soon = mock.Mock()
        self.loop.call_soon = call_soon

        self.loop._start_serving(protocol_factory, self.sock)
        accept_loop = call_soon.call_args[0][0]

        # cancelled
        cancelled_fut = asyncio.Future(loop=self.loop)
        cancelled_fut.cancel()
        accept_loop(cancelled_fut)
        self.assertTrue(self.sock.close.called)

    def test_stop_serving(self):
        listener = mock.Mock()
        self.loop._stop_serving(listener)
        self.assertTrue(listener.close.called)
        self.proactor._stop_serving.assert_called_with(listener)
591
592
# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
595