1import unittest
2from test import test_support
3
4import errno
5import itertools
6import socket
7import select
8import time
9import traceback
10import Queue
11import sys
12import os
13import array
14import contextlib
15import signal
16import math
17import weakref
18try:
19    import _socket
20except ImportError:
21    _socket = None
22
23
def try_address(host, port=0, family=socket.AF_INET):
    """Try to bind a socket on the given host:port and return True
    if that has been possible.

    The probe socket is always closed again, whether or not the bind
    succeeded.

    Args:
        host: host name or address to bind to.
        port: port to bind to; 0 lets the OS choose one.
        family: address family of the probe socket (AF_INET by default).

    Returns:
        True if a SOCK_STREAM socket of *family* could be bound to
        (host, port); False if creating or binding it raised socket.error.
    """
    try:
        sock = socket.socket(family, socket.SOCK_STREAM)
    except socket.error:
        return False
    try:
        sock.bind((host, port))
    except socket.error:
        # socket.gaierror (name-resolution failure) subclasses socket.error,
        # so it is handled here as well.
        return False
    finally:
        # Previously the socket leaked whenever bind() failed.
        sock.close()
    return True
35
# Host address used by the client/server tests (from test_support).
HOST = test_support.HOST
# Payload exchanged by most client/server tests.  (HOST and MSG used to be
# defined a second time further down with the same values.)
MSG = b'Michael Gilfix was here\n'
# Whether test_support considers IPv6 usable on this machine.
SUPPORTS_IPV6 = test_support.IPV6_ENABLED

# thread/threading are set to None when the interpreter lacks thread
# support, so dependent tests can be skipped.
try:
    import thread
    import threading
except ImportError:
    thread = None
    threading = None
49
class SocketTCPTest(unittest.TestCase):
    """Fixture providing a listening IPv4 TCP server socket.

    setUp() creates self.serv, binds it with test_support.bind_port()
    (the chosen port is kept in self.port) and listens with a backlog of
    one; tearDown() closes the socket and drops the reference.
    """

    def setUp(self):
        self.serv = listener = socket.socket(family=socket.AF_INET,
                                             type=socket.SOCK_STREAM)
        self.port = test_support.bind_port(listener)
        listener.listen(1)

    def tearDown(self):
        listener = self.serv
        listener.close()
        self.serv = None
60
class SocketUDPTest(unittest.TestCase):
    """Fixture providing a bound IPv4 UDP server socket.

    setUp() creates self.serv and binds it with test_support.bind_port(),
    recording the port in self.port; tearDown() closes and drops it.
    """

    def setUp(self):
        self.serv = dgram = socket.socket(family=socket.AF_INET,
                                          type=socket.SOCK_DGRAM)
        self.port = test_support.bind_port(dgram)

    def tearDown(self):
        dgram = self.serv
        dgram.close()
        self.serv = None
70
class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their tests (or by
    clientSetUp()) are caught and transferred to the main thread to
    alert the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Swap the true setup/teardown functions for the threaded wrappers,
        # keeping the originals so the wrappers can call them.
        self.__setUp = self.setUp
        self.__tearDown = self.tearDown
        self.setUp = self._setUp
        self.tearDown = self._tearDown

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        """Start the client thread, then run the real server-side setUp.

        server_ready is set once the server setup finishes even when it
        raises, so the client thread is never left blocked on it.  On
        success, waits until the client thread has run clientSetUp().
        """
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        self.queue = Queue.Queue(1)
        # Set when the server-side setUp raises; clientRun then skips the
        # client half of the test.
        self.server_crashed = False

        # Map the server test name (e.g. "testFoo") to its client half
        # ("_testFoo") and run that in a new thread.
        methodname = self.id()
        i = methodname.rfind('.')
        methodname = methodname[i+1:]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        try:
            self.__setUp()
        except BaseException:
            self.server_crashed = True
            raise
        finally:
            # Release the client thread even if the server setup failed.
            self.server_ready.set()
        self.client_ready.wait()

    def _tearDown(self):
        """Run the real tearDown, wait for the client thread to finish and
        fail the test if the client reported an exception."""
        self.__tearDown()
        self.done.wait()

        if not self.queue.empty():
            msg = self.queue.get()
            self.fail(msg)

    def clientRun(self, test_func):
        """Body of the client thread: set up, run test_func, tear down.

        Exceptions from clientSetUp() or test_func are put on self.queue
        for _tearDown to report.  client_ready and done are always set, so
        the main thread never waits forever on a client that failed early.
        """
        self.server_ready.wait()
        if self.server_crashed:
            # The server-side setUp raised: unittest will not call tearDown
            # and the client half of this test must not run.
            self.client_ready.set()
            self.done.set()
            return
        try:
            try:
                self.clientSetUp()
            finally:
                # Unblock _setUp even if clientSetUp raised.
                self.client_ready.set()
            if not callable(test_func):
                raise TypeError("test_func must be a callable function.")
            test_func()
        except Exception as strerror:
            self.queue.put(strerror)
        finally:
            try:
                self.clientTearDown()
            finally:
                # ThreadableTest.clientTearDown sets done itself; this also
                # covers subclass teardowns that raise before reaching it.
                self.done.set()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        self.done.set()
        thread.exit()
168
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """SocketTCPTest paired with a client thread (see ThreadableTest).

    In the client thread, self.cli is a fresh, unconnected IPv4 TCP socket
    that clientTearDown() closes.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName)
        # Runs after TestCase.__init__ so it wraps the final setUp/tearDown.
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(family=socket.AF_INET,
                                 type=socket.SOCK_STREAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
182
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """SocketUDPTest paired with a client thread (see ThreadableTest).

    In the client thread, self.cli is a fresh IPv4 UDP socket that
    clientTearDown() closes.
    """

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName)
        # Runs after TestCase.__init__ so it wraps the final setUp/tearDown.
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(family=socket.AF_INET,
                                 type=socket.SOCK_DGRAM)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
196
class SocketConnectedTest(ThreadedTCPSocketTest):
    """Threaded TCP fixture whose client is connected before the test runs.

    Server side: self.cli_conn is the socket returned by accept().
    Client side: self.serv_conn is the client's socket (the same object as
    self.cli), connected to (HOST, self.port).
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # accept() blocks until the client connects, so let the client
        # thread proceed before calling it.
        self.serverExplicitReady()
        accepted = self.serv.accept()
        self.cli_conn = accepted[0]

    def tearDown(self):
        server_side = self.cli_conn
        server_side.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        server_address = (HOST, self.port)
        self.cli.connect(server_address)
        self.serv_conn = self.cli

    def clientTearDown(self):
        client_side = self.serv_conn
        client_side.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)
224
class SocketPairTest(unittest.TestCase, ThreadableTest):
    """Threaded fixture built on socket.socketpair().

    The main thread uses and closes self.serv; the client thread uses and
    closes self.cli.
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        pair = socket.socketpair()
        self.serv, self.cli = pair

    def tearDown(self):
        server_end = self.serv
        server_end.close()
        self.serv = None

    def clientSetUp(self):
        # Both ends already exist (created in setUp); nothing to prepare.
        return

    def clientTearDown(self):
        client_end = self.cli
        client_end.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
245
246
247#######################################################################
248## Begin Tests
249
250class GeneralModuleTests(unittest.TestCase):
251
252    @unittest.skipUnless(_socket is not None, 'need _socket module')
253    def test_csocket_repr(self):
254        s = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
255        try:
256            expected = ('<socket object, fd=%s, family=%s, type=%s, protocol=%s>'
257                        % (s.fileno(), s.family, s.type, s.proto))
258            self.assertEqual(repr(s), expected)
259        finally:
260            s.close()
261        expected = ('<socket object, fd=-1, family=%s, type=%s, protocol=%s>'
262                    % (s.family, s.type, s.proto))
263        self.assertEqual(repr(s), expected)
264
265    def test_weakref(self):
266        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
267        p = weakref.proxy(s)
268        self.assertEqual(p.fileno(), s.fileno())
269        s.close()
270        s = None
271        try:
272            p.fileno()
273        except ReferenceError:
274            pass
275        else:
276            self.fail('Socket proxy still exists')
277
278    def test_weakref__sock(self):
279        s = socket.socket()._sock
280        w = weakref.ref(s)
281        self.assertIs(w(), s)
282        del s
283        test_support.gc_collect()
284        self.assertIsNone(w())
285
286    def testSocketError(self):
287        # Testing socket module exceptions
288        def raise_error(*args, **kwargs):
289            raise socket.error
290        def raise_herror(*args, **kwargs):
291            raise socket.herror
292        def raise_gaierror(*args, **kwargs):
293            raise socket.gaierror
294        self.assertRaises(socket.error, raise_error,
295                              "Error raising socket exception.")
296        self.assertRaises(socket.error, raise_herror,
297                              "Error raising socket exception.")
298        self.assertRaises(socket.error, raise_gaierror,
299                              "Error raising socket exception.")
300
301    def testSendtoErrors(self):
302        # Testing that sendto doesn't mask failures. See #10169.
303        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
304        self.addCleanup(s.close)
305        s.bind(('', 0))
306        sockname = s.getsockname()
307        # 2 args
308        with self.assertRaises(UnicodeEncodeError):
309            s.sendto(u'\u2620', sockname)
310        with self.assertRaises(TypeError) as cm:
311            s.sendto(5j, sockname)
312        self.assertIn('not complex', str(cm.exception))
313        with self.assertRaises(TypeError) as cm:
314            s.sendto('foo', None)
315        self.assertIn('not NoneType', str(cm.exception))
316        # 3 args
317        with self.assertRaises(UnicodeEncodeError):
318            s.sendto(u'\u2620', 0, sockname)
319        with self.assertRaises(TypeError) as cm:
320            s.sendto(5j, 0, sockname)
321        self.assertIn('not complex', str(cm.exception))
322        with self.assertRaises(TypeError) as cm:
323            s.sendto('foo', 0, None)
324        self.assertIn('not NoneType', str(cm.exception))
325        with self.assertRaises(TypeError) as cm:
326            s.sendto('foo', 'bar', sockname)
327        self.assertIn('an integer is required', str(cm.exception))
328        with self.assertRaises(TypeError) as cm:
329            s.sendto('foo', None, None)
330        self.assertIn('an integer is required', str(cm.exception))
331        # wrong number of args
332        with self.assertRaises(TypeError) as cm:
333            s.sendto('foo')
334        self.assertIn('(1 given)', str(cm.exception))
335        with self.assertRaises(TypeError) as cm:
336            s.sendto('foo', 0, sockname, 4)
337        self.assertIn('(4 given)', str(cm.exception))
338
339
340    def testCrucialConstants(self):
341        # Testing for mission critical constants
342        socket.AF_INET
343        socket.SOCK_STREAM
344        socket.SOCK_DGRAM
345        socket.SOCK_RAW
346        socket.SOCK_RDM
347        socket.SOCK_SEQPACKET
348        socket.SOL_SOCKET
349        socket.SO_REUSEADDR
350
351    def testHostnameRes(self):
352        # Testing hostname resolution mechanisms
353        hostname = socket.gethostname()
354        try:
355            ip = socket.gethostbyname(hostname)
356        except socket.error:
357            # Probably name lookup wasn't set up right; skip this test
358            self.skipTest('name lookup failure')
359        self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
360        try:
361            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
362        except socket.error:
363            # Probably a similar problem as above; skip this test
364            self.skipTest('address lookup failure')
365        all_host_names = [hostname, hname] + aliases
366        fqhn = socket.getfqdn(ip)
367        if not fqhn in all_host_names:
368            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
369
370    @unittest.skipUnless(hasattr(sys, 'getrefcount'),
371                         'test needs sys.getrefcount()')
372    def testRefCountGetNameInfo(self):
373        # Testing reference count for getnameinfo
374        try:
375            # On some versions, this loses a reference
376            orig = sys.getrefcount(__name__)
377            socket.getnameinfo(__name__,0)
378        except TypeError:
379            self.assertEqual(sys.getrefcount(__name__), orig,
380                             "socket.getnameinfo loses a reference")
381
382    def testInterpreterCrash(self):
383        # Making sure getnameinfo doesn't crash the interpreter
384        try:
385            # On some versions, this crashes the interpreter.
386            socket.getnameinfo(('x', 0, 0, 0), 0)
387        except socket.error:
388            pass
389
390    def testNtoH(self):
391        # This just checks that htons etc. are their own inverse,
392        # when looking at the lower 16 or 32 bits.
393        sizes = {socket.htonl: 32, socket.ntohl: 32,
394                 socket.htons: 16, socket.ntohs: 16}
395        for func, size in sizes.items():
396            mask = (1L<<size) - 1
397            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
398                self.assertEqual(i & mask, func(func(i&mask)) & mask)
399
400            swapped = func(mask)
401            self.assertEqual(swapped & mask, mask)
402            self.assertRaises(OverflowError, func, 1L<<34)
403
404    def testNtoHErrors(self):
405        good_values = [ 1, 2, 3, 1L, 2L, 3L ]
406        bad_values = [ -1, -2, -3, -1L, -2L, -3L ]
407        for k in good_values:
408            socket.ntohl(k)
409            socket.ntohs(k)
410            socket.htonl(k)
411            socket.htons(k)
412        for k in bad_values:
413            self.assertRaises(OverflowError, socket.ntohl, k)
414            self.assertRaises(OverflowError, socket.ntohs, k)
415            self.assertRaises(OverflowError, socket.htonl, k)
416            self.assertRaises(OverflowError, socket.htons, k)
417
418    def testGetServBy(self):
419        eq = self.assertEqual
420        # Find one service that exists, then check all the related interfaces.
421        # I've ordered this by protocols that have both a tcp and udp
422        # protocol, at least for modern Linuxes.
423        if (sys.platform.startswith('linux') or
424            sys.platform.startswith('freebsd') or
425            sys.platform.startswith('netbsd') or
426            sys.platform == 'darwin'):
427            # avoid the 'echo' service on this platform, as there is an
428            # assumption breaking non-standard port/protocol entry
429            services = ('daytime', 'qotd', 'domain')
430        else:
431            services = ('echo', 'daytime', 'domain')
432        for service in services:
433            try:
434                port = socket.getservbyname(service, 'tcp')
435                break
436            except socket.error:
437                pass
438        else:
439            raise socket.error
440        # Try same call with optional protocol omitted
441        port2 = socket.getservbyname(service)
442        eq(port, port2)
443        # Try udp, but don't barf if it doesn't exist
444        try:
445            udpport = socket.getservbyname(service, 'udp')
446        except socket.error:
447            udpport = None
448        else:
449            eq(udpport, port)
450        # Now make sure the lookup by port returns the same service name
451        eq(socket.getservbyport(port2), service)
452        eq(socket.getservbyport(port, 'tcp'), service)
453        if udpport is not None:
454            eq(socket.getservbyport(udpport, 'udp'), service)
455        # Make sure getservbyport does not accept out of range ports.
456        self.assertRaises(OverflowError, socket.getservbyport, -1)
457        self.assertRaises(OverflowError, socket.getservbyport, 65536)
458
459    def testDefaultTimeout(self):
460        # Testing default timeout
461        # The default timeout should initially be None
462        self.assertEqual(socket.getdefaulttimeout(), None)
463        s = socket.socket()
464        self.assertEqual(s.gettimeout(), None)
465        s.close()
466
467        # Set the default timeout to 10, and see if it propagates
468        socket.setdefaulttimeout(10)
469        self.assertEqual(socket.getdefaulttimeout(), 10)
470        s = socket.socket()
471        self.assertEqual(s.gettimeout(), 10)
472        s.close()
473
474        # Reset the default timeout to None, and see if it propagates
475        socket.setdefaulttimeout(None)
476        self.assertEqual(socket.getdefaulttimeout(), None)
477        s = socket.socket()
478        self.assertEqual(s.gettimeout(), None)
479        s.close()
480
481        # Check that setting it to an invalid value raises ValueError
482        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
483
484        # Check that setting it to an invalid type raises TypeError
485        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
486
487    @unittest.skipUnless(hasattr(socket, 'inet_aton'),
488                         'test needs socket.inet_aton()')
489    def testIPv4_inet_aton_fourbytes(self):
490        # Test that issue1008086 and issue767150 are fixed.
491        # It must return 4 bytes.
492        self.assertEqual('\x00'*4, socket.inet_aton('0.0.0.0'))
493        self.assertEqual('\xff'*4, socket.inet_aton('255.255.255.255'))
494
495    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
496                         'test needs socket.inet_pton()')
497    def testIPv4toString(self):
498        from socket import inet_aton as f, inet_pton, AF_INET
499        g = lambda a: inet_pton(AF_INET, a)
500
501        self.assertEqual('\x00\x00\x00\x00', f('0.0.0.0'))
502        self.assertEqual('\xff\x00\xff\x00', f('255.0.255.0'))
503        self.assertEqual('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
504        self.assertEqual('\x01\x02\x03\x04', f('1.2.3.4'))
505        self.assertEqual('\xff\xff\xff\xff', f('255.255.255.255'))
506
507        self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0'))
508        self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0'))
509        self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
510        self.assertEqual('\xff\xff\xff\xff', g('255.255.255.255'))
511
512    @unittest.skipUnless(hasattr(socket, 'inet_pton'),
513                         'test needs socket.inet_pton()')
514    def testIPv6toString(self):
515        try:
516            from socket import inet_pton, AF_INET6, has_ipv6
517            if not has_ipv6:
518                self.skipTest('IPv6 not available')
519        except ImportError:
520            self.skipTest('could not import needed symbols from socket')
521        f = lambda a: inet_pton(AF_INET6, a)
522
523        self.assertEqual('\x00' * 16, f('::'))
524        self.assertEqual('\x00' * 16, f('0::0'))
525        self.assertEqual('\x00\x01' + '\x00' * 14, f('1::'))
526        self.assertEqual(
527            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
528            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
529        )
530
531    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
532                         'test needs socket.inet_ntop()')
533    def testStringToIPv4(self):
534        from socket import inet_ntoa as f, inet_ntop, AF_INET
535        g = lambda a: inet_ntop(AF_INET, a)
536
537        self.assertEqual('1.0.1.0', f('\x01\x00\x01\x00'))
538        self.assertEqual('170.85.170.85', f('\xaa\x55\xaa\x55'))
539        self.assertEqual('255.255.255.255', f('\xff\xff\xff\xff'))
540        self.assertEqual('1.2.3.4', f('\x01\x02\x03\x04'))
541
542        self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00'))
543        self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55'))
544        self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff'))
545
546    @unittest.skipUnless(hasattr(socket, 'inet_ntop'),
547                         'test needs socket.inet_ntop()')
548    def testStringToIPv6(self):
549        try:
550            from socket import inet_ntop, AF_INET6, has_ipv6
551            if not has_ipv6:
552                self.skipTest('IPv6 not available')
553        except ImportError:
554            self.skipTest('could not import needed symbols from socket')
555        f = lambda a: inet_ntop(AF_INET6, a)
556
557        self.assertEqual('::', f('\x00' * 16))
558        self.assertEqual('::1', f('\x00' * 15 + '\x01'))
559        self.assertEqual(
560            'aef:b01:506:1001:ffff:9997:55:170',
561            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
562        )
563
564    # XXX The following don't test module-level functionality...
565
566    def _get_unused_port(self, bind_address='0.0.0.0'):
567        """Use a temporary socket to elicit an unused ephemeral port.
568
569        Args:
570            bind_address: Hostname or IP address to search for a port on.
571
572        Returns: A most likely to be unused port.
573        """
574        tempsock = socket.socket()
575        tempsock.bind((bind_address, 0))
576        host, port = tempsock.getsockname()
577        tempsock.close()
578        return port
579
580    def testSockName(self):
581        # Testing getsockname()
582        port = self._get_unused_port()
583        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
584        self.addCleanup(sock.close)
585        sock.bind(("0.0.0.0", port))
586        name = sock.getsockname()
587        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
588        # it reasonable to get the host's addr in addition to 0.0.0.0.
589        # At least for eCos.  This is required for the S/390 to pass.
590        try:
591            my_ip_addr = socket.gethostbyname(socket.gethostname())
592        except socket.error:
593            # Probably name lookup wasn't set up right; skip this test
594            self.skipTest('name lookup failure')
595        self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
596        self.assertEqual(name[1], port)
597
598    def testGetSockOpt(self):
599        # Testing getsockopt()
600        # We know a socket should start without reuse==0
601        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
602        self.addCleanup(sock.close)
603        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
604        self.assertFalse(reuse != 0, "initial mode is reuse")
605
606    def testSetSockOpt(self):
607        # Testing setsockopt()
608        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
609        self.addCleanup(sock.close)
610        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
611        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
612        self.assertFalse(reuse == 0, "failed to set reuse mode")
613
614    def testSendAfterClose(self):
615        # testing send() after close() with timeout
616        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
617        sock.settimeout(1)
618        sock.close()
619        self.assertRaises(socket.error, sock.send, "spam")
620
621    def testNewAttributes(self):
622        # testing .family, .type and .protocol
623        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
624        self.assertEqual(sock.family, socket.AF_INET)
625        self.assertEqual(sock.type, socket.SOCK_STREAM)
626        self.assertEqual(sock.proto, 0)
627        sock.close()
628
629    def test_getsockaddrarg(self):
630        sock = socket.socket()
631        self.addCleanup(sock.close)
632        port = test_support.find_unused_port()
633        big_port = port + 65536
634        neg_port = port - 65536
635        self.assertRaises(OverflowError, sock.bind, (HOST, big_port))
636        self.assertRaises(OverflowError, sock.bind, (HOST, neg_port))
637        # Since find_unused_port() is inherently subject to race conditions, we
638        # call it a couple times if necessary.
639        for i in itertools.count():
640            port = test_support.find_unused_port()
641            try:
642                sock.bind((HOST, port))
643            except OSError as e:
644                if e.errno != errno.EADDRINUSE or i == 5:
645                    raise
646            else:
647                break
648
649    @unittest.skipUnless(os.name == "nt", "Windows specific")
650    def test_sock_ioctl(self):
651        self.assertTrue(hasattr(socket.socket, 'ioctl'))
652        self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
653        self.assertTrue(hasattr(socket, 'RCVALL_ON'))
654        self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
655        self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
656        s = socket.socket()
657        self.addCleanup(s.close)
658        self.assertRaises(ValueError, s.ioctl, -1, None)
659        s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100))
660
661    def testGetaddrinfo(self):
662        try:
663            socket.getaddrinfo('localhost', 80)
664        except socket.gaierror as err:
665            if err.errno == socket.EAI_SERVICE:
666                # see http://bugs.python.org/issue1282647
667                self.skipTest("buggy libc version")
668            raise
669        # len of every sequence is supposed to be == 5
670        for info in socket.getaddrinfo(HOST, None):
671            self.assertEqual(len(info), 5)
672        # host can be a domain name, a string representation of an
673        # IPv4/v6 address or None
674        socket.getaddrinfo('localhost', 80)
675        socket.getaddrinfo('127.0.0.1', 80)
676        socket.getaddrinfo(None, 80)
677        if SUPPORTS_IPV6:
678            socket.getaddrinfo('::1', 80)
679        # port can be a string service name such as "http", a numeric
680        # port number (int or long), or None
681        socket.getaddrinfo(HOST, "http")
682        socket.getaddrinfo(HOST, 80)
683        socket.getaddrinfo(HOST, 80L)
684        socket.getaddrinfo(HOST, None)
685        # test family and socktype filters
686        infos = socket.getaddrinfo(HOST, None, socket.AF_INET)
687        for family, _, _, _, _ in infos:
688            self.assertEqual(family, socket.AF_INET)
689        infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
690        for _, socktype, _, _, _ in infos:
691            self.assertEqual(socktype, socket.SOCK_STREAM)
692        # test proto and flags arguments
693        socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
694        socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
695        # a server willing to support both IPv4 and IPv6 will
696        # usually do this
697        socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
698                           socket.AI_PASSIVE)
699
700        # Issue 17269: test workaround for OS X platform bug segfault
701        if hasattr(socket, 'AI_NUMERICSERV'):
702            try:
703                # The arguments here are undefined and the call may succeed
704                # or fail.  All we care here is that it doesn't segfault.
705                socket.getaddrinfo("localhost", None, 0, 0, 0,
706                                   socket.AI_NUMERICSERV)
707            except socket.gaierror:
708                pass
709
710    def check_sendall_interrupted(self, with_timeout):
711        # socketpair() is not strictly required, but it makes things easier.
712        if not hasattr(signal, 'alarm') or not hasattr(socket, 'socketpair'):
713            self.skipTest("signal.alarm and socket.socketpair required for this test")
714        # Our signal handlers clobber the C errno by calling a math function
715        # with an invalid domain value.
716        def ok_handler(*args):
717            self.assertRaises(ValueError, math.acosh, 0)
718        def raising_handler(*args):
719            self.assertRaises(ValueError, math.acosh, 0)
720            1 // 0
721        c, s = socket.socketpair()
722        old_alarm = signal.signal(signal.SIGALRM, raising_handler)
723        try:
724            if with_timeout:
725                # Just above the one second minimum for signal.alarm
726                c.settimeout(1.5)
727            with self.assertRaises(ZeroDivisionError):
728                signal.alarm(1)
729                c.sendall(b"x" * test_support.SOCK_MAX_SIZE)
730            if with_timeout:
731                signal.signal(signal.SIGALRM, ok_handler)
732                signal.alarm(1)
733                self.assertRaises(socket.timeout, c.sendall,
734                                  b"x" * test_support.SOCK_MAX_SIZE)
735        finally:
736            signal.alarm(0)
737            signal.signal(signal.SIGALRM, old_alarm)
738            c.close()
739            s.close()
740
741    def test_sendall_interrupted(self):
742        self.check_sendall_interrupted(False)
743
744    def test_sendall_interrupted_with_timeout(self):
745        self.check_sendall_interrupted(True)
746
747    def test_listen_backlog(self):
748        for backlog in 0, -1:
749            srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
750            srv.bind((HOST, 0))
751            srv.listen(backlog)
752            srv.close()
753
754    @test_support.cpython_only
755    def test_listen_backlog_overflow(self):
756        # Issue 15989
757        import _testcapi
758        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
759        srv.bind((HOST, 0))
760        self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
761        srv.close()
762
763    @unittest.skipUnless(SUPPORTS_IPV6, 'IPv6 required for this test.')
764    def test_flowinfo(self):
765        self.assertRaises(OverflowError, socket.getnameinfo,
766                          ('::1',0, 0xffffffff), 0)
767        s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
768        try:
769            self.assertRaises(OverflowError, s.bind, ('::1', 0, -10))
770        finally:
771            s.close()
772
773
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest):
    """Basic send/recv round trips over a connected TCP connection.

    Each testX method is the server half and is paired with a _testX
    client half (presumably run in a separate thread by the
    ThreadableTest-based fixture).  Judging from usage below,
    self.cli_conn is the server's end of the connection and
    self.serv_conn is the client's end.
    """

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Testing large receive over TCP
        # (the 1024-byte request is larger than MSG)
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)

    def _testRecv(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Testing receive in chunks over TCP
        # First recv() asks for all but the last 3 bytes; the second picks
        # up whatever remains.
        seg1 = self.cli_conn.recv(len(MSG) - 3)
        seg2 = self.cli_conn.recv(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG)

    def testRecvFrom(self):
        # Testing large recvfrom() over TCP
        msg, addr = self.cli_conn.recvfrom(1024)
        self.assertEqual(msg, MSG)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecvFrom(self):
        # Testing recvfrom() in chunks over TCP
        seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
        seg2, addr = self.cli_conn.recvfrom(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG)

    def testSendAll(self):
        # Testing sendall() with a 2048 byte string over TCP
        # Read until recv() returns '' (end of stream); presumably the
        # client's end is closed by the fixture's client teardown.
        msg = ''
        while 1:
            read = self.cli_conn.recv(1024)
            if not read:
                break
            msg += read
        self.assertEqual(msg, 'f' * 2048)

    def _testSendAll(self):
        big_chunk = 'f' * 2048
        self.serv_conn.sendall(big_chunk)

    @unittest.skipUnless(hasattr(socket, 'fromfd'),
                         'socket.fromfd not available')
    def testFromFd(self):
        # Testing fromfd()
        # Wrap the server connection's descriptor in a new socket object and
        # read the client's message through it.
        fd = self.cli_conn.fileno()
        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        msg = sock.recv(1024)
        self.assertEqual(msg, MSG)

    def _testFromFd(self):
        self.serv_conn.send(MSG)

    def testDup(self):
        # Testing dup()
        sock = self.cli_conn.dup()
        self.addCleanup(sock.close)
        msg = sock.recv(1024)
        self.assertEqual(msg, MSG)

    def _testDup(self):
        self.serv_conn.send(MSG)

    def testShutdown(self):
        # Testing shutdown()
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)
        # wait for _testShutdown to finish: on OS X, when the server
        # closes the connection the client also becomes disconnected,
        # and the client's shutdown call will fail. (Issue #4397.)
        self.done.wait()

    def _testShutdown(self):
        self.serv_conn.send(MSG)
        # how=2, i.e. SHUT_RDWR on common platforms
        self.serv_conn.shutdown(2)

    # Same server half as testShutdown, paired with a client half that also
    # checks out-of-range shutdown() arguments.
    testShutdown_overflow = test_support.cpython_only(testShutdown)

    @test_support.cpython_only
    def _testShutdown_overflow(self):
        import _testcapi
        self.serv_conn.send(MSG)
        # Issue 15989
        # Values that do not fit in a C int must raise OverflowError.  The
        # second one is presumably chosen so that truncation to an unsigned
        # int would wrap around to the valid value 2.
        self.assertRaises(OverflowError, self.serv_conn.shutdown,
                          _testcapi.INT_MAX + 1)
        self.assertRaises(OverflowError, self.serv_conn.shutdown,
                          2 + (_testcapi.UINT_MAX + 1))
        self.serv_conn.shutdown(2)
878
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicUDPTest(ThreadedUDPSocketTest):
    """Deliver one MSG datagram from the UDP client to the server."""

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def _sendMsgToServer(self):
        # Client half shared by every test below: a single datagram.
        self.cli.sendto(MSG, 0, (HOST, self.port))

    def testSendtoAndRecv(self):
        # sendto() on the client, plain recv() on the server.
        received = self.serv.recv(len(MSG))
        self.assertEqual(received, MSG)

    _testSendtoAndRecv = _sendMsgToServer

    def testRecvFrom(self):
        # recvfrom() returns (data, address); only the data is checked.
        received = self.serv.recvfrom(len(MSG))[0]
        self.assertEqual(received, MSG)

    _testRecvFrom = _sendMsgToServer

    def testRecvFromNegative(self):
        # A negative buffer size must be rejected with ValueError.
        with self.assertRaises(ValueError):
            self.serv.recvfrom(-1)

    _testRecvFromNegative = _sendMsgToServer
907
@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
    """Closing the accepted connection must show up as EOF on the client."""

    def testClose(self):
        accepted = self.serv.accept()[0]
        accepted.close()

        client = self.cli
        # Wait (up to one second) for the client to become readable...
        readable = select.select([client], [], [], 1.0)[0]
        self.assertEqual(readable, [client])
        # ...and then recv() must report end of stream.
        self.assertEqual(client.recv(1), '')

    def _testClose(self):
        self.cli.connect((HOST, self.port))
        # Keep the client socket alive while the server half inspects it.
        time.sleep(1.0)
923
@unittest.skipUnless(hasattr(socket, 'socketpair'),
                     'test needs socket.socketpair()')
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicSocketPairTest(SocketPairTest):
    """Send MSG in each direction across a socket.socketpair()."""

    def __init__(self, methodName='runTest'):
        SocketPairTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Client end -> server end.
        self.assertEqual(self.serv.recv(1024), MSG)

    def _testRecv(self):
        self.cli.send(MSG)

    def testSend(self):
        self.serv.send(MSG)

    def _testSend(self):
        # Server end -> client end.
        self.assertEqual(self.cli.recv(1024), MSG)
945
@unittest.skipUnless(thread, 'Threading required for this test.')
class NonBlockingTCPTests(ThreadedTCPSocketTest):
    """Non-blocking accept()/recv() on the server side of a TCP pair."""

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def testSetBlocking(self):
        # Testing whether set blocking works
        self.serv.setblocking(True)
        self.assertIsNone(self.serv.gettimeout())
        self.serv.setblocking(False)
        self.assertEqual(self.serv.gettimeout(), 0.0)
        start = time.time()
        try:
            self.serv.accept()
        except socket.error:
            pass
        end = time.time()
        self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.")

    def _testSetBlocking(self):
        pass

    @test_support.cpython_only
    def testSetBlocking_overflow(self):
        # Issue 15989: setblocking(UINT_MAX + 1) must act as a true flag
        # (blocking, timeout None); presumably it must not be truncated to 0
        # on its way to C.
        import _testcapi
        if _testcapi.UINT_MAX >= _testcapi.ULONG_MAX:
            self.skipTest('needs UINT_MAX < ULONG_MAX')
        self.serv.setblocking(False)
        self.assertEqual(self.serv.gettimeout(), 0.0)
        self.serv.setblocking(_testcapi.UINT_MAX + 1)
        self.assertIsNone(self.serv.gettimeout())

    _testSetBlocking_overflow = test_support.cpython_only(_testSetBlocking)

    def testAccept(self):
        # Testing non-blocking accept
        self.serv.setblocking(0)
        try:
            conn, addr = self.serv.accept()
        except socket.error:
            pass
        else:
            # Unexpected success: don't leak the connection before failing.
            conn.close()
            self.fail("Error trying to do non-blocking accept.")
        read, write, err = select.select([self.serv], [], [])
        if self.serv in read:
            conn, addr = self.serv.accept()
            conn.close()
        else:
            self.fail("Error trying to do accept after select.")

    def _testAccept(self):
        # Delay so the server's first non-blocking accept() finds no client.
        time.sleep(0.1)
        self.cli.connect((HOST, self.port))

    def testConnect(self):
        # Testing non-blocking connect
        conn, addr = self.serv.accept()
        conn.close()

    def _testConnect(self):
        self.cli.settimeout(10)
        self.cli.connect((HOST, self.port))

    def testRecv(self):
        # Testing non-blocking recv
        conn, addr = self.serv.accept()
        # Registered immediately so the connection is closed on every exit
        # path, including the self.fail() branches below.
        self.addCleanup(conn.close)
        conn.setblocking(0)
        try:
            conn.recv(len(MSG))
        except socket.error:
            pass
        else:
            self.fail("Error trying to do non-blocking recv.")
        read, write, err = select.select([conn], [], [])
        if conn not in read:
            self.fail("Error during select call to non-blocking socket.")
        msg = conn.recv(len(MSG))
        self.assertEqual(msg, MSG)

    def _testRecv(self):
        self.cli.connect((HOST, self.port))
        # Delay so the server's first non-blocking recv() finds no data.
        time.sleep(0.1)
        self.cli.send(MSG)
1033
@unittest.skipUnless(thread, 'Threading required for this test.')
class FileObjectClassTestCase(SocketConnectedTest):
    """Read through a makefile() file object on a connected TCP socket.

    The server half reads from self.serv_file (wrapping the server's end,
    self.cli_conn) and the client half writes to self.cli_file (wrapping
    the client's end, self.serv_conn).  Subclasses rerun these tests with
    other values of bufsize.
    """

    bufsize = -1 # Use default buffer size

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def setUp(self):
        SocketConnectedTest.setUp(self)
        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)

    def tearDown(self):
        self.serv_file.close()
        self.assertTrue(self.serv_file.closed)
        SocketConnectedTest.tearDown(self)
        self.serv_file = None

    def clientSetUp(self):
        SocketConnectedTest.clientSetUp(self)
        self.cli_file = self.serv_conn.makefile('wb')

    def clientTearDown(self):
        # Closing cli_file also flushes anything a client half left buffered
        # (see _testReadlineAfterReadNoNewline).
        self.cli_file.close()
        self.assertTrue(self.cli_file.closed)
        self.cli_file = None
        SocketConnectedTest.clientTearDown(self)

    def testSmallRead(self):
        # Performing small file read test
        first_seg = self.serv_file.read(len(MSG)-3)
        second_seg = self.serv_file.read(3)
        msg = first_seg + second_seg
        self.assertEqual(msg, MSG)

    def _testSmallRead(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testFullRead(self):
        # read until EOF
        msg = self.serv_file.read()
        self.assertEqual(msg, MSG)

    def _testFullRead(self):
        self.cli_file.write(MSG)
        self.cli_file.close()

    def testUnbufferedRead(self):
        # Performing unbuffered file read test
        # One byte at a time until read() returns '' (end of stream).
        buf = ''
        while 1:
            char = self.serv_file.read(1)
            if not char:
                break
            buf += char
        self.assertEqual(buf, MSG)

    def _testUnbufferedRead(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testReadline(self):
        # Performing file readline test
        line = self.serv_file.readline()
        self.assertEqual(line, MSG)

    def _testReadline(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testReadlineAfterRead(self):
        # Mix fixed-size read() calls with readline(): readline() must
        # continue exactly where the previous read() stopped.
        a_baloo_is = self.serv_file.read(len("A baloo is"))
        self.assertEqual("A baloo is", a_baloo_is)
        _a_bear = self.serv_file.read(len(" a bear"))
        self.assertEqual(" a bear", _a_bear)
        line = self.serv_file.readline()
        self.assertEqual("\n", line)
        line = self.serv_file.readline()
        self.assertEqual("A BALOO IS A BEAR.\n", line)
        line = self.serv_file.readline()
        self.assertEqual(MSG, line)

    def _testReadlineAfterRead(self):
        self.cli_file.write("A baloo is a bear\n")
        self.cli_file.write("A BALOO IS A BEAR.\n")
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testReadlineAfterReadNoNewline(self):
        # The final readline() has no newline to stop at, so it returns the
        # unterminated tail "Line" once the stream ends.
        end_of_ = self.serv_file.read(len("End Of "))
        self.assertEqual("End Of ", end_of_)
        line = self.serv_file.readline()
        self.assertEqual("Line", line)

    def _testReadlineAfterReadNoNewline(self):
        # No explicit flush: the data is flushed when clientTearDown closes
        # cli_file.
        self.cli_file.write("End Of Line")

    def testClosedAttr(self):
        self.assertTrue(not self.serv_file.closed)

    def _testClosedAttr(self):
        self.assertTrue(not self.cli_file.closed)
1137
1138
1139class FileObjectInterruptedTestCase(unittest.TestCase):
1140    """Test that the file object correctly handles EINTR internally."""
1141
1142    class MockSocket(object):
1143        def __init__(self, recv_funcs=()):
1144            # A generator that returns callables that we'll call for each
1145            # call to recv().
1146            self._recv_step = iter(recv_funcs)
1147
1148        def recv(self, size):
1149            return self._recv_step.next()()
1150
1151    @staticmethod
1152    def _raise_eintr():
1153        raise socket.error(errno.EINTR)
1154
1155    def _test_readline(self, size=-1, **kwargs):
1156        mock_sock = self.MockSocket(recv_funcs=[
1157                lambda : "This is the first line\nAnd the sec",
1158                self._raise_eintr,
1159                lambda : "ond line is here\n",
1160                lambda : "",
1161            ])
1162        fo = socket._fileobject(mock_sock, **kwargs)
1163        self.assertEqual(fo.readline(size), "This is the first line\n")
1164        self.assertEqual(fo.readline(size), "And the second line is here\n")
1165
1166    def _test_read(self, size=-1, **kwargs):
1167        mock_sock = self.MockSocket(recv_funcs=[
1168                lambda : "This is the first line\nAnd the sec",
1169                self._raise_eintr,
1170                lambda : "ond line is here\n",
1171                lambda : "",
1172            ])
1173        fo = socket._fileobject(mock_sock, **kwargs)
1174        self.assertEqual(fo.read(size), "This is the first line\n"
1175                          "And the second line is here\n")
1176
1177    def test_default(self):
1178        self._test_readline()
1179        self._test_readline(size=100)
1180        self._test_read()
1181        self._test_read(size=100)
1182
1183    def test_with_1k_buffer(self):
1184        self._test_readline(bufsize=1024)
1185        self._test_readline(size=100, bufsize=1024)
1186        self._test_read(bufsize=1024)
1187        self._test_read(size=100, bufsize=1024)
1188
1189    def _test_readline_no_buffer(self, size=-1):
1190        mock_sock = self.MockSocket(recv_funcs=[
1191                lambda : "aa",
1192                lambda : "\n",
1193                lambda : "BB",
1194                self._raise_eintr,
1195                lambda : "bb",
1196                lambda : "",
1197            ])
1198        fo = socket._fileobject(mock_sock, bufsize=0)
1199        self.assertEqual(fo.readline(size), "aa\n")
1200        self.assertEqual(fo.readline(size), "BBbb")
1201
1202    def test_no_buffer(self):
1203        self._test_readline_no_buffer()
1204        self._test_readline_no_buffer(size=4)
1205        self._test_read(bufsize=0)
1206        self._test_read(size=100, bufsize=0)
1207
1208
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):

    """Rerun the FileObjectClassTestCase tests with bufsize == 0.

    Only in unbuffered mode may a caller read one line through a file
    object, throw that object away, make a fresh one on the same socket
    and keep reading without losing data held in the first object's
    buffer.  httplib depends on this when it reads several requests from
    one socket.
    """

    bufsize = 0  # unbuffered reads

    def testUnbufferedReadline(self):
        # Read line one, swap in a brand-new file object, then read line two.
        first = self.serv_file.readline()
        self.assertEqual(first, "A. " + MSG)
        self.serv_file = self.cli_conn.makefile('rb', 0)
        second = self.serv_file.readline()
        self.assertEqual(second, "B. " + MSG)

    def _testUnbufferedReadline(self):
        for prefix in ("A. ", "B. "):
            self.cli_file.write(prefix + MSG)
        self.cli_file.flush()
1233
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Rerun the file object tests with bufsize=1 and check line buffering.

    testLinebufferedWrite reverses the usual roles: the server half writes
    through serv_file and the client half reads through cli_file.
    """

    bufsize = 1 # Default-buffered for reading; line-buffered for writing

    class SocketMemo(object):
        """Wrap a socket and record every chunk handed to send()/sendall(),
        so a test can examine how writes were split up."""
        def __init__(self, sock):
            self._sock = sock
            self.sent = []

        def send(self, data, flags=0):
            n = self._sock.send(data, flags)
            # Record only the part that was actually sent.
            self.sent.append(data[:n])
            return n

        def sendall(self, data, flags=0):
            self._sock.sendall(data, flags)
            self.sent.append(data)

        def __getattr__(self, attr):
            # Everything else is delegated to the wrapped socket.
            return getattr(self._sock, attr)

        def getsent(self):
            # Normalise recorded memoryview chunks to byte strings.
            return [e.tobytes() if isinstance(e, memoryview) else e for e in self.sent]

    def setUp(self):
        FileObjectClassTestCase.setUp(self)
        self.serv_file._sock = self.SocketMemo(self.serv_file._sock)

    def testLinebufferedWrite(self):
        # Write three lines, each in small pieces.  A trailing comma on a
        # Python 2 print suppresses the newline, and the next print emits a
        # separating space.  With line buffering, each complete line should
        # reach the socket as exactly one chunk.
        msg = MSG.strip()
        print >> self.serv_file, msg,
        print >> self.serv_file, msg

        # second line:
        print >> self.serv_file, msg,
        print >> self.serv_file, msg,
        print >> self.serv_file, msg

        # third line
        print >> self.serv_file, ''

        self.serv_file.flush()

        msg1 = "%s %s\n"%(msg, msg)
        msg2 =  "%s %s %s\n"%(msg, msg, msg)
        msg3 =  "\n"
        self.assertEqual(self.serv_file._sock.getsent(), [msg1, msg2, msg3])

    def _testLinebufferedWrite(self):
        # Client half: the same three lines must arrive intact and in order.
        msg = MSG.strip()
        msg1 = "%s %s\n"%(msg, msg)
        msg2 =  "%s %s %s\n"%(msg, msg, msg)
        msg3 =  "\n"
        l1 = self.cli_file.readline()
        self.assertEqual(l1, msg1)
        l2 = self.cli_file.readline()
        self.assertEqual(l2, msg2)
        l3 = self.cli_file.readline()
        self.assertEqual(l3, msg3)
1295
1296
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Rerun the FileObjectClassTestCase tests with a 2-byte buffer."""

    bufsize = 2  # tiny buffer, to exercise the buffering code paths
1300
1301
class NetworkConnectionTest(object):
    """Mixin: build the client connection with socket.create_connection()."""
    def clientSetUp(self):
        # self.port is provided by the class this is mixed into; BasicTCPTest2
        # gets it through BasicTCPTest.
        conn = socket.create_connection((HOST, self.port))
        self.cli = self.serv_conn = conn
1309
class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
    """Rerun BasicTCPTest with the client built by create_connection(),
    showing NetworkConnectionTest does not break existing TCP behaviour."""
1313
class NetworkConnectionNoServer(unittest.TestCase):
    """Connection attempts that must fail because nothing is listening."""

    class MockSocket(socket.socket):
        # A socket whose connect() always times out.
        def connect(self, *args):
            raise socket.timeout('timed out')

    @contextlib.contextmanager
    def mocked_socket_module(self):
        """Temporarily install MockSocket as socket.socket.

        The real class is put back on exit, even if the body raises.
        """
        real_socket_class = socket.socket
        socket.socket = self.MockSocket
        try:
            yield
        finally:
            socket.socket = real_socket_class

    def test_connect(self):
        # connect() to a port nobody listens on must be refused.
        unused_port = test_support.find_unused_port()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(sock.close)
        with self.assertRaises(socket.error) as ctx:
            sock.connect((HOST, unused_port))
        self.assertEqual(ctx.exception.errno, errno.ECONNREFUSED)

    def test_create_connection(self):
        # Issue #9792: errors raised by create_connection() should have
        # a proper errno attribute.
        unused_port = test_support.find_unused_port()
        with self.assertRaises(socket.error) as ctx:
            socket.create_connection((HOST, unused_port))

        # Issue #16257: create_connection() calls getaddrinfo() against
        # 'localhost', which may return an IPv6 address as well as an IPv4
        # one:
        #   >>> socket.getaddrinfo('localhost', port, 0, SOCK_STREAM)
        #   >>> [(2,  2, 0, '', ('127.0.0.1', 41230)),
        #        (26, 2, 0, '', ('::1', 41230, 0, 0))]
        #
        # create_connection() tries every address returned and, if it cannot
        # connect to any of them, re-raises the last exception it saw.
        #
        # On Solaris that last error is ENETUNREACH rather than ECONNREFUSED,
        # and bpo-31910 reports random EADDRNOTAVAIL failures on Travis CI,
        # so both are accepted where this platform defines them.
        expected_errnos = [errno.ECONNREFUSED]
        expected_errnos.extend(getattr(errno, name)
                               for name in ('ENETUNREACH', 'EADDRNOTAVAIL')
                               if hasattr(errno, name))

        self.assertIn(ctx.exception.errno, expected_errnos)

    def test_create_connection_timeout(self):
        # Issue #9792: create_connection() must let socket.timeout through
        # unchanged instead of recasting it as a generic socket.error.
        with self.mocked_socket_module():
            with self.assertRaises(socket.timeout):
                socket.create_connection((HOST, 1234))
1374
1375
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
    """Attributes of sockets returned by socket.create_connection().

    Every server half is _justAccept (accept one connection and close it);
    the real checks live in the _test* client halves.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        # Source port used by _testSourceAddress.
        self.source_port = test_support.find_unused_port()

    def clientTearDown(self):
        # self.cli is only bound once create_connection() has succeeded.  If
        # it failed, don't mask that error with an AttributeError here, and
        # always chain to ThreadableTest.clientTearDown.
        cli = getattr(self, 'cli', None)
        try:
            if cli is not None:
                cli.close()
        finally:
            self.cli = None
            ThreadableTest.clientTearDown(self)

    def _justAccept(self):
        conn, addr = self.serv.accept()
        conn.close()

    testFamily = _justAccept
    def _testFamily(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.family, socket.AF_INET)

    testSourceAddress = _justAccept
    def _testSourceAddress(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30,
                source_address=('', self.source_port))
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.getsockname()[1], self.source_port)
        # The port number being used is sufficient to show that the bind()
        # call happened.

    testTimeoutDefault = _justAccept
    def _testTimeoutDefault(self):
        # passing no explicit timeout uses socket's global default
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(42)
        try:
            self.cli = socket.create_connection((HOST, self.port))
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), 42)

    testTimeoutNone = _justAccept
    def _testTimeoutNone(self):
        # None timeout means the same as sock.settimeout(None)
        self.assertIsNone(socket.getdefaulttimeout())
        socket.setdefaulttimeout(30)
        try:
            self.cli = socket.create_connection((HOST, self.port), timeout=None)
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertIsNone(self.cli.gettimeout())

    testTimeoutValueNamed = _justAccept
    def _testTimeoutValueNamed(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)

    testTimeoutValueNonamed = _justAccept
    def _testTimeoutValueNonamed(self):
        self.cli = socket.create_connection((HOST, self.port), 30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)
1444
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
    """recv() on create_connection() sockets with and without a timeout.

    The server half waits 3 seconds before sending "done!".  A client using
    the default timeout receives it; a client with timeout=1 must see
    socket.timeout first.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        pass

    def clientTearDown(self):
        # self.cli is only bound once create_connection() has succeeded.  If
        # it failed, don't mask that error with an AttributeError here, and
        # always chain to ThreadableTest.clientTearDown.
        cli = getattr(self, 'cli', None)
        try:
            if cli is not None:
                cli.close()
        finally:
            self.cli = None
            ThreadableTest.clientTearDown(self)

    def testInsideTimeout(self):
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        time.sleep(3)
        conn.send("done!")
    testOutsideTimeout = testInsideTimeout

    def _testInsideTimeout(self):
        self.cli = sock = socket.create_connection((HOST, self.port))
        data = sock.recv(5)
        self.assertEqual(data, "done!")

    def _testOutsideTimeout(self):
        self.cli = sock = socket.create_connection((HOST, self.port), timeout=1)
        self.assertRaises(socket.timeout, lambda: sock.recv(5))
1475
1476
1477class Urllib2FileobjectTest(unittest.TestCase):
1478
1479    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
1480    # it close the socket if the close c'tor argument is true
1481
1482    def testClose(self):
1483        class MockSocket:
1484            closed = False
1485            def flush(self): pass
1486            def close(self): self.closed = True
1487
1488        # must not close unless we request it: the original use of _fileobject
1489        # by module socket requires that the underlying socket not be closed until
1490        # the _socketobject that created the _fileobject is closed
1491        s = MockSocket()
1492        f = socket._fileobject(s)
1493        f.close()
1494        self.assertTrue(not s.closed)
1495
1496        s = MockSocket()
1497        f = socket._fileobject(s, close=True)
1498        f.close()
1499        self.assertTrue(s.closed)
1500
class TCPTimeoutTest(SocketTCPTest):
    """Timeout behaviour of accept() on a listening TCP socket."""

    def testTCPTimeout(self):
        # With a 1 second timeout and no client connecting, accept() must
        # raise socket.timeout ("Error generating a timeout exception (TCP)").
        self.serv.settimeout(1.0)
        with self.assertRaises(socket.timeout):
            self.serv.accept()

    def testTimeoutZero(self):
        # A zero timeout means non-blocking: accept() with no pending client
        # must fail with a plain socket.error, not socket.timeout.  Any
        # other exception propagates and is reported as a test error.
        self.serv.settimeout(0.0)
        try:
            self.serv.accept()
        except socket.timeout:
            # Must come before socket.error, which socket.timeout subclasses.
            self.fail("caught timeout instead of error (TCP)")
        except socket.error:
            pass
        else:
            self.fail("accept() returned success when we did not expect it")

    @unittest.skipUnless(hasattr(signal, 'alarm'),
                         'test needs signal.alarm()')
    def testInterruptedTimeout(self):
        # XXX I don't know how to do this test on MSWindows or any other
        # platform that doesn't support signal.alarm() or os.kill(), though
        # the bug should have existed on all platforms.
        # A SIGALRM delivered during a timed accept() must surface the
        # handler's exception (Alarm), not be turned into socket.timeout.
        self.serv.settimeout(5.0)   # must be longer than alarm
        class Alarm(Exception):
            pass
        def alarm_handler(signal, frame):
            raise Alarm
        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
        try:
            try:
                signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
                foo = self.serv.accept()
            except socket.timeout:
                self.fail("caught timeout instead of Alarm")
            except Alarm:
                pass
            except:
                self.fail("caught other exception instead of Alarm:"
                          " %s(%s):\n%s" %
                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
            else:
                self.fail("nothing caught")
            finally:
                signal.alarm(0)         # shut off alarm
        except Alarm:
            self.fail("got Alarm in wrong place")
        finally:
            # no alarm can be pending.  Safe to restore old handler.
            signal.signal(signal.SIGALRM, old_alarm)
1557
class UDPTimeoutTest(SocketUDPTest):
    """Timeout behaviour of recv() on a bound UDP socket."""

    def testUDPTimeout(self):
        # With a 1 second timeout and nobody sending, recv() must raise
        # socket.timeout ("Error generating a timeout exception (UDP)").
        self.serv.settimeout(1.0)
        with self.assertRaises(socket.timeout):
            self.serv.recv(1024)

    def testTimeoutZero(self):
        # A zero timeout means non-blocking: recv() with no datagram queued
        # must fail with a plain socket.error, not socket.timeout.  Any
        # other exception propagates and is reported as a test error.
        self.serv.settimeout(0.0)
        try:
            self.serv.recv(1024)
        except socket.timeout:
            # Must come before socket.error, which socket.timeout subclasses.
            self.fail("caught timeout instead of error (UDP)")
        except socket.error:
            pass
        else:
            self.fail("recv() returned success when we did not expect it")
1580
class TestExceptions(unittest.TestCase):
    """socket's exception classes form the expected hierarchy."""

    def testExceptionTree(self):
        self.assertTrue(issubclass(socket.error, Exception))
        for specific in (socket.herror, socket.gaierror, socket.timeout):
            # Each specialised error must still be catchable as socket.error.
            self.assertTrue(issubclass(specific, socket.error))
1588
1589@unittest.skipUnless(sys.platform == 'linux', 'Linux specific test')
1590class TestLinuxAbstractNamespace(unittest.TestCase):
1591
1592    UNIX_PATH_MAX = 108
1593
1594    def testLinuxAbstractNamespace(self):
1595        address = "\x00python-test-hello\x00\xff"
1596        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1597        s1.bind(address)
1598        s1.listen(1)
1599        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1600        s2.connect(s1.getsockname())
1601        s1.accept()
1602        self.assertEqual(s1.getsockname(), address)
1603        self.assertEqual(s2.getpeername(), address)
1604
1605    def testMaxName(self):
1606        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
1607        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1608        s.bind(address)
1609        self.assertEqual(s.getsockname(), address)
1610
1611    def testNameOverflow(self):
1612        address = "\x00" + "h" * self.UNIX_PATH_MAX
1613        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1614        self.assertRaises(socket.error, s.bind, address)
1615
1616
@unittest.skipUnless(thread, 'Threading required for this test.')
class BufferIOTest(SocketConnectedTest):
    """
    Test the buffer versions of socket.recv() and socket.send().

    Every server half receives into a caller-supplied buffer; every client
    half is the same _testRecvIntoArray, which sends MSG through a buffer
    object.
    """
    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecvIntoArray(self):
        buf = array.array('c', ' '*1024)
        nbytes = self.cli_conn.recv_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf.tostring()[:len(MSG)]
        self.assertEqual(msg, MSG)

    def _testRecvIntoArray(self):
        # buffer() is wrapped in check_py3k_warnings(), presumably because it
        # emits a py3k deprecation warning under -3.
        with test_support.check_py3k_warnings():
            buf = buffer(MSG)
        self.serv_conn.send(buf)

    def testRecvIntoBytearray(self):
        buf = bytearray(1024)
        nbytes = self.cli_conn.recv_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvIntoBytearray = _testRecvIntoArray

    def testRecvIntoMemoryview(self):
        buf = bytearray(1024)
        nbytes = self.cli_conn.recv_into(memoryview(buf))
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvIntoMemoryview = _testRecvIntoArray

    def testRecvFromIntoArray(self):
        buf = array.array('c', ' '*1024)
        nbytes, addr = self.cli_conn.recvfrom_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf.tostring()[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvFromIntoArray = _testRecvIntoArray

    def testRecvFromIntoBytearray(self):
        buf = bytearray(1024)
        nbytes, addr = self.cli_conn.recvfrom_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvFromIntoBytearray = _testRecvIntoArray

    def testRecvFromIntoMemoryview(self):
        buf = bytearray(1024)
        nbytes, addr = self.cli_conn.recvfrom_into(memoryview(buf))
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvFromIntoMemoryview = _testRecvIntoArray

    def testRecvFromIntoSmallBuffer(self):
        # See issue #20246: an nbytes larger than the buffer must be
        # rejected with ValueError.
        buf = bytearray(8)
        self.assertRaises(ValueError, self.cli_conn.recvfrom_into, buf, 1024)

    _testRecvFromIntoSmallBuffer = _testRecvIntoArray

    def testRecvFromIntoEmptyBuffer(self):
        # Receiving into an empty buffer (with and without nbytes=0) must
        # not raise.
        buf = bytearray()
        self.cli_conn.recvfrom_into(buf)
        self.cli_conn.recvfrom_into(buf, 0)

    _testRecvFromIntoEmptyBuffer = _testRecvIntoArray
1701
1702
# TIPC name-sequence parameters used by the TIPC tests. Servers bind
# (TIPC_ADDR_NAMESEQ, TIPC_STYPE, TIPC_LOWER, TIPC_UPPER); clients address
# an instance between TIPC_LOWER and TIPC_UPPER.
TIPC_STYPE, TIPC_LOWER, TIPC_UPPER = 2000, 200, 210
1706
def isTipcAvailable():
    """Return True if the TIPC kernel module appears to be loaded.

    Returns False when the socket module lacks AF_TIPC. Otherwise scans
    /proc/modules for a line starting with "tipc ". The TIPC module is not
    loaded automatically on Ubuntu and probably other Linux distros.

    Raises IOError for any failure to open /proc/modules other than
    ENOENT, EISDIR or EACCES; those three are reported as "not available".
    """
    if not hasattr(socket, "AF_TIPC"):
        return False
    try:
        modules = open("/proc/modules")
    except IOError as err:
        # A missing file, a directory, or no read permission all mean we
        # cannot tell, which counts as unavailable. Anything else is a real
        # error and propagates.
        if err.errno not in (errno.ENOENT, errno.EISDIR, errno.EACCES):
            raise
        return False
    with modules:
        return any(line.startswith("tipc ") for line in modules)
1730
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCTest(unittest.TestCase):
    """Send one datagram between two AF_TIPC SOCK_RDM sockets."""

    def testRDM(self):
        srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        # Register each close right away so neither socket leaks if a later
        # step (bind, sendto, recvfrom or an assertion) fails.
        self.addCleanup(srv.close)
        cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(cli.close)

        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # The server binds the whole name sequence [TIPC_LOWER, TIPC_UPPER].
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                   TIPC_LOWER, TIPC_UPPER)
        srv.bind(srvaddr)

        # The client targets the instance halfway into that range. Floor
        # division keeps the instance an int.
        sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                    TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        cli.sendto(MSG, sendaddr)

        msg, recvaddr = srv.recvfrom(1024)

        # The reported sender address must be the client's own address.
        self.assertEqual(cli.getsockname(), recvaddr)
        self.assertEqual(msg, MSG)
1751
1752
@unittest.skipUnless(isTipcAvailable(),
                     "TIPC module is not loaded, please 'sudo modprobe tipc'")
class TIPCThreadableTest(unittest.TestCase, ThreadableTest):
    """Stream-socket TIPC test with a ThreadableTest client/server pair.

    setUp() binds a listening AF_TIPC SOCK_STREAM server on the TIPC name
    sequence and accepts one connection. clientSetUp() connects to an
    instance in the middle of that sequence.
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        # addCleanup() callbacks run even if setUp() fails part-way, so
        # register each close as soon as its socket exists.
        self.addCleanup(self.srv.close)
        self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                   TIPC_LOWER, TIPC_UPPER)
        self.srv.bind(srvaddr)
        self.srv.listen(5)
        self.serverExplicitReady()
        self.conn, self.connaddr = self.srv.accept()
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        # _testStream() closes this socket on success. The cleanup covers
        # failures that happen before that point.
        self.addCleanup(self.cli.close)
        addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        self.cli.connect(addr)
        self.cliaddr = self.cli.getsockname()

    def testStream(self):
        msg = self.conn.recv(1024)
        self.assertEqual(msg, MSG)
        # The address accept() reported must match the client's own address.
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()
1789
1790
def test_main():
    """Run every socket test case listed here through test_support.

    The run is bracketed by test_support.threading_setup() and
    threading_cleanup(). The cleanup sits in a ``finally`` so it still runs
    if run_unittest() raises.
    """
    tests = [
        GeneralModuleTests,
        BasicTCPTest,
        TCPCloserTest,
        TCPTimeoutTest,
        TestExceptions,
        BufferIOTest,
        BasicTCPTest2,
        BasicUDPTest,
        UDPTimeoutTest,
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        FileObjectInterruptedTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
        NetworkConnectionNoServer,
        NetworkConnectionAttributesTest,
        NetworkConnectionBehaviourTest,
        BasicSocketPairTest,
        TestLinuxAbstractNamespace,
        TIPCTest,
        TIPCThreadableTest,
    ]

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        test_support.threading_cleanup(*thread_info)
1815
# Allow running this module directly as a script; it runs test_main().
if __name__ == "__main__":
    test_main()
1818