1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static java.nio.charset.StandardCharsets.UTF_8; 20 21 import com.google.common.base.Preconditions; 22 import io.grpc.alts.internal.TsiPeer.Property; 23 import io.netty.buffer.ByteBufAllocator; 24 import java.nio.ByteBuffer; 25 import java.security.GeneralSecurityException; 26 import java.util.Collections; 27 import java.util.logging.Level; 28 import java.util.logging.Logger; 29 30 /** 31 * A fake handshaker compatible with security/transport_security/fake_transport_security.h See 32 * {@link TsiHandshaker} for documentation. 33 */ 34 public class FakeTsiHandshaker implements TsiHandshaker { 35 private static final Logger logger = Logger.getLogger(FakeTsiHandshaker.class.getName()); 36 37 private static final TsiHandshakerFactory clientHandshakerFactory = 38 new TsiHandshakerFactory() { 39 @Override 40 public TsiHandshaker newHandshaker(String authority) { 41 return new FakeTsiHandshaker(true); 42 } 43 }; 44 45 private static final TsiHandshakerFactory serverHandshakerFactory = 46 new TsiHandshakerFactory() { 47 @Override 48 public TsiHandshaker newHandshaker(String authority) { 49 return new FakeTsiHandshaker(false); 50 } 51 }; 52 53 private boolean isClient; 54 private ByteBuffer sendBuffer = null; 55 private AltsFraming.Parser frameParser = new AltsFraming.Parser(); 56 57 private State sendState; 58 private State receiveState; 59 60 enum State { 61 CLIENT_NONE, 62 SERVER_NONE, 63 CLIENT_INIT, 64 SERVER_INIT, 65 CLIENT_FINISHED, 66 SERVER_FINISHED; 67 68 // Returns the next State. In order to advance to sendState=N, receiveState must be N-1. next()69 public State next() { 70 if (ordinal() + 1 < values().length) { 71 return values()[ordinal() + 1]; 72 } 73 throw new UnsupportedOperationException("Can't call next() on last element: " + this); 74 } 75 } 76 clientHandshakerFactory()77 public static TsiHandshakerFactory clientHandshakerFactory() { 78 return clientHandshakerFactory; 79 } 80 serverHandshakerFactory()81 public static TsiHandshakerFactory serverHandshakerFactory() { 82 return serverHandshakerFactory; 83 } 84 newFakeHandshakerClient()85 public static TsiHandshaker newFakeHandshakerClient() { 86 return clientHandshakerFactory.newHandshaker(null); 87 } 88 newFakeHandshakerServer()89 public static TsiHandshaker newFakeHandshakerServer() { 90 return serverHandshakerFactory.newHandshaker(null); 91 } 92 FakeTsiHandshaker(boolean isClient)93 protected FakeTsiHandshaker(boolean isClient) { 94 this.isClient = isClient; 95 if (isClient) { 96 sendState = State.CLIENT_NONE; 97 receiveState = State.SERVER_NONE; 98 } else { 99 sendState = State.SERVER_NONE; 100 receiveState = State.CLIENT_NONE; 101 } 102 } 103 getNextState(State state)104 private State getNextState(State state) { 105 switch (state) { 106 case CLIENT_NONE: 107 return State.CLIENT_INIT; 108 case SERVER_NONE: 109 return State.SERVER_INIT; 110 case CLIENT_INIT: 111 return State.CLIENT_FINISHED; 112 case SERVER_INIT: 113 return State.SERVER_FINISHED; 114 default: 115 return null; 116 } 117 } 118 getNextMessage()119 private String getNextMessage() { 120 State result = getNextState(sendState); 121 return result == null ? "BAD STATE" : result.toString(); 122 } 123 getExpectedMessage()124 private String getExpectedMessage() { 125 State result = getNextState(receiveState); 126 return result == null ? "BAD STATE" : result.toString(); 127 } 128 incrementSendState()129 private void incrementSendState() { 130 sendState = getNextState(sendState); 131 } 132 incrementReceiveState()133 private void incrementReceiveState() { 134 receiveState = getNextState(receiveState); 135 } 136 137 @Override getBytesToSendToPeer(ByteBuffer bytes)138 public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException { 139 Preconditions.checkNotNull(bytes); 140 141 // If we're done, return nothing. 142 if (sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED) { 143 return; 144 } 145 146 // Prepare the next message, if neeeded. 147 if (sendBuffer == null) { 148 if (sendState.next() != receiveState) { 149 // We're still waiting for bytes from the peer, so bail. 150 return; 151 } 152 ByteBuffer payload = ByteBuffer.wrap(getNextMessage().getBytes(UTF_8)); 153 sendBuffer = AltsFraming.toFrame(payload, payload.remaining()); 154 logger.log(Level.FINE, "Buffered message: {0}", getNextMessage()); 155 } 156 while (bytes.hasRemaining() && sendBuffer.hasRemaining()) { 157 bytes.put(sendBuffer.get()); 158 } 159 if (!sendBuffer.hasRemaining()) { 160 // Get ready to send the next message. 161 sendBuffer = null; 162 incrementSendState(); 163 } 164 } 165 166 @Override processBytesFromPeer(ByteBuffer bytes)167 public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException { 168 Preconditions.checkNotNull(bytes); 169 170 frameParser.readBytes(bytes); 171 if (frameParser.isComplete()) { 172 ByteBuffer messageBytes = frameParser.getRawFrame(); 173 int offset = AltsFraming.getFramingOverhead(); 174 int length = messageBytes.limit() - offset; 175 String message = new String(messageBytes.array(), offset, length, UTF_8); 176 logger.log(Level.FINE, "Read message: {0}", message); 177 178 if (!message.equals(getExpectedMessage())) { 179 throw new IllegalArgumentException( 180 "Bad handshake message. Got " 181 + message 182 + " (length = " 183 + message.length() 184 + ") expected " 185 + getExpectedMessage() 186 + " (length = " 187 + getExpectedMessage().length() 188 + ")"); 189 } 190 incrementReceiveState(); 191 return true; 192 } 193 return false; 194 } 195 196 @Override isInProgress()197 public boolean isInProgress() { 198 boolean finishedReceiving = 199 receiveState == State.CLIENT_FINISHED || receiveState == State.SERVER_FINISHED; 200 boolean finishedSending = 201 sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED; 202 return !finishedSending || !finishedReceiving; 203 } 204 205 @Override extractPeer()206 public TsiPeer extractPeer() { 207 return new TsiPeer(Collections.<Property<?>>emptyList()); 208 } 209 210 @Override extractPeerObject()211 public Object extractPeerObject() { 212 return AltsAuthContext.getDefaultInstance(); 213 } 214 215 @Override createFrameProtector(int maxFrameSize, ByteBufAllocator alloc)216 public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) { 217 Preconditions.checkState(!isInProgress(), "Handshake is not complete."); 218 219 // We use an all-zero key, since this is the fake handshaker. 220 byte[] key = new byte[AltsChannelCrypter.getKeyLength()]; 221 return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc); 222 } 223 224 @Override createFrameProtector(ByteBufAllocator alloc)225 public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) { 226 return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc); 227 } 228 } 229