1 /*
2  * Copyright (c) 2002, 2012, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 package test.java.net.Socks;
24 
25 import java.net.*;
26 import java.io.*;
27 import java.util.HashMap;
28 
/**
 * Minimal thread-per-connection SOCKS proxy used by the java.net SOCKS tests.
 *
 * <p>Each accepted connection is served by its own {@link ClientHandler}
 * thread. By default the server speaks SOCKS 5 (RFC 1928) with optional
 * username/password authentication (RFC 1929) and supports the CONNECT and
 * BIND commands. When constructed with {@code v4 == true} it speaks SOCKS 4
 * and supports CONNECT only. UDP ASSOCIATE is not implemented and is answered
 * with {@link #CMD_NOT_SUPPORTED}.
 */
public class SocksServer extends Thread {
    // Some useful SOCKS constants

    static final int PROTO_VERS4        = 4;
    static final int PROTO_VERS         = 5;
    static final int DEFAULT_PORT       = 1080;

    static final int NO_AUTH            = 0;
    static final int GSSAPI             = 1;
    static final int USER_PASSW         = 2;
    static final int NO_METHODS         = -1;

    static final int CONNECT            = 1;
    static final int BIND               = 2;
    static final int UDP_ASSOC          = 3;

    static final int IPV4               = 1;
    static final int DOMAIN_NAME        = 3;
    static final int IPV6               = 4;

    static final int REQUEST_OK         = 0;
    static final int GENERAL_FAILURE    = 1;
    static final int NOT_ALLOWED        = 2;
    static final int NET_UNREACHABLE    = 3;
    static final int HOST_UNREACHABLE   = 4;
    static final int CONN_REFUSED       = 5;
    static final int TTL_EXPIRED        = 6;
    static final int CMD_NOT_SUPPORTED  = 7;
    static final int ADDR_TYPE_NOT_SUP  = 8;

    // SOCKS 4 reply codes (the "CD" field of the reply)
    private static final int V4_REQUEST_GRANTED  = 90;
    private static final int V4_REQUEST_REJECTED = 91;

    private int port;
    private final ServerSocket server;
    private boolean useV4 = false;
    // Plain HashMap read by handler threads; presumably users are added
    // before clients connect -- NOTE(review): confirm with callers.
    private final HashMap<String,String> users = new HashMap<>();
    private volatile boolean done = false;

    /**
     * Handles the protocol exchange with one client, then relays traffic
     * between the client and the requested destination.
     * This is the bulk of the work (protocol handler).
     */
    class ClientHandler extends Thread {
        private final InputStream in;
        private final OutputStream out;
        private final Socket client;
        private Socket dest;

        /** Copies bytes from the destination back to the client on its own thread. */
        class Tunnel extends Thread {
            private final InputStream tin;
            private final OutputStream tout;

            Tunnel(InputStream in, OutputStream out) {
                tin = in;
                tout = out;
            }

            @Override
            public void run() {
                relay(tin, tout);
            }
        }

        ClientHandler(Socket s) throws IOException {
            client = s;
            in = new BufferedInputStream(client.getInputStream());
            out = new BufferedOutputStream(client.getOutputStream());
        }

        /**
         * Reads exactly {@code buf.length} bytes from {@code is} into {@code buf}.
         *
         * @throws IOException if EOF is reached first
         */
        private void readBuf(InputStream is, byte[] buf) throws IOException {
            int count = 0;
            while (count < buf.length) {
                int n = is.read(buf, count, buf.length - count);
                if (n == -1)
                    throw new IOException("unexpected EOF");
                count += n;
            }
        }

        /** Reads a big-endian unsigned 16-bit port number from the client. */
        private int readPort() throws IOException {
            byte[] b = new byte[2];
            readBuf(in, b);
            return ((b[0] & 0xff) << 8) | (b[1] & 0xff);
        }

        /**
         * Performs the username/password sub-negotiation. It replies with
         * REQUEST_OK when the credentials match an entry added through
         * {@link SocksServer#addUser} and with NOT_ALLOWED otherwise.
         *
         * @return true if the client is authenticated
         * @throws SocketException if the username length is missing/zero or
         *         the password length is missing
         */
        private boolean userPassAuth() throws IOException {
            in.read(); // sub-negotiation version, not checked
            int ulen = in.read();
            if (ulen <= 0)
                throw new SocketException("SOCKS protocol error");
            byte[] buf = new byte[ulen];
            readBuf(in, buf);
            String uname = new String(buf);
            String password = null;
            int plen = in.read();
            if (plen < 0)
                throw new SocketException("SOCKS protocol error");
            if (plen > 0) {
                buf = new byte[plen];
                readBuf(in, buf);
                password = new String(buf);
            }
            // Check username/password validity here
            System.err.println("User: '" + uname);
            System.err.println("PSWD: '" + password);
            boolean ok = false;
            if (users.containsKey(uname)) {
                String p1 = users.get(uname);
                System.err.println("p1 = " + p1);
                // Guard against a user registered with a null password.
                ok = p1 != null && p1.equals(password);
            }
            // NOTE(review): RFC 1929 specifies VER = 0x01 in this reply; the
            // server has always sent 5 and clients in use accept it.
            out.write(PROTO_VERS);
            out.write(ok ? REQUEST_OK : NOT_ALLOWED);
            out.flush();
            return ok;
        }

        /**
         * Discards client input until EOF or until one second passes with no
         * data. It is used before closing after an error reply, presumably so
         * unread request bytes do not turn the close into a reset.
         */
        private void purge() throws IOException {
            client.setSoTimeout(1000);
            try {
                while (in.read() != -1) {
                    // discard
                }
            } catch (IOException e) {
                // Read timeout (or other failure): stop draining.
            }
        }

        /** Writes an 8-byte SOCKS 4 reply: VN, CD, DSTPORT (2 bytes), DSTIP (4 bytes). */
        private void replyV4(int code, int replyPort, byte[] ip) throws IOException {
            // NOTE(review): the SOCKS 4 spec says VN should be 0; 4 is sent
            // here as before.
            out.write(PROTO_VERS4);
            out.write(code);
            out.write((replyPort >> 8) & 0xff);
            out.write(replyPort & 0xff);
            out.write(ip);
            out.flush();
        }

        /** Sends a SOCKS 4 "rejected" reply, drains the client, then closes the streams. */
        private void rejectV4(byte[] ip) throws IOException {
            replyV4(V4_REQUEST_REJECTED, 0, ip);
            purge();
            out.close();
            in.close();
        }

        /** Handles a SOCKS version 4 request; only CONNECT is supported. */
        private void getRequestV4() throws IOException {
            int ver = in.read();
            int cmd = in.read();
            if (ver == -1 || cmd == -1) {
                // EOF
                in.close();
                out.close();
                return;
            }

            if (ver != 0 && ver != PROTO_VERS4) {
                // Bad request: reply with a zeroed address.
                rejectV4(new byte[4]);
                return;
            }

            int dstPort = readPort();
            byte[] ip = new byte[4];
            readBuf(in, ip);
            // The user id is not used, but it must be consumed up to its NUL
            // terminator. EOF must not be masked (-1 & 0xff would loop forever).
            int c;
            while ((c = in.read()) != 0) {
                if (c == -1)
                    throw new EOFException("unexpected EOF in SOCKS 4 user id");
            }

            if (cmd != CONNECT) {
                rejectV4(ip);
                return;
            }

            InetAddress addr = InetAddress.getByAddress(ip);
            try {
                dest = new Socket(addr, dstPort);
            } catch (IOException e) {
                rejectV4(ip);
                return;
            }
            replyV4(V4_REQUEST_GRANTED, dstPort, ip);
            tunnel(new BufferedInputStream(dest.getInputStream()),
                   new BufferedOutputStream(dest.getOutputStream()));
        }

        /**
         * Negotiates the SOCKS 5 authentication method. Username/password is
         * selected whenever the client offers it; otherwise no authentication.
         *
         * @return false if username/password authentication failed. In that
         *         case the request must not be processed.
         */
        private boolean negotiate() throws IOException {
            in.read(); // protocol version, not checked
            int n = in.read();
            int scheme = NO_AUTH;
            if (n > 0) {
                byte[] methods = new byte[n];
                readBuf(in, methods);
                for (byte m : methods)
                    if (m == USER_PASSW)
                        scheme = USER_PASSW;
            }
            out.write(PROTO_VERS);
            out.write(scheme);
            out.flush();
            return scheme != USER_PASSW || userPassAuth();
        }

        /**
         * Sends a SOCKS 5 failure reply with a zero IPv4 address and port,
         * then closes the output stream. Write errors are ignored.
         */
        private void sendError(int code) {
            try {
                out.write(PROTO_VERS);
                out.write(code);
                out.write(0);
                out.write(IPV4);
                for (int i = 0; i < 6; i++)
                    out.write(0);
                out.flush();
                out.close();
            } catch (IOException ex) {
                // Best effort: the client may already be gone.
            }
        }

        /**
         * Writes and flushes a SOCKS 5 reply carrying {@code address}, with
         * ATYP chosen from the address family.
         *
         * @return false, with nothing written, if the address is neither
         *         IPv4 nor IPv6 (for example unresolved)
         */
        private boolean writeReply(int code, InetSocketAddress address) throws IOException {
            InetAddress ia = address.getAddress();
            int atyp;
            if (ia instanceof Inet4Address) {
                atyp = IPV4;
            } else if (ia instanceof Inet6Address) {
                atyp = IPV6;
            } else {
                return false;
            }
            out.write(PROTO_VERS);
            out.write(code);
            out.write(0);
            out.write(atyp);
            out.write(ia.getAddress());
            out.write((address.getPort() >> 8) & 0xff);
            out.write(address.getPort() & 0xff);
            out.flush();
            return true;
        }

        /**
         * Relays traffic between the client and the destination.
         * destination -> client runs on a {@link Tunnel} thread, and
         * client -> destination runs on this thread until it ends.
         */
        private void tunnel(InputStream destIn, OutputStream destOut) {
            new Tunnel(destIn, out).start();
            relay(in, destOut);
        }

        /** Connects the proxy to the destination, then starts tunneling. */
        private void doConnect(InetSocketAddress addr) throws IOException {
            dest = new Socket();
            try {
                dest.connect(addr, 10000);
            } catch (SocketTimeoutException ex) {
                sendError(HOST_UNREACHABLE);
                return;
            } catch (ConnectException cex) {
                sendError(CONN_REFUSED);
                return;
            }
            // Success: the reply echoes the destination address.
            if (!writeReply(REQUEST_OK, addr)) {
                sendError(GENERAL_FAILURE);
                return;
            }
            tunnel(new BufferedInputStream(dest.getInputStream()),
                   new BufferedOutputStream(dest.getOutputStream()));
        }

        /**
         * Handles BIND. It opens a listening socket on an ephemeral port and
         * reports it, then waits for one incoming connection. It reports that
         * peer's address and tunnels to it. The requested {@code addr} is not
         * checked: any peer is accepted.
         */
        private void doBind(InetSocketAddress addr) throws IOException {
            try (ServerSocket svr = new ServerSocket()) {
                svr.bind(null);
                if (!writeReply(REQUEST_OK, (InetSocketAddress) svr.getLocalSocketAddress())) {
                    sendError(GENERAL_FAILURE);
                    return;
                }
                dest = svr.accept();
            } // the listener is closed here; the accepted socket stays open
            if (!writeReply(REQUEST_OK, (InetSocketAddress) dest.getRemoteSocketAddress())) {
                sendError(GENERAL_FAILURE);
                return;
            }
            tunnel(dest.getInputStream(), dest.getOutputStream());
        }

        /** Reads a raw {@code len}-byte IP address and returns its literal form. */
        private String readIpLiteral(int len) throws IOException {
            byte[] buf = new byte[len];
            readBuf(in, buf);
            return InetAddress.getByAddress(buf).getHostAddress();
        }

        /** Handles a SOCKS v5 request (CONNECT or BIND). */
        private void getRequest() throws IOException {
            int ver = in.read();
            int cmd = in.read();
            if (ver == -1 || cmd == -1) {
                in.close();
                out.close();
                return;
            }
            in.read(); // RSV
            int atyp = in.read();
            String addr;

            switch (atyp) {
            case IPV4:
                addr = readIpLiteral(4);
                break;
            case DOMAIN_NAME: {
                int len = in.read();
                if (len < 0)
                    throw new EOFException("unexpected EOF in domain name length");
                byte[] buf = new byte[len];
                readBuf(in, buf);
                addr = new String(buf);
                break;
            }
            case IPV6:
                addr = readIpLiteral(16);
                break;
            default:
                // Unknown address type: the request length is unknown, so give up.
                sendError(ADDR_TYPE_NOT_SUP);
                return;
            }

            InetSocketAddress socAddr = new InetSocketAddress(addr, readPort());
            switch (cmd) {
            case CONNECT:
                doConnect(socAddr);
                break;
            case BIND:
                doBind(socAddr);
                break;
            default:
                // UDP ASSOCIATE and unknown commands are not implemented.
                sendError(CMD_NOT_SUPPORTED);
                break;
            }
        }

        @Override
        public void run() {
            try {
                if (useV4) {
                    getRequestV4();
                } else if (negotiate()) {
                    getRequest();
                }
                // If negotiate() returned false, authentication failed: the
                // connection is closed below without processing the request.
            } catch (IOException ex) {
                sendError(GENERAL_FAILURE);
            } finally {
                closeQuietly(client);
                closeQuietly(dest);
            }
        }
    }

    /** Creates a server on {@code port}: SOCKS 4 if {@code v4}, else SOCKS 5. */
    public SocksServer(int port, boolean v4) throws IOException {
        this(port);
        this.useV4 = v4;
    }

    /**
     * Creates a SOCKS 5 server bound to {@code port} on the wildcard address.
     * Port 0 selects an ephemeral port, which is reported by {@link #getPort()}.
     */
    public SocksServer(int port) throws IOException {
        this.port = port;
        server = new ServerSocket();
        if (port == 0) {
            server.bind(null);
            this.port = server.getLocalPort();
        } else {
            server.bind(new InetSocketAddress(port));
        }
    }

    /** Creates a SOCKS 5 server on {@link #DEFAULT_PORT}. */
    public SocksServer() throws IOException {
        this (DEFAULT_PORT);
    }

    /** Registers a user accepted by username/password authentication. */
    public void addUser(String user, String passwd) {
        users.put(user, passwd);
    }

    /** Returns the port the server is listening on. */
    public int getPort() {
        return port;
    }

    /**
     * Stops accepting new clients by closing the listening socket. Handlers
     * that are already running are not stopped.
     */
    public void terminate() {
        done = true;
        try { server.close(); } catch (IOException unused) {}
    }

    /** Accept loop: starts one {@link ClientHandler} per incoming connection. */
    @Override
    public void run() {
        ClientHandler cl = null;
        while (!done) {
            Socket s = null;
            try {
                s = server.accept();
                cl = new ClientHandler(s);
                cl.start();
            } catch (IOException ex) {
                if (cl != null)
                    cl.interrupt();
                if (s != null) {
                    // Accepted, but the handler could not be set up.
                    closeQuietly(s);
                } else if (server.isClosed()) {
                    // accept() would fail forever; stop instead of spinning.
                    break;
                }
            }
        }
    }

    /**
     * Copies bytes from {@code from} to {@code to}, flushing after each chunk.
     * On EOF both streams are closed. On an I/O error copying stops; the
     * error can come from the peer, or from the relay in the other direction
     * closing the socket.
     */
    private static void relay(InputStream from, OutputStream to) {
        byte[] buf = new byte[8192];
        try {
            int n;
            while ((n = from.read(buf)) != -1) {
                to.write(buf, 0, n);
                to.flush();
            }
            from.close();
            to.close();
        } catch (IOException e) {
            // The connection was closed or reset; stop rather than retry
            // (retrying would spin on a broken socket).
        }
    }

    /** Closes {@code c} if it is non-null, ignoring any IOException. */
    private static void closeQuietly(Closeable c) {
        if (c == null)
            return;
        try {
            c.close();
        } catch (IOException ignored) {
            // best effort
        }
    }
}
527