1 /* 2 * Copyright 2016 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.testing.integration; 18 19 import com.google.common.base.Preconditions; 20 import com.google.common.util.concurrent.MoreExecutors; 21 import io.grpc.ManagedChannel; 22 import io.grpc.Server; 23 import io.grpc.internal.testing.TestUtils; 24 import io.grpc.netty.GrpcSslContexts; 25 import io.grpc.netty.NegotiationType; 26 import io.grpc.netty.NettyChannelBuilder; 27 import io.grpc.netty.NettyServerBuilder; 28 import io.grpc.stub.StreamObserver; 29 import io.grpc.testing.integration.Messages.PayloadType; 30 import io.grpc.testing.integration.Messages.ResponseParameters; 31 import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; 32 import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; 33 import io.netty.handler.ssl.ClientAuth; 34 import io.netty.handler.ssl.SslContext; 35 import java.io.File; 36 import java.io.IOException; 37 import java.security.cert.CertificateException; 38 import java.security.cert.X509Certificate; 39 import java.util.concurrent.CountDownLatch; 40 import java.util.concurrent.CyclicBarrier; 41 import java.util.concurrent.ExecutorService; 42 import java.util.concurrent.Executors; 43 import java.util.concurrent.ScheduledExecutorService; 44 import java.util.concurrent.TimeUnit; 45 import org.junit.After; 46 import org.junit.Before; 47 import org.junit.Rule; 48 import org.junit.Test; 49 import org.junit.rules.Timeout; 50 import org.junit.runner.RunWith; 51 import org.junit.runners.JUnit4; 52 53 54 /** 55 * Tests that gRPC clients and servers can handle concurrent RPCs. 56 * 57 * <p>These tests use TLS to make them more realistic, and because we'd like to test the thread 58 * safety of the TLS-related code paths as well. 59 */ 60 // TODO: Consider augmenting this class to perform non-streaming, client streaming, and 61 // bidirectional streaming requests also. 62 @RunWith(JUnit4.class) 63 public class ConcurrencyTest { 64 65 @Rule public final Timeout globalTimeout = Timeout.seconds(10); 66 67 /** 68 * A response observer that signals a {@code CountDownLatch} when the proper number of responses 69 * arrives and the server signals that the RPC is complete. 70 */ 71 private static class SignalingResponseObserver 72 implements StreamObserver<StreamingOutputCallResponse> { SignalingResponseObserver(CountDownLatch responsesDoneSignal)73 public SignalingResponseObserver(CountDownLatch responsesDoneSignal) { 74 this.responsesDoneSignal = responsesDoneSignal; 75 } 76 77 @Override onCompleted()78 public void onCompleted() { 79 Preconditions.checkState(numResponsesReceived == NUM_RESPONSES_PER_REQUEST); 80 responsesDoneSignal.countDown(); 81 } 82 83 @Override onError(Throwable error)84 public void onError(Throwable error) { 85 // This should never happen. If it does happen, ensure that the error is visible. 86 error.printStackTrace(); 87 } 88 89 @Override onNext(StreamingOutputCallResponse response)90 public void onNext(StreamingOutputCallResponse response) { 91 numResponsesReceived++; 92 } 93 94 private final CountDownLatch responsesDoneSignal; 95 private int numResponsesReceived = 0; 96 } 97 98 /** 99 * A client worker task that waits until all client workers are ready, then sends a request for a 100 * server-streaming RPC and arranges for a {@code CountDownLatch} to be signaled when the RPC is 101 * complete. 102 */ 103 private class ClientWorker implements Runnable { ClientWorker(CyclicBarrier startBarrier, CountDownLatch responsesDoneSignal)104 public ClientWorker(CyclicBarrier startBarrier, CountDownLatch responsesDoneSignal) { 105 this.startBarrier = startBarrier; 106 this.responsesDoneSignal = responsesDoneSignal; 107 } 108 109 @Override run()110 public void run() { 111 try { 112 // Prepare the request. 113 StreamingOutputCallRequest.Builder requestBuilder = StreamingOutputCallRequest.newBuilder() 114 .setResponseType(PayloadType.RANDOM); 115 for (int i = 0; i < NUM_RESPONSES_PER_REQUEST; i++) { 116 requestBuilder.addResponseParameters(ResponseParameters.newBuilder() 117 .setSize(1000) 118 .setIntervalUs(0)); // No delay between responses, for maximum concurrency. 119 } 120 StreamingOutputCallRequest request = requestBuilder.build(); 121 122 // Wait until all client worker threads are poised & ready, then send the request. This way 123 // all clients send their requests at approximately the same time. 124 startBarrier.await(); 125 clientStub.streamingOutputCall(request, new SignalingResponseObserver(responsesDoneSignal)); 126 } catch (Exception e) { 127 throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e); 128 } 129 } 130 131 private final CyclicBarrier startBarrier; 132 private final CountDownLatch responsesDoneSignal; 133 } 134 135 private static final int NUM_SERVER_THREADS = 10; 136 private static final int NUM_CONCURRENT_REQUESTS = 100; 137 private static final int NUM_RESPONSES_PER_REQUEST = 100; 138 139 private Server server; 140 private ManagedChannel clientChannel; 141 private TestServiceGrpc.TestServiceStub clientStub; 142 private ScheduledExecutorService serverExecutor; 143 private ExecutorService clientExecutor; 144 145 @Before setUp()146 public void setUp() throws Exception { 147 serverExecutor = Executors.newScheduledThreadPool(NUM_SERVER_THREADS); 148 clientExecutor = Executors.newFixedThreadPool(NUM_CONCURRENT_REQUESTS); 149 150 server = newServer(); 151 152 // Create the client. Keep a reference to its channel so we can shut it down during tearDown(). 153 clientChannel = newClientChannel(); 154 clientStub = TestServiceGrpc.newStub(clientChannel); 155 } 156 157 @After tearDown()158 public void tearDown() { 159 if (server != null) { 160 server.shutdown(); 161 } 162 if (clientChannel != null) { 163 clientChannel.shutdown(); 164 } 165 166 MoreExecutors.shutdownAndAwaitTermination(serverExecutor, 5, TimeUnit.SECONDS); 167 MoreExecutors.shutdownAndAwaitTermination(clientExecutor, 5, TimeUnit.SECONDS); 168 } 169 170 /** 171 * Tests that gRPC can handle concurrent server-streaming RPCs. 172 */ 173 @Test serverStreamingTest()174 public void serverStreamingTest() throws Exception { 175 CyclicBarrier startBarrier = new CyclicBarrier(NUM_CONCURRENT_REQUESTS); 176 CountDownLatch responsesDoneSignal = new CountDownLatch(NUM_CONCURRENT_REQUESTS); 177 178 for (int i = 0; i < NUM_CONCURRENT_REQUESTS; i++) { 179 clientExecutor.execute(new ClientWorker(startBarrier, responsesDoneSignal)); 180 } 181 182 // Wait until the clients all receive their complete RPC response streams. 183 responsesDoneSignal.await(); 184 } 185 186 /** 187 * Creates and starts a new {@link TestServiceImpl} server. 188 */ newServer()189 private Server newServer() throws CertificateException, IOException { 190 File serverCertChainFile = TestUtils.loadCert("server1.pem"); 191 File serverPrivateKeyFile = TestUtils.loadCert("server1.key"); 192 X509Certificate[] serverTrustedCaCerts = { 193 TestUtils.loadX509Cert("ca.pem") 194 }; 195 196 SslContext sslContext = 197 GrpcSslContexts.forServer(serverCertChainFile, serverPrivateKeyFile) 198 .trustManager(serverTrustedCaCerts) 199 .clientAuth(ClientAuth.REQUIRE) 200 .build(); 201 202 return NettyServerBuilder.forPort(0) 203 .sslContext(sslContext) 204 .addService(new TestServiceImpl(serverExecutor)) 205 .build() 206 .start(); 207 } 208 newClientChannel()209 private ManagedChannel newClientChannel() throws CertificateException, IOException { 210 File clientCertChainFile = TestUtils.loadCert("client.pem"); 211 File clientPrivateKeyFile = TestUtils.loadCert("client.key"); 212 X509Certificate[] clientTrustedCaCerts = { 213 TestUtils.loadX509Cert("ca.pem") 214 }; 215 216 SslContext sslContext = 217 GrpcSslContexts.forClient() 218 .keyManager(clientCertChainFile, clientPrivateKeyFile) 219 .trustManager(clientTrustedCaCerts) 220 .build(); 221 222 return NettyChannelBuilder.forAddress("localhost", server.getPort()) 223 .overrideAuthority(TestUtils.TEST_SERVER_HOST) 224 .negotiationType(NegotiationType.TLS) 225 .sslContext(sslContext) 226 .build(); 227 } 228 } 229