1"""
2websocket - WebSocket client library for Python
3
4Copyright (C) 2010 Hiroki Ohtani(liris)
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License as published by the Free Software Foundation; either
9    version 2.1 of the License, or (at your option) any later version.
10
11    This library is distributed in the hope that it will be useful,
12    but WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14    Lesser General Public License for more details.
15
16    You should have received a copy of the GNU Lesser General Public
17    License along with this library; if not, write to the Free Software
18    Foundation, Inc., 51 Franklin Street, Fifth Floor,
19    Boston, MA  02110-1335  USA
20
21"""
22import array
23import os
24import struct
25
26import six
27
28from ._exceptions import *
29from ._utils import validate_utf8
30
try:
    # Prefer the compiled XOR routine from wsaccel whenever it is installed.
    from wsaccel.xormask import XorMaskerSimple

    def _mask(_m, _d):
        """XOR ``_d`` with the mask ``_m`` using wsaccel's compiled masker."""
        masker = XorMaskerSimple(_m)
        return masker.process(_d)

except ImportError:
    # wsaccel is missing: fall back to a pure python XOR loop.
    def _mask(_m, _d):
        """
        XOR every item of ``_d`` in place with the 4-item mask ``_m`` and
        return the result as a byte string.
        """
        for pos, value in enumerate(_d):
            # The mask repeats every four bytes.
            _d[pos] = value ^ _m[pos % 4]

        # array.tobytes() is the python 3 spelling of array.tostring().
        return _d.tobytes() if six.PY3 else _d.tostring()
48
# Names exported by a star-import of this module.
__all__ = [
    'ABNF', 'continuous_frame', 'frame_buffer',
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]

# closing frame status codes.
# These are the numeric values compared against the 16-bit code that
# ABNF.validate() decodes from the first two payload bytes of a close frame.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes accepted inside a close frame; ABNF._is_valid_close_status()
# checks membership in this tuple and additionally accepts 3000-4999.
# STATUS_STATUS_NOT_AVAILABLE (1005), STATUS_ABNORMAL_CLOSED (1006) and
# STATUS_TLS_HANDSHAKE_ERROR (1015) are not listed here.
# NOTE(review): presumably because RFC 6455 section 7.4.1 reserves those
# codes and forbids putting them on the wire -- confirm against the RFC.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_BAD_GATEWAY,
)
93
94
class ABNF(object):
    """
    ABNF frame class.
    see http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2

    An instance carries the header fields (fin, rsv1-3, opcode, mask) and
    the payload of a single WebSocket frame. validate() checks a frame and
    format() serializes it into the byte string that is sent on the wire.
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # data length threshold.
    # format() writes payloads shorter than LENGTH_7 into the 7-bit length
    # field, payloads shorter than LENGTH_16 into a 2-byte extended length,
    # longer ones into an 8-byte extended length, and rejects anything of
    # LENGTH_63 bytes or more.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
                 opcode=OPCODE_TEXT, mask=1, data=""):
        """
        Constructor for ABNF.
        please check RFC for arguments.

        fin, rsv1, rsv2, rsv3: header flag bits (0 or 1).

        opcode: one of the OPCODE_XXX values.

        mask: 1 if the payload is masked when formatted, 0 otherwise.

        data: frame payload; None is stored as an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        # Callable taking a byte count and returning the masking key;
        # kept as an attribute so it can be replaced per instance.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation=False):
        """
        validate the ABNF frame.

        skip_utf8_validation: skip utf8 validation of the reason text
            carried by a close frame.

        Raises WebSocketProtocolException when a reserved bit is set, the
        opcode is unknown, a ping frame is not final, or a close frame has
        an invalid length, reason text or status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate here: exceptions do not apply %-formatting to
            # their arguments, so passing the opcode as a second argument
            # would leave a literal "%r" in the message.
            raise WebSocketProtocolException(
                "Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # A close frame without a body is valid.
                return
            # A lone byte cannot hold the 2-byte status code, and payloads
            # of 126 bytes or more are rejected for close frames.
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Everything after the 2-byte status code is the reason text.
            if length > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
                raise WebSocketProtocolException("Invalid close frame.")

            # The status code is a big-endian unsigned 16-bit integer.
            code = 256 * \
                six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode.")

    @staticmethod
    def _is_valid_close_status(code):
        """
        Return True if code is listed in VALID_CLOSE_STATUS or lies in the
        range 3000-4999.
        """
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self):
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data, opcode, fin=1):
        """
        create frame to send text, binary and other data.

        data: data to send. This is string value(byte array).
            if opcode is OPCODE_TEXT and this value is unicode,
            data value is encoded into a utf-8 byte string, automatically.

        opcode: operation code. please see OPCODE_XXX.

        fin: fin flag. if set to 0, create continue fragmentation.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self):
        """
        format this object to string(byte array) to send data to server.

        Raises ValueError if a flag bit is not 0 or 1, the opcode is
        unknown, or the payload is LENGTH_63 bytes or longer.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        # First header byte: FIN, RSV1-3 and the 4-bit opcode.
        frame_header = chr(self.fin << 7
                           | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4
                           | self.opcode)
        # Second header byte: mask flag plus either the length itself or a
        # marker (0x7e / 0x7f) announcing a 2- or 8-byte extended length.
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask << 7 | length)
            frame_header = six.b(frame_header)
        elif length < ABNF.LENGTH_16:
            frame_header += chr(self.mask << 7 | 0x7e)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!H", length)
        else:
            frame_header += chr(self.mask << 7 | 0x7f)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!Q", length)

        if not self.mask:
            return frame_header + self.data
        else:
            mask_key = self.get_mask_key(4)
            return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key):
        """
        Return mask_key followed by the payload XOR-ed with mask_key.

        mask_key: 4 byte masking key; a text key is utf-8 encoded before
            it is prepended.
        """
        s = ABNF.mask(mask_key, self.data)

        if isinstance(mask_key, six.text_type):
            mask_key = mask_key.encode('utf-8')

        return mask_key + s

    @staticmethod
    def mask(mask_key, data):
        """
        mask or unmask data. Just do xor for each byte

        mask_key: 4 byte string(byte).

        data: data to mask/unmask.
        """
        if data is None:
            data = ""

        # Text arguments are turned into byte strings before the XOR.
        if isinstance(mask_key, six.text_type):
            mask_key = six.b(mask_key)

        if isinstance(data, six.text_type):
            data = six.b(data)

        # Unsigned-byte arrays let _mask() work on integer byte values.
        _m = array.array("B", mask_key)
        _d = array.array("B", data)
        return _mask(_m, _d)
263
264
class frame_buffer(object):
    """
    Incremental frame reader on top of a ``recv_fn(bufsize)`` callable.

    The parsed header, payload length and mask key are kept on the instance
    until a whole frame has been returned by recv_frame().
    """

    # Positions of the mask flag and the 7-bit length field in self.header.
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(self, recv_fn, skip_utf8_validation):
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Chunks read from the layer beneath are collected here until the
        # requested number of bytes is available.
        self.recv_buffer = []
        self.clear()

    def clear(self):
        """Forget the header, length and mask of the current frame."""
        self.header = self.length = self.mask = None

    def has_received_header(self):
        """
        Return True while no header is stored.

        Despite its name this reports that the header is still missing;
        recv_frame() relies on exactly that meaning.
        """
        return self.header is None

    def recv_header(self):
        """Read the two fixed header bytes and store them decoded."""
        raw = self.recv_strict(2)
        first, second = raw[0], raw[1]

        # Indexing a python 2 byte string yields 1-char strings, not ints.
        if six.PY2:
            first, second = ord(first), ord(second)

        # fin, rsv1, rsv2, rsv3 are the four top bits of the first byte.
        flags = tuple(first >> shift & 1 for shift in (7, 6, 5, 4))
        # Followed by opcode, mask flag and the 7-bit length field.
        self.header = flags + (first & 0xf, second >> 7 & 1, second & 0x7f)

    def has_mask(self):
        """Return the mask flag of the stored header, False without one."""
        return self.header[frame_buffer._HEADER_MASK_INDEX] if self.header else False

    def has_received_length(self):
        """Return True while no payload length is stored (see name caveat
        on has_received_header)."""
        return self.length is None

    def recv_length(self):
        """Determine the payload length, reading extended bytes if needed."""
        marker = self.header[frame_buffer._HEADER_LENGTH_INDEX] & 0x7f
        if marker < 0x7e:
            # Short payload: the 7-bit field is the length itself.
            self.length = marker
            return

        # 0x7e announces a 2-byte length, 0x7f an 8-byte one (network order).
        fmt, size = ("!H", 2) if marker == 0x7e else ("!Q", 8)
        self.length = struct.unpack(fmt, self.recv_strict(size))[0]

    def has_received_mask(self):
        """Return True while no mask is stored (see name caveat on
        has_received_header)."""
        return self.mask is None

    def recv_mask(self):
        """Read the 4-byte mask key, or store "" for an unmasked frame."""
        if self.has_mask():
            self.mask = self.recv_strict(4)
        else:
            self.mask = ""

    def recv_frame(self):
        """
        Read one complete frame and return it as a validated ABNF object.

        Header, length and mask are only read when not stored yet, so an
        interrupted call can be repeated without re-reading them.
        """
        if self.has_received_header():
            self.recv_header()
        fin, rsv1, rsv2, rsv3, opcode, masked, _ = self.header

        if self.has_received_length():
            self.recv_length()
        payload_size = self.length

        if self.has_received_mask():
            self.recv_mask()
        mask_key = self.mask

        body = self.recv_strict(payload_size)
        if masked:
            body = ABNF.mask(mask_key, body)

        # The frame is complete; make room for the next one.
        self.clear()

        result = ABNF(fin, rsv1, rsv2, rsv3, opcode, masked, body)
        result.validate(self.skip_utf8_validation)
        return result

    def recv_strict(self, bufsize):
        """
        Return exactly ``bufsize`` bytes, calling self.recv as often as
        needed; surplus bytes stay in self.recv_buffer for the next call.
        """
        missing = bufsize - sum(map(len, self.recv_buffer))
        while missing > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            chunk = self.recv(min(16384, missing))
            self.recv_buffer.append(chunk)
            missing -= len(chunk)

        joined = six.b("").join(self.recv_buffer)

        if missing == 0:
            self.recv_buffer = []
            return joined

        # More bytes than requested are buffered: keep the tail.
        self.recv_buffer = [joined[bufsize:]]
        return joined[:bufsize]
383
384
class continuous_frame(object):
    """
    Collects the frames of a fragmented message.

    ``cont_data`` holds ``[opcode, data]`` of the message being assembled
    and ``recving_frames`` the opcode of a text/binary message that has
    been started but not finished yet.
    """

    def __init__(self, fire_cont_frame, skip_utf8_validation):
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        self.cont_data = None
        self.recving_frames = None

    def validate(self, frame):
        """
        Raise WebSocketProtocolException when ``frame`` does not fit the
        current state: a continuation without a started message, or a new
        text/binary frame while a message is still being received.
        """
        if self.recving_frames:
            illegal = frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY)
        else:
            illegal = frame.opcode == ABNF.OPCODE_CONT

        if illegal:
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame):
        """Append ``frame`` to the pending message, starting one if needed."""
        if not self.cont_data:
            # First frame: remember its opcode for the whole message.
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]
        else:
            self.cont_data[1] += frame.data

        # A final frame ends the message in progress.
        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame):
        """Tell whether the collected data should be handed out now."""
        return frame.fin or self.fire_cont_frame

    def extract(self, frame):
        """
        Move the collected payload into ``frame``, reset the collector and
        return ``[opcode, frame]``.

        Raises WebSocketPayloadException when a complete text message is
        not valid utf-8 (unless validation is skipped or frames are fired
        individually).
        """
        opcode, payload = self.cont_data[0], self.cont_data[1]
        self.cont_data = None
        frame.data = payload

        must_check = (not self.fire_cont_frame
                      and opcode == ABNF.OPCODE_TEXT
                      and not self.skip_utf8_validation)
        if must_check and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [opcode, frame]
423