1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static io.grpc.alts.internal.ByteBufTestUtils.getDirectBuffer;
21 import static java.nio.charset.StandardCharsets.UTF_8;
22 import static org.junit.Assert.fail;
23 
24 import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef;
25 import io.grpc.alts.internal.TsiFrameProtector.Consumer;
26 import io.netty.buffer.ByteBuf;
27 import io.netty.buffer.Unpooled;
28 import io.netty.buffer.UnpooledByteBufAllocator;
29 import java.nio.ByteBuffer;
30 import java.security.GeneralSecurityException;
31 import java.util.ArrayList;
32 import java.util.Collections;
33 import java.util.List;
34 import javax.crypto.AEADBadTagException;
35 
36 /** Utility class that provides tests for implementations of @{link TsiHandshaker}. */
37 public final class TsiTest {
38   private static final String DECRYPTION_FAILURE_RE = "Tag mismatch!";
39 
TsiTest()40   private TsiTest() {}
41 
42   /** A @{code TsiHandshaker} pair for running tests. */
43   public static class Handshakers {
44     private final TsiHandshaker client;
45     private final TsiHandshaker server;
46 
Handshakers(TsiHandshaker client, TsiHandshaker server)47     public Handshakers(TsiHandshaker client, TsiHandshaker server) {
48       this.client = client;
49       this.server = server;
50     }
51 
getClient()52     public TsiHandshaker getClient() {
53       return client;
54     }
55 
getServer()56     public TsiHandshaker getServer() {
57       return server;
58     }
59   }
60 
61   private static final int DEFAULT_TRANSPORT_BUFFER_SIZE = 2048;
62 
63   private static final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
64 
65   private static final String EXAMPLE_MESSAGE1 = "hello world";
66   private static final String EXAMPLE_MESSAGE2 = "oysteroystersoysterseateateat";
67 
68   private static final int EXAMPLE_MESSAGE1_LEN = EXAMPLE_MESSAGE1.getBytes(UTF_8).length;
69   private static final int EXAMPLE_MESSAGE2_LEN = EXAMPLE_MESSAGE2.getBytes(UTF_8).length;
70 
getDefaultTransportBufferSize()71   static int getDefaultTransportBufferSize() {
72     return DEFAULT_TRANSPORT_BUFFER_SIZE;
73   }
74 
75   /**
76    * Performs a handshake between the client handshaker and server handshaker using a transport of
77    * length transportBufferSize.
78    */
performHandshake(int transportBufferSize, Handshakers handshakers)79   static void performHandshake(int transportBufferSize, Handshakers handshakers)
80       throws GeneralSecurityException {
81     TsiHandshaker clientHandshaker = handshakers.getClient();
82     TsiHandshaker serverHandshaker = handshakers.getServer();
83 
84     byte[] transportBufferBytes = new byte[transportBufferSize];
85     ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
86     transportBuffer.limit(0); // Start off with an empty buffer
87 
88     while (clientHandshaker.isInProgress() || serverHandshaker.isInProgress()) {
89       for (TsiHandshaker handshaker : new TsiHandshaker[] {clientHandshaker, serverHandshaker}) {
90         if (handshaker.isInProgress()) {
91           // Process any bytes on the wire.
92           if (transportBuffer.hasRemaining()) {
93             handshaker.processBytesFromPeer(transportBuffer);
94           }
95           // Put new bytes on the wire, if needed.
96           if (handshaker.isInProgress()) {
97             transportBuffer.clear();
98             handshaker.getBytesToSendToPeer(transportBuffer);
99             transportBuffer.flip();
100           }
101         }
102       }
103     }
104     clientHandshaker.extractPeer();
105     serverHandshaker.extractPeer();
106   }
107 
handshakeTest(Handshakers handshakers)108   public static void handshakeTest(Handshakers handshakers) throws GeneralSecurityException {
109     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
110   }
111 
handshakeSmallBufferTest(Handshakers handshakers)112   public static void handshakeSmallBufferTest(Handshakers handshakers)
113       throws GeneralSecurityException {
114     performHandshake(9, handshakers);
115   }
116 
117   /** Sends a message between the sender and receiver. */
sendMessage( TsiFrameProtector sender, TsiFrameProtector receiver, int recvFragmentSize, String message, RegisterRef ref)118   private static void sendMessage(
119       TsiFrameProtector sender,
120       TsiFrameProtector receiver,
121       int recvFragmentSize,
122       String message,
123       RegisterRef ref)
124       throws GeneralSecurityException {
125 
126     ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
127     final List<ByteBuf> protectOut = new ArrayList<>();
128     List<Object> unprotectOut = new ArrayList<>();
129 
130     sender.protectFlush(
131         Collections.singletonList(plaintextBuffer),
132         new Consumer<ByteBuf>() {
133           @Override
134           public void accept(ByteBuf buf) {
135             protectOut.add(buf);
136           }
137         },
138         alloc);
139     assertThat(protectOut.size()).isEqualTo(1);
140 
141     ByteBuf protect = ref.register(protectOut.get(0));
142     while (protect.isReadable()) {
143       ByteBuf buf = protect;
144       if (recvFragmentSize > 0) {
145         int size = Math.min(protect.readableBytes(), recvFragmentSize);
146         buf = protect.readSlice(size);
147       }
148       receiver.unprotect(buf, unprotectOut, alloc);
149     }
150     ByteBuf plaintextRecvd = getDirectBuffer(message.getBytes(UTF_8).length, ref);
151     for (Object unprotect : unprotectOut) {
152       ByteBuf unprotectBuf = ref.register((ByteBuf) unprotect);
153       plaintextRecvd.writeBytes(unprotectBuf);
154     }
155     assertThat(plaintextRecvd).isEqualTo(Unpooled.wrappedBuffer(message.getBytes(UTF_8)));
156   }
157 
158   /** Ping pong test. */
pingPongTest(Handshakers handshakers, RegisterRef ref)159   public static void pingPongTest(Handshakers handshakers, RegisterRef ref)
160       throws GeneralSecurityException {
161     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
162 
163     TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc);
164     TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc);
165 
166     sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref);
167     sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE2, ref);
168 
169     clientProtector.destroy();
170     serverProtector.destroy();
171   }
172 
173   /** Ping pong test with exact frame size. */
pingPongExactFrameSizeTest(Handshakers handshakers, RegisterRef ref)174   public static void pingPongExactFrameSizeTest(Handshakers handshakers, RegisterRef ref)
175       throws GeneralSecurityException {
176     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
177 
178     int frameSize =
179         EXAMPLE_MESSAGE1.getBytes(UTF_8).length
180             + AltsTsiFrameProtector.getHeaderBytes()
181             + FakeChannelCrypter.getTagBytes();
182 
183     TsiFrameProtector clientProtector =
184         handshakers.getClient().createFrameProtector(frameSize, alloc);
185     TsiFrameProtector serverProtector =
186         handshakers.getServer().createFrameProtector(frameSize, alloc);
187 
188     sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref);
189     sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE1, ref);
190 
191     clientProtector.destroy();
192     serverProtector.destroy();
193   }
194 
195   /** Ping pong test with small buffer size. */
pingPongSmallBufferTest(Handshakers handshakers, RegisterRef ref)196   public static void pingPongSmallBufferTest(Handshakers handshakers, RegisterRef ref)
197       throws GeneralSecurityException {
198     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
199 
200     TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc);
201     TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc);
202 
203     sendMessage(clientProtector, serverProtector, 1, EXAMPLE_MESSAGE1, ref);
204     sendMessage(serverProtector, clientProtector, 1, EXAMPLE_MESSAGE2, ref);
205 
206     clientProtector.destroy();
207     serverProtector.destroy();
208   }
209 
210   /** Ping pong test with small frame size. */
pingPongSmallFrameTest( int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)211   public static void pingPongSmallFrameTest(
212       int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)
213       throws GeneralSecurityException {
214     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
215 
216     // We send messages using small non-aligned buffers. We use 3 and 5, small primes.
217     TsiFrameProtector clientProtector =
218         handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc);
219     TsiFrameProtector serverProtector =
220         handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc);
221 
222     sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
223     sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
224 
225     clientProtector.destroy();
226     serverProtector.destroy();
227   }
228 
229   /** Ping pong test with small frame and small buffer. */
pingPongSmallFrameSmallBufferTest( int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)230   public static void pingPongSmallFrameSmallBufferTest(
231       int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)
232       throws GeneralSecurityException {
233     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
234 
235     // We send messages using small non-aligned buffers. We use 3 and 5, small primes.
236     TsiFrameProtector clientProtector =
237         handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc);
238     TsiFrameProtector serverProtector =
239         handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc);
240 
241     sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
242     sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
243 
244     sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
245     sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
246 
247     clientProtector.destroy();
248     serverProtector.destroy();
249   }
250 
251   /** Test corrupted counter. */
corruptedCounterTest(Handshakers handshakers, RegisterRef ref)252   public static void corruptedCounterTest(Handshakers handshakers, RegisterRef ref)
253       throws GeneralSecurityException {
254     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
255 
256     TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
257     TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
258 
259     String message = "hello world";
260     ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
261     final List<ByteBuf> protectOut = new ArrayList<>();
262     List<Object> unprotectOut = new ArrayList<>();
263 
264     sender.protectFlush(
265         Collections.singletonList(plaintextBuffer),
266         new Consumer<ByteBuf>() {
267           @Override
268           public void accept(ByteBuf buf) {
269             protectOut.add(buf);
270           }
271         },
272         alloc);
273     assertThat(protectOut.size()).isEqualTo(1);
274 
275     ByteBuf protect = ref.register(protectOut.get(0));
276     // Unprotect once to increase receiver counter.
277     receiver.unprotect(protect.slice(), unprotectOut, alloc);
278     assertThat(unprotectOut.size()).isEqualTo(1);
279     ref.register((ByteBuf) unprotectOut.get(0));
280 
281     try {
282       receiver.unprotect(protect, unprotectOut, alloc);
283       fail("Exception expected");
284     } catch (AEADBadTagException ex) {
285       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
286     }
287 
288     sender.destroy();
289     receiver.destroy();
290   }
291 
292   /** Test corrupted ciphertext. */
corruptedCiphertextTest(Handshakers handshakers, RegisterRef ref)293   public static void corruptedCiphertextTest(Handshakers handshakers, RegisterRef ref)
294       throws GeneralSecurityException {
295     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
296 
297     TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
298     TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
299 
300     String message = "hello world";
301     ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
302     final List<ByteBuf> protectOut = new ArrayList<>();
303     List<Object> unprotectOut = new ArrayList<>();
304 
305     sender.protectFlush(
306         Collections.singletonList(plaintextBuffer),
307         new Consumer<ByteBuf>() {
308           @Override
309           public void accept(ByteBuf buf) {
310             protectOut.add(buf);
311           }
312         },
313         alloc);
314     assertThat(protectOut.size()).isEqualTo(1);
315 
316     ByteBuf protect = ref.register(protectOut.get(0));
317     int ciphertextIdx = protect.writerIndex() - FakeChannelCrypter.getTagBytes() - 2;
318     protect.setByte(ciphertextIdx, protect.getByte(ciphertextIdx) + 1);
319 
320     try {
321       receiver.unprotect(protect, unprotectOut, alloc);
322       fail("Exception expected");
323     } catch (AEADBadTagException ex) {
324       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
325     }
326 
327     sender.destroy();
328     receiver.destroy();
329   }
330 
331   /** Test corrupted tag. */
corruptedTagTest(Handshakers handshakers, RegisterRef ref)332   public static void corruptedTagTest(Handshakers handshakers, RegisterRef ref)
333       throws GeneralSecurityException {
334     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
335 
336     TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
337     TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
338 
339     String message = "hello world";
340     ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
341     final List<ByteBuf> protectOut = new ArrayList<>();
342     List<Object> unprotectOut = new ArrayList<>();
343 
344     sender.protectFlush(
345         Collections.singletonList(plaintextBuffer),
346         new Consumer<ByteBuf>() {
347           @Override
348           public void accept(ByteBuf buf) {
349             protectOut.add(buf);
350           }
351         },
352         alloc);
353     assertThat(protectOut.size()).isEqualTo(1);
354 
355     ByteBuf protect = ref.register(protectOut.get(0));
356     int tagIdx = protect.writerIndex() - 1;
357     protect.setByte(tagIdx, protect.getByte(tagIdx) + 1);
358 
359     try {
360       receiver.unprotect(protect, unprotectOut, alloc);
361       fail("Exception expected");
362     } catch (AEADBadTagException ex) {
363       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
364     }
365 
366     sender.destroy();
367     receiver.destroy();
368   }
369 
370   /** Test reflected ciphertext. */
reflectedCiphertextTest(Handshakers handshakers, RegisterRef ref)371   public static void reflectedCiphertextTest(Handshakers handshakers, RegisterRef ref)
372       throws GeneralSecurityException {
373     performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
374 
375     TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
376     TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
377 
378     String message = "hello world";
379     ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
380     final List<ByteBuf> protectOut = new ArrayList<>();
381     List<Object> unprotectOut = new ArrayList<>();
382 
383     sender.protectFlush(
384         Collections.singletonList(plaintextBuffer),
385         new Consumer<ByteBuf>() {
386           @Override
387           public void accept(ByteBuf buf) {
388             protectOut.add(buf);
389           }
390         },
391         alloc);
392     assertThat(protectOut.size()).isEqualTo(1);
393 
394     ByteBuf protect = ref.register(protectOut.get(0));
395     try {
396       sender.unprotect(protect.slice(), unprotectOut, alloc);
397       fail("Exception expected");
398     } catch (AEADBadTagException ex) {
399       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
400     }
401 
402     sender.destroy();
403     receiver.destroy();
404   }
405 }
406