1"""Mock socket module used by the smtpd and smtplib tests.
2"""
3
4# imported for _GLOBAL_DEFAULT_TIMEOUT
5import socket as socket_module
6
# Mock socket module state.
# _defaulttimeout: value reported by getdefaulttimeout(), changed via
#     setdefaulttimeout().
# _reply_data: a single line queued by reply_with(); the next MockSocket
#     created takes it over and resets this back to None.
_defaulttimeout = _reply_data = None
10
def reply_with(line):
    """Queue *line* as the first incoming line of the next MockSocket.

    This is typically called *before* the socket object exists.  Only one
    line is held: a later call replaces an earlier one that no socket has
    consumed yet.  The socket created next hands the line out (with CRLF
    appended) through recv() or through the file returned by makefile().
    """
    global _reply_data
    _reply_data = line
17
18
class MockFile:
    """Mock file object returned by MockSocket.makefile().

    *lines* is a list of byte strings, used directly rather than copied, so
    lines appended to it after the file was created are still seen here.
    Each line is handed out with CRLF appended.
    """
    def __init__(self, lines):
        self.lines = lines

    def readline(self, limit=-1):
        """Return the next queued line followed by CRLF.

        If *limit* is non-negative and smaller than the line plus its CRLF,
        only the first *limit* bytes are returned.  The unread part of the
        line is pushed back so the next call returns it.  If the whole line
        fits within *limit*, nothing is pushed back.
        Raises IndexError when no lines are queued.
        """
        result = self.lines.pop(0) + b'\r\n'
        if 0 <= limit < len(result):
            # Push back the unread remainder without the CRLF we added; the
            # next readline() appends it again.
            # NOTE: if the cut falls between the CR and the LF, the pushed
            # back remainder is empty and the next call yields a full CRLF
            # rather than just the missing LF -- an approximation.
            self.lines.insert(0, result[limit:-2])
            result = result[:limit]
        return result

    def close(self):
        """No-op; there is no underlying resource to release."""
        pass
33
34
class MockSocket:
    """Mock socket object used by smtpd and smtplib tests.

    Incoming data lives in ``self.lines``, a list of byte strings.  That
    list is consumed by recv(), and makefile() hands the same list to the
    file object it returns.  Each line is delivered with CRLF appended.
    Outgoing data passed to send()/sendall() is appended to ``self.output``,
    and the most recent chunk is kept in ``self.last``.
    """
    def __init__(self, family=None):
        global _reply_data
        self.family = family
        # Everything written by the code under test, in call order.
        self.output = []
        # Pending incoming lines (without CRLF).
        self.lines = []
        # A line queued by reply_with() before this socket existed is taken
        # over by exactly one socket: it moves here and the global is cleared.
        if _reply_data:
            self.lines.append(_reply_data)
            _reply_data = None
        # Set by accept() to the connection object it returned.
        self.conn = None
        self.timeout = None

    def queue_recv(self, line):
        """Append *line* to the incoming data; CRLF is added when it is read."""
        self.lines.append(line)

    def recv(self, bufsize, flags=None):
        """Return the next queued line plus CRLF.

        *bufsize* and *flags* are ignored: a whole line is always returned,
        regardless of its length.  Raises IndexError if nothing is queued.
        """
        data = self.lines.pop(0) + b'\r\n'
        return data

    def fileno(self):
        """Return a fixed fake descriptor (0)."""
        return 0

    def settimeout(self, timeout):
        """Record *timeout* exactly as given, like socket.settimeout().

        As with a real socket, None means blocking mode and gettimeout()
        reports None; it is not replaced by the module-wide default.
        create_connection() is the place that applies getdefaulttimeout()
        when the caller did not pass a timeout.
        """
        self.timeout = timeout

    def gettimeout(self):
        """Return the value last recorded by settimeout() (None initially)."""
        return self.timeout

    def setsockopt(self, level, optname, value):
        """No-op; socket options are not modelled."""
        pass

    def getsockopt(self, level, optname, buflen=None):
        """Always report 0 for any option."""
        return 0

    def bind(self, address):
        """No-op; no address is actually bound."""
        pass

    def accept(self):
        """Return a fresh MockSocket as the accepted connection.

        The new socket is also stored on ``self.conn`` so tests can reach it.
        The peer address reported is the placeholder ``'c'``.
        """
        self.conn = MockSocket()
        return self.conn, 'c'

    def getsockname(self):
        """Return a fixed wildcard address."""
        return ('0.0.0.0', 0)

    def setblocking(self, flag):
        """No-op; blocking mode is not modelled."""
        pass

    def listen(self, backlog):
        """No-op."""
        pass

    def makefile(self, mode='r', bufsize=-1):
        """Return a MockFile that reads from this socket's own line list.

        *mode* and *bufsize* are ignored.
        """
        handle = MockFile(self.lines)
        return handle

    def sendall(self, data, flags=None):
        """Record *data* as sent and report its full length."""
        self.last = data
        self.output.append(data)
        return len(data)

    def send(self, data, flags=None):
        """Record *data* as sent and report its full length."""
        self.last = data
        self.output.append(data)
        return len(data)

    def getpeername(self):
        """Return a fixed placeholder peer address."""
        return ('peer-address', 'peer-port')

    def close(self):
        """No-op; there is nothing to release."""
        pass
109
110
def socket(family=None, type=None, proto=None):
    """Stand-in for socket.socket(): build a MockSocket for *family*.

    *type* and *proto* are accepted only for signature compatibility and are
    ignored.
    """
    new_sock = MockSocket(family=family)
    return new_sock
113
def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT,
                      source_address=None):
    """Stand-in for socket.create_connection(); returns a new MockSocket.

    No connection is attempted.  *address* is indexed as (host, port); the
    port only has to be convertible with int(), otherwise ``error`` is
    raised.  When *timeout* is left at the global-default sentinel, this
    module's getdefaulttimeout() value is used; any other value is passed
    to settimeout() unchanged.  *source_address* is accepted for signature
    compatibility and ignored.
    """
    port = address[1]
    try:
        int(port)
    except ValueError:
        # Surface a bad port as a socket error with a useful message, without
        # chaining the internal ValueError.
        raise error(f"invalid port: {port!r}") from None
    ms = MockSocket()
    if timeout is socket_module._GLOBAL_DEFAULT_TIMEOUT:
        timeout = getdefaulttimeout()
    ms.settimeout(timeout)
    return ms
125
126
def setdefaulttimeout(timeout):
    """Store *timeout* as the module-wide default read by getdefaulttimeout()."""
    global _defaulttimeout
    _defaulttimeout = timeout
130
131
def getdefaulttimeout():
    """Report the module-wide default timeout last stored by setdefaulttimeout()."""
    current = _defaulttimeout
    return current
134
135
def getfqdn(name=''):
    """Stand-in for socket.getfqdn(); always returns an empty string.

    *name* is accepted and ignored, so the signature matches the real
    ``socket.getfqdn([name])``.  Calls with no argument behave exactly as
    before.
    """
    return ""
138
139
def gethostname():
    """Stand-in for socket.gethostname(); always returns an empty string.

    The real function returns a str, and the other lookup stubs in this
    module (getfqdn, gethostbyname) also answer "".  Returning None would
    break callers that treat the result as text.
    """
    return ""
142
143
def gethostbyname(name):
    """Pretend to resolve *name*; the answer is always an empty string.

    The argument is ignored, so no DNS lookup ever happens.
    """
    unresolved_address = ""
    return unresolved_address
146
def getaddrinfo(*args, **kw):
    """Forward every argument unchanged to the real socket.getaddrinfo()."""
    real_getaddrinfo = socket_module.getaddrinfo
    return real_getaddrinfo(*args, **kw)
149
# Exception classes are the real socket module's own objects, so they are
# interchangeable with socket.error / socket.gaierror.
error = socket_module.error
gaierror = socket_module.gaierror
152
153
# Constants
# Address families and the stream type reuse the real socket module values.
AF_INET, AF_INET6 = socket_module.AF_INET, socket_module.AF_INET6
SOCK_STREAM = socket_module.SOCK_STREAM
# MockSocket.setsockopt() ignores its arguments, so these option names only
# need to exist; they are deliberately left as None.
SOL_SOCKET = SO_REUSEADDR = None
160