1 package fi.iki.elonen;
2 
3 import java.io.EOFException;
4 import java.io.IOException;
5 import java.io.InputStream;
6 import java.io.OutputStream;
7 import java.nio.ByteBuffer;
8 import java.nio.CharBuffer;
9 import java.nio.charset.CharacterCodingException;
10 import java.nio.charset.Charset;
11 import java.nio.charset.CharsetDecoder;
12 import java.nio.charset.CharsetEncoder;
13 import java.util.Arrays;
14 import java.util.List;
15 
16 public class WebSocketFrame {
17     private OpCode opCode;
18     private boolean fin;
19     private byte[] maskingKey;
20 
21     private byte[] payload;
22 
23     private transient int _payloadLength;
24     private transient String _payloadString;
25 
WebSocketFrame(OpCode opCode, boolean fin)26     private WebSocketFrame(OpCode opCode, boolean fin) {
27         setOpCode(opCode);
28         setFin(fin);
29     }
30 
WebSocketFrame(OpCode opCode, boolean fin, byte[] payload, byte[] maskingKey)31     public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload, byte[] maskingKey) {
32         this(opCode, fin);
33         setMaskingKey(maskingKey);
34         setBinaryPayload(payload);
35     }
36 
WebSocketFrame(OpCode opCode, boolean fin, byte[] payload)37     public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload) {
38         this(opCode, fin, payload, null);
39     }
40 
WebSocketFrame(OpCode opCode, boolean fin, String payload, byte[] maskingKey)41     public WebSocketFrame(OpCode opCode, boolean fin, String payload, byte[] maskingKey) throws CharacterCodingException {
42         this(opCode, fin);
43         setMaskingKey(maskingKey);
44         setTextPayload(payload);
45     }
46 
WebSocketFrame(OpCode opCode, boolean fin, String payload)47     public WebSocketFrame(OpCode opCode, boolean fin, String payload) throws CharacterCodingException {
48         this(opCode, fin, payload, null);
49     }
50 
WebSocketFrame(WebSocketFrame clone)51     public WebSocketFrame(WebSocketFrame clone) {
52         setOpCode(clone.getOpCode());
53         setFin(clone.isFin());
54         setBinaryPayload(clone.getBinaryPayload());
55         setMaskingKey(clone.getMaskingKey());
56     }
57 
WebSocketFrame(OpCode opCode, List<WebSocketFrame> fragments)58     public WebSocketFrame(OpCode opCode, List<WebSocketFrame> fragments) throws WebSocketException {
59         setOpCode(opCode);
60         setFin(true);
61 
62         long _payloadLength = 0;
63         for (WebSocketFrame inter : fragments) {
64             _payloadLength += inter.getBinaryPayload().length;
65         }
66         if (_payloadLength < 0 || _payloadLength > Integer.MAX_VALUE) {
67             throw new WebSocketException(CloseCode.MessageTooBig, "Max frame length has been exceeded.");
68         }
69         this._payloadLength = (int) _payloadLength;
70         byte[] payload = new byte[this._payloadLength];
71         int offset = 0;
72         for (WebSocketFrame inter : fragments) {
73             System.arraycopy(inter.getBinaryPayload(), 0, payload, offset, inter.getBinaryPayload().length);
74             offset += inter.getBinaryPayload().length;
75         }
76         setBinaryPayload(payload);
77     }
78 
79     // --------------------------------GETTERS---------------------------------
80 
getOpCode()81     public OpCode getOpCode() {
82         return opCode;
83     }
84 
setOpCode(OpCode opcode)85     public void setOpCode(OpCode opcode) {
86         this.opCode = opcode;
87     }
88 
isFin()89     public boolean isFin() {
90         return fin;
91     }
92 
setFin(boolean fin)93     public void setFin(boolean fin) {
94         this.fin = fin;
95     }
96 
isMasked()97     public boolean isMasked() {
98         return maskingKey != null && maskingKey.length == 4;
99     }
100 
getMaskingKey()101     public byte[] getMaskingKey() {
102         return maskingKey;
103     }
104 
setMaskingKey(byte[] maskingKey)105     public void setMaskingKey(byte[] maskingKey) {
106         if (maskingKey != null && maskingKey.length != 4) {
107             throw new IllegalArgumentException("MaskingKey " + Arrays.toString(maskingKey) + " hasn't length 4");
108         }
109         this.maskingKey = maskingKey;
110     }
111 
setUnmasked()112     public void setUnmasked() {
113         setMaskingKey(null);
114     }
115 
getBinaryPayload()116     public byte[] getBinaryPayload() {
117         return payload;
118     }
119 
setBinaryPayload(byte[] payload)120     public void setBinaryPayload(byte[] payload) {
121         this.payload = payload;
122         this._payloadLength = payload.length;
123         this._payloadString = null;
124     }
125 
getTextPayload()126     public String getTextPayload() {
127         if (_payloadString == null) {
128             try {
129                 _payloadString = binary2Text(getBinaryPayload());
130             } catch (CharacterCodingException e) {
131                 throw new RuntimeException("Undetected CharacterCodingException", e);
132             }
133         }
134         return _payloadString;
135     }
136 
setTextPayload(String payload)137     public void setTextPayload(String payload) throws CharacterCodingException {
138         this.payload = text2Binary(payload);
139         this._payloadLength = payload.length();
140         this._payloadString = payload;
141     }
142 
143     // --------------------------------SERIALIZATION---------------------------
144 
read(InputStream in)145     public static WebSocketFrame read(InputStream in) throws IOException {
146         byte head = (byte) checkedRead(in.read());
147         boolean fin = ((head & 0x80) != 0);
148         OpCode opCode = OpCode.find((byte) (head & 0x0F));
149         if ((head & 0x70) != 0) {
150             throw new WebSocketException(CloseCode.ProtocolError, "The reserved bits (" + Integer.toBinaryString(head & 0x70) + ") must be 0.");
151         }
152         if (opCode == null) {
153             throw new WebSocketException(CloseCode.ProtocolError, "Received frame with reserved/unknown opcode " + (head & 0x0F) + ".");
154         } else if (opCode.isControlFrame() && !fin) {
155             throw new WebSocketException(CloseCode.ProtocolError, "Fragmented control frame.");
156         }
157 
158         WebSocketFrame frame = new WebSocketFrame(opCode, fin);
159         frame.readPayloadInfo(in);
160         frame.readPayload(in);
161         if (frame.getOpCode() == OpCode.Close) {
162             return new CloseFrame(frame);
163         } else {
164             return frame;
165         }
166     }
167 
checkedRead(int read)168     private static int checkedRead(int read) throws IOException {
169         if (read < 0) {
170             throw new EOFException();
171         }
172         //System.out.println(Integer.toBinaryString(read) + "/" + read + "/" + Integer.toHexString(read));
173         return read;
174     }
175 
176 
readPayloadInfo(InputStream in)177     private void readPayloadInfo(InputStream in) throws IOException {
178         byte b = (byte) checkedRead(in.read());
179         boolean masked = ((b & 0x80) != 0);
180 
181         _payloadLength = (byte) (0x7F & b);
182         if (_payloadLength == 126) {
183             // checkedRead must return int for this to work
184             _payloadLength = (checkedRead(in.read()) << 8 | checkedRead(in.read())) & 0xFFFF;
185             if (_payloadLength < 126) {
186                 throw new WebSocketException(CloseCode.ProtocolError, "Invalid data frame 2byte length. (not using minimal length encoding)");
187             }
188         } else if (_payloadLength == 127) {
189             long _payloadLength = ((long) checkedRead(in.read())) << 56 |
190                     ((long) checkedRead(in.read())) << 48 |
191                     ((long) checkedRead(in.read())) << 40 |
192                     ((long) checkedRead(in.read())) << 32 |
193                     checkedRead(in.read()) << 24 | checkedRead(in.read()) << 16 | checkedRead(in.read()) << 8 | checkedRead(in.read());
194             if (_payloadLength < 65536) {
195                 throw new WebSocketException(CloseCode.ProtocolError, "Invalid data frame 4byte length. (not using minimal length encoding)");
196             }
197             if (_payloadLength < 0 || _payloadLength > Integer.MAX_VALUE) {
198                 throw new WebSocketException(CloseCode.MessageTooBig, "Max frame length has been exceeded.");
199             }
200             this._payloadLength = (int) _payloadLength;
201         }
202 
203         if (opCode.isControlFrame()) {
204             if (_payloadLength > 125) {
205                 throw new WebSocketException(CloseCode.ProtocolError, "Control frame with payload length > 125 bytes.");
206             }
207             if (opCode == OpCode.Close && _payloadLength == 1) {
208                 throw new WebSocketException(CloseCode.ProtocolError, "Received close frame with payload len 1.");
209             }
210         }
211 
212         if (masked) {
213             maskingKey = new byte[4];
214             int read = 0;
215             while (read < maskingKey.length) {
216                 read += checkedRead(in.read(maskingKey, read, maskingKey.length - read));
217             }
218         }
219     }
220 
readPayload(InputStream in)221     private void readPayload(InputStream in) throws IOException {
222         payload = new byte[_payloadLength];
223         int read = 0;
224         while (read < _payloadLength) {
225             read += checkedRead(in.read(payload, read, _payloadLength - read));
226         }
227 
228         if (isMasked()) {
229             for (int i = 0; i < payload.length; i++) {
230                 payload[i] ^= maskingKey[i % 4];
231             }
232         }
233 
234         //Test for Unicode errors
235         if (getOpCode() == OpCode.Text) {
236             _payloadString = binary2Text(getBinaryPayload());
237         }
238     }
239 
write(OutputStream out)240     public void write(OutputStream out) throws IOException {
241         byte header = 0;
242         if (fin) {
243             header |= 0x80;
244         }
245         header |= opCode.getValue() & 0x0F;
246         out.write(header);
247 
248         _payloadLength = getBinaryPayload().length;
249         if (_payloadLength <= 125) {
250             out.write(isMasked() ? 0x80 | (byte) _payloadLength : (byte) _payloadLength);
251         } else if (_payloadLength <= 0xFFFF) {
252             out.write(isMasked() ? 0xFE : 126);
253             out.write(_payloadLength >>> 8);
254             out.write(_payloadLength);
255         } else {
256             out.write(isMasked() ? 0xFF : 127);
257             out.write(_payloadLength >>> 56 & 0); //integer only contains 31 bit
258             out.write(_payloadLength >>> 48 & 0);
259             out.write(_payloadLength >>> 40 & 0);
260             out.write(_payloadLength >>> 32 & 0);
261             out.write(_payloadLength >>> 24);
262             out.write(_payloadLength >>> 16);
263             out.write(_payloadLength >>> 8);
264             out.write(_payloadLength);
265         }
266 
267 
268         if (isMasked()) {
269             out.write(maskingKey);
270             for (int i = 0; i < _payloadLength; i++) {
271                 out.write(getBinaryPayload()[i] ^ maskingKey[i % 4]);
272             }
273         } else {
274             out.write(getBinaryPayload());
275         }
276         out.flush();
277     }
278 
279     // --------------------------------ENCODING--------------------------------
280 
281     public static final Charset TEXT_CHARSET = Charset.forName("UTF-8");
282     public static final CharsetDecoder TEXT_DECODER = TEXT_CHARSET.newDecoder();
283     public static final CharsetEncoder TEXT_ENCODER = TEXT_CHARSET.newEncoder();
284 
285 
binary2Text(byte[] payload)286     public static String binary2Text(byte[] payload) throws CharacterCodingException {
287         return TEXT_DECODER.decode(ByteBuffer.wrap(payload)).toString();
288     }
289 
binary2Text(byte[] payload, int offset, int length)290     public static String binary2Text(byte[] payload, int offset, int length) throws CharacterCodingException {
291         return TEXT_DECODER.decode(ByteBuffer.wrap(payload, offset, length)).toString();
292     }
293 
text2Binary(String payload)294     public static byte[] text2Binary(String payload) throws CharacterCodingException {
295         return TEXT_ENCODER.encode(CharBuffer.wrap(payload)).array();
296     }
297 
298     @Override
toString()299     public String toString() {
300         final StringBuilder sb = new StringBuilder("WS[");
301         sb.append(getOpCode());
302         sb.append(", ").append(isFin() ? "fin" : "inter");
303         sb.append(", ").append(isMasked() ? "masked" : "unmasked");
304         sb.append(", ").append(payloadToString());
305         sb.append(']');
306         return sb.toString();
307     }
308 
payloadToString()309     protected String payloadToString() {
310         if (payload == null) return "null";
311         else {
312             final StringBuilder sb = new StringBuilder();
313             sb.append('[').append(payload.length).append("b] ");
314             if (getOpCode() == OpCode.Text) {
315                 String text = getTextPayload();
316                 if (text.length() > 100)
317                     sb.append(text.substring(0, 100)).append("...");
318                 else
319                     sb.append(text);
320             } else {
321                 sb.append("0x");
322                 for (int i = 0; i < Math.min(payload.length, 50); ++i)
323                     sb.append(Integer.toHexString((int) payload[i] & 0xFF));
324                 if (payload.length > 50)
325                     sb.append("...");
326             }
327             return sb.toString();
328         }
329     }
330 
331     // --------------------------------CONSTANTS-------------------------------
332 
333     public static enum OpCode {
334         Continuation(0), Text(1), Binary(2), Close(8), Ping(9), Pong(10);
335 
336         private final byte code;
337 
OpCode(int code)338         private OpCode(int code) {
339             this.code = (byte) code;
340         }
341 
getValue()342         public byte getValue() {
343             return code;
344         }
345 
isControlFrame()346         public boolean isControlFrame() {
347             return this == Close || this == Ping || this == Pong;
348         }
349 
find(byte value)350         public static OpCode find(byte value) {
351             for (OpCode opcode : values()) {
352                 if (opcode.getValue() == value) {
353                     return opcode;
354                 }
355             }
356             return null;
357         }
358     }
359 
360     public static enum CloseCode {
361         NormalClosure(1000), GoingAway(1001), ProtocolError(1002), UnsupportedData(1003), NoStatusRcvd(1005),
362         AbnormalClosure(1006), InvalidFramePayloadData(1007), PolicyViolation(1008), MessageTooBig(1009),
363         MandatoryExt(1010), InternalServerError(1011), TLSHandshake(1015);
364 
365         private final int code;
366 
CloseCode(int code)367         private CloseCode(int code) {
368             this.code = code;
369         }
370 
getValue()371         public int getValue() {
372             return code;
373         }
374 
find(int value)375         public static CloseCode find(int value) {
376             for (CloseCode code : values()) {
377                 if (code.getValue() == value) {
378                     return code;
379                 }
380             }
381             return null;
382         }
383     }
384 
385     // ------------------------------------------------------------------------
386 
387     public static class CloseFrame extends WebSocketFrame {
388         private CloseCode _closeCode;
389         private String _closeReason;
390 
CloseFrame(WebSocketFrame wrap)391         private CloseFrame(WebSocketFrame wrap) throws CharacterCodingException {
392             super(wrap);
393             assert wrap.getOpCode() == OpCode.Close;
394             if (wrap.getBinaryPayload().length >= 2) {
395                 _closeCode = CloseCode.find((wrap.getBinaryPayload()[0] & 0xFF) << 8 |
396                         (wrap.getBinaryPayload()[1] & 0xFF));
397                 _closeReason = binary2Text(getBinaryPayload(), 2, getBinaryPayload().length - 2);
398             }
399         }
400 
CloseFrame(CloseCode code, String closeReason)401         public CloseFrame(CloseCode code, String closeReason) throws CharacterCodingException {
402             super(OpCode.Close, true, generatePayload(code, closeReason));
403         }
404 
generatePayload(CloseCode code, String closeReason)405         private static byte[] generatePayload(CloseCode code, String closeReason) throws CharacterCodingException {
406             if (code != null) {
407                 byte[] reasonBytes = text2Binary(closeReason);
408                 byte[] payload = new byte[reasonBytes.length + 2];
409                 payload[0] = (byte) ((code.getValue() >> 8) & 0xFF);
410                 payload[1] = (byte) ((code.getValue()) & 0xFF);
411                 System.arraycopy(reasonBytes, 0, payload, 2, reasonBytes.length);
412                 return payload;
413             } else {
414                 return new byte[0];
415             }
416         }
417 
payloadToString()418         protected String payloadToString() {
419             return (_closeCode != null ? _closeCode : "UnknownCloseCode[" + _closeCode + "]") + (_closeReason != null && !_closeReason.isEmpty() ? ": " + _closeReason : "");
420         }
421 
getCloseCode()422         public CloseCode getCloseCode() {
423             return _closeCode;
424         }
425 
getCloseReason()426         public String getCloseReason() {
427             return _closeReason;
428         }
429     }
430 }
431