1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
import static com.google.common.base.Preconditions.checkArgument;

import io.netty.util.concurrent.DefaultThreadFactory;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
35 
36 public final class TrafficControlProxy {
37 
38   private static final int DEFAULT_BAND_BPS = 1024 * 1024;
39   private static final int DEFAULT_DELAY_NANOS = 200 * 1000 * 1000;
40   private static final Logger logger = Logger.getLogger(TrafficControlProxy.class.getName());
41 
42   // TODO: make host and ports arguments
43   private final String localhost = "localhost";
44   private final int serverPort;
45   private final int queueLength;
46   private final int chunkSize;
47   private final int bandwidth;
48   private final long latency;
49   private volatile boolean shutDown;
50   private ServerSocket clientAcceptor;
51   private Socket serverSock;
52   private Socket clientSock;
53   private final ThreadPoolExecutor executor =
54       new ThreadPoolExecutor(5, 10, 1, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(),
55           new DefaultThreadFactory("proxy-pool", true));
56 
57   /**
58    * Returns a new TrafficControlProxy with default bandwidth and latency.
59    */
TrafficControlProxy(int serverPort)60   public TrafficControlProxy(int serverPort) {
61     this(serverPort, DEFAULT_BAND_BPS, DEFAULT_DELAY_NANOS, TimeUnit.NANOSECONDS);
62   }
63 
64   /**
65    * Returns a new TrafficControlProxy with bandwidth set to targetBPS, and latency set to
66    * targetLatency in latencyUnits.
67    */
TrafficControlProxy(int serverPort, int targetBps, int targetLatency, TimeUnit latencyUnits)68   public TrafficControlProxy(int serverPort, int targetBps, int targetLatency,
69       TimeUnit latencyUnits) {
70     checkArgument(targetBps > 0);
71     checkArgument(targetLatency > 0);
72     this.serverPort = serverPort;
73     bandwidth = targetBps;
74     // divide by 2 because latency is applied in both directions
75     latency = latencyUnits.toNanos(targetLatency) / 2;
76     queueLength = (int) Math.max(bandwidth * latency / TimeUnit.SECONDS.toNanos(1), 1);
77     chunkSize = Math.max(1, queueLength);
78   }
79 
80   /**
81    * Starts a new thread that waits for client and server and start reader/writer threads.
82    */
start()83   public void start() throws IOException {
84     // ClientAcceptor uses a ServerSocket server so that the client can connect to the proxy as it
85     // normally would a server. serverSock then connects the server using a regular Socket as a
86     // client normally would.
87     clientAcceptor = new ServerSocket();
88     clientAcceptor.bind(new InetSocketAddress(localhost, 0));
89     executor.execute(new Runnable() {
90       @Override
91       public void run() {
92         try {
93           clientSock = clientAcceptor.accept();
94           serverSock = new Socket();
95           serverSock.connect(new InetSocketAddress(localhost, serverPort));
96           startWorkers();
97         } catch (IOException e) {
98           throw new RuntimeException(e);
99         }
100       }
101     });
102     logger.info("Started new proxy on port " + clientAcceptor.getLocalPort()
103         + " with Queue Length " + queueLength);
104   }
105 
getPort()106   public int getPort() {
107     return clientAcceptor.getLocalPort();
108   }
109 
110   /** Interrupt all workers and close sockets. */
shutDown()111   public void shutDown() throws IOException {
112     // TODO: Handle case where a socket fails to close, therefore blocking the others from closing
113     logger.info("Proxy shutting down... ");
114     shutDown = true;
115     executor.shutdown();
116     clientAcceptor.close();
117     clientSock.close();
118     serverSock.close();
119     logger.info("Shutdown Complete");
120   }
121 
startWorkers()122   private void startWorkers() throws IOException {
123     DataInputStream clientIn = new DataInputStream(clientSock.getInputStream());
124     DataOutputStream clientOut = new DataOutputStream(serverSock.getOutputStream());
125     DataInputStream serverIn = new DataInputStream(serverSock.getInputStream());
126     DataOutputStream serverOut = new DataOutputStream(clientSock.getOutputStream());
127 
128     MessageQueue clientPipe = new MessageQueue(clientIn, clientOut);
129     MessageQueue serverPipe = new MessageQueue(serverIn, serverOut);
130 
131     executor.execute(new Reader(clientPipe));
132     executor.execute(new Writer(clientPipe));
133     executor.execute(new Reader(serverPipe));
134     executor.execute(new Writer(serverPipe));
135   }
136 
137   private final class Reader implements Runnable {
138 
139     private final MessageQueue queue;
140 
Reader(MessageQueue queue)141     Reader(MessageQueue queue) {
142       this.queue = queue;
143     }
144 
145     @Override
run()146     public void run() {
147       while (!shutDown) {
148         try {
149           queue.readIn();
150         } catch (IOException e) {
151           shutDown = true;
152         } catch (InterruptedException e) {
153           shutDown = true;
154         }
155       }
156     }
157 
158   }
159 
160   private final class Writer implements Runnable {
161 
162     private final MessageQueue queue;
163 
Writer(MessageQueue queue)164     Writer(MessageQueue queue) {
165       this.queue = queue;
166     }
167 
168     @Override
run()169     public void run() {
170       while (!shutDown) {
171         try {
172           queue.writeOut();
173         } catch (IOException e) {
174           shutDown = true;
175         } catch (InterruptedException e) {
176           shutDown = true;
177         }
178       }
179     }
180   }
181 
182   /**
183    * A Delay Queue that counts by number of bytes instead of the number of elements.
184    */
185   private class MessageQueue {
186     DataInputStream inStream;
187     DataOutputStream outStream;
188     int bytesQueued;
189     BlockingQueue<Message> queue = new DelayQueue<Message>();
190 
MessageQueue(DataInputStream inputStream, DataOutputStream outputStream)191     MessageQueue(DataInputStream inputStream, DataOutputStream outputStream) {
192       inStream = inputStream;
193       outStream = outputStream;
194     }
195 
196     /**
197      * Take a message off the queue and write it to an endpoint. Blocks until a message becomes
198      * available.
199      */
writeOut()200     void writeOut() throws InterruptedException, IOException {
201       Message next = queue.take();
202       outStream.write(next.message, 0, next.messageLength);
203       incrementBytes(-next.messageLength);
204     }
205 
206     /**
207      * Read bytes from an endpoint and add them as a message to the queue. Blocks if the queue is
208      * full.
209      */
readIn()210     void readIn() throws InterruptedException, IOException {
211       byte[] request = new byte[getNextChunk()];
212       int readableBytes = inStream.read(request);
213       long sendTime = System.nanoTime() + latency;
214       queue.put(new Message(sendTime, request, readableBytes));
215       incrementBytes(readableBytes);
216     }
217 
218     /**
219      * Block until space on the queue becomes available. Returns how many bytes can be read on to
220      * the queue
221      */
getNextChunk()222     synchronized int getNextChunk() throws InterruptedException {
223       while (bytesQueued == queueLength) {
224         wait();
225       }
226       return Math.max(0, Math.min(chunkSize, queueLength - bytesQueued));
227     }
228 
incrementBytes(int delta)229     synchronized void incrementBytes(int delta) {
230       bytesQueued += delta;
231       if (bytesQueued < queueLength) {
232         notifyAll();
233       }
234     }
235   }
236 
237   private static class Message implements Delayed {
238     long sendTime;
239     byte[] message;
240     int messageLength;
241 
Message(long sendTime, byte[] message, int messageLength)242     Message(long sendTime, byte[] message, int messageLength) {
243       this.sendTime = sendTime;
244       this.message = message;
245       this.messageLength = messageLength;
246     }
247 
248     @Override
compareTo(Delayed o)249     public int compareTo(Delayed o) {
250       return ((Long) sendTime).compareTo(((Message) o).sendTime);
251     }
252 
253     @Override
getDelay(TimeUnit unit)254     public long getDelay(TimeUnit unit) {
255       return unit.convert(sendTime - System.nanoTime(), TimeUnit.NANOSECONDS);
256     }
257   }
258 }
259