1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.integration.Messages.PayloadType;
import io.grpc.testing.integration.Messages.ResponseParameters;
import io.grpc.testing.integration.Messages.StreamingOutputCallRequest;
import io.grpc.testing.integration.Messages.StreamingOutputCallResponse;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
52 
53 
54 /**
55  * Tests that gRPC clients and servers can handle concurrent RPCs.
56  *
57  * <p>These tests use TLS to make them more realistic, and because we'd like to test the thread
58  * safety of the TLS-related code paths as well.
59  */
60 // TODO: Consider augmenting this class to perform non-streaming, client streaming, and
61 // bidirectional streaming requests also.
62 @RunWith(JUnit4.class)
63 public class ConcurrencyTest {
64 
65   @Rule public final Timeout globalTimeout = Timeout.seconds(10);
66 
67   /**
68    * A response observer that signals a {@code CountDownLatch} when the proper number of responses
69    * arrives and the server signals that the RPC is complete.
70    */
71   private static class SignalingResponseObserver
72       implements StreamObserver<StreamingOutputCallResponse> {
SignalingResponseObserver(CountDownLatch responsesDoneSignal)73     public SignalingResponseObserver(CountDownLatch responsesDoneSignal) {
74       this.responsesDoneSignal = responsesDoneSignal;
75     }
76 
77     @Override
onCompleted()78     public void onCompleted() {
79       Preconditions.checkState(numResponsesReceived == NUM_RESPONSES_PER_REQUEST);
80       responsesDoneSignal.countDown();
81     }
82 
83     @Override
onError(Throwable error)84     public void onError(Throwable error) {
85       // This should never happen. If it does happen, ensure that the error is visible.
86       error.printStackTrace();
87     }
88 
89     @Override
onNext(StreamingOutputCallResponse response)90     public void onNext(StreamingOutputCallResponse response) {
91       numResponsesReceived++;
92     }
93 
94     private final CountDownLatch responsesDoneSignal;
95     private int numResponsesReceived = 0;
96   }
97 
98   /**
99    * A client worker task that waits until all client workers are ready, then sends a request for a
100    * server-streaming RPC and arranges for a {@code CountDownLatch} to be signaled when the RPC is
101    * complete.
102    */
103   private class ClientWorker implements Runnable {
ClientWorker(CyclicBarrier startBarrier, CountDownLatch responsesDoneSignal)104     public ClientWorker(CyclicBarrier startBarrier, CountDownLatch responsesDoneSignal) {
105       this.startBarrier = startBarrier;
106       this.responsesDoneSignal = responsesDoneSignal;
107     }
108 
109     @Override
run()110     public void run() {
111       try {
112         // Prepare the request.
113         StreamingOutputCallRequest.Builder requestBuilder = StreamingOutputCallRequest.newBuilder()
114             .setResponseType(PayloadType.RANDOM);
115         for (int i = 0; i < NUM_RESPONSES_PER_REQUEST; i++) {
116           requestBuilder.addResponseParameters(ResponseParameters.newBuilder()
117               .setSize(1000)
118               .setIntervalUs(0));  // No delay between responses, for maximum concurrency.
119         }
120         StreamingOutputCallRequest request = requestBuilder.build();
121 
122         // Wait until all client worker threads are poised & ready, then send the request. This way
123         // all clients send their requests at approximately the same time.
124         startBarrier.await();
125         clientStub.streamingOutputCall(request, new SignalingResponseObserver(responsesDoneSignal));
126       } catch (Exception e) {
127         throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e);
128       }
129     }
130 
131     private final CyclicBarrier startBarrier;
132     private final CountDownLatch responsesDoneSignal;
133   }
134 
135   private static final int NUM_SERVER_THREADS = 10;
136   private static final int NUM_CONCURRENT_REQUESTS = 100;
137   private static final int NUM_RESPONSES_PER_REQUEST = 100;
138 
139   private Server server;
140   private ManagedChannel clientChannel;
141   private TestServiceGrpc.TestServiceStub clientStub;
142   private ScheduledExecutorService serverExecutor;
143   private ExecutorService clientExecutor;
144 
145   @Before
setUp()146   public void setUp() throws Exception {
147     serverExecutor = Executors.newScheduledThreadPool(NUM_SERVER_THREADS);
148     clientExecutor = Executors.newFixedThreadPool(NUM_CONCURRENT_REQUESTS);
149 
150     server = newServer();
151 
152     // Create the client. Keep a reference to its channel so we can shut it down during tearDown().
153     clientChannel = newClientChannel();
154     clientStub = TestServiceGrpc.newStub(clientChannel);
155   }
156 
157   @After
tearDown()158   public void tearDown() {
159     if (server != null) {
160       server.shutdown();
161     }
162     if (clientChannel != null) {
163       clientChannel.shutdown();
164     }
165 
166     MoreExecutors.shutdownAndAwaitTermination(serverExecutor, 5, TimeUnit.SECONDS);
167     MoreExecutors.shutdownAndAwaitTermination(clientExecutor, 5, TimeUnit.SECONDS);
168   }
169 
170   /**
171    * Tests that gRPC can handle concurrent server-streaming RPCs.
172    */
173   @Test
serverStreamingTest()174   public void serverStreamingTest() throws Exception {
175     CyclicBarrier startBarrier = new CyclicBarrier(NUM_CONCURRENT_REQUESTS);
176     CountDownLatch responsesDoneSignal = new CountDownLatch(NUM_CONCURRENT_REQUESTS);
177 
178     for (int i = 0; i < NUM_CONCURRENT_REQUESTS; i++) {
179       clientExecutor.execute(new ClientWorker(startBarrier, responsesDoneSignal));
180     }
181 
182     // Wait until the clients all receive their complete RPC response streams.
183     responsesDoneSignal.await();
184   }
185 
186   /**
187    * Creates and starts a new {@link TestServiceImpl} server.
188    */
newServer()189   private Server newServer() throws CertificateException, IOException {
190     File serverCertChainFile = TestUtils.loadCert("server1.pem");
191     File serverPrivateKeyFile = TestUtils.loadCert("server1.key");
192     X509Certificate[] serverTrustedCaCerts = {
193       TestUtils.loadX509Cert("ca.pem")
194     };
195 
196     SslContext sslContext =
197         GrpcSslContexts.forServer(serverCertChainFile, serverPrivateKeyFile)
198                        .trustManager(serverTrustedCaCerts)
199                        .clientAuth(ClientAuth.REQUIRE)
200                        .build();
201 
202     return NettyServerBuilder.forPort(0)
203         .sslContext(sslContext)
204         .addService(new TestServiceImpl(serverExecutor))
205         .build()
206         .start();
207   }
208 
newClientChannel()209   private ManagedChannel newClientChannel() throws CertificateException, IOException {
210     File clientCertChainFile = TestUtils.loadCert("client.pem");
211     File clientPrivateKeyFile = TestUtils.loadCert("client.key");
212     X509Certificate[] clientTrustedCaCerts = {
213       TestUtils.loadX509Cert("ca.pem")
214     };
215 
216     SslContext sslContext =
217         GrpcSslContexts.forClient()
218                        .keyManager(clientCertChainFile, clientPrivateKeyFile)
219                        .trustManager(clientTrustedCaCerts)
220                        .build();
221 
222     return NettyChannelBuilder.forAddress("localhost", server.getPort())
223         .overrideAuthority(TestUtils.TEST_SERVER_HOST)
224         .negotiationType(NegotiationType.TLS)
225         .sslContext(sslContext)
226         .build();
227   }
228 }
229