1 /*
2  * Copyright (C) 2014 Square, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.squareup.okhttp;
17 
import com.squareup.okhttp.internal.NamedRunnable;
import com.squareup.okhttp.internal.Util;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.Okio;
38 
39 /**
40  * A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer.
41  * See <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC 1928</a>.
42  */
43 public final class SocksProxy {
44   public final String HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS = "onlyProxyCanResolveMe.org";
45 
46   private static final int VERSION_5 = 5;
47   private static final int METHOD_NONE = 0xff;
48   private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0;
49   private static final int ADDRESS_TYPE_IPV4 = 1;
50   private static final int ADDRESS_TYPE_DOMAIN_NAME = 3;
51   private static final int COMMAND_CONNECT = 1;
52   private static final int REPLY_SUCCEEDED = 0;
53 
54   private static final Logger logger = Logger.getLogger(SocksProxy.class.getName());
55 
56   private final ExecutorService executor = Executors.newCachedThreadPool(
57       Util.threadFactory("SocksProxy", false));
58 
59   private ServerSocket serverSocket;
60   private AtomicInteger connectionCount = new AtomicInteger();
61 
play()62   public void play() throws IOException {
63     serverSocket = new ServerSocket(0);
64     executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) {
65       @Override protected void execute() {
66         try {
67           while (true) {
68             Socket socket = serverSocket.accept();
69             connectionCount.incrementAndGet();
70             service(socket);
71           }
72         } catch (SocketException e) {
73           logger.info(name + " done accepting connections: " + e.getMessage());
74         } catch (IOException e) {
75           logger.log(Level.WARNING, name + " failed unexpectedly", e);
76         }
77       }
78     });
79   }
80 
proxy()81   public Proxy proxy() {
82     return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved(
83         "localhost", serverSocket.getLocalPort()));
84   }
85 
connectionCount()86   public int connectionCount() {
87     return connectionCount.get();
88   }
89 
shutdown()90   public void shutdown() throws Exception {
91     serverSocket.close();
92     executor.shutdown();
93     if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
94       throw new IOException("Gave up waiting for executor to shut down");
95     }
96   }
97 
service(final Socket from)98   private void service(final Socket from) {
99     executor.execute(new NamedRunnable("SocksProxy %s", from.getRemoteSocketAddress()) {
100       @Override protected void execute() {
101         try {
102           BufferedSource fromSource = Okio.buffer(Okio.source(from));
103           BufferedSink fromSink = Okio.buffer(Okio.sink(from));
104           hello(fromSource, fromSink);
105           acceptCommand(from.getInetAddress(), fromSource, fromSink);
106         } catch (IOException e) {
107           logger.log(Level.WARNING, name + " failed", e);
108           Util.closeQuietly(from);
109         }
110       }
111     });
112   }
113 
hello(BufferedSource fromSource, BufferedSink fromSink)114   private void hello(BufferedSource fromSource, BufferedSink fromSink) throws IOException {
115     int version = fromSource.readByte() & 0xff;
116     int methodCount = fromSource.readByte() & 0xff;
117     int selectedMethod = METHOD_NONE;
118 
119     if (version != VERSION_5) {
120       throw new ProtocolException("unsupported version: " + version);
121     }
122 
123     for (int i = 0; i < methodCount; i++) {
124       int candidateMethod = fromSource.readByte() & 0xff;
125       if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) {
126         selectedMethod = candidateMethod;
127       }
128     }
129 
130     switch (selectedMethod) {
131       case METHOD_NO_AUTHENTICATION_REQUIRED:
132         fromSink.writeByte(VERSION_5);
133         fromSink.writeByte(selectedMethod);
134         fromSink.emit();
135         break;
136 
137       default:
138         throw new ProtocolException("unsupported method: " + selectedMethod);
139     }
140   }
141 
acceptCommand(InetAddress fromAddress, BufferedSource fromSource, BufferedSink fromSink)142   private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource,
143       BufferedSink fromSink) throws IOException {
144     // Read the command.
145     int version = fromSource.readByte() & 0xff;
146     if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version);
147     int command = fromSource.readByte() & 0xff;
148     int reserved = fromSource.readByte() & 0xff;
149     if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved);
150 
151     int addressType = fromSource.readByte() & 0xff;
152     InetAddress toAddress;
153     switch (addressType) {
154       case ADDRESS_TYPE_IPV4:
155         toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L));
156         break;
157 
158       case ADDRESS_TYPE_DOMAIN_NAME:
159         int domainNameLength = fromSource.readByte() & 0xff;
160         String domainName = fromSource.readUtf8(domainNameLength);
161         // Resolve HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS to localhost.
162         toAddress = domainName.equalsIgnoreCase(HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS)
163             ? InetAddress.getByName("localhost")
164             : InetAddress.getByName(domainName);
165         break;
166 
167       default:
168         throw new ProtocolException("unsupported address type: " + addressType);
169     }
170 
171     int port = fromSource.readShort() & 0xffff;
172 
173     switch (command) {
174       case COMMAND_CONNECT:
175         Socket toSocket = new Socket(toAddress, port);
176         byte[] localAddress = toSocket.getLocalAddress().getAddress();
177         if (localAddress.length != 4) {
178           throw new ProtocolException("unexpected address: " + toSocket.getLocalAddress());
179         }
180 
181         // Write the reply.
182         fromSink.writeByte(VERSION_5);
183         fromSink.writeByte(REPLY_SUCCEEDED);
184         fromSink.writeByte(0);
185         fromSink.writeByte(ADDRESS_TYPE_IPV4);
186         fromSink.write(localAddress);
187         fromSink.writeShort(toSocket.getLocalPort());
188         fromSink.emit();
189 
190         logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress);
191 
192         // Copy sources to sinks in both directions.
193         BufferedSource toSource = Okio.buffer(Okio.source(toSocket));
194         BufferedSink toSink = Okio.buffer(Okio.sink(toSocket));
195         transfer(fromAddress, toAddress, fromSource, toSink);
196         transfer(fromAddress, toAddress, toSource, fromSink);
197         break;
198 
199       default:
200         throw new ProtocolException("unexpected command: " + command);
201     }
202   }
203 
transfer(final InetAddress fromAddress, final InetAddress toAddress, final BufferedSource source, final BufferedSink sink)204   private void transfer(final InetAddress fromAddress, final InetAddress toAddress,
205       final BufferedSource source, final BufferedSink sink) {
206     executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) {
207       @Override protected void execute() {
208         Buffer buffer = new Buffer();
209         try {
210           while (true) {
211             long byteCount = source.read(buffer, 2048L);
212             if (byteCount == -1L) break;
213             sink.write(buffer, byteCount);
214             sink.emit();
215           }
216         } catch (SocketException e) {
217           logger.info(name + " done: " + e.getMessage());
218         } catch (IOException e) {
219           logger.log(Level.WARNING, name + " failed", e);
220         }
221 
222         try {
223           source.close();
224         } catch (IOException e) {
225           logger.log(Level.WARNING, name + " failed", e);
226         }
227 
228         try {
229           sink.close();
230         } catch (IOException e) {
231           logger.log(Level.WARNING, name + " failed", e);
232         }
233       }
234     });
235   }
236 }
237