1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static java.nio.charset.StandardCharsets.UTF_8;
20 
21 import com.google.common.base.Preconditions;
22 import io.grpc.alts.internal.TsiPeer.Property;
23 import io.netty.buffer.ByteBufAllocator;
24 import java.nio.ByteBuffer;
25 import java.security.GeneralSecurityException;
26 import java.util.Collections;
27 import java.util.logging.Level;
28 import java.util.logging.Logger;
29 
30 /**
31  * A fake handshaker compatible with security/transport_security/fake_transport_security.h See
32  * {@link TsiHandshaker} for documentation.
33  */
34 public class FakeTsiHandshaker implements TsiHandshaker {
35   private static final Logger logger = Logger.getLogger(FakeTsiHandshaker.class.getName());
36 
37   private static final TsiHandshakerFactory clientHandshakerFactory =
38       new TsiHandshakerFactory() {
39         @Override
40         public TsiHandshaker newHandshaker(String authority) {
41           return new FakeTsiHandshaker(true);
42         }
43       };
44 
45   private static final TsiHandshakerFactory serverHandshakerFactory =
46       new TsiHandshakerFactory() {
47         @Override
48         public TsiHandshaker newHandshaker(String authority) {
49           return new FakeTsiHandshaker(false);
50         }
51       };
52 
53   private boolean isClient;
54   private ByteBuffer sendBuffer = null;
55   private AltsFraming.Parser frameParser = new AltsFraming.Parser();
56 
57   private State sendState;
58   private State receiveState;
59 
60   enum State {
61     CLIENT_NONE,
62     SERVER_NONE,
63     CLIENT_INIT,
64     SERVER_INIT,
65     CLIENT_FINISHED,
66     SERVER_FINISHED;
67 
68     // Returns the next State. In order to advance to sendState=N, receiveState must be N-1.
next()69     public State next() {
70       if (ordinal() + 1 < values().length) {
71         return values()[ordinal() + 1];
72       }
73       throw new UnsupportedOperationException("Can't call next() on last element: " + this);
74     }
75   }
76 
clientHandshakerFactory()77   public static TsiHandshakerFactory clientHandshakerFactory() {
78     return clientHandshakerFactory;
79   }
80 
serverHandshakerFactory()81   public static TsiHandshakerFactory serverHandshakerFactory() {
82     return serverHandshakerFactory;
83   }
84 
newFakeHandshakerClient()85   public static TsiHandshaker newFakeHandshakerClient() {
86     return clientHandshakerFactory.newHandshaker(null);
87   }
88 
newFakeHandshakerServer()89   public static TsiHandshaker newFakeHandshakerServer() {
90     return serverHandshakerFactory.newHandshaker(null);
91   }
92 
FakeTsiHandshaker(boolean isClient)93   protected FakeTsiHandshaker(boolean isClient) {
94     this.isClient = isClient;
95     if (isClient) {
96       sendState = State.CLIENT_NONE;
97       receiveState = State.SERVER_NONE;
98     } else {
99       sendState = State.SERVER_NONE;
100       receiveState = State.CLIENT_NONE;
101     }
102   }
103 
getNextState(State state)104   private State getNextState(State state) {
105     switch (state) {
106       case CLIENT_NONE:
107         return State.CLIENT_INIT;
108       case SERVER_NONE:
109         return State.SERVER_INIT;
110       case CLIENT_INIT:
111         return State.CLIENT_FINISHED;
112       case SERVER_INIT:
113         return State.SERVER_FINISHED;
114       default:
115         return null;
116     }
117   }
118 
getNextMessage()119   private String getNextMessage() {
120     State result = getNextState(sendState);
121     return result == null ? "BAD STATE" : result.toString();
122   }
123 
getExpectedMessage()124   private String getExpectedMessage() {
125     State result = getNextState(receiveState);
126     return result == null ? "BAD STATE" : result.toString();
127   }
128 
incrementSendState()129   private void incrementSendState() {
130     sendState = getNextState(sendState);
131   }
132 
incrementReceiveState()133   private void incrementReceiveState() {
134     receiveState = getNextState(receiveState);
135   }
136 
137   @Override
getBytesToSendToPeer(ByteBuffer bytes)138   public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
139     Preconditions.checkNotNull(bytes);
140 
141     // If we're done, return nothing.
142     if (sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED) {
143       return;
144     }
145 
146     // Prepare the next message, if neeeded.
147     if (sendBuffer == null) {
148       if (sendState.next() != receiveState) {
149         // We're still waiting for bytes from the peer, so bail.
150         return;
151       }
152       ByteBuffer payload = ByteBuffer.wrap(getNextMessage().getBytes(UTF_8));
153       sendBuffer = AltsFraming.toFrame(payload, payload.remaining());
154       logger.log(Level.FINE, "Buffered message: {0}", getNextMessage());
155     }
156     while (bytes.hasRemaining() && sendBuffer.hasRemaining()) {
157       bytes.put(sendBuffer.get());
158     }
159     if (!sendBuffer.hasRemaining()) {
160       // Get ready to send the next message.
161       sendBuffer = null;
162       incrementSendState();
163     }
164   }
165 
166   @Override
processBytesFromPeer(ByteBuffer bytes)167   public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
168     Preconditions.checkNotNull(bytes);
169 
170     frameParser.readBytes(bytes);
171     if (frameParser.isComplete()) {
172       ByteBuffer messageBytes = frameParser.getRawFrame();
173       int offset = AltsFraming.getFramingOverhead();
174       int length = messageBytes.limit() - offset;
175       String message = new String(messageBytes.array(), offset, length, UTF_8);
176       logger.log(Level.FINE, "Read message: {0}", message);
177 
178       if (!message.equals(getExpectedMessage())) {
179         throw new IllegalArgumentException(
180             "Bad handshake message. Got "
181                 + message
182                 + " (length = "
183                 + message.length()
184                 + ") expected "
185                 + getExpectedMessage()
186                 + " (length = "
187                 + getExpectedMessage().length()
188                 + ")");
189       }
190       incrementReceiveState();
191       return true;
192     }
193     return false;
194   }
195 
196   @Override
isInProgress()197   public boolean isInProgress() {
198     boolean finishedReceiving =
199         receiveState == State.CLIENT_FINISHED || receiveState == State.SERVER_FINISHED;
200     boolean finishedSending =
201         sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED;
202     return !finishedSending || !finishedReceiving;
203   }
204 
205   @Override
extractPeer()206   public TsiPeer extractPeer() {
207     return new TsiPeer(Collections.<Property<?>>emptyList());
208   }
209 
210   @Override
extractPeerObject()211   public Object extractPeerObject() {
212     return AltsAuthContext.getDefaultInstance();
213   }
214 
215   @Override
createFrameProtector(int maxFrameSize, ByteBufAllocator alloc)216   public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
217     Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
218 
219     // We use an all-zero key, since this is the fake handshaker.
220     byte[] key = new byte[AltsChannelCrypter.getKeyLength()];
221     return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
222   }
223 
224   @Override
createFrameProtector(ByteBufAllocator alloc)225   public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
226     return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc);
227   }
228 }
229