1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
import static java.util.concurrent.Executors.newFixedThreadPool;

import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.integration.Messages.Payload;
import io.grpc.testing.integration.Messages.PayloadType;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
45 
46 /**
47  * Client application for the {@link TestServiceGrpc.TestServiceImplBase} that runs through a series
48  * of HTTP/2 interop tests. The tests are designed to simulate incorrect behavior on the part of the
49  * server. Some of the test cases require server-side checks and do not have assertions within the
50  * client code.
51  */
52 public final class Http2Client {
53   private static final Logger logger = Logger.getLogger(Http2Client.class.getName());
54 
55   /**
56    * The main application allowing this client to be launched from the command line.
57    */
main(String[] args)58   public static void main(String[] args) throws Exception {
59     final Http2Client client = new Http2Client();
60     client.parseArgs(args);
61     client.setUp();
62 
63     Runtime.getRuntime().addShutdownHook(new Thread() {
64       @Override
65       public void run() {
66         try {
67           client.shutdown();
68         } catch (Exception e) {
69           logger.log(Level.SEVERE, e.getMessage(), e);
70         }
71       }
72     });
73 
74     try {
75       client.run();
76     } finally {
77       client.shutdown();
78     }
79   }
80 
81   private String serverHost = "localhost";
82   private int serverPort = 8080;
83   private String testCase = Http2TestCases.RST_AFTER_DATA.name();
84 
85   private Tester tester = new Tester();
86   private ListeningExecutorService threadpool;
87 
88   protected ManagedChannel channel;
89   protected TestServiceGrpc.TestServiceBlockingStub blockingStub;
90   protected TestServiceGrpc.TestServiceStub asyncStub;
91 
parseArgs(String[] args)92   private void parseArgs(String[] args) {
93     boolean usage = false;
94     for (String arg : args) {
95       if (!arg.startsWith("--")) {
96         System.err.println("All arguments must start with '--': " + arg);
97         usage = true;
98         break;
99       }
100       String[] parts = arg.substring(2).split("=", 2);
101       String key = parts[0];
102       if ("help".equals(key)) {
103         usage = true;
104         break;
105       }
106       if (parts.length != 2) {
107         System.err.println("All arguments must be of the form --arg=value");
108         usage = true;
109         break;
110       }
111       String value = parts[1];
112       if ("server_host".equals(key)) {
113         serverHost = value;
114       } else if ("server_port".equals(key)) {
115         serverPort = Integer.parseInt(value);
116       } else if ("test_case".equals(key)) {
117         testCase = value;
118       } else {
119         System.err.println("Unknown argument: " + key);
120         usage = true;
121         break;
122       }
123     }
124     if (usage) {
125       Http2Client c = new Http2Client();
126       System.out.println(
127           "Usage: [ARGS...]"
128               + "\n"
129               + "\n  --server_host=HOST          Server to connect to. Default " + c.serverHost
130               + "\n  --server_port=PORT          Port to connect to. Default " + c.serverPort
131               + "\n  --test_case=TESTCASE        Test case to run. Default " + c.testCase
132               + "\n    Valid options:"
133               + validTestCasesHelpText()
134       );
135       System.exit(1);
136     }
137   }
138 
setUp()139   private void setUp() {
140     channel = createChannel();
141     blockingStub = TestServiceGrpc.newBlockingStub(channel);
142     asyncStub = TestServiceGrpc.newStub(channel);
143   }
144 
shutdown()145   private void shutdown() {
146     try {
147       if (channel != null) {
148         channel.shutdownNow();
149         channel.awaitTermination(1, TimeUnit.SECONDS);
150       }
151     } catch (Exception ex) {
152       throw new RuntimeException(ex);
153     }
154 
155     try {
156       if (threadpool != null) {
157         threadpool.shutdownNow();
158       }
159     } catch (Exception ex) {
160       throw new RuntimeException(ex);
161     }
162   }
163 
run()164   private void run() {
165     logger.info("Running test " + testCase);
166     try {
167       runTest(Http2TestCases.fromString(testCase));
168     } catch (RuntimeException ex) {
169       throw ex;
170     } catch (Exception ex) {
171       throw new RuntimeException(ex);
172     }
173     logger.info("Test completed.");
174   }
175 
runTest(Http2TestCases testCase)176   private void runTest(Http2TestCases testCase) throws Exception {
177     switch (testCase) {
178       case RST_AFTER_HEADER:
179         tester.rstAfterHeader();
180         break;
181       case RST_AFTER_DATA:
182         tester.rstAfterData();
183         break;
184       case RST_DURING_DATA:
185         tester.rstDuringData();
186         break;
187       case GOAWAY:
188         tester.goAway();
189         break;
190       case PING:
191         tester.ping();
192         break;
193       case MAX_STREAMS:
194         tester.maxStreams();
195         break;
196       default:
197         throw new IllegalArgumentException("Unknown test case: " + testCase);
198     }
199   }
200 
201   private class Tester {
202     private final int timeoutSeconds = 180;
203 
204     private final int responseSize = 314159;
205     private final int payloadSize = 271828;
206     private final SimpleRequest simpleRequest = SimpleRequest.newBuilder()
207         .setResponseSize(responseSize)
208         .setResponseType(PayloadType.COMPRESSABLE)
209         .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[payloadSize])))
210         .build();
211     final SimpleResponse goldenResponse = SimpleResponse.newBuilder()
212         .setPayload(Payload.newBuilder()
213             .setType(PayloadType.COMPRESSABLE)
214             .setBody(ByteString.copyFrom(new byte[responseSize])))
215         .build();
216 
rstAfterHeader()217     private void rstAfterHeader() throws Exception {
218       try {
219         blockingStub.unaryCall(simpleRequest);
220         throw new AssertionError("Expected call to fail");
221       } catch (StatusRuntimeException ex) {
222         assertRstStreamReceived(ex.getStatus());
223       }
224     }
225 
rstAfterData()226     private void rstAfterData() throws Exception {
227       // Use async stub to verify data is received.
228       RstStreamObserver responseObserver = new RstStreamObserver();
229       asyncStub.unaryCall(simpleRequest, responseObserver);
230       if (!responseObserver.awaitCompletion(timeoutSeconds, TimeUnit.SECONDS)) {
231         throw new AssertionError("Operation timed out");
232       }
233       if (responseObserver.getError() == null) {
234         throw new AssertionError("Expected call to fail");
235       }
236       assertRstStreamReceived(Status.fromThrowable(responseObserver.getError()));
237       if (responseObserver.getResponses().size() != 1) {
238         throw new AssertionError("Expected one response");
239       }
240     }
241 
rstDuringData()242     private void rstDuringData() throws Exception {
243       // Use async stub to verify no data is received.
244       RstStreamObserver responseObserver = new RstStreamObserver();
245       asyncStub.unaryCall(simpleRequest, responseObserver);
246       if (!responseObserver.awaitCompletion(timeoutSeconds, TimeUnit.SECONDS)) {
247         throw new AssertionError("Operation timed out");
248       }
249       if (responseObserver.getError() == null) {
250         throw new AssertionError("Expected call to fail");
251       }
252       assertRstStreamReceived(Status.fromThrowable(responseObserver.getError()));
253       if (responseObserver.getResponses().size() != 0) {
254         throw new AssertionError("Expected zero responses");
255       }
256     }
257 
goAway()258     private void goAway() throws Exception {
259       assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
260       TimeUnit.SECONDS.sleep(1);
261       assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
262     }
263 
ping()264     private void ping() throws Exception {
265       assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
266     }
267 
maxStreams()268     private void maxStreams() throws Exception {
269       final int numThreads = 10;
270 
271       // Preliminary call to ensure MAX_STREAMS setting is received by the client.
272       assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
273 
274       threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads));
275       List<ListenableFuture<?>> workerFutures = new ArrayList<ListenableFuture<?>>();
276       for (int i = 0; i < numThreads; i++) {
277         workerFutures.add(threadpool.submit(new MaxStreamsWorker(i, simpleRequest)));
278       }
279       ListenableFuture<?> f = Futures.allAsList(workerFutures);
280       f.get(timeoutSeconds, TimeUnit.SECONDS);
281     }
282 
283     private class RstStreamObserver implements StreamObserver<SimpleResponse> {
284       private final CountDownLatch latch = new CountDownLatch(1);
285       private final List<SimpleResponse> responses = new ArrayList<>();
286       private Throwable error;
287 
288       @Override
onNext(SimpleResponse value)289       public void onNext(SimpleResponse value) {
290         responses.add(value);
291       }
292 
293       @Override
onError(Throwable t)294       public void onError(Throwable t) {
295         error = t;
296         latch.countDown();
297       }
298 
299       @Override
onCompleted()300       public void onCompleted() {
301         latch.countDown();
302       }
303 
getResponses()304       public List<SimpleResponse> getResponses() {
305         return responses;
306       }
307 
getError()308       public Throwable getError() {
309         return error;
310       }
311 
awaitCompletion(long timeout, TimeUnit unit)312       public boolean awaitCompletion(long timeout, TimeUnit unit) throws Exception {
313         return latch.await(timeout, unit);
314       }
315     }
316 
317     private class MaxStreamsWorker implements Runnable {
318       int threadNum;
319       SimpleRequest request;
320 
MaxStreamsWorker(int threadNum, SimpleRequest request)321       MaxStreamsWorker(int threadNum, SimpleRequest request) {
322         this.threadNum = threadNum;
323         this.request = request;
324       }
325 
326       @Override
run()327       public void run() {
328         Thread.currentThread().setName("thread:" + threadNum);
329         try {
330           TestServiceGrpc.TestServiceBlockingStub blockingStub =
331               TestServiceGrpc.newBlockingStub(channel);
332           assertResponseEquals(blockingStub.unaryCall(simpleRequest), goldenResponse);
333         } catch (Exception e) {
334           throw new RuntimeException(e);
335         }
336       }
337     }
338 
assertRstStreamReceived(Status status)339     private void assertRstStreamReceived(Status status) {
340       if (!status.getCode().equals(Status.Code.UNAVAILABLE)) {
341         throw new AssertionError("Wrong status code. Expected: " + Status.Code.UNAVAILABLE
342             + " Received: " + status.getCode());
343       }
344       String http2ErrorPrefix = "HTTP/2 error code: NO_ERROR";
345       if (status.getDescription() == null
346           || !status.getDescription().startsWith(http2ErrorPrefix)) {
347         throw new AssertionError("Wrong HTTP/2 error code. Expected: " + http2ErrorPrefix
348             + " Received: " + status.getDescription());
349       }
350     }
351 
assertResponseEquals(SimpleResponse response, SimpleResponse goldenResponse)352     private void assertResponseEquals(SimpleResponse response, SimpleResponse goldenResponse) {
353       if (!response.equals(goldenResponse)) {
354         throw new AssertionError("Incorrect response received");
355       }
356     }
357   }
358 
createChannel()359   private ManagedChannel createChannel() {
360     InetAddress address;
361     try {
362       address = InetAddress.getByName(serverHost);
363     } catch (UnknownHostException ex) {
364       throw new RuntimeException(ex);
365     }
366     return NettyChannelBuilder.forAddress(new InetSocketAddress(address, serverPort))
367         .negotiationType(NegotiationType.PLAINTEXT)
368         .build();
369   }
370 
validTestCasesHelpText()371   private static String validTestCasesHelpText() {
372     StringBuilder builder = new StringBuilder();
373     for (Http2TestCases testCase : Http2TestCases.values()) {
374       String strTestcase = testCase.name().toLowerCase();
375       builder.append("\n      ")
376           .append(strTestcase)
377           .append(": ")
378           .append(testCase.description());
379     }
380     return builder.toString();
381   }
382 }
383 
384