1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static java.nio.charset.StandardCharsets.UTF_8; 20 import static org.junit.Assert.assertEquals; 21 import static org.junit.Assert.assertFalse; 22 23 import com.google.common.testing.GcFinalization; 24 import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef; 25 import io.grpc.alts.internal.TsiTest.Handshakers; 26 import io.netty.buffer.ByteBuf; 27 import io.netty.util.ReferenceCounted; 28 import io.netty.util.ResourceLeakDetector; 29 import io.netty.util.ResourceLeakDetector.Level; 30 import java.nio.ByteBuffer; 31 import java.security.GeneralSecurityException; 32 import java.util.ArrayList; 33 import java.util.List; 34 import org.junit.After; 35 import org.junit.Before; 36 import org.junit.Test; 37 import org.junit.runner.RunWith; 38 import org.junit.runners.JUnit4; 39 40 /** Unit tests for {@link TsiHandshaker}. */ 41 @RunWith(JUnit4.class) 42 public class FakeTsiTest { 43 44 private static final int OVERHEAD = 45 FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes(); 46 47 private final List<ReferenceCounted> references = new ArrayList<>(); 48 private final RegisterRef ref = 49 new RegisterRef() { 50 @Override 51 public ByteBuf register(ByteBuf buf) { 52 if (buf != null) { 53 references.add(buf); 54 } 55 return buf; 56 } 57 }; 58 newHandshakers()59 private static Handshakers newHandshakers() { 60 TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient(); 61 TsiHandshaker serverHandshaker = FakeTsiHandshaker.newFakeHandshakerServer(); 62 return new Handshakers(clientHandshaker, serverHandshaker); 63 } 64 65 @Before setUp()66 public void setUp() { 67 ResourceLeakDetector.setLevel(Level.PARANOID); 68 } 69 70 @After tearDown()71 public void tearDown() { 72 for (ReferenceCounted reference : references) { 73 reference.release(); 74 } 75 references.clear(); 76 // Increase our chances to detect ByteBuf leaks. 77 GcFinalization.awaitFullGc(); 78 } 79 80 @Test handshakeStateOrderTest()81 public void handshakeStateOrderTest() { 82 try { 83 Handshakers handshakers = newHandshakers(); 84 TsiHandshaker clientHandshaker = handshakers.getClient(); 85 TsiHandshaker serverHandshaker = handshakers.getServer(); 86 87 byte[] transportBufferBytes = new byte[TsiTest.getDefaultTransportBufferSize()]; 88 ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes); 89 transportBuffer.limit(0); // Start off with an empty buffer 90 91 transportBuffer.clear(); 92 clientHandshaker.getBytesToSendToPeer(transportBuffer); 93 transportBuffer.flip(); 94 assertEquals( 95 FakeTsiHandshaker.State.CLIENT_INIT.toString().trim(), 96 new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim()); 97 98 serverHandshaker.processBytesFromPeer(transportBuffer); 99 assertFalse(transportBuffer.hasRemaining()); 100 101 // client shouldn't offer any more bytes 102 transportBuffer.clear(); 103 clientHandshaker.getBytesToSendToPeer(transportBuffer); 104 transportBuffer.flip(); 105 assertFalse(transportBuffer.hasRemaining()); 106 107 transportBuffer.clear(); 108 serverHandshaker.getBytesToSendToPeer(transportBuffer); 109 transportBuffer.flip(); 110 assertEquals( 111 FakeTsiHandshaker.State.SERVER_INIT.toString().trim(), 112 new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim()); 113 114 clientHandshaker.processBytesFromPeer(transportBuffer); 115 assertFalse(transportBuffer.hasRemaining()); 116 117 // server shouldn't offer any more bytes 118 transportBuffer.clear(); 119 serverHandshaker.getBytesToSendToPeer(transportBuffer); 120 transportBuffer.flip(); 121 assertFalse(transportBuffer.hasRemaining()); 122 123 transportBuffer.clear(); 124 clientHandshaker.getBytesToSendToPeer(transportBuffer); 125 transportBuffer.flip(); 126 assertEquals( 127 FakeTsiHandshaker.State.CLIENT_FINISHED.toString().trim(), 128 new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim()); 129 130 serverHandshaker.processBytesFromPeer(transportBuffer); 131 assertFalse(transportBuffer.hasRemaining()); 132 133 // client shouldn't offer any more bytes 134 transportBuffer.clear(); 135 clientHandshaker.getBytesToSendToPeer(transportBuffer); 136 transportBuffer.flip(); 137 assertFalse(transportBuffer.hasRemaining()); 138 139 transportBuffer.clear(); 140 serverHandshaker.getBytesToSendToPeer(transportBuffer); 141 transportBuffer.flip(); 142 assertEquals( 143 FakeTsiHandshaker.State.SERVER_FINISHED.toString().trim(), 144 new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim()); 145 146 clientHandshaker.processBytesFromPeer(transportBuffer); 147 assertFalse(transportBuffer.hasRemaining()); 148 149 // server shouldn't offer any more bytes 150 transportBuffer.clear(); 151 serverHandshaker.getBytesToSendToPeer(transportBuffer); 152 transportBuffer.flip(); 153 assertFalse(transportBuffer.hasRemaining()); 154 } catch (GeneralSecurityException e) { 155 throw new AssertionError(e); 156 } 157 } 158 159 @Test handshake()160 public void handshake() throws GeneralSecurityException { 161 TsiTest.handshakeTest(newHandshakers()); 162 } 163 164 @Test handshakeSmallBuffer()165 public void handshakeSmallBuffer() throws GeneralSecurityException { 166 TsiTest.handshakeSmallBufferTest(newHandshakers()); 167 } 168 169 @Test pingPong()170 public void pingPong() throws GeneralSecurityException { 171 TsiTest.pingPongTest(newHandshakers(), ref); 172 } 173 174 @Test pingPongExactFrameSize()175 public void pingPongExactFrameSize() throws GeneralSecurityException { 176 TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref); 177 } 178 179 @Test pingPongSmallBuffer()180 public void pingPongSmallBuffer() throws GeneralSecurityException { 181 TsiTest.pingPongSmallBufferTest(newHandshakers(), ref); 182 } 183 184 @Test pingPongSmallFrame()185 public void pingPongSmallFrame() throws GeneralSecurityException { 186 TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref); 187 } 188 189 @Test pingPongSmallFrameSmallBuffer()190 public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException { 191 TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref); 192 } 193 194 @Test corruptedCounter()195 public void corruptedCounter() throws GeneralSecurityException { 196 TsiTest.corruptedCounterTest(newHandshakers(), ref); 197 } 198 199 @Test corruptedCiphertext()200 public void corruptedCiphertext() throws GeneralSecurityException { 201 TsiTest.corruptedCiphertextTest(newHandshakers(), ref); 202 } 203 204 @Test corruptedTag()205 public void corruptedTag() throws GeneralSecurityException { 206 TsiTest.corruptedTagTest(newHandshakers(), ref); 207 } 208 209 @Test reflectedCiphertext()210 public void reflectedCiphertext() throws GeneralSecurityException { 211 TsiTest.reflectedCiphertextTest(newHandshakers(), ref); 212 } 213 } 214