1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static java.nio.charset.StandardCharsets.UTF_8;
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertFalse;
22 
23 import com.google.common.testing.GcFinalization;
24 import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef;
25 import io.grpc.alts.internal.TsiTest.Handshakers;
26 import io.netty.buffer.ByteBuf;
27 import io.netty.util.ReferenceCounted;
28 import io.netty.util.ResourceLeakDetector;
29 import io.netty.util.ResourceLeakDetector.Level;
30 import java.nio.ByteBuffer;
31 import java.security.GeneralSecurityException;
32 import java.util.ArrayList;
33 import java.util.List;
34 import org.junit.After;
35 import org.junit.Before;
36 import org.junit.Test;
37 import org.junit.runner.RunWith;
38 import org.junit.runners.JUnit4;
39 
40 /** Unit tests for {@link TsiHandshaker}. */
41 @RunWith(JUnit4.class)
42 public class FakeTsiTest {
43 
44   private static final int OVERHEAD =
45       FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
46 
47   private final List<ReferenceCounted> references = new ArrayList<>();
48   private final RegisterRef ref =
49       new RegisterRef() {
50         @Override
51         public ByteBuf register(ByteBuf buf) {
52           if (buf != null) {
53             references.add(buf);
54           }
55           return buf;
56         }
57       };
58 
newHandshakers()59   private static Handshakers newHandshakers() {
60     TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient();
61     TsiHandshaker serverHandshaker = FakeTsiHandshaker.newFakeHandshakerServer();
62     return new Handshakers(clientHandshaker, serverHandshaker);
63   }
64 
65   @Before
setUp()66   public void setUp() {
67     ResourceLeakDetector.setLevel(Level.PARANOID);
68   }
69 
70   @After
tearDown()71   public void tearDown() {
72     for (ReferenceCounted reference : references) {
73       reference.release();
74     }
75     references.clear();
76     // Increase our chances to detect ByteBuf leaks.
77     GcFinalization.awaitFullGc();
78   }
79 
80   @Test
handshakeStateOrderTest()81   public void handshakeStateOrderTest() {
82     try {
83       Handshakers handshakers = newHandshakers();
84       TsiHandshaker clientHandshaker = handshakers.getClient();
85       TsiHandshaker serverHandshaker = handshakers.getServer();
86 
87       byte[] transportBufferBytes = new byte[TsiTest.getDefaultTransportBufferSize()];
88       ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
89       transportBuffer.limit(0); // Start off with an empty buffer
90 
91       transportBuffer.clear();
92       clientHandshaker.getBytesToSendToPeer(transportBuffer);
93       transportBuffer.flip();
94       assertEquals(
95           FakeTsiHandshaker.State.CLIENT_INIT.toString().trim(),
96           new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
97 
98       serverHandshaker.processBytesFromPeer(transportBuffer);
99       assertFalse(transportBuffer.hasRemaining());
100 
101       // client shouldn't offer any more bytes
102       transportBuffer.clear();
103       clientHandshaker.getBytesToSendToPeer(transportBuffer);
104       transportBuffer.flip();
105       assertFalse(transportBuffer.hasRemaining());
106 
107       transportBuffer.clear();
108       serverHandshaker.getBytesToSendToPeer(transportBuffer);
109       transportBuffer.flip();
110       assertEquals(
111           FakeTsiHandshaker.State.SERVER_INIT.toString().trim(),
112           new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
113 
114       clientHandshaker.processBytesFromPeer(transportBuffer);
115       assertFalse(transportBuffer.hasRemaining());
116 
117       // server shouldn't offer any more bytes
118       transportBuffer.clear();
119       serverHandshaker.getBytesToSendToPeer(transportBuffer);
120       transportBuffer.flip();
121       assertFalse(transportBuffer.hasRemaining());
122 
123       transportBuffer.clear();
124       clientHandshaker.getBytesToSendToPeer(transportBuffer);
125       transportBuffer.flip();
126       assertEquals(
127           FakeTsiHandshaker.State.CLIENT_FINISHED.toString().trim(),
128           new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
129 
130       serverHandshaker.processBytesFromPeer(transportBuffer);
131       assertFalse(transportBuffer.hasRemaining());
132 
133       // client shouldn't offer any more bytes
134       transportBuffer.clear();
135       clientHandshaker.getBytesToSendToPeer(transportBuffer);
136       transportBuffer.flip();
137       assertFalse(transportBuffer.hasRemaining());
138 
139       transportBuffer.clear();
140       serverHandshaker.getBytesToSendToPeer(transportBuffer);
141       transportBuffer.flip();
142       assertEquals(
143           FakeTsiHandshaker.State.SERVER_FINISHED.toString().trim(),
144           new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
145 
146       clientHandshaker.processBytesFromPeer(transportBuffer);
147       assertFalse(transportBuffer.hasRemaining());
148 
149       // server shouldn't offer any more bytes
150       transportBuffer.clear();
151       serverHandshaker.getBytesToSendToPeer(transportBuffer);
152       transportBuffer.flip();
153       assertFalse(transportBuffer.hasRemaining());
154     } catch (GeneralSecurityException e) {
155       throw new AssertionError(e);
156     }
157   }
158 
159   @Test
handshake()160   public void handshake() throws GeneralSecurityException {
161     TsiTest.handshakeTest(newHandshakers());
162   }
163 
164   @Test
handshakeSmallBuffer()165   public void handshakeSmallBuffer() throws GeneralSecurityException {
166     TsiTest.handshakeSmallBufferTest(newHandshakers());
167   }
168 
169   @Test
pingPong()170   public void pingPong() throws GeneralSecurityException {
171     TsiTest.pingPongTest(newHandshakers(), ref);
172   }
173 
174   @Test
pingPongExactFrameSize()175   public void pingPongExactFrameSize() throws GeneralSecurityException {
176     TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
177   }
178 
179   @Test
pingPongSmallBuffer()180   public void pingPongSmallBuffer() throws GeneralSecurityException {
181     TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
182   }
183 
184   @Test
pingPongSmallFrame()185   public void pingPongSmallFrame() throws GeneralSecurityException {
186     TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
187   }
188 
189   @Test
pingPongSmallFrameSmallBuffer()190   public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
191     TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
192   }
193 
194   @Test
corruptedCounter()195   public void corruptedCounter() throws GeneralSecurityException {
196     TsiTest.corruptedCounterTest(newHandshakers(), ref);
197   }
198 
199   @Test
corruptedCiphertext()200   public void corruptedCiphertext() throws GeneralSecurityException {
201     TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
202   }
203 
204   @Test
corruptedTag()205   public void corruptedTag() throws GeneralSecurityException {
206     TsiTest.corruptedTagTest(newHandshakers(), ref);
207   }
208 
209   @Test
reflectedCiphertext()210   public void reflectedCiphertext() throws GeneralSecurityException {
211     TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
212   }
213 }
214