1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 import static com.google.common.base.Preconditions.checkState;
21 
22 import com.google.common.annotations.VisibleForTesting;
23 import io.grpc.alts.internal.TsiFrameProtector.Consumer;
24 import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
25 import io.netty.buffer.ByteBuf;
26 import io.netty.channel.ChannelException;
27 import io.netty.channel.ChannelHandlerContext;
28 import io.netty.channel.ChannelOutboundHandler;
29 import io.netty.channel.ChannelPromise;
30 import io.netty.channel.PendingWriteQueue;
31 import io.netty.handler.codec.ByteToMessageDecoder;
32 import java.net.SocketAddress;
33 import java.security.GeneralSecurityException;
34 import java.util.ArrayList;
35 import java.util.List;
36 import java.util.concurrent.Future;
37 
38 /**
39  * Encrypts and decrypts TSI Frames. Writes are buffered here until {@link #flush} is called. Writes
40  * must not be made before the TSI handshake is complete.
41  */
42 public final class TsiFrameHandler extends ByteToMessageDecoder implements ChannelOutboundHandler {
43 
44   private TsiFrameProtector protector;
45   private PendingWriteQueue pendingUnprotectedWrites;
46 
TsiFrameHandler()47   public TsiFrameHandler() {}
48 
49   @Override
handlerAdded(ChannelHandlerContext ctx)50   public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
51     super.handlerAdded(ctx);
52     assert pendingUnprotectedWrites == null;
53     pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx));
54   }
55 
56   @Override
userEventTriggered(ChannelHandlerContext ctx, Object event)57   public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception {
58     if (event instanceof TsiHandshakeCompletionEvent) {
59       TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
60       if (tsiEvent.isSuccess()) {
61         setProtector(tsiEvent.protector());
62       }
63       // Ignore errors.  Another handler in the pipeline must handle TSI Errors.
64     }
65     // Keep propagating the message, as others may want to read it.
66     super.userEventTriggered(ctx, event);
67   }
68 
69   @VisibleForTesting
setProtector(TsiFrameProtector protector)70   void setProtector(TsiFrameProtector protector) {
71     checkState(this.protector == null);
72     this.protector = checkNotNull(protector);
73   }
74 
75   @Override
decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)76   protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
77     checkState(protector != null, "Cannot read frames while the TSI handshake is in progress");
78     protector.unprotect(in, out, ctx.alloc());
79   }
80 
81   @Override
write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)82   public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
83       throws Exception {
84     checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
85     ByteBuf msg = (ByteBuf) message;
86     if (!msg.isReadable()) {
87       // Nothing to encode.
88       @SuppressWarnings("unused") // go/futurereturn-lsc
89       Future<?> possiblyIgnoredError = promise.setSuccess();
90       return;
91     }
92 
93     // Just add the message to the pending queue. We'll write it on the next flush.
94     pendingUnprotectedWrites.add(msg, promise);
95   }
96 
97   @Override
handlerRemoved0(ChannelHandlerContext ctx)98   public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
99     if (!pendingUnprotectedWrites.isEmpty()) {
100       pendingUnprotectedWrites.removeAndFailAll(
101           new ChannelException("Pending write on removal of TSI handler"));
102     }
103   }
104 
105   @Override
exceptionCaught(ChannelHandlerContext ctx, Throwable cause)106   public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
107     pendingUnprotectedWrites.removeAndFailAll(cause);
108     super.exceptionCaught(ctx, cause);
109   }
110 
111   @Override
bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise)112   public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
113     ctx.bind(localAddress, promise);
114   }
115 
116   @Override
connect( ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise)117   public void connect(
118       ChannelHandlerContext ctx,
119       SocketAddress remoteAddress,
120       SocketAddress localAddress,
121       ChannelPromise promise) {
122     ctx.connect(remoteAddress, localAddress, promise);
123   }
124 
125   @Override
disconnect(ChannelHandlerContext ctx, ChannelPromise promise)126   public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
127     ctx.disconnect(promise);
128   }
129 
130   @Override
close(ChannelHandlerContext ctx, ChannelPromise promise)131   public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
132     ctx.close(promise);
133   }
134 
135   @Override
deregister(ChannelHandlerContext ctx, ChannelPromise promise)136   public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
137     ctx.deregister(promise);
138   }
139 
140   @Override
read(ChannelHandlerContext ctx)141   public void read(ChannelHandlerContext ctx) {
142     ctx.read();
143   }
144 
145   @Override
flush(final ChannelHandlerContext ctx)146   public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
147     checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
148     final ProtectedPromise aggregatePromise =
149         new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
150 
151     List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
152 
153     if (pendingUnprotectedWrites.isEmpty()) {
154       // Return early if there's nothing to write. Otherwise protector.protectFlush() below may
155       // not check for "no-data" and go on writing the 0-byte "data" to the socket with the
156       // protection framing.
157       return;
158     }
159     // Drain the unprotected writes.
160     while (!pendingUnprotectedWrites.isEmpty()) {
161       ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current();
162       bufs.add(in.retain());
163       // Remove and release the buffer and add its promise to the aggregate.
164       aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove());
165     }
166 
167     protector.protectFlush(
168         bufs,
169         new Consumer<ByteBuf>() {
170           @Override
171           public void accept(ByteBuf b) {
172             ctx.writeAndFlush(b, aggregatePromise.newPromise());
173           }
174         },
175         ctx.alloc());
176 
177     // We're done writing, start the flow of promise events.
178     @SuppressWarnings("unused") // go/futurereturn-lsc
179     Future<?> possiblyIgnoredError = aggregatePromise.doneAllocatingPromises();
180   }
181 }
182