1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkState;
21 import static com.google.common.base.Verify.verify;
22 
23 import com.google.common.primitives.Ints;
24 import io.netty.buffer.ByteBuf;
25 import io.netty.buffer.ByteBufAllocator;
26 import java.security.GeneralSecurityException;
27 import java.util.ArrayList;
28 import java.util.List;
29 
30 /** Frame protector that uses the ALTS framing. */
31 public final class AltsTsiFrameProtector implements TsiFrameProtector {
32   private static final int HEADER_LEN_FIELD_BYTES = 4;
33   private static final int HEADER_TYPE_FIELD_BYTES = 4;
34   private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES;
35   private static final int HEADER_TYPE_DEFAULT = 6;
36   // Total frame size including full header and tag.
37   private static final int MAX_ALLOWED_FRAME_BYTES = 16 * 1024;
38   private static final int LIMIT_MAX_ALLOWED_FRAME_BYTES = 1024 * 1024;
39 
40   private final Protector protector;
41   private final Unprotector unprotector;
42 
43   /** Create a new AltsTsiFrameProtector. */
AltsTsiFrameProtector( int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc)44   public AltsTsiFrameProtector(
45       int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
46     checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength());
47     maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_BYTES, maxProtectedFrameBytes);
48     protector = new Protector(maxProtectedFrameBytes, crypter);
49     unprotector = new Unprotector(crypter, alloc);
50   }
51 
getHeaderLenFieldBytes()52   static int getHeaderLenFieldBytes() {
53     return HEADER_LEN_FIELD_BYTES;
54   }
55 
getHeaderTypeFieldBytes()56   static int getHeaderTypeFieldBytes() {
57     return HEADER_TYPE_FIELD_BYTES;
58   }
59 
getHeaderBytes()60   public static int getHeaderBytes() {
61     return HEADER_BYTES;
62   }
63 
getHeaderTypeDefault()64   static int getHeaderTypeDefault() {
65     return HEADER_TYPE_DEFAULT;
66   }
67 
getMaxAllowedFrameBytes()68   public static int getMaxAllowedFrameBytes() {
69     return MAX_ALLOWED_FRAME_BYTES;
70   }
71 
getLimitMaxAllowedFrameBytes()72   static int getLimitMaxAllowedFrameBytes() {
73     return LIMIT_MAX_ALLOWED_FRAME_BYTES;
74   }
75 
76   @Override
protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)77   public void protectFlush(
78       List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
79       throws GeneralSecurityException {
80     protector.protectFlush(unprotectedBufs, ctxWrite, alloc);
81   }
82 
83   @Override
unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)84   public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
85       throws GeneralSecurityException {
86     unprotector.unprotect(in, out, alloc);
87   }
88 
89   @Override
destroy()90   public void destroy() {
91     try {
92       unprotector.destroy();
93     } finally {
94       protector.destroy();
95     }
96   }
97 
98   static final class Protector {
99     private final int maxUnprotectedBytesPerFrame;
100     private final int suffixBytes;
101     private ChannelCrypterNetty crypter;
102 
Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter)103     Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter) {
104       this.suffixBytes = crypter.getSuffixLength();
105       this.maxUnprotectedBytesPerFrame = maxProtectedFrameBytes - HEADER_BYTES - suffixBytes;
106       this.crypter = crypter;
107     }
108 
destroy()109     void destroy() {
110       // Shared with Unprotector and destroyed there.
111       crypter = null;
112     }
113 
protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)114     void protectFlush(
115         List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
116         throws GeneralSecurityException {
117       checkState(crypter != null, "Cannot protectFlush after destroy.");
118       ByteBuf protectedBuf;
119       try {
120         protectedBuf = handleUnprotected(unprotectedBufs, alloc);
121       } finally {
122         for (ByteBuf buf : unprotectedBufs) {
123           buf.release();
124         }
125       }
126       if (protectedBuf != null) {
127         ctxWrite.accept(protectedBuf);
128       }
129     }
130 
131     @SuppressWarnings("BetaApi") // verify is stable in Guava
handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)132     private ByteBuf handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)
133         throws GeneralSecurityException {
134       long unprotectedBytes = 0;
135       for (ByteBuf buf : unprotectedBufs) {
136         unprotectedBytes += buf.readableBytes();
137       }
138       // Empty plaintext not allowed since this should be handled as no-op in layer above.
139       checkArgument(unprotectedBytes > 0);
140 
141       // Compute number of frames and allocate a single buffer for all frames.
142       long frameNum = unprotectedBytes / maxUnprotectedBytesPerFrame + 1;
143       int lastFrameUnprotectedBytes = (int) (unprotectedBytes % maxUnprotectedBytesPerFrame);
144       if (lastFrameUnprotectedBytes == 0) {
145         frameNum--;
146         lastFrameUnprotectedBytes = maxUnprotectedBytesPerFrame;
147       }
148       long protectedBytes = frameNum * (HEADER_BYTES + suffixBytes) + unprotectedBytes;
149 
150       ByteBuf protectedBuf = alloc.directBuffer(Ints.checkedCast(protectedBytes));
151       try {
152         int bufferIdx = 0;
153         for (int frameIdx = 0; frameIdx < frameNum; ++frameIdx) {
154           int unprotectedBytesLeft =
155               (frameIdx == frameNum - 1) ? lastFrameUnprotectedBytes : maxUnprotectedBytesPerFrame;
156           // Write header (at most LIMIT_MAX_ALLOWED_FRAME_BYTES).
157           protectedBuf.writeIntLE(unprotectedBytesLeft + HEADER_TYPE_FIELD_BYTES + suffixBytes);
158           protectedBuf.writeIntLE(HEADER_TYPE_DEFAULT);
159 
160           // Ownership of the backing buffer remains with protectedBuf.
161           ByteBuf frameOut = writeSlice(protectedBuf, unprotectedBytesLeft + suffixBytes);
162           List<ByteBuf> framePlain = new ArrayList<>();
163           while (unprotectedBytesLeft > 0) {
164             // Ownership of the buffer backing in remains with unprotectedBufs.
165             ByteBuf in = unprotectedBufs.get(bufferIdx);
166             if (in.readableBytes() <= unprotectedBytesLeft) {
167               // The complete buffer belongs to this frame.
168               framePlain.add(in);
169               unprotectedBytesLeft -= in.readableBytes();
170               bufferIdx++;
171             } else {
172               // The remainder of in will be part of the next frame.
173               framePlain.add(in.readSlice(unprotectedBytesLeft));
174               unprotectedBytesLeft = 0;
175             }
176           }
177           crypter.encrypt(frameOut, framePlain);
178           verify(!frameOut.isWritable());
179         }
180         protectedBuf.readerIndex(0);
181         protectedBuf.writerIndex(protectedBuf.capacity());
182         return protectedBuf.retain();
183       } finally {
184         protectedBuf.release();
185       }
186     }
187   }
188 
189   static final class Unprotector {
190     private final int suffixBytes;
191     private final ChannelCrypterNetty crypter;
192 
193     private DeframerState state = DeframerState.READ_HEADER;
194     private int requiredProtectedBytes;
195     private ByteBuf header;
196     private ByteBuf firstFrameTag;
197     private int unhandledIdx = 0;
198     private long unhandledBytes = 0;
199     private List<ByteBuf> unhandledBufs = new ArrayList<>(16);
200 
Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc)201     Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
202       this.crypter = crypter;
203       this.suffixBytes = crypter.getSuffixLength();
204       this.header = alloc.directBuffer(HEADER_BYTES);
205       this.firstFrameTag = alloc.directBuffer(suffixBytes);
206     }
207 
addUnhandled(ByteBuf in)208     private void addUnhandled(ByteBuf in) {
209       if (in.isReadable()) {
210         ByteBuf buf = in.readRetainedSlice(in.readableBytes());
211         unhandledBufs.add(buf);
212         unhandledBytes += buf.readableBytes();
213       }
214     }
215 
unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)216     void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
217         throws GeneralSecurityException {
218       checkState(header != null, "Cannot unprotect after destroy.");
219       addUnhandled(in);
220       decodeFrame(alloc, out);
221     }
222 
223     @SuppressWarnings("fallthrough")
decodeFrame(ByteBufAllocator alloc, List<Object> out)224     private void decodeFrame(ByteBufAllocator alloc, List<Object> out)
225         throws GeneralSecurityException {
226       switch (state) {
227         case READ_HEADER:
228           if (unhandledBytes < HEADER_BYTES) {
229             return;
230           }
231           handleHeader();
232           // fall through
233         case READ_PROTECTED_PAYLOAD:
234           if (unhandledBytes < requiredProtectedBytes) {
235             return;
236           }
237           ByteBuf unprotectedBuf;
238           try {
239             unprotectedBuf = handlePayload(alloc);
240           } finally {
241             clearState();
242           }
243           if (unprotectedBuf != null) {
244             out.add(unprotectedBuf);
245           }
246           break;
247         default:
248           throw new AssertionError("impossible enum value");
249       }
250     }
251 
handleHeader()252     private void handleHeader() {
253       while (header.isWritable()) {
254         ByteBuf in = unhandledBufs.get(unhandledIdx);
255         int headerBytesToRead = Math.min(in.readableBytes(), header.writableBytes());
256         header.writeBytes(in, headerBytesToRead);
257         unhandledBytes -= headerBytesToRead;
258         if (!in.isReadable()) {
259           unhandledIdx++;
260         }
261       }
262       requiredProtectedBytes = header.readIntLE() - HEADER_TYPE_FIELD_BYTES;
263       checkArgument(
264           requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small");
265       checkArgument(
266           requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_BYTES - HEADER_BYTES,
267           "Invalid header field: frame size too large");
268       int frameType = header.readIntLE();
269       checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type");
270       state = DeframerState.READ_PROTECTED_PAYLOAD;
271     }
272 
273     @SuppressWarnings("BetaApi") // verify is stable in Guava
handlePayload(ByteBufAllocator alloc)274     private ByteBuf handlePayload(ByteBufAllocator alloc) throws GeneralSecurityException {
275       int requiredCiphertextBytes = requiredProtectedBytes - suffixBytes;
276       int firstFrameUnprotectedLen = requiredCiphertextBytes;
277 
278       // We get the ciphertexts of the first frame and copy over the tag into a single buffer.
279       List<ByteBuf> firstFrameCiphertext = new ArrayList<>();
280       while (requiredCiphertextBytes > 0) {
281         ByteBuf buf = unhandledBufs.get(unhandledIdx);
282         if (buf.readableBytes() <= requiredCiphertextBytes) {
283           // We use the whole buffer.
284           firstFrameCiphertext.add(buf);
285           requiredCiphertextBytes -= buf.readableBytes();
286           unhandledIdx++;
287         } else {
288           firstFrameCiphertext.add(buf.readSlice(requiredCiphertextBytes));
289           requiredCiphertextBytes = 0;
290         }
291       }
292       int requiredSuffixBytes = suffixBytes;
293       while (true) {
294         ByteBuf buf = unhandledBufs.get(unhandledIdx);
295         if (buf.readableBytes() <= requiredSuffixBytes) {
296           // We use the whole buffer.
297           requiredSuffixBytes -= buf.readableBytes();
298           firstFrameTag.writeBytes(buf);
299           if (requiredSuffixBytes == 0) {
300             break;
301           }
302           unhandledIdx++;
303         } else {
304           firstFrameTag.writeBytes(buf, requiredSuffixBytes);
305           break;
306         }
307       }
308       verify(unhandledIdx == unhandledBufs.size() - 1);
309       ByteBuf lastBuf = unhandledBufs.get(unhandledIdx);
310 
311       // We get the remaining ciphertexts and tags contained in the last buffer.
312       List<ByteBuf> ciphertextsAndTags = new ArrayList<>();
313       List<Integer> unprotectedLens = new ArrayList<>();
314       long requiredUnprotectedBytesCompleteFrames = firstFrameUnprotectedLen;
315       while (lastBuf.readableBytes() >= HEADER_BYTES + suffixBytes) {
316         // Read frame size.
317         int frameSize = lastBuf.readIntLE();
318         int payloadSize = frameSize - HEADER_TYPE_FIELD_BYTES - suffixBytes;
319         // Break and undo read if we don't have the complete frame yet.
320         if (lastBuf.readableBytes() < frameSize) {
321           lastBuf.readerIndex(lastBuf.readerIndex() - HEADER_LEN_FIELD_BYTES);
322           break;
323         }
324         // Check the type header.
325         checkArgument(lastBuf.readIntLE() == 6);
326         // Create a new frame (except for out buffer).
327         ciphertextsAndTags.add(lastBuf.readSlice(payloadSize + suffixBytes));
328         // Update sizes for frame.
329         requiredUnprotectedBytesCompleteFrames += payloadSize;
330         unprotectedLens.add(payloadSize);
331       }
332 
333       // We leave space for suffixBytes to allow for in-place encryption. This allows for calling
334       // doFinal in the JCE implementation which can be optimized better than update and doFinal.
335       ByteBuf unprotectedBuf =
336           alloc.directBuffer(
337               Ints.checkedCast(requiredUnprotectedBytesCompleteFrames + suffixBytes));
338       try {
339 
340         ByteBuf out = writeSlice(unprotectedBuf, firstFrameUnprotectedLen + suffixBytes);
341         crypter.decrypt(out, firstFrameTag, firstFrameCiphertext);
342         verify(out.writableBytes() == suffixBytes);
343         unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
344 
345         for (int frameIdx = 0; frameIdx < ciphertextsAndTags.size(); ++frameIdx) {
346           out = writeSlice(unprotectedBuf, unprotectedLens.get(frameIdx) + suffixBytes);
347           crypter.decrypt(out, ciphertextsAndTags.get(frameIdx));
348           verify(out.writableBytes() == suffixBytes);
349           unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
350         }
351         return unprotectedBuf.retain();
352       } finally {
353         unprotectedBuf.release();
354       }
355     }
356 
clearState()357     private void clearState() {
358       int bufsSize = unhandledBufs.size();
359       ByteBuf lastBuf = unhandledBufs.get(bufsSize - 1);
360       boolean keepLast = lastBuf.isReadable();
361       for (int bufIdx = 0; bufIdx < (keepLast ? bufsSize - 1 : bufsSize); ++bufIdx) {
362         unhandledBufs.get(bufIdx).release();
363       }
364       unhandledBufs.clear();
365       unhandledBytes = 0;
366       unhandledIdx = 0;
367       if (keepLast) {
368         unhandledBufs.add(lastBuf);
369         unhandledBytes = lastBuf.readableBytes();
370       }
371       state = DeframerState.READ_HEADER;
372       requiredProtectedBytes = 0;
373       header.clear();
374       firstFrameTag.clear();
375     }
376 
destroy()377     void destroy() {
378       for (ByteBuf unhandledBuf : unhandledBufs) {
379         unhandledBuf.release();
380       }
381       unhandledBufs.clear();
382       if (header != null) {
383         header.release();
384         header = null;
385       }
386       if (firstFrameTag != null) {
387         firstFrameTag.release();
388         firstFrameTag = null;
389       }
390       crypter.destroy();
391     }
392   }
393 
394   private enum DeframerState {
395     READ_HEADER,
396     READ_PROTECTED_PAYLOAD
397   }
398 
writeSlice(ByteBuf in, int len)399   private static ByteBuf writeSlice(ByteBuf in, int len) {
400     checkArgument(len <= in.writableBytes());
401     ByteBuf out = in.slice(in.writerIndex(), len);
402     in.writerIndex(in.writerIndex() + len);
403     return out.writerIndex(0);
404   }
405 }
406