1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
import static java.util.Arrays.asList;
import static java.util.Collections.shuffle;
import static java.util.Collections.singletonList;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.SECONDS;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterators;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.handler.ssl.SslContext;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;
60 
61 /**
62  * A stress test client following the
63  * <a href="https://github.com/grpc/grpc/blob/master/tools/run_tests/stress_test/STRESS_CLIENT_SPEC.md">
64  * specifications</a> of the gRPC stress testing framework.
65  */
66 public class StressTestClient {
67 
68   private static final Logger log = Logger.getLogger(StressTestClient.class.getName());
69 
70   /**
71    * The main application allowing this client to be launched from the command line.
72    */
main(String... args)73   public static void main(String... args) throws Exception {
74     final StressTestClient client = new StressTestClient();
75     client.parseArgs(args);
76 
77     // Attempt an orderly shutdown, if the JVM is shutdown via a signal.
78     Runtime.getRuntime().addShutdownHook(new Thread() {
79       @Override
80       public void run() {
81         client.shutdown();
82       }
83     });
84 
85     try {
86       client.startMetricsService();
87       client.runStressTest();
88       client.blockUntilStressTestComplete();
89     } catch (Exception e) {
90       log.log(Level.WARNING, "The stress test client encountered an error!", e);
91     } finally {
92       client.shutdown();
93     }
94   }
95 
96   private static final int WORKER_GRACE_PERIOD_SECS = 30;
97 
98   private List<InetSocketAddress> addresses =
99       singletonList(new InetSocketAddress("localhost", 8080));
100   private List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>();
101 
102   private String serverHostOverride;
103   private boolean useTls = false;
104   private boolean useTestCa = false;
105   private int durationSecs = -1;
106   private int channelsPerServer = 1;
107   private int stubsPerChannel = 1;
108   private int metricsPort = 8081;
109 
110   private Server metricsServer;
111   private final Map<String, Metrics.GaugeResponse> gauges =
112       new ConcurrentHashMap<String, Metrics.GaugeResponse>();
113 
114   private volatile boolean shutdown;
115 
116   /**
117    * List of futures that {@link #blockUntilStressTestComplete()} waits for.
118    */
119   private final List<ListenableFuture<?>> workerFutures =
120       new ArrayList<ListenableFuture<?>>();
121   private final List<ManagedChannel> channels = new ArrayList<>();
122   private ListeningExecutorService threadpool;
123 
124   @VisibleForTesting
parseArgs(String[] args)125   void parseArgs(String[] args) {
126     boolean usage = false;
127     String serverAddresses = "";
128     for (String arg : args) {
129       if (!arg.startsWith("--")) {
130         System.err.println("All arguments must start with '--': " + arg);
131         usage = true;
132         break;
133       }
134       String[] parts = arg.substring(2).split("=", 2);
135       String key = parts[0];
136       if ("help".equals(key)) {
137         usage = true;
138         break;
139       }
140       if (parts.length != 2) {
141         System.err.println("All arguments must be of the form --arg=value");
142         usage = true;
143         break;
144       }
145       String value = parts[1];
146       if ("server_addresses".equals(key)) {
147         // May need to apply server host overrides to the addresses, so delay processing
148         serverAddresses = value;
149       } else if ("server_host_override".equals(key)) {
150         serverHostOverride = value;
151       } else if ("use_tls".equals(key)) {
152         useTls = Boolean.parseBoolean(value);
153       } else if ("use_test_ca".equals(key)) {
154         useTestCa = Boolean.parseBoolean(value);
155       } else if ("test_cases".equals(key)) {
156         testCaseWeightPairs = parseTestCases(value);
157       } else if ("test_duration_secs".equals(key)) {
158         durationSecs = Integer.valueOf(value);
159       } else if ("num_channels_per_server".equals(key)) {
160         channelsPerServer = Integer.valueOf(value);
161       } else if ("num_stubs_per_channel".equals(key)) {
162         stubsPerChannel = Integer.valueOf(value);
163       } else if ("metrics_port".equals(key)) {
164         metricsPort = Integer.valueOf(value);
165       } else {
166         System.err.println("Unknown argument: " + key);
167         usage = true;
168         break;
169       }
170     }
171 
172     if (!usage && !serverAddresses.isEmpty()) {
173       addresses = parseServerAddresses(serverAddresses);
174       usage = addresses.isEmpty();
175     }
176 
177     if (usage) {
178       StressTestClient c = new StressTestClient();
179       System.err.println(
180           "Usage: [ARGS...]"
181               + "\n"
182               + "\n  --server_host_override=HOST    Claimed identification expected of server."
183               + "\n                                 Defaults to server host"
184               + "\n  --server_addresses=<name_1>:<port_1>,<name_2>:<port_2>...<name_N>:<port_N>"
185               + "\n    Default: " + serverAddressesToString(c.addresses)
186               + "\n  --test_cases=<testcase_1:w_1>,<testcase_2:w_2>...<testcase_n:w_n>"
187               + "\n    List of <testcase,weight> tuples. Weight is the relative frequency at which"
188               + " testcase is run."
189               + "\n    Valid Testcases:"
190               + validTestCasesHelpText()
191               + "\n  --use_tls=true|false           Whether to use TLS. Default: " + c.useTls
192               + "\n  --use_test_ca=true|false       Whether to trust our fake CA. Requires"
193               + " --use_tls=true"
194               + "\n                                 to have effect. Default: " + c.useTestCa
195               + "\n  --test_duration_secs=SECONDS   '-1' for no limit. Default: " + c.durationSecs
196               + "\n  --num_channels_per_server=INT  Number of connections to each server address."
197               + " Default: " + c.channelsPerServer
198               + "\n  --num_stubs_per_channel=INT    Default: " + c.stubsPerChannel
199               + "\n  --metrics_port=PORT            Listening port of the metrics server."
200               + " Default: " + c.metricsPort
201       );
202       System.exit(1);
203     }
204   }
205 
206   @VisibleForTesting
startMetricsService()207   void startMetricsService() throws IOException {
208     Preconditions.checkState(!shutdown, "client was shutdown.");
209 
210     metricsServer = ServerBuilder.forPort(metricsPort)
211         .addService(new MetricsServiceImpl())
212         .build()
213         .start();
214   }
215 
216   @VisibleForTesting
runStressTest()217   void runStressTest() throws Exception {
218     Preconditions.checkState(!shutdown, "client was shutdown.");
219     if (testCaseWeightPairs.isEmpty()) {
220       return;
221     }
222 
223     int numChannels = addresses.size() * channelsPerServer;
224     int numThreads = numChannels * stubsPerChannel;
225     threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads));
226     int serverIdx = -1;
227     for (InetSocketAddress address : addresses) {
228       serverIdx++;
229       for (int i = 0; i < channelsPerServer; i++) {
230         ManagedChannel channel = createChannel(address);
231         channels.add(channel);
232         for (int j = 0; j < stubsPerChannel; j++) {
233           String gaugeName =
234               String.format("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx, i, j);
235           Worker worker =
236               new Worker(channel, testCaseWeightPairs, durationSecs, gaugeName);
237 
238           workerFutures.add(threadpool.submit(worker));
239         }
240       }
241     }
242   }
243 
244   @VisibleForTesting
blockUntilStressTestComplete()245   void blockUntilStressTestComplete() throws Exception {
246     Preconditions.checkState(!shutdown, "client was shutdown.");
247 
248     ListenableFuture<?> f = Futures.allAsList(workerFutures);
249     if (durationSecs == -1) {
250       // '-1' indicates that the stress test runs until terminated by the user.
251       f.get();
252     } else {
253       f.get(durationSecs + WORKER_GRACE_PERIOD_SECS, SECONDS);
254     }
255   }
256 
257   @VisibleForTesting
shutdown()258   void shutdown() {
259     if (shutdown) {
260       return;
261     }
262     shutdown = true;
263 
264     for (ManagedChannel ch : channels) {
265       try {
266         ch.shutdownNow();
267         ch.awaitTermination(1, SECONDS);
268       } catch (Throwable t) {
269         log.log(Level.WARNING, "Error shutting down channel!", t);
270       }
271     }
272 
273     try {
274       metricsServer.shutdownNow();
275     } catch (Throwable t) {
276       log.log(Level.WARNING, "Error shutting down metrics service!", t);
277     }
278 
279     try {
280       if (threadpool != null) {
281         threadpool.shutdownNow();
282       }
283     } catch (Throwable t) {
284       log.log(Level.WARNING, "Error shutting down threadpool.", t);
285     }
286   }
287 
288   @VisibleForTesting
getMetricServerPort()289   int getMetricServerPort() {
290     return metricsServer.getPort();
291   }
292 
parseServerAddresses(String addressesStr)293   private List<InetSocketAddress> parseServerAddresses(String addressesStr) {
294     List<InetSocketAddress> addresses = new ArrayList<>();
295 
296     for (List<String> namePort : parseCommaSeparatedTuples(addressesStr)) {
297       InetAddress address;
298       String name = namePort.get(0);
299       int port = Integer.valueOf(namePort.get(1));
300       try {
301         address = InetAddress.getByName(name);
302         if (serverHostOverride != null) {
303           // Force the hostname to match the cert the server uses.
304           address = InetAddress.getByAddress(serverHostOverride, address.getAddress());
305         }
306       } catch (UnknownHostException ex) {
307         throw new RuntimeException(ex);
308       }
309       addresses.add(new InetSocketAddress(address, port));
310     }
311 
312     return addresses;
313   }
314 
parseTestCases(String testCasesStr)315   private static List<TestCaseWeightPair> parseTestCases(String testCasesStr) {
316     List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>();
317 
318     for (List<String> nameWeight : parseCommaSeparatedTuples(testCasesStr)) {
319       TestCases testCase = TestCases.fromString(nameWeight.get(0));
320       int weight = Integer.valueOf(nameWeight.get(1));
321       testCaseWeightPairs.add(new TestCaseWeightPair(testCase, weight));
322     }
323 
324     return testCaseWeightPairs;
325   }
326 
parseCommaSeparatedTuples(String str)327   private static List<List<String>> parseCommaSeparatedTuples(String str) {
328     List<List<String>> tuples = new ArrayList<List<String>>();
329     for (String tupleStr : Splitter.on(',').split(str)) {
330       int splitIdx = tupleStr.lastIndexOf(':');
331       if (splitIdx == -1) {
332         throw new IllegalArgumentException("Illegal tuple format: '" + tupleStr + "'");
333       }
334       String part0 = tupleStr.substring(0, splitIdx);
335       String part1 = tupleStr.substring(splitIdx + 1);
336       tuples.add(asList(part0, part1));
337     }
338     return tuples;
339   }
340 
createChannel(InetSocketAddress address)341   private ManagedChannel createChannel(InetSocketAddress address) {
342     SslContext sslContext = null;
343     if (useTestCa) {
344       try {
345         sslContext = GrpcSslContexts.forClient().trustManager(
346             TestUtils.loadCert("ca.pem")).build();
347       } catch (Exception ex) {
348         throw new RuntimeException(ex);
349       }
350     }
351     return NettyChannelBuilder.forAddress(address)
352         .negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT)
353         .sslContext(sslContext)
354         .build();
355   }
356 
serverAddressesToString(List<InetSocketAddress> addresses)357   private static String serverAddressesToString(List<InetSocketAddress> addresses) {
358     List<String> tmp = new ArrayList<>();
359     for (InetSocketAddress address : addresses) {
360       URI uri;
361       try {
362         uri = new URI(null, null, address.getHostName(), address.getPort(), null, null, null);
363       } catch (URISyntaxException e) {
364         throw new RuntimeException(e);
365       }
366       tmp.add(uri.getAuthority());
367     }
368     return Joiner.on(',').join(tmp);
369   }
370 
validTestCasesHelpText()371   private static String validTestCasesHelpText() {
372     StringBuilder builder = new StringBuilder();
373     for (TestCases testCase : TestCases.values()) {
374       String strTestcase = testCase.name().toLowerCase();
375       builder.append("\n      ")
376           .append(strTestcase)
377           .append(": ")
378           .append(testCase.description());
379     }
380     return builder.toString();
381   }
382 
383   /**
384    * A stress test worker. Every stub has its own stress test worker.
385    */
386   private class Worker implements Runnable {
387 
388     // Interval at which the QPS stats of metrics service are updated.
389     private static final long METRICS_COLLECTION_INTERVAL_SECS = 5;
390 
391     private final ManagedChannel channel;
392     private final List<TestCaseWeightPair> testCaseWeightPairs;
393     private final Integer durationSec;
394     private final String gaugeName;
395 
Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs, int durationSec, String gaugeName)396     Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs,
397         int durationSec, String gaugeName) {
398       Preconditions.checkArgument(durationSec >= -1, "durationSec must be gte -1.");
399       this.channel = Preconditions.checkNotNull(channel, "channel");
400       this.testCaseWeightPairs =
401           Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs");
402       this.durationSec = durationSec == -1 ? null : durationSec;
403       this.gaugeName = Preconditions.checkNotNull(gaugeName, "gaugeName");
404     }
405 
406     @Override
run()407     public void run() {
408       // Simplify debugging if the worker crashes / never terminates.
409       Thread.currentThread().setName(gaugeName);
410 
411       Tester tester = new Tester();
412       tester.setUp();
413       WeightedTestCaseSelector testCaseSelector = new WeightedTestCaseSelector(testCaseWeightPairs);
414       Long endTime = durationSec == null ? null : System.nanoTime() + SECONDS.toNanos(durationSecs);
415       long lastMetricsCollectionTime = initLastMetricsCollectionTime();
416       // Number of interop testcases run since the last time metrics have been updated.
417       long testCasesSinceLastMetricsCollection = 0;
418 
419       while (!Thread.currentThread().isInterrupted() && !shutdown
420           && (endTime == null || endTime - System.nanoTime() > 0)) {
421         try {
422           runTestCase(tester, testCaseSelector.nextTestCase());
423         } catch (Exception e) {
424           throw new RuntimeException(e);
425         }
426 
427         testCasesSinceLastMetricsCollection++;
428 
429         double durationSecs = computeDurationSecs(lastMetricsCollectionTime);
430         if (durationSecs >= METRICS_COLLECTION_INTERVAL_SECS) {
431           long qps = (long) Math.ceil(testCasesSinceLastMetricsCollection / durationSecs);
432 
433           Metrics.GaugeResponse gauge = Metrics.GaugeResponse
434               .newBuilder()
435               .setName(gaugeName)
436               .setLongValue(qps)
437               .build();
438 
439           gauges.put(gaugeName, gauge);
440 
441           lastMetricsCollectionTime = System.nanoTime();
442           testCasesSinceLastMetricsCollection = 0;
443         }
444       }
445     }
446 
initLastMetricsCollectionTime()447     private long initLastMetricsCollectionTime() {
448       return System.nanoTime() - SECONDS.toNanos(METRICS_COLLECTION_INTERVAL_SECS);
449     }
450 
computeDurationSecs(long lastMetricsCollectionTime)451     private double computeDurationSecs(long lastMetricsCollectionTime) {
452       return (System.nanoTime() - lastMetricsCollectionTime) / 1000000000.0;
453     }
454 
runTestCase(Tester tester, TestCases testCase)455     private void runTestCase(Tester tester, TestCases testCase) throws Exception {
456       // TODO(buchgr): Implement tests requiring auth, once C++ supports it.
457       switch (testCase) {
458         case EMPTY_UNARY:
459           tester.emptyUnary();
460           break;
461 
462         case LARGE_UNARY:
463           tester.largeUnary();
464           break;
465 
466         case CLIENT_STREAMING:
467           tester.clientStreaming();
468           break;
469 
470         case SERVER_STREAMING:
471           tester.serverStreaming();
472           break;
473 
474         case PING_PONG:
475           tester.pingPong();
476           break;
477 
478         case EMPTY_STREAM:
479           tester.emptyStream();
480           break;
481 
482         case UNIMPLEMENTED_METHOD: {
483           tester.unimplementedMethod();
484           break;
485         }
486 
487         case UNIMPLEMENTED_SERVICE: {
488           tester.unimplementedService();
489           break;
490         }
491 
492         case CANCEL_AFTER_BEGIN: {
493           tester.cancelAfterBegin();
494           break;
495         }
496 
497         case CANCEL_AFTER_FIRST_RESPONSE: {
498           tester.cancelAfterFirstResponse();
499           break;
500         }
501 
502         case TIMEOUT_ON_SLEEPING_SERVER: {
503           tester.timeoutOnSleepingServer();
504           break;
505         }
506 
507         default:
508           throw new IllegalArgumentException("Unknown test case: " + testCase);
509       }
510     }
511 
512     class Tester extends AbstractInteropTest {
513       @Override
createChannel()514       protected ManagedChannel createChannel() {
515         return Worker.this.channel;
516       }
517 
518       @Override
operationTimeoutMillis()519       protected int operationTimeoutMillis() {
520         // Don't enforce a timeout when using the interop tests for the stress test client.
521         // Fixes https://github.com/grpc/grpc-java/issues/1812
522         return Integer.MAX_VALUE;
523       }
524 
525       @Override
metricsExpected()526       protected boolean metricsExpected() {
527         // TODO(zhangkun83): we may want to enable the real google Instrumentation implementation in
528         // stress tests.
529         return false;
530       }
531     }
532 
533     class WeightedTestCaseSelector {
534       /**
535        * Randomly shuffled and cyclic sequence that contains each testcase proportionally
536        * to its weight.
537        */
538       final Iterator<TestCases> testCases;
539 
WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs)540       WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs) {
541         Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs");
542         Preconditions.checkArgument(testCaseWeightPairs.size() > 0);
543 
544         List<TestCases> testCases = new ArrayList<>();
545         for (TestCaseWeightPair testCaseWeightPair : testCaseWeightPairs) {
546           for (int i = 0; i < testCaseWeightPair.weight; i++) {
547             testCases.add(testCaseWeightPair.testCase);
548           }
549         }
550 
551         shuffle(testCases);
552 
553         this.testCases = Iterators.cycle(testCases);
554       }
555 
nextTestCase()556       TestCases nextTestCase() {
557         return testCases.next();
558       }
559     }
560   }
561 
562   /**
563    * Service that exports the QPS metrics of the stress test.
564    */
565   private class MetricsServiceImpl extends MetricsServiceGrpc.MetricsServiceImplBase {
566 
567     @Override
getAllGauges(Metrics.EmptyMessage request, StreamObserver<Metrics.GaugeResponse> responseObserver)568     public void getAllGauges(Metrics.EmptyMessage request,
569         StreamObserver<Metrics.GaugeResponse> responseObserver) {
570       for (Metrics.GaugeResponse gauge : gauges.values()) {
571         responseObserver.onNext(gauge);
572       }
573       responseObserver.onCompleted();
574     }
575 
576     @Override
getGauge(Metrics.GaugeRequest request, StreamObserver<Metrics.GaugeResponse> responseObserver)577     public void getGauge(Metrics.GaugeRequest request,
578         StreamObserver<Metrics.GaugeResponse> responseObserver) {
579       String gaugeName = request.getName();
580       Metrics.GaugeResponse gauge = gauges.get(gaugeName);
581       if (gauge != null) {
582         responseObserver.onNext(gauge);
583         responseObserver.onCompleted();
584       } else {
585         responseObserver.onError(new StatusException(Status.NOT_FOUND));
586       }
587     }
588   }
589 
590   @VisibleForTesting
591   static class TestCaseWeightPair {
592     final TestCases testCase;
593     final int weight;
594 
TestCaseWeightPair(TestCases testCase, int weight)595     TestCaseWeightPair(TestCases testCase, int weight) {
596       Preconditions.checkArgument(weight >= 0, "weight must be positive.");
597       this.testCase = Preconditions.checkNotNull(testCase, "testCase");
598       this.weight = weight;
599     }
600 
601     @Override
equals(Object other)602     public boolean equals(Object other) {
603       if (!(other instanceof TestCaseWeightPair)) {
604         return false;
605       }
606       TestCaseWeightPair that = (TestCaseWeightPair) other;
607       return testCase.equals(that.testCase) && weight == that.weight;
608     }
609 
610     @Override
hashCode()611     public int hashCode() {
612       return Objects.hashCode(testCase, weight);
613     }
614   }
615 
616   @VisibleForTesting
addresses()617   List<InetSocketAddress> addresses() {
618     return Collections.unmodifiableList(addresses);
619   }
620 
621   @VisibleForTesting
serverHostOverride()622   String serverHostOverride() {
623     return serverHostOverride;
624   }
625 
626   @VisibleForTesting
useTls()627   boolean useTls() {
628     return useTls;
629   }
630 
631   @VisibleForTesting
useTestCa()632   boolean useTestCa() {
633     return useTestCa;
634   }
635 
636   @VisibleForTesting
testCaseWeightPairs()637   List<TestCaseWeightPair> testCaseWeightPairs() {
638     return testCaseWeightPairs;
639   }
640 
641   @VisibleForTesting
durationSecs()642   int durationSecs() {
643     return durationSecs;
644   }
645 
646   @VisibleForTesting
channelsPerServer()647   int channelsPerServer() {
648     return channelsPerServer;
649   }
650 
651   @VisibleForTesting
stubsPerChannel()652   int stubsPerChannel() {
653     return stubsPerChannel;
654   }
655 
656   @VisibleForTesting
metricsPort()657   int metricsPort() {
658     return metricsPort;
659   }
660 }
661