1"""
2websocket - WebSocket client library for Python
3
4Copyright (C) 2010 Hiroki Ohtani(liris)
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License as published by the Free Software Foundation; either
9    version 2.1 of the License, or (at your option) any later version.
10
11    This library is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14    Lesser General Public License for more details.
15
16    You should have received a copy of the GNU Lesser General Public
17    License along with this library; if not, write to the Free Software
18    Foundation, Inc., 51 Franklin Street, Fifth Floor,
    Boston, MA  02110-1301  USA
20
21"""
22import hashlib
23import hmac
24import os
25
26import six
27
28from ._exceptions import *
29from ._http import *
30from ._logging import *
31from ._socket import *
32
33if six.PY3:
34    from base64 import encodebytes as base64encode
35else:
36    from base64 import encodestring as base64encode
37
# Public API of this module: the handshake entry point and its result class.
__all__ = ["handshake_response", "handshake"]
39
# Use the standard library's constant-time comparison when it exists; very old
# Python releases lack hmac.compare_digest, so fall back to plain equality.
compare_digest = getattr(hmac, "compare_digest", None)
if compare_digest is None:
    def compare_digest(s1, s2):
        """Compare two digests for equality (fallback, not constant-time)."""
        return s1 == s2

# WebSocket protocol version advertised in the Sec-WebSocket-Version header.
VERSION = 13
48
49
class handshake_response(object):
    """Outcome of a completed opening handshake.

    Attributes:
        status: status code returned by the server.
        headers: response headers as read from the socket.
        subprotocol: the negotiated subprotocol, or None if none was agreed.
    """

    def __init__(self, status, headers, subprotocol):
        self.status, self.headers, self.subprotocol = (
            status, headers, subprotocol)
56
57
def handshake(sock, hostname, port, resource, **options):
    """Run the client side of the WebSocket opening handshake on *sock*.

    Builds and sends the upgrade request, reads the server's response
    headers and validates them against the key that was sent.

    :param sock: connected socket to the server.
    :param hostname: server host, used for the Host/Origin headers.
    :param port: server port.
    :param resource: request target for the GET line.
    :param options: extra handshake options (host, origin, subprotocols,
        header, cookie), forwarded to the request builder.
    :return: handshake_response with the status, response headers and the
        negotiated subprotocol.
    :raises WebSocketBadStatusException: if the response status is not 101.
    :raises WebSocketException: if the response headers fail validation.
    """
    request_lines, sent_key = _get_handshake_headers(
        resource, hostname, port, options)
    request = "\r\n".join(request_lines)
    send(sock, request)
    dump("request header", request)

    status, response_headers = _get_resp_headers(sock)
    valid, negotiated = _validate(
        response_headers, sent_key, options.get("subprotocols"))
    if valid:
        return handshake_response(status, response_headers, negotiated)
    raise WebSocketException("Invalid WebSocket Header")
71
72
def _pack_hostname(hostname):
    """Return *hostname* in the form required inside Host/Origin headers.

    IPv6 literals contain ':' and must be wrapped in square brackets so an
    optional ":port" suffix stays unambiguous (RFC 3986, section 3.2.2).
    Values that contain no ':' or are already bracketed are returned as-is.
    """
    if hostname and ":" in hostname and not hostname.startswith("["):
        return "[%s]" % hostname
    return hostname


def _get_handshake_headers(resource, host, port, options):
    """Build the request lines for the client's opening handshake.

    :param resource: request target placed on the GET line.
    :param host: server host name or address; IPv6 literals are bracketed.
    :param port: server port; omitted from Host/Origin when it is 80 or 443
        (NOTE: regardless of ws/wss, since the scheme is not known here).
    :param options: handshake options. Keys read here:
        ``host`` (overrides the Host header), ``origin`` (overrides the
        Origin header), ``subprotocols`` (names joined into
        Sec-WebSocket-Protocol), ``header`` (dict of name -> value, or an
        iterable of ready-made "Name: value" lines; a falsy value such as
        None is ignored) and ``cookie`` (sent verbatim as the Cookie header).
    :return: ``(headers, key)``. ``headers`` is a list of request lines
        ending in two empty strings, so ``"\\r\\n".join(headers)`` ends the
        header block with a blank line. ``key`` is the Sec-WebSocket-Key
        that was placed in the request.
    """
    headers = [
        "GET %s HTTP/1.1" % resource,
        "Upgrade: websocket",
        "Connection: Upgrade",
    ]

    packed_host = _pack_hostname(host)
    if port == 80 or port == 443:
        hostport = packed_host
    else:
        hostport = "%s:%d" % (packed_host, port)

    custom_host = options.get("host")
    if custom_host:
        headers.append("Host: %s" % custom_host)
    else:
        headers.append("Host: %s" % hostport)

    custom_origin = options.get("origin")
    if custom_origin:
        headers.append("Origin: %s" % custom_origin)
    else:
        headers.append("Origin: http://%s" % hostport)

    key = _create_sec_websocket_key()
    headers.append("Sec-WebSocket-Key: %s" % key)
    headers.append("Sec-WebSocket-Version: %s" % VERSION)

    subprotocols = options.get("subprotocols")
    if subprotocols:
        headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))

    # A caller may pass header=None; skip it instead of crashing in extend().
    header = options.get("header")
    if header:
        if isinstance(header, dict):
            header = [": ".join(item) for item in header.items()]
        headers.extend(header)

    cookie = options.get("cookie")
    if cookie:
        headers.append("Cookie: %s" % cookie)

    # Two empty entries make the joined request end with "\r\n\r\n".
    headers.append("")
    headers.append("")

    return headers, key
117
118
def _get_resp_headers(sock, success_status=101):
    """Read the handshake response and insist on the expected status code.

    :param sock: socket the handshake request was sent on.
    :param success_status: status code treated as success (101 by default).
    :return: ``(status, headers)`` as produced by ``read_headers``.
    :raises WebSocketBadStatusException: if the status differs from
        *success_status*.
    """
    status_code, response_headers = read_headers(sock)
    if status_code == success_status:
        return status_code, response_headers
    # The format string and the status are passed separately; presumably the
    # exception class performs the formatting - confirm in _exceptions.
    raise WebSocketBadStatusException("Handshake status %d", status_code)
124
# Response headers that a valid handshake response must carry, mapped to the
# lowercase value each one is required to have.
_HEADERS_TO_CHECK = dict(
    upgrade="websocket",
    connection="upgrade",
)
129
130
def _validate(headers, key, subprotocols):
    """Validate the server's opening-handshake response headers.

    :param headers: response headers, looked up by lowercase name
        (presumably normalised that way by ``read_headers`` - confirm).
    :param key: the Sec-WebSocket-Key value sent in the request.
    :param subprotocols: subprotocols offered by the client, or a falsy
        value if none were offered.
    :return: ``(True, subprotocol)`` when the response is acceptable, where
        ``subprotocol`` is the lowercased negotiated subprotocol (None when
        none were offered); ``(False, None)`` otherwise.
    """
    subproto = None
    for name, expected in _HEADERS_TO_CHECK.items():
        received = headers.get(name, None)
        if not received:
            return False, None
        # These headers may carry comma-separated token lists (e.g.
        # "Connection: keep-alive, Upgrade"); accept if any token matches.
        tokens = [token.strip().lower() for token in received.split(",")]
        if expected not in tokens:
            return False, None

    if subprotocols:
        subproto = headers.get("sec-websocket-protocol", None)
        # The server may omit the header entirely; treat that as a failed
        # negotiation instead of crashing on None.lower().
        if subproto:
            subproto = subproto.lower()
        if not subproto or subproto not in [s.lower() for s in subprotocols]:
            error("Invalid subprotocol: " + str(subprotocols))
            return False, None

    result = headers.get("sec-websocket-accept", None)
    if not result:
        return False, None
    # NOTE: both sides are lowercased before comparing, so this check is
    # case-insensitive even though base64 is case-sensitive; kept as-is for
    # compatibility with existing behaviour.
    result = result.lower()

    if isinstance(result, six.text_type):
        result = result.encode('utf-8')

    # Expected accept value: base64(SHA-1(key + RFC 6455 magic GUID)).
    value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
    hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
    success = compare_digest(hashed, result)

    if success:
        return True, subproto
    else:
        return False, None
163
164
def _create_sec_websocket_key():
    """Return a fresh value for the Sec-WebSocket-Key request header.

    The key is 16 bytes from ``os.urandom``, base64-encoded and decoded to
    text. Surrounding whitespace is stripped, because the base64 encoder
    appends a trailing newline.
    """
    encoded = base64encode(os.urandom(16))
    text = encoded.decode('utf-8')
    return text.strip()
168