"""
websocket - WebSocket client library for Python

Copyright (C) 2010 Hiroki Ohtani(liris)

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software
    Foundation, Inc., 51 Franklin Street, Fifth Floor,
    Boston, MA  02110-1335  USA

"""
import array
import os
import struct

import six

from ._exceptions import *
from ._utils import validate_utf8

try:
    # If wsaccel is available we use compiled routines to mask data.
    from wsaccel.xormask import XorMaskerSimple

    def _mask(_m, _d):
        """XOR ``_d`` with the mask ``_m`` using the wsaccel masker."""
        return XorMaskerSimple(_m).process(_d)

except ImportError:
    # wsaccel is not available, we rely on python implementations.
    def _mask(_m, _d):
        """
        XOR every element of ``_d`` with the mask ``_m`` (cycled every 4
        elements) and return the result as a byte string.

        _m: array.array("B") holding the mask key.

        _d: array.array("B") holding the payload. It is modified in place.
        """
        for i in range(len(_d)):
            _d[i] ^= _m[i % 4]

        # array.tostring() is the Python 2 spelling of tobytes().
        if six.PY3:
            return _d.tobytes()
        return _d.tostring()

__all__ = [
    'ABNF', 'continuous_frame', 'frame_buffer',
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]

# closing frame status codes.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes that ABNF.validate() accepts inside a close frame, in
# addition to the 3000-4999 range checked by ABNF._is_valid_close_status().
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_BAD_GATEWAY,
)


class ABNF(object):
    """
    ABNF frame class.
    see http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # data length threshold.
    # format() stores lengths below LENGTH_7 directly in the 7-bit length
    # field, lengths below LENGTH_16 as a 16-bit extended length, and
    # anything else (which must stay below LENGTH_63) as a 64-bit one.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
                 opcode=OPCODE_TEXT, mask=1, data=""):
        """
        Constructor for ABNF.
        please check RFC for arguments.

        data: payload of the frame. None is replaced by an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        # Callable taking a byte count and returning the mask key; format()
        # calls it with 4. Kept as an attribute so it can be replaced.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation=False):
        """
        validate the ABNF frame.

        skip_utf8_validation: skip utf8 validation.

        Raises WebSocketProtocolException when a reserved bit is set, the
        opcode is unknown, a ping frame is fragmented, or a close frame
        carries an invalid payload or status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate explicitly: exceptions do not apply %-formatting
            # to their arguments the way logging calls do.
            raise WebSocketProtocolException(
                "Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # A close frame without payload is accepted as is.
                return
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Anything after the first two bytes is checked as utf-8 text.
            if length > 2 and not skip_utf8_validation \
                    and not validate_utf8(self.data[2:]):
                raise WebSocketProtocolException("Invalid close frame.")

            # The first two bytes are the status code, big-endian.
            code = 256 * \
                six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode.")

    @staticmethod
    def _is_valid_close_status(code):
        """
        Return True if code is listed in VALID_CLOSE_STATUS or lies in the
        range 3000-4999.
        """
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self):
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data, opcode, fin=1):
        """
        create frame to send text, binary and other data.

        data: data to send. This is string value(byte array).
            if opcode is OPCODE_TEXT and this value is unicode,
            data value is encoded to utf-8 bytes, automatically.

        opcode: operation code. please see OPCODE_XXX.

        fin: fin flag. if set to 0, create continue fragmentation.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self):
        """
        format this object to string(byte array) to send data to server.

        Raises ValueError when fin/rsv1/rsv2/rsv3 is not 0 or 1, the opcode
        is unknown, or the data is too long.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        # First byte: fin, rsv1-3 and opcode.
        frame_header = chr(self.fin << 7
                           | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4
                           | self.opcode)
        # Second byte: mask flag plus either the length itself or a marker
        # (0x7e / 0x7f) announcing a 16-bit / 64-bit extended length.
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask << 7 | length)
            frame_header = six.b(frame_header)
        elif length < ABNF.LENGTH_16:
            frame_header += chr(self.mask << 7 | 0x7e)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!H", length)
        else:
            frame_header += chr(self.mask << 7 | 0x7f)
            frame_header = six.b(frame_header)
            frame_header += struct.pack("!Q", length)

        if not self.mask:
            data = self.data
            if isinstance(data, six.text_type):
                # Same text conversion ABNF.mask() applies on the masked
                # path. Without it, bytes + str raises TypeError on
                # Python 3 (e.g. for the default data="").
                data = six.b(data)
            return frame_header + data

        mask_key = self.get_mask_key(4)
        return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key):
        """Return mask_key followed by self.data masked with it."""
        s = ABNF.mask(mask_key, self.data)

        if isinstance(mask_key, six.text_type):
            mask_key = mask_key.encode('utf-8')

        return mask_key + s

    @staticmethod
    def mask(mask_key, data):
        """
        mask or unmask data. Just do xor for each byte

        mask_key: 4 byte string(byte).

        data: data to mask/unmask.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, six.text_type):
            mask_key = six.b(mask_key)

        if isinstance(data, six.text_type):
            data = six.b(data)

        _m = array.array("B", mask_key)
        _d = array.array("B", data)
        return _mask(_m, _d)


class frame_buffer(object):
    """
    Reads frames piece by piece through a caller supplied receive function.

    The parsed header, length and mask are kept on the instance until the
    whole frame has been read, then reset by clear().
    """

    # Positions of the mask flag and length bits in the self.header tuple.
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(self, recv_fn, skip_utf8_validation):
        """
        recv_fn: callable taking a maximum byte count and returning the
            bytes received.

        skip_utf8_validation: passed to ABNF.validate() for each frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Buffers over the packets from the layer beneath until desired amount
        # bytes of bytes are received.
        self.recv_buffer = []
        self.clear()

    def clear(self):
        """Forget the header, length and mask of the current frame."""
        self.header = None
        self.length = None
        self.mask = None

    def has_received_header(self):
        """
        Return True when the header has NOT been read yet.

        The name reads inverted; recv_frame() relies on this meaning.
        """
        return self.header is None

    def recv_header(self):
        """Read the first two bytes of a frame and store them in self.header."""
        header = self.recv_strict(2)
        b1 = header[0]

        # Indexing a Python 2 str yields a 1-char str, not an int.
        if six.PY2:
            b1 = ord(b1)

        fin = b1 >> 7 & 1
        rsv1 = b1 >> 6 & 1
        rsv2 = b1 >> 5 & 1
        rsv3 = b1 >> 4 & 1
        opcode = b1 & 0xf
        b2 = header[1]

        if six.PY2:
            b2 = ord(b2)

        has_mask = b2 >> 7 & 1
        length_bits = b2 & 0x7f

        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)

    def has_mask(self):
        """Return the mask flag of the stored header, or False without one."""
        if not self.header:
            return False
        return self.header[frame_buffer._HEADER_MASK_INDEX]

    def has_received_length(self):
        """
        Return True when the length has NOT been read yet.

        The name reads inverted; recv_frame() relies on this meaning.
        """
        return self.length is None

    def recv_length(self):
        """
        Determine the payload length and store it in self.length.

        Reads 2 or 8 more bytes when the header announces an extended
        length (0x7e or 0x7f).
        """
        bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
        length_bits = bits & 0x7f
        if length_bits == 0x7e:
            v = self.recv_strict(2)
            self.length = struct.unpack("!H", v)[0]
        elif length_bits == 0x7f:
            v = self.recv_strict(8)
            self.length = struct.unpack("!Q", v)[0]
        else:
            self.length = length_bits

    def has_received_mask(self):
        """
        Return True when the mask has NOT been read yet.

        The name reads inverted; recv_frame() relies on this meaning.
        """
        return self.mask is None

    def recv_mask(self):
        """Read the 4 byte mask key if the header has the mask flag set."""
        self.mask = self.recv_strict(4) if self.has_mask() else ""

    def recv_frame(self):
        """
        Read one complete frame and return it as a validated ABNF object.

        Header, length and mask are each read only if still missing, so
        progress stored on the instance is reused when this is called
        again after an exception.
        """
        # Header
        if self.has_received_header():
            self.recv_header()
        (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header

        # Frame length
        if self.has_received_length():
            self.recv_length()
        length = self.length

        # Mask
        if self.has_received_mask():
            self.recv_mask()
        mask = self.mask

        # Payload
        payload = self.recv_strict(length)
        if has_mask:
            payload = ABNF.mask(mask, payload)

        # Reset for next frame
        self.clear()

        frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
        frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize):
        """
        Return exactly bufsize bytes, calling self.recv until enough data
        is buffered. Surplus bytes stay in self.recv_buffer.
        """
        shortage = bufsize - sum(len(x) for x in self.recv_buffer)
        # NOTE(review): if self.recv ever returns empty data, shortage never
        # shrinks and this loops forever -- presumably recv_fn raises on a
        # closed connection; confirm against the caller.
        while shortage > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            bytes_ = self.recv(min(16384, shortage))
            self.recv_buffer.append(bytes_)
            shortage -= len(bytes_)

        unified = six.b("").join(self.recv_buffer)

        if shortage == 0:
            self.recv_buffer = []
            return unified
        else:
            # More than requested is buffered: keep the tail for next time.
            self.recv_buffer = [unified[bufsize:]]
            return unified[:bufsize]


class continuous_frame(object):
    """Collects the data of fragmented messages."""

    def __init__(self, fire_cont_frame, skip_utf8_validation):
        """
        fire_cont_frame: if true, is_fire() is true for every frame, not
            only for frames with fin set.

        skip_utf8_validation: skip the utf8 validation done in extract().
        """
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        # [opcode, data] of the message being collected, or None.
        self.cont_data = None
        # Opcode of the message in progress, or None when there is none.
        self.recving_frames = None

    def validate(self, frame):
        """
        Raise WebSocketProtocolException if a continuation frame arrives
        with no message in progress, or a text/binary frame arrives while
        one is in progress.
        """
        if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
            raise WebSocketProtocolException("Illegal frame")
        if self.recving_frames and \
                frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame):
        """Append the frame's data to the message being collected."""
        if self.cont_data:
            self.cont_data[1] += frame.data
        else:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]

        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame):
        """Return a true value if frame.fin or fire_cont_frame is set."""
        return frame.fin or self.fire_cont_frame

    def extract(self, frame):
        """
        Return [opcode, frame] with frame.data replaced by the collected
        data, and reset the collected data.

        Raises WebSocketPayloadException if the collected text fails utf8
        validation, unless fire_cont_frame or skip_utf8_validation is set.
        """
        data = self.cont_data
        self.cont_data = None
        frame.data = data[1]
        if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [data[0], frame]