1 /*
2  * Copyright 2015 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.netty;
18 
19 import static com.google.common.base.Charsets.UTF_8;
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertFalse;
22 import static org.junit.Assert.assertNotNull;
23 import static org.junit.Assert.assertNull;
24 import static org.junit.Assert.assertTrue;
25 import static org.mockito.Matchers.any;
26 import static org.mockito.Mockito.mock;
27 import static org.mockito.Mockito.times;
28 
29 import io.grpc.internal.testing.TestUtils;
30 import io.grpc.netty.ProtocolNegotiators.HostPort;
31 import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
32 import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
33 import io.netty.bootstrap.Bootstrap;
34 import io.netty.bootstrap.ServerBootstrap;
35 import io.netty.buffer.ByteBuf;
36 import io.netty.buffer.ByteBufUtil;
37 import io.netty.channel.Channel;
38 import io.netty.channel.ChannelFuture;
39 import io.netty.channel.ChannelHandler;
40 import io.netty.channel.ChannelHandlerContext;
41 import io.netty.channel.ChannelInboundHandler;
42 import io.netty.channel.ChannelPipeline;
43 import io.netty.channel.DefaultEventLoopGroup;
44 import io.netty.channel.embedded.EmbeddedChannel;
45 import io.netty.channel.local.LocalAddress;
46 import io.netty.channel.local.LocalChannel;
47 import io.netty.channel.local.LocalServerChannel;
48 import io.netty.handler.proxy.ProxyConnectException;
49 import io.netty.handler.ssl.SslContext;
50 import io.netty.handler.ssl.SslHandler;
51 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
52 import io.netty.handler.ssl.SupportedCipherSuiteFilter;
53 import java.io.File;
54 import java.net.InetSocketAddress;
55 import java.net.SocketAddress;
56 import java.util.logging.Filter;
57 import java.util.logging.Level;
58 import java.util.logging.LogRecord;
59 import java.util.logging.Logger;
60 import javax.net.ssl.SSLContext;
61 import javax.net.ssl.SSLEngine;
62 import javax.net.ssl.SSLException;
63 import org.junit.Before;
64 import org.junit.Rule;
65 import org.junit.Test;
66 import org.junit.rules.ExpectedException;
67 import org.junit.rules.Timeout;
68 import org.junit.runner.RunWith;
69 import org.junit.runners.JUnit4;
70 import org.mockito.ArgumentCaptor;
71 import org.mockito.Mockito;
72 
73 @RunWith(JUnit4.class)
74 public class ProtocolNegotiatorsTest {
75   private static final Runnable NOOP_RUNNABLE = new Runnable() {
76     @Override public void run() {}
77   };
78 
79   @Rule public final Timeout globalTimeout = Timeout.seconds(5);
80 
81   @Rule
82   public final ExpectedException thrown = ExpectedException.none();
83 
84   private GrpcHttp2ConnectionHandler grpcHandler = mock(GrpcHttp2ConnectionHandler.class);
85 
86   private EmbeddedChannel channel = new EmbeddedChannel();
87   private ChannelPipeline pipeline = channel.pipeline();
88   private SslContext sslContext;
89   private SSLEngine engine;
90   private ChannelHandlerContext channelHandlerCtx;
91 
92   @Before
setUp()93   public void setUp() throws Exception {
94     File serverCert = TestUtils.loadCert("server1.pem");
95     File key = TestUtils.loadCert("server1.key");
96     sslContext = GrpcSslContexts.forServer(serverCert, key)
97         .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
98     engine = SSLContext.getDefault().createSSLEngine();
99   }
100 
101   @Test
tlsHandler_failsOnNullEngine()102   public void tlsHandler_failsOnNullEngine() throws Exception {
103     thrown.expect(NullPointerException.class);
104     thrown.expectMessage("ssl");
105 
106     Object unused = ProtocolNegotiators.serverTls(null);
107   }
108 
109   @Test
tlsAdapter_exceptionClosesChannel()110   public void tlsAdapter_exceptionClosesChannel() throws Exception {
111     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
112 
113     // Use addFirst due to the funny error handling in EmbeddedChannel.
114     pipeline.addFirst(handler);
115 
116     pipeline.fireExceptionCaught(new Exception("bad"));
117 
118     assertFalse(channel.isOpen());
119   }
120 
121   @Test
tlsHandler_handlerAddedAddsSslHandler()122   public void tlsHandler_handlerAddedAddsSslHandler() throws Exception {
123     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
124 
125     pipeline.addLast(handler);
126 
127     assertTrue(pipeline.first() instanceof SslHandler);
128   }
129 
130   @Test
tlsHandler_userEventTriggeredNonSslEvent()131   public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception {
132     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
133     pipeline.addLast(handler);
134     channelHandlerCtx = pipeline.context(handler);
135     Object nonSslEvent = new Object();
136 
137     pipeline.fireUserEventTriggered(nonSslEvent);
138 
139     // A non ssl event should not cause the grpcHandler to be in the pipeline yet.
140     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
141     assertNull(grpcHandlerCtx);
142   }
143 
144   @Test
tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol()145   public void tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception {
146     SslHandler badSslHandler = new SslHandler(engine, false) {
147       @Override
148       public String applicationProtocol() {
149         return "badprotocol";
150       }
151     };
152 
153     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
154     pipeline.addLast(handler);
155 
156     pipeline.replace(SslHandler.class, null, badSslHandler);
157     channelHandlerCtx = pipeline.context(handler);
158     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
159 
160     pipeline.fireUserEventTriggered(sslEvent);
161 
162     // No h2 protocol was specified, so this should be closed.
163     assertFalse(channel.isOpen());
164     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
165     assertNull(grpcHandlerCtx);
166   }
167 
168   @Test
tlsHandler_userEventTriggeredSslEvent_handshakeFailure()169   public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception {
170     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
171     pipeline.addLast(handler);
172     channelHandlerCtx = pipeline.context(handler);
173     Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad"));
174 
175     pipeline.fireUserEventTriggered(sslEvent);
176 
177     // No h2 protocol was specified, so this should be closed.
178     assertFalse(channel.isOpen());
179     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
180     assertNull(grpcHandlerCtx);
181   }
182 
183   @Test
tlsHandler_userEventTriggeredSslEvent_supportedProtocolH2()184   public void tlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception {
185     SslHandler goodSslHandler = new SslHandler(engine, false) {
186       @Override
187       public String applicationProtocol() {
188         return "h2";
189       }
190     };
191 
192     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
193     pipeline.addLast(handler);
194 
195     pipeline.replace(SslHandler.class, null, goodSslHandler);
196     channelHandlerCtx = pipeline.context(handler);
197     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
198 
199     pipeline.fireUserEventTriggered(sslEvent);
200 
201     assertTrue(channel.isOpen());
202     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
203     assertNotNull(grpcHandlerCtx);
204   }
205 
206   @Test
tlsHandler_userEventTriggeredSslEvent_supportedProtocolGrpcExp()207   public void tlsHandler_userEventTriggeredSslEvent_supportedProtocolGrpcExp() throws Exception {
208     SslHandler goodSslHandler = new SslHandler(engine, false) {
209       @Override
210       public String applicationProtocol() {
211         return "grpc-exp";
212       }
213     };
214 
215     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
216     pipeline.addLast(handler);
217 
218     pipeline.replace(SslHandler.class, null, goodSslHandler);
219     channelHandlerCtx = pipeline.context(handler);
220     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
221 
222     pipeline.fireUserEventTriggered(sslEvent);
223 
224     assertTrue(channel.isOpen());
225     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
226     assertNotNull(grpcHandlerCtx);
227   }
228 
229   @Test
engineLog()230   public void engineLog() {
231     ChannelHandler handler = new ServerTlsHandler(sslContext, grpcHandler);
232     pipeline.addLast(handler);
233     channelHandlerCtx = pipeline.context(handler);
234 
235     Logger logger = Logger.getLogger(ProtocolNegotiators.class.getName());
236     Filter oldFilter = logger.getFilter();
237     try {
238       logger.setFilter(new Filter() {
239         @Override
240         public boolean isLoggable(LogRecord record) {
241           // We still want to the log method to be exercised, just not printed to stderr.
242           return false;
243         }
244       });
245 
246       ProtocolNegotiators.logSslEngineDetails(
247           Level.INFO, channelHandlerCtx, "message", new Exception("bad"));
248     } finally {
249       logger.setFilter(oldFilter);
250     }
251   }
252 
253   @Test
tls_failsOnNullSslContext()254   public void tls_failsOnNullSslContext() {
255     thrown.expect(NullPointerException.class);
256 
257     Object unused = ProtocolNegotiators.tls(null);
258   }
259 
260   @Test
tls_hostAndPort()261   public void tls_hostAndPort() throws SSLException {
262     SslContext ctx = GrpcSslContexts.forClient().build();
263     TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
264     HostPort hostPort = negotiator.parseAuthority("authority:1234");
265 
266     assertEquals("authority", hostPort.host);
267     assertEquals(1234, hostPort.port);
268   }
269 
270   @Test
tls_host()271   public void tls_host() throws SSLException {
272     SslContext ctx = GrpcSslContexts.forClient().build();
273     TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
274     HostPort hostPort = negotiator.parseAuthority("[::1]");
275 
276     assertEquals("[::1]", hostPort.host);
277     assertEquals(-1, hostPort.port);
278   }
279 
280   @Test
tls_invalidHost()281   public void tls_invalidHost() throws SSLException {
282     SslContext ctx = GrpcSslContexts.forClient().build();
283     TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
284     HostPort hostPort = negotiator.parseAuthority("bad_host:1234");
285 
286     // Even though it looks like a port, we treat it as part of the authority, since the host is
287     // invalid.
288     assertEquals("bad_host:1234", hostPort.host);
289     assertEquals(-1, hostPort.port);
290   }
291 
292   @Test
httpProxy_nullAddressNpe()293   public void httpProxy_nullAddressNpe() throws Exception {
294     thrown.expect(NullPointerException.class);
295     Object unused =
296         ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext());
297   }
298 
299   @Test
httpProxy_nullNegotiatorNpe()300   public void httpProxy_nullNegotiatorNpe() throws Exception {
301     thrown.expect(NullPointerException.class);
302     Object unused = ProtocolNegotiators.httpProxy(
303         InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null);
304   }
305 
306   @Test
httpProxy_nullUserPassNoException()307   public void httpProxy_nullUserPassNoException() throws Exception {
308     assertNotNull(ProtocolNegotiators.httpProxy(
309         InetSocketAddress.createUnresolved("localhost", 80), null, null,
310         ProtocolNegotiators.plaintext()));
311   }
312 
313   @Test
httpProxy_completes()314   public void httpProxy_completes() throws Exception {
315     DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
316     // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
317     // the channel is already active.
318     LocalAddress proxy = new LocalAddress("httpProxy_completes");
319     SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);
320 
321     ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
322     Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
323         .childHandler(mockHandler)
324         .bind(proxy).sync().channel();
325 
326     ProtocolNegotiator nego =
327         ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
328     ChannelHandler handler = nego.newHandler(grpcHandler);
329     Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler)
330         .register().sync().channel();
331     pipeline = channel.pipeline();
332     // Wait for initialization to complete
333     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
334     // The grpcHandler must be in the pipeline, but we don't actually want it during our test
335     // because it will consume all events since it is a mock. We only use it because it is required
336     // to construct the Handler.
337     pipeline.remove(grpcHandler);
338     channel.connect(host).sync();
339     serverChannel.close();
340     ArgumentCaptor<ChannelHandlerContext> contextCaptor =
341         ArgumentCaptor.forClass(ChannelHandlerContext.class);
342     Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
343     ChannelHandlerContext serverContext = contextCaptor.getValue();
344 
345     final String golden = "isThisThingOn?";
346     ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));
347 
348     // Wait for sending initial request to complete
349     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
350     ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
351     Mockito.verify(mockHandler)
352         .channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
353     ByteBuf b = (ByteBuf) objectCaptor.getValue();
354     String request = b.toString(UTF_8);
355     b.release();
356     assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n"));
357     assertTrue("No CONNECT: " + request, request.startsWith("CONNECT specialHost:314 "));
358     assertTrue("No host header: " + request, request.contains("host: specialHost:314"));
359 
360     assertFalse(negotiationFuture.isDone());
361     serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync();
362     negotiationFuture.sync();
363 
364     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
365     objectCaptor.getAllValues().clear();
366     Mockito.verify(mockHandler, times(2))
367         .channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
368     b = (ByteBuf) objectCaptor.getAllValues().get(1);
369     // If we were using the real grpcHandler, this would have been the HTTP/2 preface
370     String preface = b.toString(UTF_8);
371     b.release();
372     assertEquals(golden, preface);
373 
374     channel.close();
375   }
376 
377   @Test
httpProxy_500()378   public void httpProxy_500() throws Exception {
379     DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
380     // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
381     // the channel is already active.
382     LocalAddress proxy = new LocalAddress("httpProxy_500");
383     SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);
384 
385     ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
386     Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
387         .childHandler(mockHandler)
388         .bind(proxy).sync().channel();
389 
390     ProtocolNegotiator nego =
391         ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
392     ChannelHandler handler = nego.newHandler(grpcHandler);
393     Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler)
394         .register().sync().channel();
395     pipeline = channel.pipeline();
396     // Wait for initialization to complete
397     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
398     // The grpcHandler must be in the pipeline, but we don't actually want it during our test
399     // because it will consume all events since it is a mock. We only use it because it is required
400     // to construct the Handler.
401     pipeline.remove(grpcHandler);
402     channel.connect(host).sync();
403     serverChannel.close();
404     ArgumentCaptor<ChannelHandlerContext> contextCaptor =
405         ArgumentCaptor.forClass(ChannelHandlerContext.class);
406     Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
407     ChannelHandlerContext serverContext = contextCaptor.getValue();
408 
409     final String golden = "isThisThingOn?";
410     ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));
411 
412     // Wait for sending initial request to complete
413     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
414     ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
415     Mockito.verify(mockHandler)
416         .channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
417     ByteBuf request = (ByteBuf) objectCaptor.getValue();
418     request.release();
419 
420     assertFalse(negotiationFuture.isDone());
421     String response = "HTTP/1.1 500 OMG\r\nContent-Length: 4\r\n\r\noops";
422     serverContext.writeAndFlush(bb(response, serverContext.channel())).sync();
423     thrown.expect(ProxyConnectException.class);
424     try {
425       negotiationFuture.sync();
426     } finally {
427       channel.close();
428     }
429   }
430 
bb(String s, Channel c)431   private static ByteBuf bb(String s, Channel c) {
432     return ByteBufUtil.writeUtf8(c.alloc(), s);
433   }
434 }
435