1 /* 2 * Copyright 2016 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.testing.integration; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 21 import io.netty.util.concurrent.DefaultThreadFactory; 22 import java.io.DataInputStream; 23 import java.io.DataOutputStream; 24 import java.io.IOException; 25 import java.net.InetSocketAddress; 26 import java.net.ServerSocket; 27 import java.net.Socket; 28 import java.util.concurrent.BlockingQueue; 29 import java.util.concurrent.DelayQueue; 30 import java.util.concurrent.Delayed; 31 import java.util.concurrent.LinkedBlockingQueue; 32 import java.util.concurrent.ThreadPoolExecutor; 33 import java.util.concurrent.TimeUnit; 34 import java.util.logging.Logger; 35 36 public final class TrafficControlProxy { 37 38 private static final int DEFAULT_BAND_BPS = 1024 * 1024; 39 private static final int DEFAULT_DELAY_NANOS = 200 * 1000 * 1000; 40 private static final Logger logger = Logger.getLogger(TrafficControlProxy.class.getName()); 41 42 // TODO: make host and ports arguments 43 private final String localhost = "localhost"; 44 private final int serverPort; 45 private final int queueLength; 46 private final int chunkSize; 47 private final int bandwidth; 48 private final long latency; 49 private volatile boolean shutDown; 50 private ServerSocket clientAcceptor; 51 private Socket serverSock; 52 private Socket clientSock; 53 private final ThreadPoolExecutor executor = 54 new ThreadPoolExecutor(5, 10, 1, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), 55 new DefaultThreadFactory("proxy-pool", true)); 56 57 /** 58 * Returns a new TrafficControlProxy with default bandwidth and latency. 59 */ TrafficControlProxy(int serverPort)60 public TrafficControlProxy(int serverPort) { 61 this(serverPort, DEFAULT_BAND_BPS, DEFAULT_DELAY_NANOS, TimeUnit.NANOSECONDS); 62 } 63 64 /** 65 * Returns a new TrafficControlProxy with bandwidth set to targetBPS, and latency set to 66 * targetLatency in latencyUnits. 67 */ TrafficControlProxy(int serverPort, int targetBps, int targetLatency, TimeUnit latencyUnits)68 public TrafficControlProxy(int serverPort, int targetBps, int targetLatency, 69 TimeUnit latencyUnits) { 70 checkArgument(targetBps > 0); 71 checkArgument(targetLatency > 0); 72 this.serverPort = serverPort; 73 bandwidth = targetBps; 74 // divide by 2 because latency is applied in both directions 75 latency = latencyUnits.toNanos(targetLatency) / 2; 76 queueLength = (int) Math.max(bandwidth * latency / TimeUnit.SECONDS.toNanos(1), 1); 77 chunkSize = Math.max(1, queueLength); 78 } 79 80 /** 81 * Starts a new thread that waits for client and server and start reader/writer threads. 82 */ start()83 public void start() throws IOException { 84 // ClientAcceptor uses a ServerSocket server so that the client can connect to the proxy as it 85 // normally would a server. serverSock then connects the server using a regular Socket as a 86 // client normally would. 87 clientAcceptor = new ServerSocket(); 88 clientAcceptor.bind(new InetSocketAddress(localhost, 0)); 89 executor.execute(new Runnable() { 90 @Override 91 public void run() { 92 try { 93 clientSock = clientAcceptor.accept(); 94 serverSock = new Socket(); 95 serverSock.connect(new InetSocketAddress(localhost, serverPort)); 96 startWorkers(); 97 } catch (IOException e) { 98 throw new RuntimeException(e); 99 } 100 } 101 }); 102 logger.info("Started new proxy on port " + clientAcceptor.getLocalPort() 103 + " with Queue Length " + queueLength); 104 } 105 getPort()106 public int getPort() { 107 return clientAcceptor.getLocalPort(); 108 } 109 110 /** Interrupt all workers and close sockets. */ shutDown()111 public void shutDown() throws IOException { 112 // TODO: Handle case where a socket fails to close, therefore blocking the others from closing 113 logger.info("Proxy shutting down... "); 114 shutDown = true; 115 executor.shutdown(); 116 clientAcceptor.close(); 117 clientSock.close(); 118 serverSock.close(); 119 logger.info("Shutdown Complete"); 120 } 121 startWorkers()122 private void startWorkers() throws IOException { 123 DataInputStream clientIn = new DataInputStream(clientSock.getInputStream()); 124 DataOutputStream clientOut = new DataOutputStream(serverSock.getOutputStream()); 125 DataInputStream serverIn = new DataInputStream(serverSock.getInputStream()); 126 DataOutputStream serverOut = new DataOutputStream(clientSock.getOutputStream()); 127 128 MessageQueue clientPipe = new MessageQueue(clientIn, clientOut); 129 MessageQueue serverPipe = new MessageQueue(serverIn, serverOut); 130 131 executor.execute(new Reader(clientPipe)); 132 executor.execute(new Writer(clientPipe)); 133 executor.execute(new Reader(serverPipe)); 134 executor.execute(new Writer(serverPipe)); 135 } 136 137 private final class Reader implements Runnable { 138 139 private final MessageQueue queue; 140 Reader(MessageQueue queue)141 Reader(MessageQueue queue) { 142 this.queue = queue; 143 } 144 145 @Override run()146 public void run() { 147 while (!shutDown) { 148 try { 149 queue.readIn(); 150 } catch (IOException e) { 151 shutDown = true; 152 } catch (InterruptedException e) { 153 shutDown = true; 154 } 155 } 156 } 157 158 } 159 160 private final class Writer implements Runnable { 161 162 private final MessageQueue queue; 163 Writer(MessageQueue queue)164 Writer(MessageQueue queue) { 165 this.queue = queue; 166 } 167 168 @Override run()169 public void run() { 170 while (!shutDown) { 171 try { 172 queue.writeOut(); 173 } catch (IOException e) { 174 shutDown = true; 175 } catch (InterruptedException e) { 176 shutDown = true; 177 } 178 } 179 } 180 } 181 182 /** 183 * A Delay Queue that counts by number of bytes instead of the number of elements. 184 */ 185 private class MessageQueue { 186 DataInputStream inStream; 187 DataOutputStream outStream; 188 int bytesQueued; 189 BlockingQueue<Message> queue = new DelayQueue<Message>(); 190 MessageQueue(DataInputStream inputStream, DataOutputStream outputStream)191 MessageQueue(DataInputStream inputStream, DataOutputStream outputStream) { 192 inStream = inputStream; 193 outStream = outputStream; 194 } 195 196 /** 197 * Take a message off the queue and write it to an endpoint. Blocks until a message becomes 198 * available. 199 */ writeOut()200 void writeOut() throws InterruptedException, IOException { 201 Message next = queue.take(); 202 outStream.write(next.message, 0, next.messageLength); 203 incrementBytes(-next.messageLength); 204 } 205 206 /** 207 * Read bytes from an endpoint and add them as a message to the queue. Blocks if the queue is 208 * full. 209 */ readIn()210 void readIn() throws InterruptedException, IOException { 211 byte[] request = new byte[getNextChunk()]; 212 int readableBytes = inStream.read(request); 213 long sendTime = System.nanoTime() + latency; 214 queue.put(new Message(sendTime, request, readableBytes)); 215 incrementBytes(readableBytes); 216 } 217 218 /** 219 * Block until space on the queue becomes available. Returns how many bytes can be read on to 220 * the queue 221 */ getNextChunk()222 synchronized int getNextChunk() throws InterruptedException { 223 while (bytesQueued == queueLength) { 224 wait(); 225 } 226 return Math.max(0, Math.min(chunkSize, queueLength - bytesQueued)); 227 } 228 incrementBytes(int delta)229 synchronized void incrementBytes(int delta) { 230 bytesQueued += delta; 231 if (bytesQueued < queueLength) { 232 notifyAll(); 233 } 234 } 235 } 236 237 private static class Message implements Delayed { 238 long sendTime; 239 byte[] message; 240 int messageLength; 241 Message(long sendTime, byte[] message, int messageLength)242 Message(long sendTime, byte[] message, int messageLength) { 243 this.sendTime = sendTime; 244 this.message = message; 245 this.messageLength = messageLength; 246 } 247 248 @Override compareTo(Delayed o)249 public int compareTo(Delayed o) { 250 return ((Long) sendTime).compareTo(((Message) o).sendTime); 251 } 252 253 @Override getDelay(TimeUnit unit)254 public long getDelay(TimeUnit unit) { 255 return unit.convert(sendTime - System.nanoTime(), TimeUnit.NANOSECONDS); 256 } 257 } 258 } 259