1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static io.grpc.alts.internal.ByteBufTestUtils.getDirectBuffer; 21 import static java.nio.charset.StandardCharsets.UTF_8; 22 import static org.junit.Assert.fail; 23 24 import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef; 25 import io.grpc.alts.internal.TsiFrameProtector.Consumer; 26 import io.netty.buffer.ByteBuf; 27 import io.netty.buffer.Unpooled; 28 import io.netty.buffer.UnpooledByteBufAllocator; 29 import java.nio.ByteBuffer; 30 import java.security.GeneralSecurityException; 31 import java.util.ArrayList; 32 import java.util.Collections; 33 import java.util.List; 34 import javax.crypto.AEADBadTagException; 35 36 /** Utility class that provides tests for implementations of @{link TsiHandshaker}. */ 37 public final class TsiTest { 38 private static final String DECRYPTION_FAILURE_RE = "Tag mismatch!"; 39 TsiTest()40 private TsiTest() {} 41 42 /** A @{code TsiHandshaker} pair for running tests. */ 43 public static class Handshakers { 44 private final TsiHandshaker client; 45 private final TsiHandshaker server; 46 Handshakers(TsiHandshaker client, TsiHandshaker server)47 public Handshakers(TsiHandshaker client, TsiHandshaker server) { 48 this.client = client; 49 this.server = server; 50 } 51 getClient()52 public TsiHandshaker getClient() { 53 return client; 54 } 55 getServer()56 public TsiHandshaker getServer() { 57 return server; 58 } 59 } 60 61 private static final int DEFAULT_TRANSPORT_BUFFER_SIZE = 2048; 62 63 private static final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT; 64 65 private static final String EXAMPLE_MESSAGE1 = "hello world"; 66 private static final String EXAMPLE_MESSAGE2 = "oysteroystersoysterseateateat"; 67 68 private static final int EXAMPLE_MESSAGE1_LEN = EXAMPLE_MESSAGE1.getBytes(UTF_8).length; 69 private static final int EXAMPLE_MESSAGE2_LEN = EXAMPLE_MESSAGE2.getBytes(UTF_8).length; 70 getDefaultTransportBufferSize()71 static int getDefaultTransportBufferSize() { 72 return DEFAULT_TRANSPORT_BUFFER_SIZE; 73 } 74 75 /** 76 * Performs a handshake between the client handshaker and server handshaker using a transport of 77 * length transportBufferSize. 78 */ performHandshake(int transportBufferSize, Handshakers handshakers)79 static void performHandshake(int transportBufferSize, Handshakers handshakers) 80 throws GeneralSecurityException { 81 TsiHandshaker clientHandshaker = handshakers.getClient(); 82 TsiHandshaker serverHandshaker = handshakers.getServer(); 83 84 byte[] transportBufferBytes = new byte[transportBufferSize]; 85 ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes); 86 transportBuffer.limit(0); // Start off with an empty buffer 87 88 while (clientHandshaker.isInProgress() || serverHandshaker.isInProgress()) { 89 for (TsiHandshaker handshaker : new TsiHandshaker[] {clientHandshaker, serverHandshaker}) { 90 if (handshaker.isInProgress()) { 91 // Process any bytes on the wire. 92 if (transportBuffer.hasRemaining()) { 93 handshaker.processBytesFromPeer(transportBuffer); 94 } 95 // Put new bytes on the wire, if needed. 96 if (handshaker.isInProgress()) { 97 transportBuffer.clear(); 98 handshaker.getBytesToSendToPeer(transportBuffer); 99 transportBuffer.flip(); 100 } 101 } 102 } 103 } 104 clientHandshaker.extractPeer(); 105 serverHandshaker.extractPeer(); 106 } 107 handshakeTest(Handshakers handshakers)108 public static void handshakeTest(Handshakers handshakers) throws GeneralSecurityException { 109 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 110 } 111 handshakeSmallBufferTest(Handshakers handshakers)112 public static void handshakeSmallBufferTest(Handshakers handshakers) 113 throws GeneralSecurityException { 114 performHandshake(9, handshakers); 115 } 116 117 /** Sends a message between the sender and receiver. */ sendMessage( TsiFrameProtector sender, TsiFrameProtector receiver, int recvFragmentSize, String message, RegisterRef ref)118 private static void sendMessage( 119 TsiFrameProtector sender, 120 TsiFrameProtector receiver, 121 int recvFragmentSize, 122 String message, 123 RegisterRef ref) 124 throws GeneralSecurityException { 125 126 ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); 127 final List<ByteBuf> protectOut = new ArrayList<>(); 128 List<Object> unprotectOut = new ArrayList<>(); 129 130 sender.protectFlush( 131 Collections.singletonList(plaintextBuffer), 132 new Consumer<ByteBuf>() { 133 @Override 134 public void accept(ByteBuf buf) { 135 protectOut.add(buf); 136 } 137 }, 138 alloc); 139 assertThat(protectOut.size()).isEqualTo(1); 140 141 ByteBuf protect = ref.register(protectOut.get(0)); 142 while (protect.isReadable()) { 143 ByteBuf buf = protect; 144 if (recvFragmentSize > 0) { 145 int size = Math.min(protect.readableBytes(), recvFragmentSize); 146 buf = protect.readSlice(size); 147 } 148 receiver.unprotect(buf, unprotectOut, alloc); 149 } 150 ByteBuf plaintextRecvd = getDirectBuffer(message.getBytes(UTF_8).length, ref); 151 for (Object unprotect : unprotectOut) { 152 ByteBuf unprotectBuf = ref.register((ByteBuf) unprotect); 153 plaintextRecvd.writeBytes(unprotectBuf); 154 } 155 assertThat(plaintextRecvd).isEqualTo(Unpooled.wrappedBuffer(message.getBytes(UTF_8))); 156 } 157 158 /** Ping pong test. */ pingPongTest(Handshakers handshakers, RegisterRef ref)159 public static void pingPongTest(Handshakers handshakers, RegisterRef ref) 160 throws GeneralSecurityException { 161 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 162 163 TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc); 164 TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc); 165 166 sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref); 167 sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE2, ref); 168 169 clientProtector.destroy(); 170 serverProtector.destroy(); 171 } 172 173 /** Ping pong test with exact frame size. */ pingPongExactFrameSizeTest(Handshakers handshakers, RegisterRef ref)174 public static void pingPongExactFrameSizeTest(Handshakers handshakers, RegisterRef ref) 175 throws GeneralSecurityException { 176 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 177 178 int frameSize = 179 EXAMPLE_MESSAGE1.getBytes(UTF_8).length 180 + AltsTsiFrameProtector.getHeaderBytes() 181 + FakeChannelCrypter.getTagBytes(); 182 183 TsiFrameProtector clientProtector = 184 handshakers.getClient().createFrameProtector(frameSize, alloc); 185 TsiFrameProtector serverProtector = 186 handshakers.getServer().createFrameProtector(frameSize, alloc); 187 188 sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref); 189 sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE1, ref); 190 191 clientProtector.destroy(); 192 serverProtector.destroy(); 193 } 194 195 /** Ping pong test with small buffer size. */ pingPongSmallBufferTest(Handshakers handshakers, RegisterRef ref)196 public static void pingPongSmallBufferTest(Handshakers handshakers, RegisterRef ref) 197 throws GeneralSecurityException { 198 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 199 200 TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc); 201 TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc); 202 203 sendMessage(clientProtector, serverProtector, 1, EXAMPLE_MESSAGE1, ref); 204 sendMessage(serverProtector, clientProtector, 1, EXAMPLE_MESSAGE2, ref); 205 206 clientProtector.destroy(); 207 serverProtector.destroy(); 208 } 209 210 /** Ping pong test with small frame size. */ pingPongSmallFrameTest( int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)211 public static void pingPongSmallFrameTest( 212 int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref) 213 throws GeneralSecurityException { 214 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 215 216 // We send messages using small non-aligned buffers. We use 3 and 5, small primes. 217 TsiFrameProtector clientProtector = 218 handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc); 219 TsiFrameProtector serverProtector = 220 handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc); 221 222 sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref); 223 sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref); 224 225 clientProtector.destroy(); 226 serverProtector.destroy(); 227 } 228 229 /** Ping pong test with small frame and small buffer. */ pingPongSmallFrameSmallBufferTest( int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)230 public static void pingPongSmallFrameSmallBufferTest( 231 int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref) 232 throws GeneralSecurityException { 233 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 234 235 // We send messages using small non-aligned buffers. We use 3 and 5, small primes. 236 TsiFrameProtector clientProtector = 237 handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc); 238 TsiFrameProtector serverProtector = 239 handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc); 240 241 sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref); 242 sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref); 243 244 sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref); 245 sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref); 246 247 clientProtector.destroy(); 248 serverProtector.destroy(); 249 } 250 251 /** Test corrupted counter. */ corruptedCounterTest(Handshakers handshakers, RegisterRef ref)252 public static void corruptedCounterTest(Handshakers handshakers, RegisterRef ref) 253 throws GeneralSecurityException { 254 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 255 256 TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc); 257 TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc); 258 259 String message = "hello world"; 260 ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); 261 final List<ByteBuf> protectOut = new ArrayList<>(); 262 List<Object> unprotectOut = new ArrayList<>(); 263 264 sender.protectFlush( 265 Collections.singletonList(plaintextBuffer), 266 new Consumer<ByteBuf>() { 267 @Override 268 public void accept(ByteBuf buf) { 269 protectOut.add(buf); 270 } 271 }, 272 alloc); 273 assertThat(protectOut.size()).isEqualTo(1); 274 275 ByteBuf protect = ref.register(protectOut.get(0)); 276 // Unprotect once to increase receiver counter. 277 receiver.unprotect(protect.slice(), unprotectOut, alloc); 278 assertThat(unprotectOut.size()).isEqualTo(1); 279 ref.register((ByteBuf) unprotectOut.get(0)); 280 281 try { 282 receiver.unprotect(protect, unprotectOut, alloc); 283 fail("Exception expected"); 284 } catch (AEADBadTagException ex) { 285 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE); 286 } 287 288 sender.destroy(); 289 receiver.destroy(); 290 } 291 292 /** Test corrupted ciphertext. */ corruptedCiphertextTest(Handshakers handshakers, RegisterRef ref)293 public static void corruptedCiphertextTest(Handshakers handshakers, RegisterRef ref) 294 throws GeneralSecurityException { 295 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 296 297 TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc); 298 TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc); 299 300 String message = "hello world"; 301 ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); 302 final List<ByteBuf> protectOut = new ArrayList<>(); 303 List<Object> unprotectOut = new ArrayList<>(); 304 305 sender.protectFlush( 306 Collections.singletonList(plaintextBuffer), 307 new Consumer<ByteBuf>() { 308 @Override 309 public void accept(ByteBuf buf) { 310 protectOut.add(buf); 311 } 312 }, 313 alloc); 314 assertThat(protectOut.size()).isEqualTo(1); 315 316 ByteBuf protect = ref.register(protectOut.get(0)); 317 int ciphertextIdx = protect.writerIndex() - FakeChannelCrypter.getTagBytes() - 2; 318 protect.setByte(ciphertextIdx, protect.getByte(ciphertextIdx) + 1); 319 320 try { 321 receiver.unprotect(protect, unprotectOut, alloc); 322 fail("Exception expected"); 323 } catch (AEADBadTagException ex) { 324 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE); 325 } 326 327 sender.destroy(); 328 receiver.destroy(); 329 } 330 331 /** Test corrupted tag. */ corruptedTagTest(Handshakers handshakers, RegisterRef ref)332 public static void corruptedTagTest(Handshakers handshakers, RegisterRef ref) 333 throws GeneralSecurityException { 334 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 335 336 TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc); 337 TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc); 338 339 String message = "hello world"; 340 ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); 341 final List<ByteBuf> protectOut = new ArrayList<>(); 342 List<Object> unprotectOut = new ArrayList<>(); 343 344 sender.protectFlush( 345 Collections.singletonList(plaintextBuffer), 346 new Consumer<ByteBuf>() { 347 @Override 348 public void accept(ByteBuf buf) { 349 protectOut.add(buf); 350 } 351 }, 352 alloc); 353 assertThat(protectOut.size()).isEqualTo(1); 354 355 ByteBuf protect = ref.register(protectOut.get(0)); 356 int tagIdx = protect.writerIndex() - 1; 357 protect.setByte(tagIdx, protect.getByte(tagIdx) + 1); 358 359 try { 360 receiver.unprotect(protect, unprotectOut, alloc); 361 fail("Exception expected"); 362 } catch (AEADBadTagException ex) { 363 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE); 364 } 365 366 sender.destroy(); 367 receiver.destroy(); 368 } 369 370 /** Test reflected ciphertext. */ reflectedCiphertextTest(Handshakers handshakers, RegisterRef ref)371 public static void reflectedCiphertextTest(Handshakers handshakers, RegisterRef ref) 372 throws GeneralSecurityException { 373 performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers); 374 375 TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc); 376 TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc); 377 378 String message = "hello world"; 379 ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); 380 final List<ByteBuf> protectOut = new ArrayList<>(); 381 List<Object> unprotectOut = new ArrayList<>(); 382 383 sender.protectFlush( 384 Collections.singletonList(plaintextBuffer), 385 new Consumer<ByteBuf>() { 386 @Override 387 public void accept(ByteBuf buf) { 388 protectOut.add(buf); 389 } 390 }, 391 alloc); 392 assertThat(protectOut.size()).isEqualTo(1); 393 394 ByteBuf protect = ref.register(protectOut.get(0)); 395 try { 396 sender.unprotect(protect.slice(), unprotectOut, alloc); 397 fail("Exception expected"); 398 } catch (AEADBadTagException ex) { 399 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE); 400 } 401 402 sender.destroy(); 403 receiver.destroy(); 404 } 405 } 406