1 // Copyright 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "securegcm/ukey2_handshake.h"
16 
17 #include <sstream>
18 
19 #include "securegcm/d2d_crypto_ops.h"
20 #include "securemessage/public_key_proto_util.h"
21 
22 namespace securegcm {
23 
24 using securemessage::ByteBuffer;
25 using securemessage::CryptoOps;
26 using securemessage::GenericPublicKey;
27 using securemessage::PublicKeyProtoUtil;
28 
29 namespace {
30 
// HKDF salt for deriving the next-protocol (client/server) key material.
constexpr char kUkey2HkdfSalt[] = "UKEY2 v1 next";

// HKDF salt for deriving the out-of-band verification string.
constexpr char kUkey2VerificationStringSalt[] = "UKEY2 v1 auth";

// Highest handshake version implemented here.
constexpr uint32_t kVersion = 1;

// Every ClientInit/ServerInit nonce must be exactly this long (as per
// go/ukey2).
constexpr uint32_t kNonceLengthInBytes = 32;

// The single next protocol supported at the moment.
constexpr char kNextProtocol[] = "AES_256_CBC-HMAC_SHA256";
45 
46 // Creates the appropriate KeyPair for |cipher|.
GenerateKeyPair(UKey2Handshake::HandshakeCipher cipher)47 std::unique_ptr<CryptoOps::KeyPair> GenerateKeyPair(
48     UKey2Handshake::HandshakeCipher cipher) {
49   switch (cipher) {
50     case UKey2Handshake::HandshakeCipher::P256_SHA512:
51       return CryptoOps::GenerateEcP256KeyPair();
52     default:
53       return nullptr;
54   }
55 }
56 
57 // Parses a CryptoOps::PublicKey from a serialized GenericPublicKey.
ParsePublicKey(const string & serialized_generic_public_key)58 std::unique_ptr<securemessage::CryptoOps::PublicKey> ParsePublicKey(
59     const string& serialized_generic_public_key) {
60   GenericPublicKey generic_public_key;
61   if (!generic_public_key.ParseFromString(serialized_generic_public_key)) {
62     return nullptr;
63   }
64   return PublicKeyProtoUtil::ParsePublicKey(generic_public_key);
65 }
66 
67 }  // namespace
68 
69 // static.
ForInitiator(HandshakeCipher cipher)70 std::unique_ptr<UKey2Handshake> UKey2Handshake::ForInitiator(
71     HandshakeCipher cipher) {
72   return std::unique_ptr<UKey2Handshake>(
73       new UKey2Handshake(InternalState::CLIENT_START, cipher));
74 }
75 
76 // static.
ForResponder(HandshakeCipher cipher)77 std::unique_ptr<UKey2Handshake> UKey2Handshake::ForResponder(
78     HandshakeCipher cipher) {
79   return std::unique_ptr<UKey2Handshake>(
80       new UKey2Handshake(InternalState::SERVER_START, cipher));
81 }
82 
UKey2Handshake(InternalState state,HandshakeCipher cipher)83 UKey2Handshake::UKey2Handshake(InternalState state, HandshakeCipher cipher)
84     : handshake_state_(state),
85       handshake_cipher_(cipher),
86       handshake_role_(state == InternalState::CLIENT_START
87                           ? HandshakeRole::CLIENT
88                           : HandshakeRole::SERVER),
89       our_key_pair_(GenerateKeyPair(cipher)) {}
90 
GetHandshakeState() const91 UKey2Handshake::State UKey2Handshake::GetHandshakeState() const {
92   switch (handshake_state_) {
93     case InternalState::CLIENT_START:
94     case InternalState::CLIENT_WAITING_FOR_SERVER_INIT:
95     case InternalState::CLIENT_AFTER_SERVER_INIT:
96     case InternalState::SERVER_START:
97     case InternalState::SERVER_AFTER_CLIENT_INIT:
98     case InternalState::SERVER_WAITING_FOR_CLIENT_FINISHED:
99       // Fallthrough intended -- these are all in-progress states.
100       return State::kInProgress;
101     case InternalState::HANDSHAKE_VERIFICATION_NEEDED:
102       return State::kVerificationNeeded;
103     case InternalState::HANDSHAKE_VERIFICATION_IN_PROGRESS:
104       return State::kVerificationInProgress;
105     case InternalState::HANDSHAKE_FINISHED:
106       return State::kFinished;
107     case InternalState::HANDSHAKE_ALREADY_USED:
108       return State::kAlreadyUsed;
109     case InternalState::HANDSHAKE_ERROR:
110       return State::kError;
111     default:
112       // Unreachable.
113       return State::kError;
114   }
115 }
116 
GetLastError() const117 const string& UKey2Handshake::GetLastError() const {
118   return last_error_;
119 }
120 
GetNextHandshakeMessage()121 std::unique_ptr<string> UKey2Handshake::GetNextHandshakeMessage() {
122   switch (handshake_state_) {
123     case InternalState::CLIENT_START: {
124       std::unique_ptr<string> client_init = MakeClientInitUkey2Message();
125       if (!client_init) {
126         // |last_error_| is already set.
127         return nullptr;
128       }
129 
130       wrapped_client_init_ = *client_init;
131       handshake_state_ = InternalState::CLIENT_WAITING_FOR_SERVER_INIT;
132       return client_init;
133     }
134 
135     case InternalState::SERVER_AFTER_CLIENT_INIT: {
136       std::unique_ptr<string> server_init = MakeServerInitUkey2Message();
137       if (!server_init) {
138         // |last_error_| is already set.
139         return nullptr;
140       }
141 
142       wrapped_server_init_ = *server_init;
143       handshake_state_ = InternalState::SERVER_WAITING_FOR_CLIENT_FINISHED;
144       return server_init;
145     }
146 
147     case InternalState::CLIENT_AFTER_SERVER_INIT: {
148       // Make sure we have a message 3 for the chosen cipher.
149       if (raw_message3_map_.count(handshake_cipher_) == 0) {
150         std::ostringstream stream;
151         stream << "Client state is CLIENT_AFTER_SERVER_INIT, and cipher is "
152                << static_cast<int>(handshake_cipher_)
153                << ", but no corresponding raw "
154                << "[Client Finished] message has been generated.";
155         SetError(stream.str());
156         return nullptr;
157       }
158       handshake_state_ = InternalState::HANDSHAKE_VERIFICATION_NEEDED;
159       return std::unique_ptr<string>(
160           new string(raw_message3_map_[handshake_cipher_]));
161     }
162 
163     default: {
164       std::ostringstream stream;
165       stream << "Cannot get next message in state "
166              << static_cast<int>(handshake_state_);
167       SetError(stream.str());
168       return nullptr;
169     }
170   }
171 }
172 
173 UKey2Handshake::ParseResult
ParseHandshakeMessage(const string & handshake_message)174 UKey2Handshake::ParseHandshakeMessage(const string& handshake_message) {
175   switch (handshake_state_) {
176     case InternalState::SERVER_START:
177       return ParseClientInitUkey2Message(handshake_message);
178     case InternalState::CLIENT_WAITING_FOR_SERVER_INIT:
179       return ParseServerInitUkey2Message(handshake_message);
180     case InternalState::SERVER_WAITING_FOR_CLIENT_FINISHED:
181       return ParseClientFinishUkey2Message(handshake_message);
182     default:
183       std::ostringstream stream;
184       stream << "Cannot parse message in state "
185              << static_cast<int>(handshake_state_);
186       SetError(stream.str());
187       return {false, nullptr};
188   }
189 }
190 
GetVerificationString(int byte_length)191 std::unique_ptr<string> UKey2Handshake::GetVerificationString(int byte_length) {
192   if (byte_length < 1 || byte_length > 32) {
193     SetError("Minimum length is 1 byte, max is 32 bytes.");
194     return nullptr;
195   }
196 
197   if (handshake_state_ != InternalState::HANDSHAKE_VERIFICATION_NEEDED) {
198     std::ostringstream stream;
199     stream << "Unexpected state: " << static_cast<int>(handshake_state_);
200     SetError(stream.str());
201     return nullptr;
202   }
203 
204   if (!our_key_pair_ || !our_key_pair_->private_key || !their_public_key_) {
205     SetError("One of our private key or their public key is null.");
206     return nullptr;
207   }
208 
209   switch (handshake_cipher_) {
210     case HandshakeCipher::P256_SHA512:
211       derived_secret_key_ = CryptoOps::KeyAgreementSha256(
212           *(our_key_pair_->private_key), *their_public_key_);
213       break;
214     default:
215       // Unreachable.
216       return nullptr;
217   }
218 
219   if (!derived_secret_key_) {
220     SetError("Failed to derive shared secret key.");
221     return nullptr;
222   }
223 
224   std::unique_ptr<string> auth_string = CryptoOps::Hkdf(
225       derived_secret_key_->data().String(),
226       string(kUkey2VerificationStringSalt, sizeof(kUkey2VerificationStringSalt)),
227       wrapped_client_init_ + wrapped_server_init_);
228 
229   handshake_state_ = InternalState::HANDSHAKE_VERIFICATION_IN_PROGRESS;
230   return auth_string;
231 }
232 
VerifyHandshake()233 bool UKey2Handshake::VerifyHandshake() {
234   if (handshake_state_ != InternalState::HANDSHAKE_VERIFICATION_IN_PROGRESS) {
235     std::ostringstream stream;
236     stream << "Unexpected state: " << static_cast<int>(handshake_state_);
237     SetError(stream.str());
238     return false;
239   }
240 
241   handshake_state_ = InternalState::HANDSHAKE_FINISHED;
242   return true;
243 }
244 
ToConnectionContext()245 std::unique_ptr<D2DConnectionContextV1> UKey2Handshake::ToConnectionContext() {
246   if (InternalState::HANDSHAKE_FINISHED != handshake_state_) {
247     std::ostringstream stream;
248     stream << "ToConnectionContext can only be called when handshake is "
249            << "completed, but current state is "
250            << static_cast<int>(handshake_state_);
251     SetError(stream.str());
252     return nullptr;
253   }
254 
255   if (!derived_secret_key_) {
256     SetError("Derived key is null.");
257     return nullptr;
258   }
259 
260   string info = wrapped_client_init_ + wrapped_server_init_;
261   std::unique_ptr<string> master_key_data = CryptoOps::Hkdf(
262       derived_secret_key_->data().String(), kUkey2HkdfSalt, info);
263 
264   if (!master_key_data) {
265     SetError("Failed to create master key.");
266     return nullptr;
267   }
268 
269   // Derive separate encode keys for both client and server.
270   CryptoOps::SecretKey master_key(*master_key_data, CryptoOps::AES_256_KEY);
271   std::unique_ptr<CryptoOps::SecretKey> client_key =
272       D2DCryptoOps::DeriveNewKeyForPurpose(master_key, "client");
273   std::unique_ptr<CryptoOps::SecretKey> server_key =
274       D2DCryptoOps::DeriveNewKeyForPurpose(master_key, "server");
275   if (!client_key || !server_key) {
276     SetError("Failed to derive client or server key.");
277     return nullptr;
278   }
279 
280   handshake_state_ = InternalState::HANDSHAKE_ALREADY_USED;
281 
282   return std::unique_ptr<D2DConnectionContextV1>(new D2DConnectionContextV1(
283       handshake_role_ == HandshakeRole::CLIENT ? *client_key : *server_key,
284       handshake_role_ == HandshakeRole::CLIENT ? *server_key : *client_key,
285       0 /* initial encode sequence number */,
286       0 /* initial decode sequence number */));
287 }
288 
ParseClientInitUkey2Message(const string & handshake_message)289 UKey2Handshake::ParseResult UKey2Handshake::ParseClientInitUkey2Message(
290     const string& handshake_message) {
291   // Deserialize the protobuf.
292   Ukey2Message message;
293   if (!message.ParseFromString(handshake_message)) {
294     return CreateFailedResultWithAlert(Ukey2Alert::BAD_MESSAGE,
295                                        "Can't parse message 1.");
296   }
297 
298   // Verify that message_type == CLIENT_INIT.
299   if (!message.has_message_type() ||
300       message.message_type() != Ukey2Message::CLIENT_INIT) {
301     return CreateFailedResultWithAlert(
302         Ukey2Alert::BAD_MESSAGE,
303         "Expected, but did not find ClientInit message type.");
304   }
305 
306   // Derserialize message_data as a ClientInit message.
307   if (!message.has_message_data()) {
308     return CreateFailedResultWithAlert(
309         Ukey2Alert::BAD_MESSAGE_DATA,
310         "Expected message data, but did not find it.");
311   }
312 
313   Ukey2ClientInit client_init;
314   if (!client_init.ParseFromString(message.message_data())) {
315     return CreateFailedResultWithAlert(
316         Ukey2Alert::BAD_MESSAGE_DATA,
317         "Can't parse message data into ClientInit.");
318   }
319 
320   // Check that version == VERSION.
321   if (!client_init.has_version()) {
322     return CreateFailedResultWithAlert(Ukey2Alert::BAD_VERSION,
323                                        "ClientInit missing version.");
324   }
325   if (client_init.version() != kVersion) {
326     return CreateFailedResultWithAlert(Ukey2Alert::BAD_VERSION,
327                                        "ClientInit version mismatch.");
328   }
329 
330   // Check that random is exactly kNonceLengthInBytes.
331   if (!client_init.has_random()) {
332     return CreateFailedResultWithAlert(Ukey2Alert::BAD_RANDOM,
333                                        "ClientInit missing random.");
334   }
335   if (client_init.random().length() != kNonceLengthInBytes) {
336     return CreateFailedResultWithAlert(
337         Ukey2Alert::BAD_RANDOM, "ClientInit has incorrect nonce length.");
338   }
339 
340   // Check to see if any of the handshake_cipher in handshake_cipher_commitment
341   // are acceptable. Servers should select the first ahdnshake_cipher that it
342   // finds acceptable to support clients signalling deprecated but supported
343   // HandshakeCiphers. If no handshake_cipher is acceptable (or there are no
344   // HandshakeCiphers in the message), the server sends a BAD_HANDSHAKE_CIPHER
345   // alert message.
346   if (client_init.cipher_commitments_size() == 0) {
347     return CreateFailedResultWithAlert(
348         Ukey2Alert::BAD_HANDSHAKE_CIPHER,
349         "ClientInit is missing cipher commitments.");
350   }
351 
352   for (const Ukey2ClientInit::CipherCommitment& commitment :
353        client_init.cipher_commitments()) {
354     if (!commitment.has_handshake_cipher() || !commitment.has_commitment() ||
355         commitment.commitment().empty()) {
356       return CreateFailedResultWithAlert(
357           Ukey2Alert::BAD_HANDSHAKE_CIPHER,
358           "ClientInit has improperly formatted cipher commitment.");
359     }
360 
361     // TODO(aczeskis): for now we only support one cipher, eventually support
362     // more.
363     if (commitment.handshake_cipher() == static_cast<int>(handshake_cipher_)) {
364       peer_commitment_ = commitment.commitment();
365     }
366   }
367 
368   if (peer_commitment_.empty()) {
369     return CreateFailedResultWithAlert(Ukey2Alert::BAD_HANDSHAKE_CIPHER,
370                                        "No acceptable commitments found");
371   }
372 
373   // Checks that next_protocol contains a protocol that the server supports. We
374   // currently only support one protocol.
375   if (!client_init.has_next_protocol() ||
376       client_init.next_protocol() != kNextProtocol) {
377     return CreateFailedResultWithAlert(Ukey2Alert::BAD_NEXT_PROTOCOL,
378                                        "Incorrect next protocol.");
379   }
380 
381   // Store raw message for AUTH_STRING computation.
382   wrapped_client_init_ = handshake_message;
383   handshake_state_ = InternalState::SERVER_AFTER_CLIENT_INIT;
384   return CreateSuccessResult();
385 }
386 
ParseServerInitUkey2Message(const string & handshake_message)387 UKey2Handshake::ParseResult UKey2Handshake::ParseServerInitUkey2Message(
388     const string& handshake_message) {
389   // Deserialize the protobuf.
390   Ukey2Message message;
391   if (!message.ParseFromString(handshake_message)) {
392     return CreateFailedResultWithAlert(Ukey2Alert::BAD_MESSAGE,
393                                        "Can't parse message 2.");
394   }
395 
396   // Verify that message_type == SERVER_INIT.
397   if (!message.has_message_type() ||
398       message.message_type() != Ukey2Message::SERVER_INIT) {
399     return CreateFailedResultWithAlert(
400         Ukey2Alert::BAD_MESSAGE,
401         "Expected, but did not find SERVER_INIT message type.");
402   }
403 
404   // Derserialize message_data as a ServerInit message.
405   if (!message.has_message_data()) {
406     return CreateFailedResultWithAlert(
407         Ukey2Alert::BAD_MESSAGE_DATA,
408         "Expected message data, but did not find it.");
409   }
410 
411   Ukey2ServerInit server_init;
412   if (!server_init.ParseFromString(message.message_data())) {
413     return CreateFailedResultWithAlert(
414         Ukey2Alert::BAD_MESSAGE_DATA,
415         "Can't parse message data into ServerInit.");
416   }
417 
418   // Check that version == VERSION.
419   if (!server_init.has_version()) {
420     return CreateFailedResultWithAlert(Ukey2Alert::BAD_VERSION,
421                                        "ServerInit missing version.");
422   }
423   if (server_init.version() != kVersion) {
424     return CreateFailedResultWithAlert(Ukey2Alert::BAD_VERSION,
425                                        "ServerInit version mismatch.");
426   }
427 
428   // Check that random is exactly kNonceLengthInBytes.
429   if (!server_init.has_random()) {
430     return CreateFailedResultWithAlert(Ukey2Alert::BAD_RANDOM,
431                                        "ServerInit missing random.");
432   }
433   if (server_init.random().length() != kNonceLengthInBytes) {
434     return CreateFailedResultWithAlert(
435         Ukey2Alert::BAD_RANDOM, "ServerInit has incorrect nonce length.");
436   }
437 
438   // Check that the handshake_cipher matches a handshake cipher that was sent in
439   // ClientInit::cipher_commitments().
440   if (!server_init.has_handshake_cipher()) {
441     return CreateFailedResultWithAlert(Ukey2Alert::BAD_HANDSHAKE_CIPHER,
442                                        "No handshake cipher found.");
443   }
444 
445   Ukey2HandshakeCipher cipher = server_init.handshake_cipher();
446   HandshakeCipher server_cipher;
447   switch (static_cast<HandshakeCipher>(cipher)) {
448     case HandshakeCipher::P256_SHA512:
449       server_cipher = static_cast<HandshakeCipher>(cipher);
450       break;
451     default:
452       return CreateFailedResultWithAlert(Ukey2Alert::BAD_HANDSHAKE_CIPHER,
453                                          "No acceptable handshake found.");
454   }
455 
456   // Check that public_key parses into a correct public key structure.
457   if (!server_init.has_public_key()) {
458     return CreateFailedResultWithAlert(Ukey2Alert::BAD_PUBLIC_KEY,
459                                        "No public key found in ServerInit.");
460   }
461 
462   their_public_key_ = ParsePublicKey(server_init.public_key());
463   if (!their_public_key_) {
464     return CreateFailedResultWithAlert(Ukey2Alert::BAD_PUBLIC_KEY,
465                                        "Failed to parse public key.");
466   }
467 
468   // Store raw message for AUTH_STRING computation.
469   wrapped_server_init_ = handshake_message;
470   handshake_state_ = InternalState::CLIENT_AFTER_SERVER_INIT;
471   return CreateSuccessResult();
472 }
473 
ParseClientFinishUkey2Message(const string & handshake_message)474 UKey2Handshake::ParseResult UKey2Handshake::ParseClientFinishUkey2Message(
475     const string& handshake_message) {
476   // Deserialize the protobuf.
477   Ukey2Message message;
478   if (!message.ParseFromString(handshake_message)) {
479     return CreateFailedResultWithoutAlert("Can't parse message 3.");
480   }
481 
482   // Verify that message_type == CLIENT_FINISH.
483   if (!message.has_message_type() ||
484       message.message_type() != Ukey2Message::CLIENT_FINISH) {
485     return CreateFailedResultWithoutAlert(
486         "Expected, but did not find CLIENT_FINISH message type.");
487   }
488 
489   // Verify that the hash of the CLientFinished message matches the expected
490   // commitment from ClientInit.
491   if (!VerifyCommitment(handshake_message)) {
492     return CreateFailedResultWithoutAlert(last_error_);
493   }
494 
495   // Deserialize message_data as a ClientFinished message.
496   if (!message.has_message_data()) {
497     return CreateFailedResultWithoutAlert(
498         "Expected message data, but didn't find it.");
499   }
500 
501   Ukey2ClientFinished client_finished;
502   if (!client_finished.ParseFromString(message.message_data())) {
503     return CreateFailedResultWithoutAlert("Failed to parse ClientFinished.");
504   }
505 
506   // Check that public_key parses into a correct public key structure.
507   if (!client_finished.has_public_key()) {
508     return CreateFailedResultWithoutAlert(
509         "No public key found in ClientFinished.");
510   }
511 
512   their_public_key_ = ParsePublicKey(client_finished.public_key());
513   if (!their_public_key_) {
514     return CreateFailedResultWithoutAlert("Failed to parse public key.");
515   }
516 
517   handshake_state_ = InternalState::HANDSHAKE_VERIFICATION_NEEDED;
518   return CreateSuccessResult();
519 }
520 
CreateFailedResultWithAlert(Ukey2Alert::AlertType alert_type,const string & error_message)521 UKey2Handshake::ParseResult UKey2Handshake::CreateFailedResultWithAlert(
522     Ukey2Alert::AlertType alert_type, const string& error_message) {
523   if (!Ukey2Alert_AlertType_IsValid(alert_type)) {
524     std::ostringstream stream;
525     stream << "Unknown alert type: " << static_cast<int>(alert_type);
526     SetError(stream.str());
527     return {false, nullptr};
528   }
529 
530   Ukey2Alert alert;
531   alert.set_type(alert_type);
532   if (!error_message.empty()) {
533     alert.set_error_message(error_message);
534   }
535 
536   std::unique_ptr<string> alert_message =
537       MakeUkey2Message(Ukey2Message::ALERT, alert.SerializeAsString());
538 
539   SetError(error_message);
540   ParseResult result{false, std::move(alert_message)};
541   return result;
542 }
543 
544 UKey2Handshake::ParseResult
CreateFailedResultWithoutAlert(const string & error_message)545 UKey2Handshake::CreateFailedResultWithoutAlert(const string& error_message) {
546   SetError(error_message);
547   return {false, nullptr};
548 }
549 
CreateSuccessResult()550 UKey2Handshake::ParseResult UKey2Handshake::CreateSuccessResult() {
551   return {true, nullptr};
552 }
553 
VerifyCommitment(const string & handshake_message)554 bool UKey2Handshake::VerifyCommitment(const string& handshake_message) {
555   std::unique_ptr<ByteBuffer> actual_client_finish_hash;
556   switch (handshake_cipher_) {
557     case HandshakeCipher::P256_SHA512:
558       actual_client_finish_hash =
559           CryptoOps::Sha512(ByteBuffer(handshake_message));
560       break;
561     default:
562       // Unreachable.
563       return false;
564   }
565 
566   if (!actual_client_finish_hash) {
567     SetError("Failed to hash ClientFinish message.");
568     return false;
569   }
570 
571   // Note: Equals() is a time constant comparison operation.
572   if (!actual_client_finish_hash->Equals(peer_commitment_)) {
573     SetError("Failed to verify commitment.");
574     return false;
575   }
576 
577   return true;
578 }
579 
580 std::unique_ptr<Ukey2ClientInit::CipherCommitment>
GenerateP256Sha512Commitment()581 UKey2Handshake::GenerateP256Sha512Commitment() {
582   // Generate the corresponding ClientFinished message if it's not done yet.
583   if (raw_message3_map_.count(HandshakeCipher::P256_SHA512) == 0) {
584     if (!our_key_pair_ || !our_key_pair_->public_key) {
585       SetError("Invalid public key.");
586       return nullptr;
587     }
588 
589     std::unique_ptr<GenericPublicKey> generic_public_key =
590         PublicKeyProtoUtil::EncodePublicKey(*(our_key_pair_->public_key));
591     if (!generic_public_key) {
592       SetError("Failed to encode generic public key.");
593       return nullptr;
594     }
595 
596     Ukey2ClientFinished client_finished;
597     client_finished.set_public_key(generic_public_key->SerializeAsString());
598     std::unique_ptr<string> serialized_ukey2_message = MakeUkey2Message(
599         Ukey2Message::CLIENT_FINISH, client_finished.SerializeAsString());
600     if (!serialized_ukey2_message) {
601       SetError("Failed to serialized Ukey2Message.");
602       return nullptr;
603     }
604 
605     raw_message3_map_[HandshakeCipher::P256_SHA512] = *serialized_ukey2_message;
606   }
607 
608   // Create the SHA512 commitment from raw message 3.
609   std::unique_ptr<ByteBuffer> commitment = CryptoOps::Sha512(
610       ByteBuffer(raw_message3_map_[HandshakeCipher::P256_SHA512]));
611   if (!commitment) {
612     SetError("Failed to hash message for commitment.");
613     return nullptr;
614   }
615 
616   // Wrap the commitment in a proto.
617   std::unique_ptr<Ukey2ClientInit::CipherCommitment>
618       handshake_cipher_commitment(new Ukey2ClientInit::CipherCommitment());
619   handshake_cipher_commitment->set_handshake_cipher(P256_SHA512);
620   handshake_cipher_commitment->set_commitment(commitment->String());
621 
622   return handshake_cipher_commitment;
623 }
624 
MakeClientInitUkey2Message()625 std::unique_ptr<string> UKey2Handshake::MakeClientInitUkey2Message() {
626   std::unique_ptr<ByteBuffer> nonce =
627       CryptoOps::SecureRandom(kNonceLengthInBytes);
628   if (!nonce) {
629     SetError("Failed to generate nonce.");
630     return nullptr;
631   }
632 
633   Ukey2ClientInit client_init;
634   client_init.set_version(kVersion);
635   client_init.set_random(nonce->String());
636   client_init.set_next_protocol(kNextProtocol);
637 
638   // At the moment, we only support one cipher.
639   std::unique_ptr<Ukey2ClientInit::CipherCommitment>
640       handshake_cipher_commitment = GenerateP256Sha512Commitment();
641   if (!handshake_cipher_commitment) {
642     // |last_error_| already set.
643     return nullptr;
644   }
645   *(client_init.add_cipher_commitments()) = *handshake_cipher_commitment;
646 
647   return MakeUkey2Message(Ukey2Message::CLIENT_INIT,
648                           client_init.SerializeAsString());
649 }
650 
MakeServerInitUkey2Message()651 std::unique_ptr<string> UKey2Handshake::MakeServerInitUkey2Message() {
652   std::unique_ptr<ByteBuffer> nonce =
653       CryptoOps::SecureRandom(kNonceLengthInBytes);
654   if (!nonce) {
655     SetError("Failed to generate nonce.");
656     return nullptr;
657   }
658 
659   if (!our_key_pair_ || !our_key_pair_->public_key) {
660     SetError("Invalid key pair.");
661     return nullptr;
662   }
663 
664   std::unique_ptr<GenericPublicKey> public_key =
665       PublicKeyProtoUtil::EncodePublicKey(*(our_key_pair_->public_key));
666   if (!public_key) {
667     SetError("Failed to encode public key.");
668     return nullptr;
669   }
670 
671   Ukey2ServerInit server_init;
672   server_init.set_version(kVersion);
673   server_init.set_random(nonce->String());
674   server_init.set_handshake_cipher(
675       static_cast<Ukey2HandshakeCipher>(handshake_cipher_));
676   server_init.set_public_key(public_key->SerializeAsString());
677 
678   return MakeUkey2Message(Ukey2Message::SERVER_INIT,
679                           server_init.SerializeAsString());
680 }
681 
682 // Generates the serialized representation of a Ukey2Message based on the
683 // provided |type| and |data|. On error, returns nullptr and writes error
684 // message to |out_error|.
MakeUkey2Message(Ukey2Message::Type type,const string & data)685 std::unique_ptr<string> UKey2Handshake::MakeUkey2Message(
686     Ukey2Message::Type type, const string& data) {
687   Ukey2Message message;
688   if (!Ukey2Message::Type_IsValid(type)) {
689     std::ostringstream stream;
690     stream << "Invalid message type: " << type;
691     SetError(stream.str());
692     return nullptr;
693   }
694   message.set_message_type(type);
695 
696   // Only ALERT messages can have a blank data field.
697   if (type != Ukey2Message::ALERT) {
698     if (data.length() == 0) {
699       SetError("Cannot send empty message data for non-alert messages");
700       return nullptr;
701     }
702   }
703   message.set_message_data(data);
704 
705   std::unique_ptr<string> serialized(new string());
706   message.SerializeToString(serialized.get());
707   return serialized;
708 }
709 
SetError(const string & error_message)710 void UKey2Handshake::SetError(const string& error_message) {
711   handshake_state_ = InternalState::HANDSHAKE_ERROR;
712   last_error_ = error_message;
713 }
714 
715 }  // namespace securegcm
716