/*
 * Copyright 2015 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.benchmarks.netty;

import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.grpc.benchmarks.ByteBufOutputMarshaller;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.StreamObserver;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.Enumeration;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Abstract base class for Netty end-to-end benchmarks.
 */
public abstract class AbstractBenchmark {

  private static final Logger logger = Logger.getLogger(AbstractBenchmark.class.getName());

  /**
   * Standard message sizes.
   */
  public enum MessageSize {
    // Max out at 1MB to avoid creating messages larger than Netty's buffer pool can handle
    // by default.
    SMALL(10), MEDIUM(1024), LARGE(65536), JUMBO(1048576);

    private final int bytes;

    MessageSize(int bytes) {
      this.bytes = bytes;
    }

    public int bytes() {
      return bytes;
    }
  }

  /**
   * Standard flow-control window sizes.
   */
  public enum FlowWindowSize {
    SMALL(16383), MEDIUM(65535), LARGE(1048575), JUMBO(8388607);

    private final int bytes;

    FlowWindowSize(int bytes) {
      this.bytes = bytes;
    }

    public int bytes() {
      return bytes;
    }
  }

  /**
   * Executor types used by the channel and server.
   */
  public enum ExecutorType {
    DEFAULT, DIRECT;
  }

  /**
   * Supported channel types.
   */
  public enum ChannelType {
    NIO, LOCAL;
  }

  private static final CallOptions CALL_OPTIONS = CallOptions.DEFAULT;

  private static final InetAddress BENCHMARK_ADDR = buildBenchmarkAddr();

  /**
   * Resolve the address bound to the benchmark interface. Currently we assume it's a
   * child interface of the loopback interface with the term 'benchmark' in its name.
   *
   * <p>This allows traffic shaping to be applied to an IP address and lets the benchmarks
   * detect its presence and use it. E.g. for Linux we can apply netem to a specific IP to
   * do traffic shaping, bind that IP to the loopback adapter and then apply a label to that
   * binding so that it appears as a child interface.
   *
   * <pre>
   * sudo tc qdisc del dev lo root
   * sudo tc qdisc add dev lo root handle 1: prio
   * sudo tc qdisc add dev lo parent 1:1 handle 2: netem delay 0.1ms rate 10gbit
   * sudo tc filter add dev lo parent 1:0 protocol ip prio 1  \
   *            u32 match ip dst 127.127.127.127 flowid 2:1
   * sudo ip addr add dev lo 127.127.127.127/32 label lo:benchmark
   * </pre>
   */
  private static InetAddress buildBenchmarkAddr() {
    InetAddress tmp = null;
    try {
      Enumeration<NetworkInterface> networkInterfaces = NetworkInterface.getNetworkInterfaces();
      outer: while (networkInterfaces.hasMoreElements()) {
        NetworkInterface networkInterface = networkInterfaces.nextElement();
        if (!networkInterface.isLoopback()) {
          continue;
        }
        Enumeration<NetworkInterface> subInterfaces = networkInterface.getSubInterfaces();
        while (subInterfaces.hasMoreElements()) {
          NetworkInterface subLoopback = subInterfaces.nextElement();
          if (subLoopback.getDisplayName().contains("benchmark")) {
            tmp = subLoopback.getInetAddresses().nextElement();
            System.out.println("\nResolved benchmark address to " + tmp + " on "
                + subLoopback.getDisplayName() + "\n\n");
            break outer;
          }
        }
      }
    } catch (SocketException se) {
      System.out.println("\nWARNING: Error trying to resolve benchmark interface \n" + se);
    }
    if (tmp == null) {
      try {
        System.out.println(
            "\nWARNING: Unable to resolve benchmark interface, defaulting to localhost");
        tmp = InetAddress.getLocalHost();
      } catch (UnknownHostException uhe) {
        throw new RuntimeException(uhe);
      }
    }
    return tmp;
  }

  protected Server server;
  protected ByteBuf request;
  protected ByteBuf response;
  protected MethodDescriptor<ByteBuf, ByteBuf> unaryMethod;
  private MethodDescriptor<ByteBuf, ByteBuf> pingPongMethod;
  private MethodDescriptor<ByteBuf, ByteBuf> flowControlledStreaming;
  protected ManagedChannel[] channels;

  public AbstractBenchmark() {
  }

  /**
   * Initialize the benchmark environment: build and start the server and client channels with
   * the requested executor, message-size, flow-control, and channel settings.
   */
  public void setup(ExecutorType clientExecutor,
                    ExecutorType serverExecutor,
                    MessageSize requestSize,
                    MessageSize responseSize,
                    FlowWindowSize windowSize,
                    ChannelType channelType,
                    int maxConcurrentStreams,
                    int channelCount) throws Exception {
    NettyServerBuilder serverBuilder;
    NettyChannelBuilder channelBuilder;
    if (channelType == ChannelType.LOCAL) {
      LocalAddress address = new LocalAddress("netty-e2e-benchmark");
      serverBuilder = NettyServerBuilder.forAddress(address);
      serverBuilder.channelType(LocalServerChannel.class);
      channelBuilder = NettyChannelBuilder.forAddress(address);
      channelBuilder.channelType(LocalChannel.class);
    } else {
      // Pick a port using an ephemeral socket.
      ServerSocket sock = new ServerSocket();
      sock.bind(new InetSocketAddress(BENCHMARK_ADDR, 0));
      SocketAddress address = sock.getLocalSocketAddress();
      sock.close();
      serverBuilder = NettyServerBuilder.forAddress(address);
      channelBuilder = NettyChannelBuilder.forAddress(address);
    }

    if (serverExecutor == ExecutorType.DIRECT) {
      serverBuilder.directExecutor();
    }
    if (clientExecutor == ExecutorType.DIRECT) {
      channelBuilder.directExecutor();
    }

    // Always use a different worker group from the client.
    ThreadFactory serverThreadFactory = new DefaultThreadFactory("STF pool", true /* daemon */);
    serverBuilder.workerEventLoopGroup(new NioEventLoopGroup(0, serverThreadFactory));

    // Always set the connection and stream flow-control windows to the same value.
    serverBuilder.flowControlWindow(windowSize.bytes());
    channelBuilder.flowControlWindow(windowSize.bytes());

    channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
    serverBuilder.maxConcurrentCallsPerConnection(maxConcurrentStreams);

    // Create buffers of the desired size for requests and responses.
    PooledByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
    // Use a heap buffer for now, since MessageFramer doesn't know how to directly convert this
    // into a WritableBuffer.
    // TODO(carl-mastrangelo): convert this into a regular buffer() call.  See
    // https://github.com/grpc/grpc-java/issues/2062#issuecomment-234646216
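    // Note: writerIndex(capacity() - 1) below marks capacity - 1 bytes as readable, so each
    // slice() of these buffers carries a payload one byte short of the nominal message size.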
    request = alloc.heapBuffer(requestSize.bytes());
    request.writerIndex(request.capacity() - 1);
    response = alloc.heapBuffer(responseSize.bytes());
    response.writerIndex(response.capacity() - 1);

    // Simple method that sends and receives a Netty ByteBuf.
    unaryMethod = MethodDescriptor.<ByteBuf, ByteBuf>newBuilder()
        .setType(MethodType.UNARY)
        .setFullMethodName("benchmark/unary")
        .setRequestMarshaller(new ByteBufOutputMarshaller())
        .setResponseMarshaller(new ByteBufOutputMarshaller())
        .build();

    pingPongMethod = unaryMethod.toBuilder()
        .setType(MethodType.BIDI_STREAMING)
        .setFullMethodName("benchmark/pingPong")
        .build();
    flowControlledStreaming = pingPongMethod.toBuilder()
        .setFullMethodName("benchmark/flowControlledStreaming")
        .build();

    // Server implementation of unary & streaming methods.
    serverBuilder.addService(
        ServerServiceDefinition.builder(
            new ServiceDescriptor("benchmark",
                unaryMethod,
                pingPongMethod,
                flowControlledStreaming))
            .addMethod(unaryMethod, new ServerCallHandler<ByteBuf, ByteBuf>() {
                  @Override
                  public ServerCall.Listener<ByteBuf> startCall(
                      final ServerCall<ByteBuf, ByteBuf> call,
                      Metadata headers) {
                    call.sendHeaders(new Metadata());
                    call.request(1);
                    return new ServerCall.Listener<ByteBuf>() {
                      @Override
                      public void onMessage(ByteBuf message) {
                        // Consume the request and echo the canned response.
                        message.release();
                        call.sendMessage(response.slice());
                      }

                      @Override
                      public void onHalfClose() {
                        call.close(Status.OK, new Metadata());
                      }

                      @Override
                      public void onCancel() {
                      }

                      @Override
                      public void onComplete() {
                      }
                    };
                  }
                })
            .addMethod(pingPongMethod, new ServerCallHandler<ByteBuf, ByteBuf>() {
                  @Override
                  public ServerCall.Listener<ByteBuf> startCall(
                      final ServerCall<ByteBuf, ByteBuf> call,
                      Metadata headers) {
                    call.sendHeaders(new Metadata());
                    call.request(1);
                    return new ServerCall.Listener<ByteBuf>() {
                      @Override
                      public void onMessage(ByteBuf message) {
                        message.release();
                        call.sendMessage(response.slice());
                        // Request the next message.
                        call.request(1);
                      }

                      @Override
                      public void onHalfClose() {
                        call.close(Status.OK, new Metadata());
                      }

                      @Override
                      public void onCancel() {
                      }

                      @Override
                      public void onComplete() {
                      }
                    };
                  }
                })
            .addMethod(flowControlledStreaming, new ServerCallHandler<ByteBuf, ByteBuf>() {
                  @Override
                  public ServerCall.Listener<ByteBuf> startCall(
                      final ServerCall<ByteBuf, ByteBuf> call,
                      Metadata headers) {
                    call.sendHeaders(new Metadata());
                    call.request(1);
                    return new ServerCall.Listener<ByteBuf>() {
                      @Override
                      public void onMessage(ByteBuf message) {
                        message.release();
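                        // Saturate the stream: keep writing until the transport signals it
                        // can no longer accept messages without buffering.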
                        while (call.isReady()) {
                          call.sendMessage(response.slice());
                        }
                        // Request the next message.
                        call.request(1);
                      }

                      @Override
                      public void onHalfClose() {
                        call.close(Status.OK, new Metadata());
                      }

                      @Override
                      public void onCancel() {
                      }

                      @Override
                      public void onComplete() {
                      }

                      @Override
                      public void onReady() {
                        while (call.isReady()) {
                          call.sendMessage(response.slice());
                        }
                      }
                    };
                  }
                })
            .build());

    // Build and start the clients and servers.
    server = serverBuilder.build();
    server.start();
    channels = new ManagedChannel[channelCount];
    ThreadFactory clientThreadFactory = new DefaultThreadFactory("CTF pool", true /* daemon */);
    for (int i = 0; i < channelCount; i++) {
      // Use a dedicated event-loop for each channel.
      channels[i] = channelBuilder
          .eventLoopGroup(new NioEventLoopGroup(1, clientThreadFactory))
          .build();
    }
  }

  /**
   * Start a continuously executing set of unary calls that will terminate when
   * {@code done.get()} is true. Each completed call will increment the counter by the specified
   * delta, which benchmarks can use to measure QPS or bandwidth.
   */
  protected void startUnaryCalls(int callsPerChannel,
                                 final AtomicLong counter,
                                 final AtomicBoolean done,
                                 final long counterDelta) {
    for (final ManagedChannel channel : channels) {
      for (int i = 0; i < callsPerChannel; i++) {
        StreamObserver<ByteBuf> observer = new StreamObserver<ByteBuf>() {
          @Override
          public void onNext(ByteBuf value) {
            counter.addAndGet(counterDelta);
          }

          @Override
          public void onError(Throwable t) {
            done.set(true);
          }

          @Override
          public void onCompleted() {
            if (!done.get()) {
              ByteBuf slice = request.slice();
              ClientCalls.asyncUnaryCall(
                  channel.newCall(unaryMethod, CALL_OPTIONS), slice, this);
            }
          }
        };
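        // Bootstrap the loop: onCompleted() issues the first call, and each completed call
        // starts the next one until done is set.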
        observer.onCompleted();
      }
    }
  }

  /**
   * Start a continuously executing set of duplex streaming ping-pong calls that will terminate
   * when {@code done.get()} is true. Each response message increments the counter by the
   * specified delta, which benchmarks can use to measure messages per second or bandwidth.
   */
  protected CountDownLatch startStreamingCalls(int callsPerChannel, final AtomicLong counter,
      final AtomicBoolean record, final AtomicBoolean done, final long counterDelta) {
    final CountDownLatch latch = new CountDownLatch(callsPerChannel * channels.length);
    for (final ManagedChannel channel : channels) {
      for (int i = 0; i < callsPerChannel; i++) {
        final ClientCall<ByteBuf, ByteBuf> streamingCall =
            channel.newCall(pingPongMethod, CALL_OPTIONS);
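        // The response observer echoes through the request observer, which is only created
        // below, so hold it in an AtomicReference that is set before any response can arrive.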
        final AtomicReference<StreamObserver<ByteBuf>> requestObserverRef =
            new AtomicReference<StreamObserver<ByteBuf>>();
        final AtomicBoolean ignoreMessages = new AtomicBoolean();
        StreamObserver<ByteBuf> requestObserver = ClientCalls.asyncBidiStreamingCall(
            streamingCall,
            new StreamObserver<ByteBuf>() {
              @Override
              public void onNext(ByteBuf value) {
                if (done.get()) {
                  if (!ignoreMessages.getAndSet(true)) {
                    requestObserverRef.get().onCompleted();
                  }
                  return;
                }
                requestObserverRef.get().onNext(request.slice());
                if (record.get()) {
                  counter.addAndGet(counterDelta);
                }
                // request() is invoked automatically because the observer uses automatic
                // inbound flow control.
              }

              @Override
              public void onError(Throwable t) {
                logger.log(Level.WARNING, "call error", t);
                latch.countDown();
              }

              @Override
              public void onCompleted() {
                latch.countDown();
              }
            });
        requestObserverRef.set(requestObserver);
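        // Prime the ping-pong with two in-flight requests so a response is always pending.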
        requestObserver.onNext(request.slice());
        requestObserver.onNext(request.slice());
      }
    }
    return latch;
  }

  /**
   * Start a continuously executing set of flow-controlled streaming calls, in which the server
   * saturates the stream with responses, that will terminate when {@code done.get()} is true.
   * Each received message increments the counter by the specified delta, which benchmarks can
   * use to measure messages per second or bandwidth.
   */
  protected CountDownLatch startFlowControlledStreamingCalls(int callsPerChannel,
      final AtomicLong counter, final AtomicBoolean record, final AtomicBoolean done,
      final long counterDelta) {
    final CountDownLatch latch = new CountDownLatch(callsPerChannel * channels.length);
    for (final ManagedChannel channel : channels) {
      for (int i = 0; i < callsPerChannel; i++) {
        final ClientCall<ByteBuf, ByteBuf> streamingCall =
            channel.newCall(flowControlledStreaming, CALL_OPTIONS);
        final AtomicReference<StreamObserver<ByteBuf>> requestObserverRef =
            new AtomicReference<StreamObserver<ByteBuf>>();
        final AtomicBoolean ignoreMessages = new AtomicBoolean();
        StreamObserver<ByteBuf> requestObserver = ClientCalls.asyncBidiStreamingCall(
            streamingCall,
            new StreamObserver<ByteBuf>() {
              @Override
              public void onNext(ByteBuf value) {
                StreamObserver<ByteBuf> obs = requestObserverRef.get();
                if (done.get()) {
                  if (!ignoreMessages.getAndSet(true)) {
                    obs.onCompleted();
                  }
                  return;
                }
                if (record.get()) {
                  counter.addAndGet(counterDelta);
                }
                // request() is invoked automatically because the observer uses automatic
                // inbound flow control.
              }

              @Override
              public void onError(Throwable t) {
                logger.log(Level.WARNING, "call error", t);
                latch.countDown();
              }

              @Override
              public void onCompleted() {
                latch.countDown();
              }
            });
        requestObserverRef.set(requestObserver);

        // Add some outstanding requests to ensure the server is filling the connection.
        streamingCall.request(5);
        requestObserver.onNext(request.slice());
      }
    }
    return latch;
  }

  /**
   * Shut down all the client channels and then shut down the server.
   */
  protected void teardown() throws Exception {
    logger.fine("shutting down channels");
    for (ManagedChannel channel : channels) {
      channel.shutdown();
    }
    logger.fine("shutting down server");
    server.shutdown();
    if (!server.awaitTermination(5, TimeUnit.SECONDS)) {
      logger.warning("Failed to shutdown server");
    }
    logger.fine("server shut down");
    for (ManagedChannel channel : channels) {
      if (!channel.awaitTermination(1, TimeUnit.SECONDS)) {
        logger.warning("Failed to shutdown client");
      }
    }
    logger.fine("channels shut down");
  }
}