1 /* Copyright 2018 Google LLC
2  *
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 package com.google.security.cryptauth.lib.securegcm;
16 
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2Alert;
import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ClientFinished;
import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ClientInit;
import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ClientInit.CipherCommitment;
import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2Message;
import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ServerInit;
import com.google.security.cryptauth.lib.securemessage.CryptoOps;
import com.google.security.cryptauth.lib.securemessage.PublicKeyProtoUtil;
import com.google.security.cryptauth.lib.securemessage.SecureMessageProto.GenericPublicKey;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import javax.annotation.Nullable;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
44 
45 /**
46  * Implements UKEY2 and produces a {@link D2DConnectionContext}.
47  *
48  * <p>Client Usage:
49  * <code>
50  * try {
51  *   Ukey2Handshake client = Ukey2Handshake.forInitiator(HandshakeCipher.P256_SHA512);
52  *   byte[] handshakeMessage;
53  *
54  *   // Message 1 (Client Init)
55  *   handshakeMessage = client.getNextHandshakeMessage();
56  *   sendMessageToServer(handshakeMessage);
57  *
58  *   // Message 2 (Server Init)
59  *   handshakeMessage = receiveMessageFromServer();
60  *   client.parseHandshakeMessage(handshakeMessage);
61  *
62  *   // Message 3 (Client Finish)
63  *   handshakeMessage = client.getNextHandshakeMessage();
64  *   sendMessageToServer(handshakeMessage);
65  *
66  *   // Get the auth string
67  *   byte[] clientAuthString = client.getVerificationString(STRING_LENGTH);
68  *   showStringToUser(clientAuthString);
69  *
70  *   // Using out-of-band channel, verify auth string, then call:
71  *   client.verifyHandshake();
72  *
73  *   // Make a connection context
74  *   D2DConnectionContext clientContext = client.toConnectionContext();
75  * } catch (AlertException e) {
 *   log(e.getMessage());
77  *   sendMessageToServer(e.getAlertMessageToSend());
78  * } catch (HandshakeException e) {
79  *   log(e);
80  *   // terminate handshake
81  * }
82  * </code>
83  *
84  * <p>Server Usage:
85  * <code>
86  * try {
87  *   Ukey2Handshake server = Ukey2Handshake.forResponder(HandshakeCipher.P256_SHA512);
88  *   byte[] handshakeMessage;
89  *
90  *   // Message 1 (Client Init)
91  *   handshakeMessage = receiveMessageFromClient();
92  *   server.parseHandshakeMessage(handshakeMessage);
93  *
94  *   // Message 2 (Server Init)
95  *   handshakeMessage = server.getNextHandshakeMessage();
 *   sendMessageToClient(handshakeMessage);
97  *
98  *   // Message 3 (Client Finish)
99  *   handshakeMessage = receiveMessageFromClient();
100  *   server.parseHandshakeMessage(handshakeMessage);
101  *
102  *   // Get the auth string
103  *   byte[] serverAuthString = server.getVerificationString(STRING_LENGTH);
104  *   showStringToUser(serverAuthString);
105  *
106  *   // Using out-of-band channel, verify auth string, then call:
107  *   server.verifyHandshake();
108  *
109  *   // Make a connection context
110  *   D2DConnectionContext serverContext = server.toConnectionContext();
111  * } catch (AlertException e) {
 *   log(e.getMessage());
113  *   sendMessageToClient(e.getAlertMessageToSend());
114  * } catch (HandshakeException e) {
115  *   log(e);
116  *   // terminate handshake
117  * }
118  * </code>
119  */
120 public class Ukey2Handshake {
121 
122   /**
123    * Creates a {@link Ukey2Handshake} with a particular cipher that can be used by an initiator /
124    * client.
125    *
126    * @throws HandshakeException
127    */
forInitiator(HandshakeCipher cipher)128   public static Ukey2Handshake forInitiator(HandshakeCipher cipher) throws HandshakeException {
129     return new Ukey2Handshake(InternalState.CLIENT_START, cipher);
130   }
131 
132   /**
133    * Creates a {@link Ukey2Handshake} with a particular cipher that can be used by an responder /
134    * server.
135    *
136    * @throws HandshakeException
137    */
forResponder(HandshakeCipher cipher)138   public static Ukey2Handshake forResponder(HandshakeCipher cipher) throws HandshakeException {
139     return new Ukey2Handshake(InternalState.SERVER_START, cipher);
140   }
141 
142   /**
143    * Handshake States. Meaning of states:
144    * <ul>
145    * <li>IN_PROGRESS: The handshake is in progress, caller should use
146    * {@link Ukey2Handshake#getNextHandshakeMessage()} and
147    * {@link Ukey2Handshake#parseHandshakeMessage(byte[])} to continue the handshake.
148    * <li>VERIFICATION_NEEDED: The handshake is complete, but pending verification of the
149    * authentication string. Clients should use {@link Ukey2Handshake#getVerificationString(int)} to
150    * get the verification string and use out-of-band methods to authenticate the handshake.
151    * <li>VERIFICATION_IN_PROGRESS: The handshake is complete, verification string has been
152    * generated, but has not been confirmed. After authenticating the handshake out-of-band, use
153    * {@link Ukey2Handshake#verifyHandshake()} to mark the handshake as verified.
154    * <li>FINISHED: The handshake is finished, and caller can use
155    * {@link Ukey2Handshake#toConnectionContext()} to produce a {@link D2DConnectionContext}.
156    * <li>ALREADY_USED: The handshake has already been used and should be discarded / garbage
157    * collected.
158    * <li>ERROR: The handshake produced an error and should be destroyed.
159    * </ul>
160    */
161   public enum State {
162     IN_PROGRESS,
163     VERIFICATION_NEEDED,
164     VERIFICATION_IN_PROGRESS,
165     FINISHED,
166     ALREADY_USED,
167     ERROR,
168   }
169 
170   /**
171    * Currently implemented UKEY2 handshake ciphers. Each cipher is a tuple consisting of a key
172    * negotiation cipher and a hash function used for a commitment. Currently the ciphers are:
173    * <code>
174    *   +-----------------------------------------------------+
175    *   | Enum        | Key negotiation       | Hash function |
176    *   +-------------+-----------------------+---------------+
177    *   | P256_SHA512 | ECDH using NIST P-256 | SHA512        |
178    *   +-----------------------------------------------------+
179    * </code>
180    *
181    * <p>Note that these should correspond to values in device_to_device_messages.proto.
182    */
183   public enum HandshakeCipher {
184     P256_SHA512(UkeyProto.Ukey2HandshakeCipher.P256_SHA512);
185     // TODO(aczeskis): add CURVE25519_SHA512
186 
187     private final UkeyProto.Ukey2HandshakeCipher value;
188 
HandshakeCipher(UkeyProto.Ukey2HandshakeCipher value)189     HandshakeCipher(UkeyProto.Ukey2HandshakeCipher value) {
190       // Make sure we only accept values that are valid as per the ukey protobuf.
191       // NOTE: Don't use switch statement on value, as that will trigger a bug. b/30682989.
192       if (value == UkeyProto.Ukey2HandshakeCipher.P256_SHA512) {
193           this.value = value;
194       } else {
195           throw new IllegalArgumentException("Unknown cipher value: " + value);
196       }
197     }
198 
getValue()199     public UkeyProto.Ukey2HandshakeCipher getValue() {
200       return value;
201     }
202   }
203 
204   /**
205    * If thrown, this exception contains information that should be sent on the wire. Specifically,
206    * the {@link #getAlertMessageToSend()} method returns a <code>byte[]</code> that communicates the
207    * error to the other party in the handshake. Meanwhile, the {@link #getMessage()} method can be
208    * used to get a log-able error message.
209    */
210   public static class AlertException extends Exception {
211     private final Ukey2Alert alertMessageToSend;
212 
AlertException(String alertMessageToLog, Ukey2Alert alertMessageToSend)213     public AlertException(String alertMessageToLog, Ukey2Alert alertMessageToSend) {
214       super(alertMessageToLog);
215       this.alertMessageToSend = alertMessageToSend;
216     }
217 
218     /**
219      * @return a message suitable for sending to other member of handshake.
220      */
getAlertMessageToSend()221     public byte[] getAlertMessageToSend() {
222       return alertMessageToSend.toByteArray();
223     }
224   }
225 
  // Maximum version of the handshake supported by this class. It is written into outgoing
  // ClientInit/ServerInit messages, and incoming ones with a different version are rejected with
  // BAD_VERSION.
  public static final int VERSION = 1;

  // Random nonce is fixed at 32 bytes (as per go/ukey2). Incoming ClientInit/ServerInit messages
  // whose random field has a different length are rejected with BAD_RANDOM.
  private static final int NONCE_LENGTH_IN_BYTES = 32;

  // Charset name used, for example, to encode the HKDF salt strings.
  private static final String UTF_8 = "UTF-8";

  // Currently, we only support one next protocol. Clients advertise it in ClientInit; servers
  // reject any other value with BAD_NEXT_PROTOCOL.
  private static final String NEXT_PROTOCOL = "AES_256_CBC-HMAC_SHA256";

  // Clients need to store a map of message 3's (client finishes) for each commitment.
  // getNextHandshakeMessage() returns the entry for handshakeCipher as message 3.
  // NOTE(review): the entries appear to be produced while building the ClientInit commitments
  // (generateP256SHA512Commitment) — confirm.
  private final HashMap<HandshakeCipher, byte[]> rawMessage3Map = new HashMap<>();

  // The single cipher this handshake offers (client) or accepts (server).
  private final HandshakeCipher handshakeCipher;
  // CLIENT or SERVER, derived in the constructor from the starting InternalState.
  private final HandshakeRole handshakeRole;
  // Current position in the internal state machine; see InternalState.
  private InternalState handshakeState;
  // Key pair generated in the constructor for handshakeCipher.
  private final KeyPair ourKeyPair;
  // Peer's public key; on the client it is parsed from ServerInit (message 2).
  private PublicKey theirPublicKey;
  // Key-agreement result, computed in getVerificationString() and consumed by
  // toConnectionContext().
  private SecretKey derivedSecretKey;

  // Servers need to store client commitments (taken from ClientInit, message 1).
  // NOTE(review): presumably verified against ClientFinished (message 3) — confirm.
  private byte[] theirCommitment;

  // We store the raw messages for computing the authentication strings and next key. They are
  // recorded whether this endpoint sent or received them.
  private byte[] rawMessage1;
  private byte[] rawMessage2;
253 
  // Enums for internal state machinery. Transitions driven by the public methods:
  //   client: CLIENT_START --getNextHandshakeMessage--> CLIENT_WAITING_FOR_SERVER_INIT
  //           --parseHandshakeMessage--> CLIENT_AFTER_SERVER_INIT
  //           --getNextHandshakeMessage--> HANDSHAKE_VERIFICATION_NEEDED
  //   server: SERVER_START --parseHandshakeMessage--> SERVER_AFTER_CLIENT_INIT
  //           --getNextHandshakeMessage--> SERVER_WAITING_FOR_CLIENT_FINISHED
  //           --parseHandshakeMessage--> HANDSHAKE_VERIFICATION_NEEDED
  //   common: HANDSHAKE_VERIFICATION_NEEDED --getVerificationString-->
  //           HANDSHAKE_VERIFICATION_IN_PROGRESS --verifyHandshake--> HANDSHAKE_FINISHED
  //           --toConnectionContext--> HANDSHAKE_ALREADY_USED
  //   HANDSHAKE_ERROR is entered on failures (e.g. when the peer sends an ALERT).
  private enum InternalState {
    // Initiator/client state
    CLIENT_START,
    CLIENT_WAITING_FOR_SERVER_INIT,
    CLIENT_AFTER_SERVER_INIT,

    // Responder/server state
    SERVER_START,
    SERVER_AFTER_CLIENT_INIT,
    SERVER_WAITING_FOR_CLIENT_FINISHED,

    // Common completion state
    HANDSHAKE_VERIFICATION_NEEDED,
    HANDSHAKE_VERIFICATION_IN_PROGRESS,
    HANDSHAKE_FINISHED,
    HANDSHAKE_ALREADY_USED,
    HANDSHAKE_ERROR,
  }
273 
  // Helps us remember our role in the handshake. Fixed in the constructor; toConnectionContext()
  // uses it to decide which of the derived client/server keys is passed to the context first.
  private enum HandshakeRole {
    CLIENT,
    SERVER
  }
279 
280   /**
281    * Never invoked directly. Caller should use {@link #forInitiator(HandshakeCipher)} or
282    * {@link #forResponder(HandshakeCipher)} instead.
283    *
284    * @throws HandshakeException if an unrecoverable error occurs and the connection should be shut
285    * down.
286    */
Ukey2Handshake(InternalState state, HandshakeCipher cipher)287   private Ukey2Handshake(InternalState state, HandshakeCipher cipher) throws HandshakeException {
288     if (cipher == null) {
289       throwIllegalArgumentException("Invalid handshake cipher");
290     }
291     this.handshakeCipher = cipher;
292 
293     switch (state) {
294       case CLIENT_START:
295         handshakeRole = HandshakeRole.CLIENT;
296         break;
297       case SERVER_START:
298         handshakeRole = HandshakeRole.SERVER;
299         break;
300       default:
301         throwIllegalStateException("Invalid handshake state");
302         handshakeRole = null; // unreachable, but makes compiler happy
303     }
304     this.handshakeState = state;
305 
306     this.ourKeyPair = genKeyPair(cipher);
307   }
308 
309   /**
310    * Get the next handshake message suitable for sending on the wire.
311    *
312    * @throws HandshakeException if an unrecoverable error occurs and the connection should be shut
313    * down.
314    */
getNextHandshakeMessage()315   public byte[] getNextHandshakeMessage() throws HandshakeException {
316     switch (handshakeState) {
317       case CLIENT_START:
318         rawMessage1 = makeUkey2Message(Ukey2Message.Type.CLIENT_INIT, makeClientInitMessage());
319         handshakeState = InternalState.CLIENT_WAITING_FOR_SERVER_INIT;
320         return rawMessage1;
321 
322       case SERVER_AFTER_CLIENT_INIT:
323         rawMessage2 = makeUkey2Message(Ukey2Message.Type.SERVER_INIT, makeServerInitMessage());
324         handshakeState = InternalState.SERVER_WAITING_FOR_CLIENT_FINISHED;
325         return rawMessage2;
326 
327       case CLIENT_AFTER_SERVER_INIT:
328         // Make sure we have a message 3 for the chosen cipher.
329         if (!rawMessage3Map.containsKey(handshakeCipher)) {
330           throwIllegalStateException(
331               "Client state is CLIENT_AFTER_SERVER_INIT, and cipher is "
332                   + handshakeCipher
333                   + ", but no corresponding raw client finished message has been generated");
334         }
335         handshakeState = InternalState.HANDSHAKE_VERIFICATION_NEEDED;
336         return rawMessage3Map.get(handshakeCipher);
337 
338       default:
339         throwIllegalStateException("Cannot get next message in state: " + handshakeState);
340         return null; // unreachable, but makes compiler happy
341     }
342   }
343 
344   /**
345    * Returns an authentication string suitable for authenticating the handshake out-of-band. Note
346    * that the authentication string can be short (e.g., a 6 digit visual confirmation code). Note:
347    * this should only be called when the state returned byte {@link #getHandshakeState()} is
348    * {@link State#VERIFICATION_NEEDED}, which means this can only be called once.
349    *
350    * @param byteLength length of output in bytes. Min length is 1; max length is 32.
351    */
getVerificationString(int byteLength)352   public byte[] getVerificationString(int byteLength) throws HandshakeException {
353     if (byteLength < 1 || byteLength > 32) {
354       throwIllegalArgumentException("Minimum length is 1 byte, max is 32 bytes");
355     }
356 
357     if (handshakeState != InternalState.HANDSHAKE_VERIFICATION_NEEDED) {
358       throwIllegalStateException("Unexpected state: " + handshakeState);
359     }
360 
361     try {
362       derivedSecretKey =
363           EnrollmentCryptoOps.doKeyAgreement(ourKeyPair.getPrivate(), theirPublicKey);
364     } catch (InvalidKeyException e) {
365       // unreachable in practice
366       throwHandshakeException(e);
367     }
368 
369     ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
370     try {
371       byteStream.write(rawMessage1);
372       byteStream.write(rawMessage2);
373     } catch (IOException e) {
374       // unreachable in practice
375       throwHandshakeException(e);
376     }
377     byte[] info = byteStream.toByteArray();
378 
379     byte[] salt = null;
380 
381     try {
382       salt = "UKEY2 v1 auth".getBytes(UTF_8);
383     } catch (UnsupportedEncodingException e) {
384       // unreachable in practice
385       throwHandshakeException(e);
386     }
387 
388     byte[] authString = null;
389     try {
390       authString = CryptoOps.hkdf(derivedSecretKey, salt, info);
391     } catch (InvalidKeyException | NoSuchAlgorithmException e) {
392       // unreachable in practice
393       throwHandshakeException(e);
394     }
395 
396     handshakeState = InternalState.HANDSHAKE_VERIFICATION_IN_PROGRESS;
397     return Arrays.copyOf(authString, byteLength);
398   }
399 
400   /**
401    * Invoked to let handshake state machine know that caller has validated the authentication
402    * string obtained via {@link #getVerificationString(int)}; Note: this should only be called when
403    * the state returned byte {@link #getHandshakeState()} is {@link State#VERIFICATION_IN_PROGRESS}.
404    */
verifyHandshake()405   public void verifyHandshake() {
406     if (handshakeState != InternalState.HANDSHAKE_VERIFICATION_IN_PROGRESS) {
407       throwIllegalStateException("Unexpected state: " + handshakeState);
408     }
409     handshakeState = InternalState.HANDSHAKE_FINISHED;
410   }
411 
412   /**
413    * Parses the given handshake message.
414    * @throws AlertException if an error occurs that should be sent to other party.
415    * @throws HandshakeException in an error occurs and the connection should be torn down.
416    */
parseHandshakeMessage(byte[] handshakeMessage)417   public void parseHandshakeMessage(byte[] handshakeMessage)
418       throws AlertException, HandshakeException {
419     switch (handshakeState) {
420       case SERVER_START:
421         parseMessage1(handshakeMessage);
422         handshakeState = InternalState.SERVER_AFTER_CLIENT_INIT;
423         break;
424 
425       case CLIENT_WAITING_FOR_SERVER_INIT:
426         parseMessage2(handshakeMessage);
427         handshakeState = InternalState.CLIENT_AFTER_SERVER_INIT;
428         break;
429 
430       case SERVER_WAITING_FOR_CLIENT_FINISHED:
431         parseMessage3(handshakeMessage);
432         handshakeState = InternalState.HANDSHAKE_VERIFICATION_NEEDED;
433         break;
434 
435       default:
436         throwIllegalStateException("Cannot parse message in state " + handshakeState);
437     }
438   }
439 
440   /**
441    * Returns the current state of the handshake. See {@link State}.
442    */
getHandshakeState()443   public State getHandshakeState() {
444     switch (handshakeState) {
445       case CLIENT_START:
446       case CLIENT_WAITING_FOR_SERVER_INIT:
447       case CLIENT_AFTER_SERVER_INIT:
448       case SERVER_START:
449       case SERVER_WAITING_FOR_CLIENT_FINISHED:
450       case SERVER_AFTER_CLIENT_INIT:
451         // fallback intended -- these are all in-progress states
452         return State.IN_PROGRESS;
453 
454       case HANDSHAKE_ERROR:
455         return State.ERROR;
456 
457       case HANDSHAKE_VERIFICATION_NEEDED:
458         return State.VERIFICATION_NEEDED;
459 
460       case HANDSHAKE_VERIFICATION_IN_PROGRESS:
461         return State.VERIFICATION_IN_PROGRESS;
462 
463       case HANDSHAKE_FINISHED:
464         return State.FINISHED;
465 
466       case HANDSHAKE_ALREADY_USED:
467         return State.ALREADY_USED;
468 
469       default:
470         // unreachable in practice
471         throwIllegalStateException("Unknown state");
472         return null; // really unreachable, but makes compiler happy
473     }
474   }
475 
476   /**
477    * Can be called to generate a {@link D2DConnectionContext}. Note: this should only be called
478    * when the state returned byte {@link #getHandshakeState()} is {@link State#FINISHED}.
479    *
480    * @throws HandshakeException
481    */
toConnectionContext()482   public D2DConnectionContext toConnectionContext() throws HandshakeException {
483     switch (handshakeState) {
484       case HANDSHAKE_ERROR:
485         throwIllegalStateException("Cannot make context; handshake had error");
486         return null; // makes linter happy
487       case HANDSHAKE_ALREADY_USED:
488         throwIllegalStateException("Cannot reuse handshake context; is has already been used");
489         return null; // makes linter happy
490       case HANDSHAKE_VERIFICATION_NEEDED:
491         throwIllegalStateException("Handshake not verified, cannot create context");
492         return null; // makes linter happy
493       case HANDSHAKE_FINISHED:
494         // We're done, okay to return a context
495         break;
496       default:
497         // unreachable in practice
498         throwIllegalStateException("Handshake is not complete; cannot create connection context");
499     }
500 
501     if (derivedSecretKey == null) {
502       throwIllegalStateException("Unexpected state error: derived key is null");
503     }
504 
505     ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
506     try {
507       byteStream.write(rawMessage1);
508       byteStream.write(rawMessage2);
509     } catch (IOException e) {
510       // unreachable in practice
511       throwHandshakeException(e);
512     }
513     byte[] info = byteStream.toByteArray();
514 
515     byte[] salt = null;
516     try {
517       salt = "UKEY2 v1 next".getBytes(UTF_8);
518     } catch (UnsupportedEncodingException e) {
519       // unreachable
520       throwHandshakeException(e);
521     }
522 
523     SecretKey nextProtocolKey = null;
524     try {
525       nextProtocolKey = new SecretKeySpec(CryptoOps.hkdf(derivedSecretKey, salt, info), "AES");
526     } catch (InvalidKeyException | NoSuchAlgorithmException e) {
527       // unreachable in practice
528       throwHandshakeException(e);
529     }
530 
531     SecretKey clientKey = null;
532     SecretKey serverKey = null;
533     try {
534       clientKey = D2DCryptoOps.deriveNewKeyForPurpose(nextProtocolKey, "client");
535       serverKey = D2DCryptoOps.deriveNewKeyForPurpose(nextProtocolKey, "server");
536     } catch (InvalidKeyException | NoSuchAlgorithmException e) {
537       // unreachable in practice
538       throwHandshakeException(e);
539     }
540 
541     handshakeState = InternalState.HANDSHAKE_ALREADY_USED;
542 
543     return new D2DConnectionContextV1(
544         handshakeRole == HandshakeRole.CLIENT ? clientKey : serverKey,
545         handshakeRole == HandshakeRole.CLIENT ? serverKey : clientKey,
546         0 /* initial encode sequence number */,
547         0 /* initial decode sequence number */);
548   }
549 
550   /**
551    * Generates the byte[] encoding of a {@link Ukey2ClientInit} message.
552    *
553    * @throws HandshakeException
554    */
makeClientInitMessage()555   private byte[] makeClientInitMessage() throws HandshakeException {
556     Ukey2ClientInit.Builder clientInit = Ukey2ClientInit.newBuilder();
557     clientInit.setVersion(VERSION);
558     clientInit.setRandom(ByteString.copyFrom(generateRandomNonce()));
559     clientInit.setNextProtocol(NEXT_PROTOCOL);
560 
561     // At the moment, we only support one cipher
562     clientInit.addCipherCommitments(generateP256SHA512Commitment());
563 
564     return clientInit.build().toByteArray();
565   }
566 
567   /**
568    * Generates the byte[] encoding of a {@link Ukey2ServerInit} message.
569    */
makeServerInitMessage()570   private byte[] makeServerInitMessage() {
571     Ukey2ServerInit.Builder serverInit = Ukey2ServerInit.newBuilder();
572     serverInit.setVersion(VERSION);
573     serverInit.setRandom(ByteString.copyFrom(generateRandomNonce()));
574     serverInit.setHandshakeCipher(handshakeCipher.getValue());
575     serverInit.setPublicKey(
576         PublicKeyProtoUtil.encodePublicKey(ourKeyPair.getPublic()).toByteString());
577 
578     return serverInit.build().toByteArray();
579   }
580 
581   /**
582    * Generates a keypair for the provided handshake cipher. Currently only P256_SHA512 is
583    * supported.
584    *
585    * @throws HandshakeException
586    */
genKeyPair(HandshakeCipher cipher)587   private KeyPair genKeyPair(HandshakeCipher cipher) throws HandshakeException {
588     switch (cipher) {
589       case P256_SHA512:
590         return PublicKeyProtoUtil.generateEcP256KeyPair();
591       default:
592         // Should never happen
593         throwHandshakeException("unknown cipher: " + cipher);
594     }
595     return null; // unreachable, but makes compiler happy
596   }
597 
598   /**
599    * Attempts to parse message 1 (which is a wrapped {@link Ukey2ClientInit}). See go/ukey2 for
600    * details.
601    *
602    * @throws AlertException if an error occurs
603    */
parseMessage1(byte[] handshakeMessage)604   private void parseMessage1(byte[] handshakeMessage) throws AlertException, HandshakeException {
605     // Deserialize the protobuf; send a BAD_MESSAGE message if deserialization fails
606     Ukey2Message message = null;
607     try {
608       message = Ukey2Message.parseFrom(handshakeMessage);
609     } catch (InvalidProtocolBufferException e) {
610       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE,
611           "Can't parse message 1 " + e.getMessage());
612     }
613 
614     // Verify that message_type == Type.CLIENT_INIT; send a BAD_MESSAGE_TYPE message if mismatch
615     if (!message.hasMessageType() || message.getMessageType() != Ukey2Message.Type.CLIENT_INIT) {
616       throwAlertException(
617           Ukey2Alert.AlertType.BAD_MESSAGE_TYPE,
618           "Expected, but did not find ClientInit message type");
619     }
620 
621     // Deserialize message_data as a ClientInit message; send a BAD_MESSAGE_DATA message if
622     // deserialization fails
623     if (!message.hasMessageData()) {
624       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
625           "Expected message data, but didn't find it");
626     }
627     Ukey2ClientInit clientInit = null;
628     try {
629       clientInit = Ukey2ClientInit.parseFrom(message.getMessageData());
630     } catch (InvalidProtocolBufferException e) {
631       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
632           "Can't parse message data into ClientInit");
633     }
634 
635     // Check that version == VERSION; send BAD_VERSION message if mismatch
636     if (!clientInit.hasVersion()) {
637       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ClientInit missing version");
638     }
639     if (clientInit.getVersion() != VERSION) {
640       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ClientInit version mismatch");
641     }
642 
643     // Check that random is exactly NONCE_LENGTH_IN_BYTES bytes; send Alert.BAD_RANDOM message if
644     // not.
645     if (!clientInit.hasRandom()) {
646       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ClientInit missing random");
647     }
648     if (clientInit.getRandom().toByteArray().length != NONCE_LENGTH_IN_BYTES) {
649       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ClientInit has incorrect nonce length");
650     }
651 
652     // Check to see if any of the handshake_cipher in cipher_commitment are acceptable. Servers
653     // should select the first handshake_cipher that it finds acceptable to support clients
654     // signaling deprecated but supported HandshakeCiphers. If no handshake_cipher is acceptable
655     // (or there are no HandshakeCiphers in the message), the server sends a BAD_HANDSHAKE_CIPHER
656     //  message
657     List<Ukey2ClientInit.CipherCommitment> commitments = clientInit.getCipherCommitmentsList();
658     if (commitments.isEmpty()) {
659       throwAlertException(
660           Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER, "ClientInit is missing cipher commitments");
661     }
662     for (Ukey2ClientInit.CipherCommitment commitment : commitments) {
663       if (!commitment.hasHandshakeCipher()
664           || !commitment.hasCommitment()) {
665         throwAlertException(
666             Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER,
667             "ClientInit has improperly formatted cipher commitment");
668       }
669 
670       // TODO(aczeskis): for now we only support one cipher, eventually support more
671       if (commitment.getHandshakeCipher() == handshakeCipher.getValue()) {
672         theirCommitment = commitment.getCommitment().toByteArray();
673       }
674     }
675     if (theirCommitment == null) {
676       throwAlertException(Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER,
677           "No acceptable commitments found");
678     }
679 
680     // Checks that next_protocol contains a protocol that the server supports. Send a
681     // BAD_NEXT_PROTOCOL message if not. We currently only support one protocol
682     if (!clientInit.hasNextProtocol() || !NEXT_PROTOCOL.equals(clientInit.getNextProtocol())) {
683       throwAlertException(Ukey2Alert.AlertType.BAD_NEXT_PROTOCOL, "Incorrect next protocol");
684     }
685 
686     // Store raw message for AUTH_STRING computation
687     rawMessage1 = handshakeMessage;
688   }
689 
690   /**
691    * Attempts to parse message 2 (which is a wrapped {@link Ukey2ServerInit}). See go/ukey2 for
692    * details.
693    */
parseMessage2(final byte[] handshakeMessage)694   private void parseMessage2(final byte[] handshakeMessage)
695       throws AlertException, HandshakeException {
696     // Deserialize the protobuf; send a BAD_MESSAGE message if deserialization fails
697     Ukey2Message message = null;
698     try {
699       message = Ukey2Message.parseFrom(handshakeMessage);
700     } catch (InvalidProtocolBufferException e) {
701       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE,
702           "Can't parse message 2 " + e.getMessage());
703     }
704 
705     // Verify that message_type == Type.SERVER_INIT; send a BAD_MESSAGE_TYPE message if mismatch
706     if (!message.hasMessageType()) {
707       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_TYPE,
708           "Expected, but did not find message type");
709     }
710     if (message.getMessageType() == Ukey2Message.Type.ALERT) {
711       handshakeState = InternalState.HANDSHAKE_ERROR;
712       throwHandshakeMessageFromAlertMessage(message);
713     }
714     if (message.getMessageType() != Ukey2Message.Type.SERVER_INIT) {
715       throwAlertException(
716           Ukey2Alert.AlertType.BAD_MESSAGE_TYPE,
717           "Expected, but did not find SERVER_INIT message type");
718     }
719 
720     // Deserialize message_data as a ServerInit message; send a BAD_MESSAGE_DATA message if
721     // deserialization fails
722     if (!message.hasMessageData()) {
723 
724       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
725           "Expected message data, but didn't find it");
726     }
727     Ukey2ServerInit serverInit = null;
728     try {
729       serverInit = Ukey2ServerInit.parseFrom(message.getMessageData());
730     } catch (InvalidProtocolBufferException e) {
731       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
732           "Can't parse message data into ServerInit");
733     }
734 
735     // Check that version == VERSION; send BAD_VERSION message if mismatch
736     if (!serverInit.hasVersion()) {
737       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ServerInit missing version");
738     }
739     if (serverInit.getVersion() != VERSION) {
740       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ServerInit version mismatch");
741     }
742 
743     // Check that random is exactly NONCE_LENGTH_IN_BYTES bytes; send Alert.BAD_RANDOM message if
744     // not.
745     if (!serverInit.hasRandom()) {
746       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ServerInit missing random");
747     }
748     if (serverInit.getRandom().toByteArray().length != NONCE_LENGTH_IN_BYTES) {
749       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ServerInit has incorrect nonce length");
750     }
751 
752     // Check that handshake_cipher matches a handshake cipher that was sent in
753     // ClientInit.cipher_commitments. If not, send a BAD_HANDSHAKECIPHER message
754     if (!serverInit.hasHandshakeCipher()) {
755       throwAlertException(Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER, "No handshake cipher found");
756     }
757     HandshakeCipher serverCipher = null;
758     for (HandshakeCipher cipher : HandshakeCipher.values()) {
759       if (cipher.getValue() == serverInit.getHandshakeCipher()) {
760         serverCipher = cipher;
761         break;
762       }
763     }
764     if (serverCipher == null || serverCipher != handshakeCipher) {
765       throwAlertException(Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER,
766           "No acceptable handshake cipher found");
767     }
768 
769     // Check that public_key parses into a correct public key structure. If not, send a
770     // BAD_PUBLIC_KEY message.
771     if (!serverInit.hasPublicKey()) {
772       throwAlertException(Ukey2Alert.AlertType.BAD_PUBLIC_KEY, "No public key found in ServerInit");
773     }
774     theirPublicKey = parseP256PublicKey(serverInit.getPublicKey().toByteArray());
775 
776     // Store raw message for AUTH_STRING computation
777     rawMessage2 = handshakeMessage;
778   }
779 
780   /**
781    * Attempts to parse message 3 (which is a wrapped {@link Ukey2ClientFinished}). See go/ukey2 for
782    * details.
783    */
parseMessage3(final byte[] handshakeMessage)784   private void parseMessage3(final byte[] handshakeMessage) throws HandshakeException {
785     // Deserialize the protobuf; terminate the connection if deserialization fails.
786     Ukey2Message message = null;
787     try {
788       message = Ukey2Message.parseFrom(handshakeMessage);
789     } catch (InvalidProtocolBufferException e) {
790       throwHandshakeException("Can't parse message 3", e);
791     }
792 
793     // Verify that message_type == Type.CLIENT_FINISH; terminate connection if mismatch occurs
794     if (!message.hasMessageType()) {
795       throw new HandshakeException("Expected, but did not find message type");
796     }
797     if (message.getMessageType() == Ukey2Message.Type.ALERT) {
798       throwHandshakeMessageFromAlertMessage(message);
799     }
800     if (message.getMessageType() != Ukey2Message.Type.CLIENT_FINISH) {
801       throwHandshakeException("Expected, but did not find CLIENT_FINISH message type");
802     }
803 
804     // Verify that the hash of the ClientFinished matches the expected commitment from ClientInit.
805     // Terminate the connection if the expected match fails.
806     verifyCommitment(handshakeMessage);
807 
808     // Deserialize message_data as a ClientFinished message; terminate the connection if
809     // deserialization fails.
810     if (!message.hasMessageData()) {
811       throwHandshakeException("Expected message data, but didn't find it");
812     }
813     Ukey2ClientFinished clientFinished = null;
814     try {
815       clientFinished = Ukey2ClientFinished.parseFrom(message.getMessageData());
816     } catch (InvalidProtocolBufferException e) {
817       throwHandshakeException(e);
818     }
819 
820     // Check that public_key parses into a correct public key structure. If not, terminate the
821     // connection.
822     if (!clientFinished.hasPublicKey()) {
823       throwHandshakeException("No public key found in ClientFinished");
824     }
825     try {
826       theirPublicKey = parseP256PublicKey(clientFinished.getPublicKey().toByteArray());
827     } catch (AlertException e) {
828       // Wrap in a HandshakeException because error should not be sent on the wire.
829       throwHandshakeException(e);
830     }
831   }
832 
verifyCommitment(byte[] handshakeMessage)833   private void verifyCommitment(byte[] handshakeMessage) throws HandshakeException {
834     byte[] actualClientFinishHash = null;
835     switch (handshakeCipher) {
836       case P256_SHA512:
837         actualClientFinishHash = sha512(handshakeMessage);
838         break;
839       default:
840         // should be unreachable
841         throwIllegalStateException("Unexpected handshakeCipher");
842     }
843 
844     // Time constant after Java SE 6 Update 17
845     // See http://www.oracle.com/technetwork/java/javase/6u17-141447.html
846     if (!MessageDigest.isEqual(actualClientFinishHash, theirCommitment)) {
847       throwHandshakeException("Commitment does not match");
848     }
849   }
850 
throwHandshakeMessageFromAlertMessage(Ukey2Message message)851   private void throwHandshakeMessageFromAlertMessage(Ukey2Message message)
852       throws HandshakeException {
853     if (message.hasMessageData()) {
854       Ukey2Alert alert = null;
855       try {
856         alert = Ukey2Alert.parseFrom(message.getMessageData());
857       } catch (InvalidProtocolBufferException e) {
858         throwHandshakeException("Cannot parse alert message", e);
859       }
860 
861       if (alert.hasType() && alert.hasErrorMessage()) {
862         throwHandshakeException(
863             "Received Alert message. Type: "
864                 + alert.getType()
865                 + " Error Message: "
866                 + alert.getErrorMessage());
867       } else if (alert.hasType()) {
868         throwHandshakeException("Received Alert message. Type: " + alert.getType());
869       }
870     }
871 
872     throwHandshakeException("Received empty Alert Message");
873   }
874 
875   /**
876    * Parses an encoded public P256 key.
877    */
parseP256PublicKey(byte[] encodedPublicKey)878   private PublicKey parseP256PublicKey(byte[] encodedPublicKey)
879       throws AlertException, HandshakeException {
880     try {
881       return PublicKeyProtoUtil.parsePublicKey(GenericPublicKey.parseFrom(encodedPublicKey));
882     } catch (InvalidProtocolBufferException | InvalidKeySpecException e) {
883       throwAlertException(Ukey2Alert.AlertType.BAD_PUBLIC_KEY,
884           "Cannot parse public key: " + e.getMessage());
885       return null; // unreachable, but makes compiler happy
886     }
887   }
888 
889   /**
890    * Generates a {@link CipherCommitment} for the P256_SHA512 cipher.
891    */
generateP256SHA512Commitment()892   private CipherCommitment generateP256SHA512Commitment() throws HandshakeException {
893     // Generate the corresponding finished message if it's not done yet
894     if (!rawMessage3Map.containsKey(HandshakeCipher.P256_SHA512)) {
895       generateP256SHA512ClientFinished(ourKeyPair);
896     }
897 
898     CipherCommitment.Builder cipherCommitment = CipherCommitment.newBuilder();
899     cipherCommitment.setHandshakeCipher(UkeyProto.Ukey2HandshakeCipher.P256_SHA512);
900     cipherCommitment.setCommitment(
901         ByteString.copyFrom(sha512(rawMessage3Map.get(HandshakeCipher.P256_SHA512))));
902 
903     return cipherCommitment.build();
904   }
905 
906   /**
907    * Generates and records a {@link Ukey2ClientFinished} message for the P256_SHA512 cipher.
908    */
generateP256SHA512ClientFinished(KeyPair p256KeyPair)909   private Ukey2ClientFinished generateP256SHA512ClientFinished(KeyPair p256KeyPair) {
910     byte[] encodedKey = PublicKeyProtoUtil.encodePublicKey(p256KeyPair.getPublic()).toByteArray();
911 
912     Ukey2ClientFinished.Builder clientFinished = Ukey2ClientFinished.newBuilder();
913     clientFinished.setPublicKey(ByteString.copyFrom(encodedKey));
914 
915     rawMessage3Map.put(
916         HandshakeCipher.P256_SHA512,
917         makeUkey2Message(Ukey2Message.Type.CLIENT_FINISH, clientFinished.build().toByteArray()));
918 
919     return clientFinished.build();
920   }
921 
922   /**
923    * Generates the serialized representation of a {@link Ukey2Message} based on the provided type
924    * and data.
925    */
makeUkey2Message(Ukey2Message.Type messageType, byte[] messageData)926   private byte[] makeUkey2Message(Ukey2Message.Type messageType, byte[] messageData) {
927     Ukey2Message.Builder message = Ukey2Message.newBuilder();
928 
929     switch (messageType) {
930       case ALERT:
931       case CLIENT_INIT:
932       case SERVER_INIT:
933       case CLIENT_FINISH:
934         // fall through intentional; valid message types
935         break;
936       default:
937         throwIllegalArgumentException("Invalid message type: " + messageType);
938     }
939     message.setMessageType(messageType);
940 
941     // Alerts a blank message data field
942     if (messageType != Ukey2Message.Type.ALERT) {
943       if (messageData == null || messageData.length == 0) {
944         throwIllegalArgumentException("Cannot send empty message data for non-alert messages");
945       }
946       message.setMessageData(ByteString.copyFrom(messageData));
947     }
948 
949     return message.build().toByteArray();
950   }
951 
952   /**
953    * Returns a {@link Ukey2Alert} message of given type and having the loggable additional data if
954    * present.
955    */
makeAlertMessage(Ukey2Alert.AlertType alertType, @Nullable String loggableAdditionalData)956   private Ukey2Alert makeAlertMessage(Ukey2Alert.AlertType alertType,
957       @Nullable String loggableAdditionalData) throws HandshakeException {
958     switch (alertType) {
959       case BAD_MESSAGE:
960       case BAD_MESSAGE_TYPE:
961       case INCORRECT_MESSAGE:
962       case BAD_MESSAGE_DATA:
963       case BAD_VERSION:
964       case BAD_RANDOM:
965       case BAD_HANDSHAKE_CIPHER:
966       case BAD_NEXT_PROTOCOL:
967       case BAD_PUBLIC_KEY:
968       case INTERNAL_ERROR:
969         // fall through intentional; valid alert types
970         break;
971       default:
972         throwHandshakeException("Unknown alert type: " + alertType);
973     }
974 
975     Ukey2Alert.Builder alert = Ukey2Alert.newBuilder();
976     alert.setType(alertType);
977 
978     if (loggableAdditionalData != null) {
979       alert.setErrorMessage(loggableAdditionalData);
980     }
981 
982     return alert.build();
983   }
984 
985   /**
986    * Generates a cryptoraphically random nonce of NONCE_LENGTH_IN_BYTES bytes.
987    */
generateRandomNonce()988   private static byte[] generateRandomNonce() {
989     SecureRandom rng = new SecureRandom();
990     byte[] randomNonce = new byte[NONCE_LENGTH_IN_BYTES];
991     rng.nextBytes(randomNonce);
992     return randomNonce;
993   }
994 
995   /**
996    * Handy wrapper to do SHA512.
997    */
sha512(byte[] input)998   private byte[] sha512(byte[] input) throws HandshakeException {
999     MessageDigest sha512;
1000     try {
1001       sha512 = MessageDigest.getInstance("SHA-512");
1002       return sha512.digest(input);
1003     } catch (NoSuchAlgorithmException e) {
1004       throwHandshakeException("No security provider initialized yet?", e);
1005       return null; // unreachable in practice, but makes compiler happy
1006     }
1007   }
1008 
1009   // Exception wrappers that remember to set the handshake state to ERROR
1010 
throwAlertException(Ukey2Alert.AlertType alertType, String alertLogStatement)1011   private void throwAlertException(Ukey2Alert.AlertType alertType, String alertLogStatement)
1012       throws AlertException, HandshakeException {
1013     handshakeState = InternalState.HANDSHAKE_ERROR;
1014     throw new AlertException(alertLogStatement, makeAlertMessage(alertType, alertLogStatement));
1015   }
1016 
throwHandshakeException(String logMessage)1017   private void throwHandshakeException(String logMessage) throws HandshakeException {
1018     handshakeState = InternalState.HANDSHAKE_ERROR;
1019     throw new HandshakeException(logMessage);
1020   }
1021 
throwHandshakeException(Exception e)1022   private void throwHandshakeException(Exception e) throws HandshakeException {
1023     handshakeState = InternalState.HANDSHAKE_ERROR;
1024     throw new HandshakeException(e);
1025   }
1026 
throwHandshakeException(String logMessage, Exception e)1027   private void throwHandshakeException(String logMessage, Exception e) throws HandshakeException {
1028     handshakeState = InternalState.HANDSHAKE_ERROR;
1029     throw new HandshakeException(logMessage, e);
1030   }
1031 
throwIllegalStateException(String logMessage)1032   private void throwIllegalStateException(String logMessage) {
1033     handshakeState = InternalState.HANDSHAKE_ERROR;
1034     throw new IllegalStateException(logMessage);
1035   }
1036 
throwIllegalArgumentException(String logMessage)1037   private void throwIllegalArgumentException(String logMessage) {
1038     handshakeState = InternalState.HANDSHAKE_ERROR;
1039     throw new IllegalArgumentException(logMessage);
1040   }
1041 }
1042