1"""
2websocket - WebSocket client library for Python
3
4Copyright (C) 2010 Hiroki Ohtani(liris)
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License as published by the Free Software Foundation; either
9    version 2.1 of the License, or (at your option) any later version.
10
11    This library is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14    Lesser General Public License for more details.
15
16    You should have received a copy of the GNU Lesser General Public
17    License along with this library; if not, write to the Free Software
18    Foundation, Inc., 51 Franklin Street, Fifth Floor,
19    Boston, MA  02110-1335  USA
20
21"""
22import errno
23import os
24import socket
25import sys
26
27import six
28
29from ._exceptions import *
30from ._logging import *
31from ._socket import*
32from ._ssl_compat import *
33from ._url import *
34
35if six.PY3:
36    from base64 import encodebytes as base64encode
37else:
38    from base64 import encodestring as base64encode
39
# Public API of this module; every other definition here is a private helper.
__all__ = [
    "proxy_info",
    "connect",
    "read_headers",
]
41
42
class proxy_info(object):
    """HTTP proxy settings extracted from keyword options.

    Reads ``http_proxy_host``, ``http_proxy_port``, ``http_proxy_auth`` and
    ``http_no_proxy``. The last three are only taken from *options* when the
    proxy host is truthy; otherwise ``port`` is ``0`` and ``auth`` /
    ``no_proxy`` are ``None``.
    """

    def __init__(self, **options):
        self.host = options.get("http_proxy_host", None)
        # Without a proxy host the other settings are meaningless, so read
        # them from an empty mapping to get the neutral defaults.
        source = options if self.host else {}
        self.port = source.get("http_proxy_port", 0)
        self.auth = source.get("http_proxy_auth", None)
        self.no_proxy = source.get("http_no_proxy", None)
55
56
def connect(url, options, proxy, socket):
    """Open (or reuse) the transport socket for the WebSocket *url*.

    :param url: WebSocket URL, split by ``parse_url`` into hostname, port,
        resource and whether TLS is required.
    :param options: object exposing ``sockopt``, ``timeout`` and ``sslopt``.
    :param proxy: proxy settings (e.g. a ``proxy_info``) handed to
        ``_get_addrinfo_list``.
    :param socket: an already-open socket; when truthy it is returned as-is
        and nothing is connected.
    :return: ``(sock, (hostname, port, resource))``.
    :raises WebSocketException: if name resolution yields no addresses, or
        TLS is required but SSL support is unavailable.

    A socket opened here is closed again if any later step fails.
    """
    hostname, port, resource, is_secure = parse_url(url)
    target = (hostname, port, resource)

    if socket:
        return socket, target

    addrinfo_list, need_tunnel, auth = _get_addrinfo_list(
        hostname, port, is_secure, proxy)
    if not addrinfo_list:
        raise WebSocketException(
            "Host not found.: " + hostname + ":" + str(port))

    sock = None
    connected = False
    try:
        sock = _open_socket(addrinfo_list, options.sockopt, options.timeout)
        if need_tunnel:
            sock = _tunnel(sock, hostname, port, auth)
        if is_secure:
            if not HAVE_SSL:
                raise WebSocketException("SSL not available.")
            sock = _ssl_socket(sock, options.sslopt, hostname)
        connected = True
    finally:
        # Any failure above (including KeyboardInterrupt) must not leak
        # the partially set-up socket.
        if not connected and sock:
            sock.close()

    return sock, target
86
87
def _get_addrinfo_list(hostname, port, is_secure, proxy):
    """Resolve the address(es) to dial, going through a proxy if one applies.

    ``get_proxy_info`` decides, from *proxy*'s ``host``/``port``/``auth``/
    ``no_proxy``, whether a proxy is used for *hostname*. With no proxy,
    *hostname*:*port* is resolved directly. Otherwise the proxy host is
    resolved, using port 80 when the proxy port is falsy, and the caller is
    told to open a CONNECT tunnel.

    :return: ``(addrinfo_list, need_tunnel, proxy_auth)``. ``addrinfo_list``
        is the ``socket.getaddrinfo`` result; ``need_tunnel`` is True only
        for the proxy route; ``proxy_auth`` is the auth value returned by
        ``get_proxy_info`` (``None`` on the direct route).
    """
    phost, pport, pauth = get_proxy_info(
        hostname, is_secure, proxy.host, proxy.port, proxy.auth,
        proxy.no_proxy)
    if not phost:
        addrinfo_list = socket.getaddrinfo(
            hostname, port, 0, 0, socket.SOL_TCP)
        return addrinfo_list, False, None

    # ``x or default`` replaces the fragile ``x and x or default`` idiom.
    pport = pport or 80
    addrinfo_list = socket.getaddrinfo(phost, pport, 0, 0, socket.SOL_TCP)
    return addrinfo_list, True, pauth
99
100
def _open_socket(addrinfo_list, sockopt, timeout):
    """Connect to the first address in *addrinfo_list* that accepts.

    Entries are ``socket.getaddrinfo``-style tuples: item 0 is the address
    family and item 4 the socket address. For each entry a socket is
    created and given *timeout*. It then receives ``DEFAULT_SOCKET_OPTION``
    followed by *sockopt* (each a tuple of ``setsockopt`` arguments) and is
    connected. A refused connection moves on to the next entry; any other
    error is raised at once. Every socket that is not returned is closed,
    so failed attempts do not leak file descriptors.

    :return: the connected socket.
    :raises ValueError: if *addrinfo_list* is empty.
    :raises socket.error: the last refusal if every address refused, or the
        first other connect error. Either one carries a ``remote_ip``
        attribute naming the address that failed.
    """
    if not addrinfo_list:
        # Without this guard the final ``raise err`` would be ``raise None``,
        # which fails with an unrelated TypeError.
        raise ValueError("addrinfo_list must not be empty")

    # errno.WSAECONNREFUSED only exists on Windows, where socket errors may
    # carry the Winsock code rather than ECONNREFUSED.
    refused_errnos = [errno.ECONNREFUSED]
    if hasattr(errno, "WSAECONNREFUSED"):
        refused_errnos.append(errno.WSAECONNREFUSED)

    err = None
    for addrinfo in addrinfo_list:
        family = addrinfo[0]
        address = addrinfo[4]
        sock = socket.socket(family)
        try:
            sock.settimeout(timeout)
            for opts in DEFAULT_SOCKET_OPTION:
                sock.setsockopt(*opts)
            for opts in sockopt:
                sock.setsockopt(*opts)
        except BaseException:
            # Configuration errors are not retried, but the socket must
            # still be released.
            sock.close()
            raise

        try:
            sock.connect(address)
        except socket.error as error:
            sock.close()
            error.remote_ip = str(address[0])
            if error.errno in refused_errnos:
                err = error
                continue
            raise
        except BaseException:
            # e.g. KeyboardInterrupt during a blocking connect.
            sock.close()
            raise
        return sock

    # Reached only when every address refused; err is the last refusal.
    raise err
128
129
def _can_use_sni():
    """Return whether the interpreter is new enough for the SNI code path.

    That means Python 2.7.9+ on the 2.x line, or Python 3.2+ on 3.x.
    """
    if six.PY2:
        return sys.version_info >= (2, 7, 9)
    return sys.version_info >= (3, 2)
132
133
def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
    """Wrap *sock* with TLS through an ``ssl.SSLContext``, sending SNI.

    :param sock: connected socket to wrap.
    :param sslopt: TLS options. Keys read: ``ssl_version`` (default
        ``ssl.PROTOCOL_SSLv23``), ``cert_reqs`` (required), ``ca_certs``,
        ``certfile``/``keyfile``/``password``, ``ciphers``, ``cert_chain``
        (a ``(certfile, keyfile, password)`` triple), and
        ``do_handshake_on_connect`` / ``suppress_ragged_eofs`` (both default
        True).
    :param hostname: passed as ``server_hostname`` to ``wrap_socket``.
    :param check_hostname: value for ``SSLContext.check_hostname``; applied
        only when ``HAVE_CONTEXT_CHECK_HOSTNAME`` is true.
    :return: the socket returned by ``SSLContext.wrap_socket``.
    """
    context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))

    if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE:
        cafile = sslopt.get('ca_certs', None)
        if cafile is None and hasattr(context, 'load_default_certs'):
            # load_verify_locations() raises TypeError when given no location
            # at all, so use the platform trust store when no bundle is set.
            context.load_default_certs(ssl.Purpose.SERVER_AUTH)
        else:
            context.load_verify_locations(cafile=cafile)
    if sslopt.get('certfile', None):
        context.load_cert_chain(
            sslopt['certfile'],
            sslopt.get('keyfile', None),
            sslopt.get('password', None),
        )
    # see
    # https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
    # Order matters. Lowering verify_mode to CERT_NONE while check_hostname
    # is enabled raises ValueError (e.g. PROTOCOL_TLS_CLIENT enables it by
    # default). Before Python 3.7, enabling check_hostname while verify_mode
    # is CERT_NONE raised ValueError too. So disable hostname checking before
    # setting verify_mode, and enable it only afterwards.
    if HAVE_CONTEXT_CHECK_HOSTNAME and not check_hostname:
        context.check_hostname = check_hostname
    context.verify_mode = sslopt['cert_reqs']
    if HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname:
        context.check_hostname = check_hostname
    if 'ciphers' in sslopt:
        context.set_ciphers(sslopt['ciphers'])
    if 'cert_chain' in sslopt:
        certfile, keyfile, password = sslopt['cert_chain']
        context.load_cert_chain(certfile, keyfile, password)

    return context.wrap_socket(
        sock,
        do_handshake_on_connect=sslopt.get('do_handshake_on_connect', True),
        suppress_ragged_eofs=sslopt.get('suppress_ragged_eofs', True),
        server_hostname=hostname,
    )
162
163
def _ssl_socket(sock, user_sslopt, hostname):
    """Upgrade *sock* to TLS for *hostname* using the caller's ssl options.

    Certificate verification defaults to ``ssl.CERT_REQUIRED`` unless
    *user_sslopt* overrides ``cert_reqs``. If the caller gives no
    ``ca_certs``, a CA bundle is used when that path is an existing file.
    The path comes from the ``WEBSOCKET_CLIENT_CA_BUNDLE`` environment
    variable, or ``cacert.pem`` next to this module when the variable is
    unset or empty.

    Hostname checking (``check_hostname`` option, default True) is forced
    off when ``cert_reqs`` is ``ssl.CERT_NONE``. When the ssl context cannot
    check hostnames itself, ``match_hostname`` is applied to the peer
    certificate after wrapping. *user_sslopt* itself is not modified.
    """
    # Work on a copy so the caller's dict is never mutated.
    sslopt = {'cert_reqs': ssl.CERT_REQUIRED}
    sslopt.update(user_sslopt)

    bundle_path = (os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
                   or os.path.join(os.path.dirname(__file__), "cacert.pem"))
    if os.path.isfile(bundle_path) and user_sslopt.get('ca_certs', None) is None:
        sslopt['ca_certs'] = bundle_path

    # Always strip the option so it is never forwarded as a keyword to
    # ssl.wrap_socket; hostname checking needs certificate verification.
    wants_hostname_check = sslopt.pop('check_hostname', True)
    check_hostname = (sslopt["cert_reqs"] != ssl.CERT_NONE
                      and wants_hostname_check)

    if _can_use_sni():
        wrapped = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
    else:
        wrapped = ssl.wrap_socket(sock, **sslopt)

    if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname:
        match_hostname(wrapped.getpeercert(), hostname)

    return wrapped
188
189
def _tunnel(sock, host, port, auth):
    """Ask the HTTP proxy connected on *sock* to open a tunnel to host:port.

    Sends an HTTP/1.0 ``CONNECT`` request, with a Basic
    ``Proxy-Authorization`` header when *auth* has a non-empty user name,
    then reads the proxy's reply with ``read_headers``.

    :param sock: socket already connected to the proxy.
    :param host: target host name or IP address. A bare IPv6 literal is
        wrapped in ``[...]`` as the authority syntax requires.
    :param port: target port (formatted with ``%d``).
    :param auth: ``None`` or a sequence ``(user[, password])``.
    :return: *sock*, now tunnelled to the target.
    :raises WebSocketProxyException: if reading the reply fails, or its
        status is missing or not 2xx.
    """
    debug("Connecting proxy...")
    # RFC 3986 authority syntax requires brackets around an IPv6 literal;
    # "CONNECT ::1:443" would be ambiguous to the proxy.
    if host and ":" in host and not host.startswith("["):
        authority_host = "[%s]" % host
    else:
        authority_host = host
    connect_header = "CONNECT %s:%d HTTP/1.0\r\n" % (authority_host, port)
    # TODO: support digest auth.
    if auth and auth[0]:
        auth_str = auth[0]
        # The password is optional; tolerate a one-item auth sequence.
        if len(auth) > 1 and auth[1]:
            auth_str += ":" + auth[1]
        encoded_str = base64encode(auth_str.encode()).strip().decode()
        connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
    connect_header += "\r\n"
    dump("request header", connect_header)

    send(sock, connect_header)

    try:
        status, _ = read_headers(sock)
    except Exception as e:
        raise WebSocketProxyException(str(e))

    # Per RFC 9110 (CONNECT), any 2xx reply means the tunnel is established.
    # A None status means the proxy sent no status line at all.
    if status is None or not 200 <= status < 300:
        raise WebSocketProxyException(
            "failed CONNECT via proxy status: %r" % status)

    return sock
215
216
def read_headers(sock):
    """Read an HTTP status line and header block from *sock*.

    Lines are read with ``recv_line``, decoded as UTF-8 and stripped, until
    the first blank line. The first line is the status line; its second
    space-separated field is parsed as the integer status code. Every later
    line must have the form ``name: value``.

    :return: ``(status, headers)``. ``status`` is an int, or ``None`` if the
        block ended before any status line. ``headers`` maps lower-cased
        header names to stripped values; a repeated header keeps only its
        last value.
    :raises WebSocketException: on a header line without a ``:``.
    Errors from a malformed status line (``IndexError`` / ``ValueError``
    from the split and ``int()``) propagate unchanged.
    """
    status = None
    headers = {}
    trace("--- response header ---")

    while True:
        line = recv_line(sock)
        line = line.decode('utf-8').strip()
        if not line:
            break
        trace(line)
        # ``is None`` rather than falsiness: a parsed status of 0 must not
        # cause the next header line to be re-parsed as a status line.
        if status is None:
            # e.g. "HTTP/1.1 101 Switching Protocols" -> 101
            status_info = line.split(" ", 2)
            status = int(status_info[1])
        else:
            kv = line.split(":", 1)
            if len(kv) == 2:
                key, value = kv
                headers[key.lower()] = value.strip()
            else:
                raise WebSocketException("Invalid header")

    trace("-----------------------")

    return status, headers
243