/* * Copyright 2016 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.grpc.testing.integration; import static java.util.Arrays.asList; import static java.util.Collections.shuffle; import static java.util.Collections.singletonList; import static java.util.concurrent.Executors.newFixedThreadPool; import static java.util.concurrent.TimeUnit.SECONDS; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.common.collect.Iterators; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import io.grpc.ManagedChannel; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NegotiationType; import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.StreamObserver; import io.netty.handler.ssl.SslContext; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Level; import java.util.logging.Logger; /** * A stress test client following the * * specifications of the gRPC stress testing framework. */ public class StressTestClient { private static final Logger log = Logger.getLogger(StressTestClient.class.getName()); /** * The main application allowing this client to be launched from the command line. */ public static void main(String... args) throws Exception { final StressTestClient client = new StressTestClient(); client.parseArgs(args); // Attempt an orderly shutdown, if the JVM is shutdown via a signal. Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { client.shutdown(); } }); try { client.startMetricsService(); client.runStressTest(); client.blockUntilStressTestComplete(); } catch (Exception e) { log.log(Level.WARNING, "The stress test client encountered an error!", e); } finally { client.shutdown(); } } private static final int WORKER_GRACE_PERIOD_SECS = 30; private List addresses = singletonList(new InetSocketAddress("localhost", 8080)); private List testCaseWeightPairs = new ArrayList<>(); private String serverHostOverride; private boolean useTls = false; private boolean useTestCa = false; private int durationSecs = -1; private int channelsPerServer = 1; private int stubsPerChannel = 1; private int metricsPort = 8081; private Server metricsServer; private final Map gauges = new ConcurrentHashMap(); private volatile boolean shutdown; /** * List of futures that {@link #blockUntilStressTestComplete()} waits for. */ private final List> workerFutures = new ArrayList>(); private final List channels = new ArrayList<>(); private ListeningExecutorService threadpool; @VisibleForTesting void parseArgs(String[] args) { boolean usage = false; String serverAddresses = ""; for (String arg : args) { if (!arg.startsWith("--")) { System.err.println("All arguments must start with '--': " + arg); usage = true; break; } String[] parts = arg.substring(2).split("=", 2); String key = parts[0]; if ("help".equals(key)) { usage = true; break; } if (parts.length != 2) { System.err.println("All arguments must be of the form --arg=value"); usage = true; break; } String value = parts[1]; if ("server_addresses".equals(key)) { // May need to apply server host overrides to the addresses, so delay processing serverAddresses = value; } else if ("server_host_override".equals(key)) { serverHostOverride = value; } else if ("use_tls".equals(key)) { useTls = Boolean.parseBoolean(value); } else if ("use_test_ca".equals(key)) { useTestCa = Boolean.parseBoolean(value); } else if ("test_cases".equals(key)) { testCaseWeightPairs = parseTestCases(value); } else if ("test_duration_secs".equals(key)) { durationSecs = Integer.valueOf(value); } else if ("num_channels_per_server".equals(key)) { channelsPerServer = Integer.valueOf(value); } else if ("num_stubs_per_channel".equals(key)) { stubsPerChannel = Integer.valueOf(value); } else if ("metrics_port".equals(key)) { metricsPort = Integer.valueOf(value); } else { System.err.println("Unknown argument: " + key); usage = true; break; } } if (!usage && !serverAddresses.isEmpty()) { addresses = parseServerAddresses(serverAddresses); usage = addresses.isEmpty(); } if (usage) { StressTestClient c = new StressTestClient(); System.err.println( "Usage: [ARGS...]" + "\n" + "\n --server_host_override=HOST Claimed identification expected of server." + "\n Defaults to server host" + "\n --server_addresses=:,:...:" + "\n Default: " + serverAddressesToString(c.addresses) + "\n --test_cases=,..." + "\n List of tuples. Weight is the relative frequency at which" + " testcase is run." + "\n Valid Testcases:" + validTestCasesHelpText() + "\n --use_tls=true|false Whether to use TLS. Default: " + c.useTls + "\n --use_test_ca=true|false Whether to trust our fake CA. Requires" + " --use_tls=true" + "\n to have effect. Default: " + c.useTestCa + "\n --test_duration_secs=SECONDS '-1' for no limit. Default: " + c.durationSecs + "\n --num_channels_per_server=INT Number of connections to each server address." + " Default: " + c.channelsPerServer + "\n --num_stubs_per_channel=INT Default: " + c.stubsPerChannel + "\n --metrics_port=PORT Listening port of the metrics server." + " Default: " + c.metricsPort ); System.exit(1); } } @VisibleForTesting void startMetricsService() throws IOException { Preconditions.checkState(!shutdown, "client was shutdown."); metricsServer = ServerBuilder.forPort(metricsPort) .addService(new MetricsServiceImpl()) .build() .start(); } @VisibleForTesting void runStressTest() throws Exception { Preconditions.checkState(!shutdown, "client was shutdown."); if (testCaseWeightPairs.isEmpty()) { return; } int numChannels = addresses.size() * channelsPerServer; int numThreads = numChannels * stubsPerChannel; threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads)); int serverIdx = -1; for (InetSocketAddress address : addresses) { serverIdx++; for (int i = 0; i < channelsPerServer; i++) { ManagedChannel channel = createChannel(address); channels.add(channel); for (int j = 0; j < stubsPerChannel; j++) { String gaugeName = String.format("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx, i, j); Worker worker = new Worker(channel, testCaseWeightPairs, durationSecs, gaugeName); workerFutures.add(threadpool.submit(worker)); } } } } @VisibleForTesting void blockUntilStressTestComplete() throws Exception { Preconditions.checkState(!shutdown, "client was shutdown."); ListenableFuture f = Futures.allAsList(workerFutures); if (durationSecs == -1) { // '-1' indicates that the stress test runs until terminated by the user. f.get(); } else { f.get(durationSecs + WORKER_GRACE_PERIOD_SECS, SECONDS); } } @VisibleForTesting void shutdown() { if (shutdown) { return; } shutdown = true; for (ManagedChannel ch : channels) { try { ch.shutdownNow(); ch.awaitTermination(1, SECONDS); } catch (Throwable t) { log.log(Level.WARNING, "Error shutting down channel!", t); } } try { metricsServer.shutdownNow(); } catch (Throwable t) { log.log(Level.WARNING, "Error shutting down metrics service!", t); } try { if (threadpool != null) { threadpool.shutdownNow(); } } catch (Throwable t) { log.log(Level.WARNING, "Error shutting down threadpool.", t); } } @VisibleForTesting int getMetricServerPort() { return metricsServer.getPort(); } private List parseServerAddresses(String addressesStr) { List addresses = new ArrayList<>(); for (List namePort : parseCommaSeparatedTuples(addressesStr)) { InetAddress address; String name = namePort.get(0); int port = Integer.valueOf(namePort.get(1)); try { address = InetAddress.getByName(name); if (serverHostOverride != null) { // Force the hostname to match the cert the server uses. address = InetAddress.getByAddress(serverHostOverride, address.getAddress()); } } catch (UnknownHostException ex) { throw new RuntimeException(ex); } addresses.add(new InetSocketAddress(address, port)); } return addresses; } private static List parseTestCases(String testCasesStr) { List testCaseWeightPairs = new ArrayList<>(); for (List nameWeight : parseCommaSeparatedTuples(testCasesStr)) { TestCases testCase = TestCases.fromString(nameWeight.get(0)); int weight = Integer.valueOf(nameWeight.get(1)); testCaseWeightPairs.add(new TestCaseWeightPair(testCase, weight)); } return testCaseWeightPairs; } private static List> parseCommaSeparatedTuples(String str) { List> tuples = new ArrayList>(); for (String tupleStr : Splitter.on(',').split(str)) { int splitIdx = tupleStr.lastIndexOf(':'); if (splitIdx == -1) { throw new IllegalArgumentException("Illegal tuple format: '" + tupleStr + "'"); } String part0 = tupleStr.substring(0, splitIdx); String part1 = tupleStr.substring(splitIdx + 1); tuples.add(asList(part0, part1)); } return tuples; } private ManagedChannel createChannel(InetSocketAddress address) { SslContext sslContext = null; if (useTestCa) { try { sslContext = GrpcSslContexts.forClient().trustManager( TestUtils.loadCert("ca.pem")).build(); } catch (Exception ex) { throw new RuntimeException(ex); } } return NettyChannelBuilder.forAddress(address) .negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT) .sslContext(sslContext) .build(); } private static String serverAddressesToString(List addresses) { List tmp = new ArrayList<>(); for (InetSocketAddress address : addresses) { URI uri; try { uri = new URI(null, null, address.getHostName(), address.getPort(), null, null, null); } catch (URISyntaxException e) { throw new RuntimeException(e); } tmp.add(uri.getAuthority()); } return Joiner.on(',').join(tmp); } private static String validTestCasesHelpText() { StringBuilder builder = new StringBuilder(); for (TestCases testCase : TestCases.values()) { String strTestcase = testCase.name().toLowerCase(); builder.append("\n ") .append(strTestcase) .append(": ") .append(testCase.description()); } return builder.toString(); } /** * A stress test worker. Every stub has its own stress test worker. */ private class Worker implements Runnable { // Interval at which the QPS stats of metrics service are updated. private static final long METRICS_COLLECTION_INTERVAL_SECS = 5; private final ManagedChannel channel; private final List testCaseWeightPairs; private final Integer durationSec; private final String gaugeName; Worker(ManagedChannel channel, List testCaseWeightPairs, int durationSec, String gaugeName) { Preconditions.checkArgument(durationSec >= -1, "durationSec must be gte -1."); this.channel = Preconditions.checkNotNull(channel, "channel"); this.testCaseWeightPairs = Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs"); this.durationSec = durationSec == -1 ? null : durationSec; this.gaugeName = Preconditions.checkNotNull(gaugeName, "gaugeName"); } @Override public void run() { // Simplify debugging if the worker crashes / never terminates. Thread.currentThread().setName(gaugeName); Tester tester = new Tester(); tester.setUp(); WeightedTestCaseSelector testCaseSelector = new WeightedTestCaseSelector(testCaseWeightPairs); Long endTime = durationSec == null ? null : System.nanoTime() + SECONDS.toNanos(durationSecs); long lastMetricsCollectionTime = initLastMetricsCollectionTime(); // Number of interop testcases run since the last time metrics have been updated. long testCasesSinceLastMetricsCollection = 0; while (!Thread.currentThread().isInterrupted() && !shutdown && (endTime == null || endTime - System.nanoTime() > 0)) { try { runTestCase(tester, testCaseSelector.nextTestCase()); } catch (Exception e) { throw new RuntimeException(e); } testCasesSinceLastMetricsCollection++; double durationSecs = computeDurationSecs(lastMetricsCollectionTime); if (durationSecs >= METRICS_COLLECTION_INTERVAL_SECS) { long qps = (long) Math.ceil(testCasesSinceLastMetricsCollection / durationSecs); Metrics.GaugeResponse gauge = Metrics.GaugeResponse .newBuilder() .setName(gaugeName) .setLongValue(qps) .build(); gauges.put(gaugeName, gauge); lastMetricsCollectionTime = System.nanoTime(); testCasesSinceLastMetricsCollection = 0; } } } private long initLastMetricsCollectionTime() { return System.nanoTime() - SECONDS.toNanos(METRICS_COLLECTION_INTERVAL_SECS); } private double computeDurationSecs(long lastMetricsCollectionTime) { return (System.nanoTime() - lastMetricsCollectionTime) / 1000000000.0; } private void runTestCase(Tester tester, TestCases testCase) throws Exception { // TODO(buchgr): Implement tests requiring auth, once C++ supports it. switch (testCase) { case EMPTY_UNARY: tester.emptyUnary(); break; case LARGE_UNARY: tester.largeUnary(); break; case CLIENT_STREAMING: tester.clientStreaming(); break; case SERVER_STREAMING: tester.serverStreaming(); break; case PING_PONG: tester.pingPong(); break; case EMPTY_STREAM: tester.emptyStream(); break; case UNIMPLEMENTED_METHOD: { tester.unimplementedMethod(); break; } case UNIMPLEMENTED_SERVICE: { tester.unimplementedService(); break; } case CANCEL_AFTER_BEGIN: { tester.cancelAfterBegin(); break; } case CANCEL_AFTER_FIRST_RESPONSE: { tester.cancelAfterFirstResponse(); break; } case TIMEOUT_ON_SLEEPING_SERVER: { tester.timeoutOnSleepingServer(); break; } default: throw new IllegalArgumentException("Unknown test case: " + testCase); } } class Tester extends AbstractInteropTest { @Override protected ManagedChannel createChannel() { return Worker.this.channel; } @Override protected int operationTimeoutMillis() { // Don't enforce a timeout when using the interop tests for the stress test client. // Fixes https://github.com/grpc/grpc-java/issues/1812 return Integer.MAX_VALUE; } @Override protected boolean metricsExpected() { // TODO(zhangkun83): we may want to enable the real google Instrumentation implementation in // stress tests. return false; } } class WeightedTestCaseSelector { /** * Randomly shuffled and cyclic sequence that contains each testcase proportionally * to its weight. */ final Iterator testCases; WeightedTestCaseSelector(List testCaseWeightPairs) { Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs"); Preconditions.checkArgument(testCaseWeightPairs.size() > 0); List testCases = new ArrayList<>(); for (TestCaseWeightPair testCaseWeightPair : testCaseWeightPairs) { for (int i = 0; i < testCaseWeightPair.weight; i++) { testCases.add(testCaseWeightPair.testCase); } } shuffle(testCases); this.testCases = Iterators.cycle(testCases); } TestCases nextTestCase() { return testCases.next(); } } } /** * Service that exports the QPS metrics of the stress test. */ private class MetricsServiceImpl extends MetricsServiceGrpc.MetricsServiceImplBase { @Override public void getAllGauges(Metrics.EmptyMessage request, StreamObserver responseObserver) { for (Metrics.GaugeResponse gauge : gauges.values()) { responseObserver.onNext(gauge); } responseObserver.onCompleted(); } @Override public void getGauge(Metrics.GaugeRequest request, StreamObserver responseObserver) { String gaugeName = request.getName(); Metrics.GaugeResponse gauge = gauges.get(gaugeName); if (gauge != null) { responseObserver.onNext(gauge); responseObserver.onCompleted(); } else { responseObserver.onError(new StatusException(Status.NOT_FOUND)); } } } @VisibleForTesting static class TestCaseWeightPair { final TestCases testCase; final int weight; TestCaseWeightPair(TestCases testCase, int weight) { Preconditions.checkArgument(weight >= 0, "weight must be positive."); this.testCase = Preconditions.checkNotNull(testCase, "testCase"); this.weight = weight; } @Override public boolean equals(Object other) { if (!(other instanceof TestCaseWeightPair)) { return false; } TestCaseWeightPair that = (TestCaseWeightPair) other; return testCase.equals(that.testCase) && weight == that.weight; } @Override public int hashCode() { return Objects.hashCode(testCase, weight); } } @VisibleForTesting List addresses() { return Collections.unmodifiableList(addresses); } @VisibleForTesting String serverHostOverride() { return serverHostOverride; } @VisibleForTesting boolean useTls() { return useTls; } @VisibleForTesting boolean useTestCa() { return useTestCa; } @VisibleForTesting List testCaseWeightPairs() { return testCaseWeightPairs; } @VisibleForTesting int durationSecs() { return durationSecs; } @VisibleForTesting int channelsPerServer() { return channelsPerServer; } @VisibleForTesting int stubsPerChannel() { return stubsPerChannel; } @VisibleForTesting int metricsPort() { return metricsPort; } }