1 /* 2 * Copyright 2016 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.testing.integration; 18 19 import static java.util.Arrays.asList; 20 import static java.util.Collections.shuffle; 21 import static java.util.Collections.singletonList; 22 import static java.util.concurrent.Executors.newFixedThreadPool; 23 import static java.util.concurrent.TimeUnit.SECONDS; 24 25 import com.google.common.annotations.VisibleForTesting; 26 import com.google.common.base.Joiner; 27 import com.google.common.base.Objects; 28 import com.google.common.base.Preconditions; 29 import com.google.common.base.Splitter; 30 import com.google.common.collect.Iterators; 31 import com.google.common.util.concurrent.Futures; 32 import com.google.common.util.concurrent.ListenableFuture; 33 import com.google.common.util.concurrent.ListeningExecutorService; 34 import com.google.common.util.concurrent.MoreExecutors; 35 import io.grpc.ManagedChannel; 36 import io.grpc.Server; 37 import io.grpc.ServerBuilder; 38 import io.grpc.Status; 39 import io.grpc.StatusException; 40 import io.grpc.internal.testing.TestUtils; 41 import io.grpc.netty.GrpcSslContexts; 42 import io.grpc.netty.NegotiationType; 43 import io.grpc.netty.NettyChannelBuilder; 44 import io.grpc.stub.StreamObserver; 45 import io.netty.handler.ssl.SslContext; 46 import java.io.IOException; 47 import java.net.InetAddress; 48 import java.net.InetSocketAddress; 49 import java.net.URI; 50 import java.net.URISyntaxException; 51 import java.net.UnknownHostException; 52 import java.util.ArrayList; 53 import java.util.Collections; 54 import java.util.Iterator; 55 import java.util.List; 56 import java.util.Map; 57 import java.util.concurrent.ConcurrentHashMap; 58 import java.util.logging.Level; 59 import java.util.logging.Logger; 60 61 /** 62 * A stress test client following the 63 * <a href="https://github.com/grpc/grpc/blob/master/tools/run_tests/stress_test/STRESS_CLIENT_SPEC.md"> 64 * specifications</a> of the gRPC stress testing framework. 65 */ 66 public class StressTestClient { 67 68 private static final Logger log = Logger.getLogger(StressTestClient.class.getName()); 69 70 /** 71 * The main application allowing this client to be launched from the command line. 72 */ main(String... args)73 public static void main(String... args) throws Exception { 74 final StressTestClient client = new StressTestClient(); 75 client.parseArgs(args); 76 77 // Attempt an orderly shutdown, if the JVM is shutdown via a signal. 78 Runtime.getRuntime().addShutdownHook(new Thread() { 79 @Override 80 public void run() { 81 client.shutdown(); 82 } 83 }); 84 85 try { 86 client.startMetricsService(); 87 client.runStressTest(); 88 client.blockUntilStressTestComplete(); 89 } catch (Exception e) { 90 log.log(Level.WARNING, "The stress test client encountered an error!", e); 91 } finally { 92 client.shutdown(); 93 } 94 } 95 96 private static final int WORKER_GRACE_PERIOD_SECS = 30; 97 98 private List<InetSocketAddress> addresses = 99 singletonList(new InetSocketAddress("localhost", 8080)); 100 private List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>(); 101 102 private String serverHostOverride; 103 private boolean useTls = false; 104 private boolean useTestCa = false; 105 private int durationSecs = -1; 106 private int channelsPerServer = 1; 107 private int stubsPerChannel = 1; 108 private int metricsPort = 8081; 109 110 private Server metricsServer; 111 private final Map<String, Metrics.GaugeResponse> gauges = 112 new ConcurrentHashMap<String, Metrics.GaugeResponse>(); 113 114 private volatile boolean shutdown; 115 116 /** 117 * List of futures that {@link #blockUntilStressTestComplete()} waits for. 118 */ 119 private final List<ListenableFuture<?>> workerFutures = 120 new ArrayList<ListenableFuture<?>>(); 121 private final List<ManagedChannel> channels = new ArrayList<>(); 122 private ListeningExecutorService threadpool; 123 124 @VisibleForTesting parseArgs(String[] args)125 void parseArgs(String[] args) { 126 boolean usage = false; 127 String serverAddresses = ""; 128 for (String arg : args) { 129 if (!arg.startsWith("--")) { 130 System.err.println("All arguments must start with '--': " + arg); 131 usage = true; 132 break; 133 } 134 String[] parts = arg.substring(2).split("=", 2); 135 String key = parts[0]; 136 if ("help".equals(key)) { 137 usage = true; 138 break; 139 } 140 if (parts.length != 2) { 141 System.err.println("All arguments must be of the form --arg=value"); 142 usage = true; 143 break; 144 } 145 String value = parts[1]; 146 if ("server_addresses".equals(key)) { 147 // May need to apply server host overrides to the addresses, so delay processing 148 serverAddresses = value; 149 } else if ("server_host_override".equals(key)) { 150 serverHostOverride = value; 151 } else if ("use_tls".equals(key)) { 152 useTls = Boolean.parseBoolean(value); 153 } else if ("use_test_ca".equals(key)) { 154 useTestCa = Boolean.parseBoolean(value); 155 } else if ("test_cases".equals(key)) { 156 testCaseWeightPairs = parseTestCases(value); 157 } else if ("test_duration_secs".equals(key)) { 158 durationSecs = Integer.valueOf(value); 159 } else if ("num_channels_per_server".equals(key)) { 160 channelsPerServer = Integer.valueOf(value); 161 } else if ("num_stubs_per_channel".equals(key)) { 162 stubsPerChannel = Integer.valueOf(value); 163 } else if ("metrics_port".equals(key)) { 164 metricsPort = Integer.valueOf(value); 165 } else { 166 System.err.println("Unknown argument: " + key); 167 usage = true; 168 break; 169 } 170 } 171 172 if (!usage && !serverAddresses.isEmpty()) { 173 addresses = parseServerAddresses(serverAddresses); 174 usage = addresses.isEmpty(); 175 } 176 177 if (usage) { 178 StressTestClient c = new StressTestClient(); 179 System.err.println( 180 "Usage: [ARGS...]" 181 + "\n" 182 + "\n --server_host_override=HOST Claimed identification expected of server." 183 + "\n Defaults to server host" 184 + "\n --server_addresses=<name_1>:<port_1>,<name_2>:<port_2>...<name_N>:<port_N>" 185 + "\n Default: " + serverAddressesToString(c.addresses) 186 + "\n --test_cases=<testcase_1:w_1>,<testcase_2:w_2>...<testcase_n:w_n>" 187 + "\n List of <testcase,weight> tuples. Weight is the relative frequency at which" 188 + " testcase is run." 189 + "\n Valid Testcases:" 190 + validTestCasesHelpText() 191 + "\n --use_tls=true|false Whether to use TLS. Default: " + c.useTls 192 + "\n --use_test_ca=true|false Whether to trust our fake CA. Requires" 193 + " --use_tls=true" 194 + "\n to have effect. Default: " + c.useTestCa 195 + "\n --test_duration_secs=SECONDS '-1' for no limit. Default: " + c.durationSecs 196 + "\n --num_channels_per_server=INT Number of connections to each server address." 197 + " Default: " + c.channelsPerServer 198 + "\n --num_stubs_per_channel=INT Default: " + c.stubsPerChannel 199 + "\n --metrics_port=PORT Listening port of the metrics server." 200 + " Default: " + c.metricsPort 201 ); 202 System.exit(1); 203 } 204 } 205 206 @VisibleForTesting startMetricsService()207 void startMetricsService() throws IOException { 208 Preconditions.checkState(!shutdown, "client was shutdown."); 209 210 metricsServer = ServerBuilder.forPort(metricsPort) 211 .addService(new MetricsServiceImpl()) 212 .build() 213 .start(); 214 } 215 216 @VisibleForTesting runStressTest()217 void runStressTest() throws Exception { 218 Preconditions.checkState(!shutdown, "client was shutdown."); 219 if (testCaseWeightPairs.isEmpty()) { 220 return; 221 } 222 223 int numChannels = addresses.size() * channelsPerServer; 224 int numThreads = numChannels * stubsPerChannel; 225 threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads)); 226 int serverIdx = -1; 227 for (InetSocketAddress address : addresses) { 228 serverIdx++; 229 for (int i = 0; i < channelsPerServer; i++) { 230 ManagedChannel channel = createChannel(address); 231 channels.add(channel); 232 for (int j = 0; j < stubsPerChannel; j++) { 233 String gaugeName = 234 String.format("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx, i, j); 235 Worker worker = 236 new Worker(channel, testCaseWeightPairs, durationSecs, gaugeName); 237 238 workerFutures.add(threadpool.submit(worker)); 239 } 240 } 241 } 242 } 243 244 @VisibleForTesting blockUntilStressTestComplete()245 void blockUntilStressTestComplete() throws Exception { 246 Preconditions.checkState(!shutdown, "client was shutdown."); 247 248 ListenableFuture<?> f = Futures.allAsList(workerFutures); 249 if (durationSecs == -1) { 250 // '-1' indicates that the stress test runs until terminated by the user. 251 f.get(); 252 } else { 253 f.get(durationSecs + WORKER_GRACE_PERIOD_SECS, SECONDS); 254 } 255 } 256 257 @VisibleForTesting shutdown()258 void shutdown() { 259 if (shutdown) { 260 return; 261 } 262 shutdown = true; 263 264 for (ManagedChannel ch : channels) { 265 try { 266 ch.shutdownNow(); 267 ch.awaitTermination(1, SECONDS); 268 } catch (Throwable t) { 269 log.log(Level.WARNING, "Error shutting down channel!", t); 270 } 271 } 272 273 try { 274 metricsServer.shutdownNow(); 275 } catch (Throwable t) { 276 log.log(Level.WARNING, "Error shutting down metrics service!", t); 277 } 278 279 try { 280 if (threadpool != null) { 281 threadpool.shutdownNow(); 282 } 283 } catch (Throwable t) { 284 log.log(Level.WARNING, "Error shutting down threadpool.", t); 285 } 286 } 287 288 @VisibleForTesting getMetricServerPort()289 int getMetricServerPort() { 290 return metricsServer.getPort(); 291 } 292 parseServerAddresses(String addressesStr)293 private List<InetSocketAddress> parseServerAddresses(String addressesStr) { 294 List<InetSocketAddress> addresses = new ArrayList<>(); 295 296 for (List<String> namePort : parseCommaSeparatedTuples(addressesStr)) { 297 InetAddress address; 298 String name = namePort.get(0); 299 int port = Integer.valueOf(namePort.get(1)); 300 try { 301 address = InetAddress.getByName(name); 302 if (serverHostOverride != null) { 303 // Force the hostname to match the cert the server uses. 304 address = InetAddress.getByAddress(serverHostOverride, address.getAddress()); 305 } 306 } catch (UnknownHostException ex) { 307 throw new RuntimeException(ex); 308 } 309 addresses.add(new InetSocketAddress(address, port)); 310 } 311 312 return addresses; 313 } 314 parseTestCases(String testCasesStr)315 private static List<TestCaseWeightPair> parseTestCases(String testCasesStr) { 316 List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>(); 317 318 for (List<String> nameWeight : parseCommaSeparatedTuples(testCasesStr)) { 319 TestCases testCase = TestCases.fromString(nameWeight.get(0)); 320 int weight = Integer.valueOf(nameWeight.get(1)); 321 testCaseWeightPairs.add(new TestCaseWeightPair(testCase, weight)); 322 } 323 324 return testCaseWeightPairs; 325 } 326 parseCommaSeparatedTuples(String str)327 private static List<List<String>> parseCommaSeparatedTuples(String str) { 328 List<List<String>> tuples = new ArrayList<List<String>>(); 329 for (String tupleStr : Splitter.on(',').split(str)) { 330 int splitIdx = tupleStr.lastIndexOf(':'); 331 if (splitIdx == -1) { 332 throw new IllegalArgumentException("Illegal tuple format: '" + tupleStr + "'"); 333 } 334 String part0 = tupleStr.substring(0, splitIdx); 335 String part1 = tupleStr.substring(splitIdx + 1); 336 tuples.add(asList(part0, part1)); 337 } 338 return tuples; 339 } 340 createChannel(InetSocketAddress address)341 private ManagedChannel createChannel(InetSocketAddress address) { 342 SslContext sslContext = null; 343 if (useTestCa) { 344 try { 345 sslContext = GrpcSslContexts.forClient().trustManager( 346 TestUtils.loadCert("ca.pem")).build(); 347 } catch (Exception ex) { 348 throw new RuntimeException(ex); 349 } 350 } 351 return NettyChannelBuilder.forAddress(address) 352 .negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT) 353 .sslContext(sslContext) 354 .build(); 355 } 356 serverAddressesToString(List<InetSocketAddress> addresses)357 private static String serverAddressesToString(List<InetSocketAddress> addresses) { 358 List<String> tmp = new ArrayList<>(); 359 for (InetSocketAddress address : addresses) { 360 URI uri; 361 try { 362 uri = new URI(null, null, address.getHostName(), address.getPort(), null, null, null); 363 } catch (URISyntaxException e) { 364 throw new RuntimeException(e); 365 } 366 tmp.add(uri.getAuthority()); 367 } 368 return Joiner.on(',').join(tmp); 369 } 370 validTestCasesHelpText()371 private static String validTestCasesHelpText() { 372 StringBuilder builder = new StringBuilder(); 373 for (TestCases testCase : TestCases.values()) { 374 String strTestcase = testCase.name().toLowerCase(); 375 builder.append("\n ") 376 .append(strTestcase) 377 .append(": ") 378 .append(testCase.description()); 379 } 380 return builder.toString(); 381 } 382 383 /** 384 * A stress test worker. Every stub has its own stress test worker. 385 */ 386 private class Worker implements Runnable { 387 388 // Interval at which the QPS stats of metrics service are updated. 389 private static final long METRICS_COLLECTION_INTERVAL_SECS = 5; 390 391 private final ManagedChannel channel; 392 private final List<TestCaseWeightPair> testCaseWeightPairs; 393 private final Integer durationSec; 394 private final String gaugeName; 395 Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs, int durationSec, String gaugeName)396 Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs, 397 int durationSec, String gaugeName) { 398 Preconditions.checkArgument(durationSec >= -1, "durationSec must be gte -1."); 399 this.channel = Preconditions.checkNotNull(channel, "channel"); 400 this.testCaseWeightPairs = 401 Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs"); 402 this.durationSec = durationSec == -1 ? null : durationSec; 403 this.gaugeName = Preconditions.checkNotNull(gaugeName, "gaugeName"); 404 } 405 406 @Override run()407 public void run() { 408 // Simplify debugging if the worker crashes / never terminates. 409 Thread.currentThread().setName(gaugeName); 410 411 Tester tester = new Tester(); 412 tester.setUp(); 413 WeightedTestCaseSelector testCaseSelector = new WeightedTestCaseSelector(testCaseWeightPairs); 414 Long endTime = durationSec == null ? null : System.nanoTime() + SECONDS.toNanos(durationSecs); 415 long lastMetricsCollectionTime = initLastMetricsCollectionTime(); 416 // Number of interop testcases run since the last time metrics have been updated. 417 long testCasesSinceLastMetricsCollection = 0; 418 419 while (!Thread.currentThread().isInterrupted() && !shutdown 420 && (endTime == null || endTime - System.nanoTime() > 0)) { 421 try { 422 runTestCase(tester, testCaseSelector.nextTestCase()); 423 } catch (Exception e) { 424 throw new RuntimeException(e); 425 } 426 427 testCasesSinceLastMetricsCollection++; 428 429 double durationSecs = computeDurationSecs(lastMetricsCollectionTime); 430 if (durationSecs >= METRICS_COLLECTION_INTERVAL_SECS) { 431 long qps = (long) Math.ceil(testCasesSinceLastMetricsCollection / durationSecs); 432 433 Metrics.GaugeResponse gauge = Metrics.GaugeResponse 434 .newBuilder() 435 .setName(gaugeName) 436 .setLongValue(qps) 437 .build(); 438 439 gauges.put(gaugeName, gauge); 440 441 lastMetricsCollectionTime = System.nanoTime(); 442 testCasesSinceLastMetricsCollection = 0; 443 } 444 } 445 } 446 initLastMetricsCollectionTime()447 private long initLastMetricsCollectionTime() { 448 return System.nanoTime() - SECONDS.toNanos(METRICS_COLLECTION_INTERVAL_SECS); 449 } 450 computeDurationSecs(long lastMetricsCollectionTime)451 private double computeDurationSecs(long lastMetricsCollectionTime) { 452 return (System.nanoTime() - lastMetricsCollectionTime) / 1000000000.0; 453 } 454 runTestCase(Tester tester, TestCases testCase)455 private void runTestCase(Tester tester, TestCases testCase) throws Exception { 456 // TODO(buchgr): Implement tests requiring auth, once C++ supports it. 457 switch (testCase) { 458 case EMPTY_UNARY: 459 tester.emptyUnary(); 460 break; 461 462 case LARGE_UNARY: 463 tester.largeUnary(); 464 break; 465 466 case CLIENT_STREAMING: 467 tester.clientStreaming(); 468 break; 469 470 case SERVER_STREAMING: 471 tester.serverStreaming(); 472 break; 473 474 case PING_PONG: 475 tester.pingPong(); 476 break; 477 478 case EMPTY_STREAM: 479 tester.emptyStream(); 480 break; 481 482 case UNIMPLEMENTED_METHOD: { 483 tester.unimplementedMethod(); 484 break; 485 } 486 487 case UNIMPLEMENTED_SERVICE: { 488 tester.unimplementedService(); 489 break; 490 } 491 492 case CANCEL_AFTER_BEGIN: { 493 tester.cancelAfterBegin(); 494 break; 495 } 496 497 case CANCEL_AFTER_FIRST_RESPONSE: { 498 tester.cancelAfterFirstResponse(); 499 break; 500 } 501 502 case TIMEOUT_ON_SLEEPING_SERVER: { 503 tester.timeoutOnSleepingServer(); 504 break; 505 } 506 507 default: 508 throw new IllegalArgumentException("Unknown test case: " + testCase); 509 } 510 } 511 512 class Tester extends AbstractInteropTest { 513 @Override createChannel()514 protected ManagedChannel createChannel() { 515 return Worker.this.channel; 516 } 517 518 @Override operationTimeoutMillis()519 protected int operationTimeoutMillis() { 520 // Don't enforce a timeout when using the interop tests for the stress test client. 521 // Fixes https://github.com/grpc/grpc-java/issues/1812 522 return Integer.MAX_VALUE; 523 } 524 525 @Override metricsExpected()526 protected boolean metricsExpected() { 527 // TODO(zhangkun83): we may want to enable the real google Instrumentation implementation in 528 // stress tests. 529 return false; 530 } 531 } 532 533 class WeightedTestCaseSelector { 534 /** 535 * Randomly shuffled and cyclic sequence that contains each testcase proportionally 536 * to its weight. 537 */ 538 final Iterator<TestCases> testCases; 539 WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs)540 WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs) { 541 Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs"); 542 Preconditions.checkArgument(testCaseWeightPairs.size() > 0); 543 544 List<TestCases> testCases = new ArrayList<>(); 545 for (TestCaseWeightPair testCaseWeightPair : testCaseWeightPairs) { 546 for (int i = 0; i < testCaseWeightPair.weight; i++) { 547 testCases.add(testCaseWeightPair.testCase); 548 } 549 } 550 551 shuffle(testCases); 552 553 this.testCases = Iterators.cycle(testCases); 554 } 555 nextTestCase()556 TestCases nextTestCase() { 557 return testCases.next(); 558 } 559 } 560 } 561 562 /** 563 * Service that exports the QPS metrics of the stress test. 564 */ 565 private class MetricsServiceImpl extends MetricsServiceGrpc.MetricsServiceImplBase { 566 567 @Override getAllGauges(Metrics.EmptyMessage request, StreamObserver<Metrics.GaugeResponse> responseObserver)568 public void getAllGauges(Metrics.EmptyMessage request, 569 StreamObserver<Metrics.GaugeResponse> responseObserver) { 570 for (Metrics.GaugeResponse gauge : gauges.values()) { 571 responseObserver.onNext(gauge); 572 } 573 responseObserver.onCompleted(); 574 } 575 576 @Override getGauge(Metrics.GaugeRequest request, StreamObserver<Metrics.GaugeResponse> responseObserver)577 public void getGauge(Metrics.GaugeRequest request, 578 StreamObserver<Metrics.GaugeResponse> responseObserver) { 579 String gaugeName = request.getName(); 580 Metrics.GaugeResponse gauge = gauges.get(gaugeName); 581 if (gauge != null) { 582 responseObserver.onNext(gauge); 583 responseObserver.onCompleted(); 584 } else { 585 responseObserver.onError(new StatusException(Status.NOT_FOUND)); 586 } 587 } 588 } 589 590 @VisibleForTesting 591 static class TestCaseWeightPair { 592 final TestCases testCase; 593 final int weight; 594 TestCaseWeightPair(TestCases testCase, int weight)595 TestCaseWeightPair(TestCases testCase, int weight) { 596 Preconditions.checkArgument(weight >= 0, "weight must be positive."); 597 this.testCase = Preconditions.checkNotNull(testCase, "testCase"); 598 this.weight = weight; 599 } 600 601 @Override equals(Object other)602 public boolean equals(Object other) { 603 if (!(other instanceof TestCaseWeightPair)) { 604 return false; 605 } 606 TestCaseWeightPair that = (TestCaseWeightPair) other; 607 return testCase.equals(that.testCase) && weight == that.weight; 608 } 609 610 @Override hashCode()611 public int hashCode() { 612 return Objects.hashCode(testCase, weight); 613 } 614 } 615 616 @VisibleForTesting addresses()617 List<InetSocketAddress> addresses() { 618 return Collections.unmodifiableList(addresses); 619 } 620 621 @VisibleForTesting serverHostOverride()622 String serverHostOverride() { 623 return serverHostOverride; 624 } 625 626 @VisibleForTesting useTls()627 boolean useTls() { 628 return useTls; 629 } 630 631 @VisibleForTesting useTestCa()632 boolean useTestCa() { 633 return useTestCa; 634 } 635 636 @VisibleForTesting testCaseWeightPairs()637 List<TestCaseWeightPair> testCaseWeightPairs() { 638 return testCaseWeightPairs; 639 } 640 641 @VisibleForTesting durationSecs()642 int durationSecs() { 643 return durationSecs; 644 } 645 646 @VisibleForTesting channelsPerServer()647 int channelsPerServer() { 648 return channelsPerServer; 649 } 650 651 @VisibleForTesting stubsPerChannel()652 int stubsPerChannel() { 653 return stubsPerChannel; 654 } 655 656 @VisibleForTesting metricsPort()657 int metricsPort() { 658 return metricsPort; 659 } 660 } 661