/* * Copyright (C) 2014 Square, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.squareup.okhttp; import com.squareup.okhttp.internal.NamedRunnable; import com.squareup.okhttp.internal.Util; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ProtocolException; import java.net.Proxy; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; import okio.Okio; /** * A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer. * See RFC 1928. */ public final class SocksProxy { public final String HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS = "onlyProxyCanResolveMe.org"; private static final int VERSION_5 = 5; private static final int METHOD_NONE = 0xff; private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0; private static final int ADDRESS_TYPE_IPV4 = 1; private static final int ADDRESS_TYPE_DOMAIN_NAME = 3; private static final int COMMAND_CONNECT = 1; private static final int REPLY_SUCCEEDED = 0; private static final Logger logger = Logger.getLogger(SocksProxy.class.getName()); private final ExecutorService executor = Executors.newCachedThreadPool( Util.threadFactory("SocksProxy", false)); private ServerSocket serverSocket; private AtomicInteger connectionCount = new AtomicInteger(); public void play() throws IOException { serverSocket = new ServerSocket(0); executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) { @Override protected void execute() { try { while (true) { Socket socket = serverSocket.accept(); connectionCount.incrementAndGet(); service(socket); } } catch (SocketException e) { logger.info(name + " done accepting connections: " + e.getMessage()); } catch (IOException e) { logger.log(Level.WARNING, name + " failed unexpectedly", e); } } }); } public Proxy proxy() { return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved( "localhost", serverSocket.getLocalPort())); } public int connectionCount() { return connectionCount.get(); } public void shutdown() throws Exception { serverSocket.close(); executor.shutdown(); if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { throw new IOException("Gave up waiting for executor to shut down"); } } private void service(final Socket from) { executor.execute(new NamedRunnable("SocksProxy %s", from.getRemoteSocketAddress()) { @Override protected void execute() { try { BufferedSource fromSource = Okio.buffer(Okio.source(from)); BufferedSink fromSink = Okio.buffer(Okio.sink(from)); hello(fromSource, fromSink); acceptCommand(from.getInetAddress(), fromSource, fromSink); } catch (IOException e) { logger.log(Level.WARNING, name + " failed", e); Util.closeQuietly(from); } } }); } private void hello(BufferedSource fromSource, BufferedSink fromSink) throws IOException { int version = fromSource.readByte() & 0xff; int methodCount = fromSource.readByte() & 0xff; int selectedMethod = METHOD_NONE; if (version != VERSION_5) { throw new ProtocolException("unsupported version: " + version); } for (int i = 0; i < methodCount; i++) { int candidateMethod = fromSource.readByte() & 0xff; if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) { selectedMethod = candidateMethod; } } switch (selectedMethod) { case METHOD_NO_AUTHENTICATION_REQUIRED: fromSink.writeByte(VERSION_5); fromSink.writeByte(selectedMethod); fromSink.emit(); break; default: throw new ProtocolException("unsupported method: " + selectedMethod); } } private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource, BufferedSink fromSink) throws IOException { // Read the command. int version = fromSource.readByte() & 0xff; if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version); int command = fromSource.readByte() & 0xff; int reserved = fromSource.readByte() & 0xff; if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved); int addressType = fromSource.readByte() & 0xff; InetAddress toAddress; switch (addressType) { case ADDRESS_TYPE_IPV4: toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L)); break; case ADDRESS_TYPE_DOMAIN_NAME: int domainNameLength = fromSource.readByte() & 0xff; String domainName = fromSource.readUtf8(domainNameLength); // Resolve HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS to localhost. toAddress = domainName.equalsIgnoreCase(HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS) ? InetAddress.getByName("localhost") : InetAddress.getByName(domainName); break; default: throw new ProtocolException("unsupported address type: " + addressType); } int port = fromSource.readShort() & 0xffff; switch (command) { case COMMAND_CONNECT: Socket toSocket = new Socket(toAddress, port); byte[] localAddress = toSocket.getLocalAddress().getAddress(); if (localAddress.length != 4) { throw new ProtocolException("unexpected address: " + toSocket.getLocalAddress()); } // Write the reply. fromSink.writeByte(VERSION_5); fromSink.writeByte(REPLY_SUCCEEDED); fromSink.writeByte(0); fromSink.writeByte(ADDRESS_TYPE_IPV4); fromSink.write(localAddress); fromSink.writeShort(toSocket.getLocalPort()); fromSink.emit(); logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress); // Copy sources to sinks in both directions. BufferedSource toSource = Okio.buffer(Okio.source(toSocket)); BufferedSink toSink = Okio.buffer(Okio.sink(toSocket)); transfer(fromAddress, toAddress, fromSource, toSink); transfer(fromAddress, toAddress, toSource, fromSink); break; default: throw new ProtocolException("unexpected command: " + command); } } private void transfer(final InetAddress fromAddress, final InetAddress toAddress, final BufferedSource source, final BufferedSink sink) { executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) { @Override protected void execute() { Buffer buffer = new Buffer(); try { while (true) { long byteCount = source.read(buffer, 2048L); if (byteCount == -1L) break; sink.write(buffer, byteCount); sink.emit(); } } catch (SocketException e) { logger.info(name + " done: " + e.getMessage()); } catch (IOException e) { logger.log(Level.WARNING, name + " failed", e); } try { source.close(); } catch (IOException e) { logger.log(Level.WARNING, name + " failed", e); } try { sink.close(); } catch (IOException e) { logger.log(Level.WARNING, name + " failed", e); } } }); } }