1 /* Copyright (c) 2016, Google Inc.
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #include <openssl/ssl.h>
16 
17 #include <assert.h>
18 #include <string.h>
19 
20 #include <utility>
21 
22 #include <openssl/aead.h>
23 #include <openssl/bytestring.h>
24 #include <openssl/digest.h>
25 #include <openssl/hkdf.h>
26 #include <openssl/hmac.h>
27 #include <openssl/mem.h>
28 
29 #include "../crypto/internal.h"
30 #include "internal.h"
31 
32 
33 BSSL_NAMESPACE_BEGIN
34 
init_key_schedule(SSL_HANDSHAKE * hs,uint16_t version,const SSL_CIPHER * cipher)35 static bool init_key_schedule(SSL_HANDSHAKE *hs, uint16_t version,
36                              const SSL_CIPHER *cipher) {
37   if (!hs->transcript.InitHash(version, cipher)) {
38     return false;
39   }
40 
41   hs->hash_len = hs->transcript.DigestLen();
42 
43   // Initialize the secret to the zero key.
44   OPENSSL_memset(hs->secret, 0, hs->hash_len);
45 
46   return true;
47 }
48 
tls13_init_key_schedule(SSL_HANDSHAKE * hs,const uint8_t * psk,size_t psk_len)49 bool tls13_init_key_schedule(SSL_HANDSHAKE *hs, const uint8_t *psk,
50                              size_t psk_len) {
51   if (!init_key_schedule(hs, ssl_protocol_version(hs->ssl), hs->new_cipher)) {
52     return false;
53   }
54 
55   hs->transcript.FreeBuffer();
56   return HKDF_extract(hs->secret, &hs->hash_len, hs->transcript.Digest(), psk,
57                       psk_len, hs->secret, hs->hash_len);
58 }
59 
tls13_init_early_key_schedule(SSL_HANDSHAKE * hs,const uint8_t * psk,size_t psk_len)60 bool tls13_init_early_key_schedule(SSL_HANDSHAKE *hs, const uint8_t *psk,
61                                    size_t psk_len) {
62   SSL *const ssl = hs->ssl;
63   return init_key_schedule(hs, ssl_session_protocol_version(ssl->session.get()),
64                            ssl->session->cipher) &&
65          HKDF_extract(hs->secret, &hs->hash_len, hs->transcript.Digest(), psk,
66                       psk_len, hs->secret, hs->hash_len);
67 }
68 
hkdf_expand_label(uint8_t * out,const EVP_MD * digest,const uint8_t * secret,size_t secret_len,const char * label,size_t label_len,const uint8_t * hash,size_t hash_len,size_t len)69 static bool hkdf_expand_label(uint8_t *out, const EVP_MD *digest,
70                               const uint8_t *secret, size_t secret_len,
71                               const char *label, size_t label_len,
72                               const uint8_t *hash, size_t hash_len,
73                               size_t len) {
74   static const char kTLS13ProtocolLabel[] = "tls13 ";
75 
76   ScopedCBB cbb;
77   CBB child;
78   Array<uint8_t> hkdf_label;
79   if (!CBB_init(cbb.get(), 2 + 1 + strlen(kTLS13ProtocolLabel) + label_len + 1 +
80                                hash_len) ||
81       !CBB_add_u16(cbb.get(), len) ||
82       !CBB_add_u8_length_prefixed(cbb.get(), &child) ||
83       !CBB_add_bytes(&child, (const uint8_t *)kTLS13ProtocolLabel,
84                      strlen(kTLS13ProtocolLabel)) ||
85       !CBB_add_bytes(&child, (const uint8_t *)label, label_len) ||
86       !CBB_add_u8_length_prefixed(cbb.get(), &child) ||
87       !CBB_add_bytes(&child, hash, hash_len) ||
88       !CBBFinishArray(cbb.get(), &hkdf_label)) {
89     return false;
90   }
91 
92   return HKDF_expand(out, len, digest, secret, secret_len, hkdf_label.data(),
93                      hkdf_label.size());
94 }
95 
96 static const char kTLS13LabelDerived[] = "derived";
97 
tls13_advance_key_schedule(SSL_HANDSHAKE * hs,const uint8_t * in,size_t len)98 bool tls13_advance_key_schedule(SSL_HANDSHAKE *hs, const uint8_t *in,
99                                 size_t len) {
100   uint8_t derive_context[EVP_MAX_MD_SIZE];
101   unsigned derive_context_len;
102   if (!EVP_Digest(nullptr, 0, derive_context, &derive_context_len,
103                   hs->transcript.Digest(), nullptr)) {
104     return false;
105   }
106 
107   if (!hkdf_expand_label(hs->secret, hs->transcript.Digest(), hs->secret,
108                          hs->hash_len, kTLS13LabelDerived,
109                          strlen(kTLS13LabelDerived), derive_context,
110                          derive_context_len, hs->hash_len)) {
111     return false;
112   }
113 
114   return HKDF_extract(hs->secret, &hs->hash_len, hs->transcript.Digest(), in,
115                       len, hs->secret, hs->hash_len);
116 }
117 
118 // derive_secret derives a secret of length |len| and writes the result in |out|
119 // with the given label and the current base secret and most recently-saved
120 // handshake context. It returns true on success and false on error.
derive_secret(SSL_HANDSHAKE * hs,uint8_t * out,size_t len,const char * label,size_t label_len)121 static bool derive_secret(SSL_HANDSHAKE *hs, uint8_t *out, size_t len,
122                           const char *label, size_t label_len) {
123   uint8_t context_hash[EVP_MAX_MD_SIZE];
124   size_t context_hash_len;
125   if (!hs->transcript.GetHash(context_hash, &context_hash_len)) {
126     return false;
127   }
128 
129   return hkdf_expand_label(out, hs->transcript.Digest(), hs->secret,
130                            hs->hash_len, label, label_len, context_hash,
131                            context_hash_len, len);
132 }
133 
tls13_set_traffic_key(SSL * ssl,enum ssl_encryption_level_t level,enum evp_aead_direction_t direction,const uint8_t * traffic_secret,size_t traffic_secret_len)134 bool tls13_set_traffic_key(SSL *ssl, enum ssl_encryption_level_t level,
135                            enum evp_aead_direction_t direction,
136                            const uint8_t *traffic_secret,
137                            size_t traffic_secret_len) {
138   const SSL_SESSION *session = SSL_get_session(ssl);
139   uint16_t version = ssl_session_protocol_version(session);
140 
141   if (traffic_secret_len > 0xff) {
142     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
143     return false;
144   }
145 
146   UniquePtr<SSLAEADContext> traffic_aead;
147   if (ssl->quic_method == nullptr) {
148     // Look up cipher suite properties.
149     const EVP_AEAD *aead;
150     size_t discard;
151     if (!ssl_cipher_get_evp_aead(&aead, &discard, &discard, session->cipher,
152                                  version, SSL_is_dtls(ssl))) {
153       return false;
154     }
155 
156     const EVP_MD *digest = ssl_session_get_digest(session);
157 
158     // Derive the key.
159     size_t key_len = EVP_AEAD_key_length(aead);
160     uint8_t key[EVP_AEAD_MAX_KEY_LENGTH];
161     if (!hkdf_expand_label(key, digest, traffic_secret, traffic_secret_len,
162                            "key", 3, NULL, 0, key_len)) {
163       return false;
164     }
165 
166     // Derive the IV.
167     size_t iv_len = EVP_AEAD_nonce_length(aead);
168     uint8_t iv[EVP_AEAD_MAX_NONCE_LENGTH];
169     if (!hkdf_expand_label(iv, digest, traffic_secret, traffic_secret_len, "iv",
170                            2, NULL, 0, iv_len)) {
171       return false;
172     }
173 
174 
175     traffic_aead = SSLAEADContext::Create(
176         direction, session->ssl_version, SSL_is_dtls(ssl), session->cipher,
177         MakeConstSpan(key, key_len), Span<const uint8_t>(),
178         MakeConstSpan(iv, iv_len));
179   } else {
180     // Install a placeholder SSLAEADContext so that SSL accessors work. The
181     // encryption itself will be handled by the SSL_QUIC_METHOD.
182     traffic_aead =
183         SSLAEADContext::CreatePlaceholderForQUIC(version, session->cipher);
184   }
185 
186   if (!traffic_aead) {
187     return false;
188   }
189 
190   if (direction == evp_aead_open) {
191     if (!ssl->method->set_read_state(ssl, std::move(traffic_aead))) {
192       return false;
193     }
194   } else {
195     if (!ssl->method->set_write_state(ssl, std::move(traffic_aead))) {
196       return false;
197     }
198   }
199 
200   // Save the traffic secret.
201   if (direction == evp_aead_open) {
202     OPENSSL_memmove(ssl->s3->read_traffic_secret, traffic_secret,
203                     traffic_secret_len);
204     ssl->s3->read_traffic_secret_len = traffic_secret_len;
205     ssl->s3->read_level = level;
206   } else {
207     OPENSSL_memmove(ssl->s3->write_traffic_secret, traffic_secret,
208                     traffic_secret_len);
209     ssl->s3->write_traffic_secret_len = traffic_secret_len;
210     ssl->s3->write_level = level;
211   }
212 
213   return true;
214 }
215 
216 
// Labels for the per-direction traffic secrets derived in this file: the
// client early traffic secret, then the client/server handshake and
// application traffic secrets.
static const char kTLS13LabelClientEarlyTraffic[] = "c e traffic";
static const char kTLS13LabelClientHandshakeTraffic[] = "c hs traffic";
static const char kTLS13LabelServerHandshakeTraffic[] = "s hs traffic";
static const char kTLS13LabelClientApplicationTraffic[] = "c ap traffic";
static const char kTLS13LabelServerApplicationTraffic[] = "s ap traffic";

// Labels for the early exporter secret and the exporter secret.
static const char kTLS13LabelEarlyExporter[] = "e exp master";
static const char kTLS13LabelExporter[] = "exp master";
225 
tls13_derive_early_secrets(SSL_HANDSHAKE * hs)226 bool tls13_derive_early_secrets(SSL_HANDSHAKE *hs) {
227   SSL *const ssl = hs->ssl;
228   if (!derive_secret(hs, hs->early_traffic_secret, hs->hash_len,
229                      kTLS13LabelClientEarlyTraffic,
230                      strlen(kTLS13LabelClientEarlyTraffic)) ||
231       !ssl_log_secret(ssl, "CLIENT_EARLY_TRAFFIC_SECRET",
232                       hs->early_traffic_secret, hs->hash_len) ||
233       !derive_secret(hs, ssl->s3->early_exporter_secret, hs->hash_len,
234                      kTLS13LabelEarlyExporter,
235                      strlen(kTLS13LabelEarlyExporter))) {
236     return false;
237   }
238   ssl->s3->early_exporter_secret_len = hs->hash_len;
239 
240   if (ssl->quic_method != nullptr) {
241     if (ssl->server) {
242       if (!ssl->quic_method->set_encryption_secrets(
243               ssl, ssl_encryption_early_data, nullptr, hs->early_traffic_secret,
244               hs->hash_len)) {
245         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
246         return false;
247       }
248     } else {
249       if (!ssl->quic_method->set_encryption_secrets(
250               ssl, ssl_encryption_early_data, hs->early_traffic_secret, nullptr,
251               hs->hash_len)) {
252         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
253         return false;
254       }
255     }
256   }
257 
258   return true;
259 }
260 
tls13_derive_handshake_secrets(SSL_HANDSHAKE * hs)261 bool tls13_derive_handshake_secrets(SSL_HANDSHAKE *hs) {
262   SSL *const ssl = hs->ssl;
263   if (!derive_secret(hs, hs->client_handshake_secret, hs->hash_len,
264                      kTLS13LabelClientHandshakeTraffic,
265                      strlen(kTLS13LabelClientHandshakeTraffic)) ||
266       !ssl_log_secret(ssl, "CLIENT_HANDSHAKE_TRAFFIC_SECRET",
267                       hs->client_handshake_secret, hs->hash_len) ||
268       !derive_secret(hs, hs->server_handshake_secret, hs->hash_len,
269                      kTLS13LabelServerHandshakeTraffic,
270                      strlen(kTLS13LabelServerHandshakeTraffic)) ||
271       !ssl_log_secret(ssl, "SERVER_HANDSHAKE_TRAFFIC_SECRET",
272                       hs->server_handshake_secret, hs->hash_len)) {
273     return false;
274   }
275 
276   if (ssl->quic_method != nullptr) {
277     if (ssl->server) {
278       if (!ssl->quic_method->set_encryption_secrets(
279               ssl, ssl_encryption_handshake, hs->client_handshake_secret,
280               hs->server_handshake_secret, hs->hash_len)) {
281         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
282         return false;
283       }
284     } else {
285       if (!ssl->quic_method->set_encryption_secrets(
286               ssl, ssl_encryption_handshake, hs->server_handshake_secret,
287               hs->client_handshake_secret, hs->hash_len)) {
288         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
289         return false;
290       }
291     }
292   }
293 
294   return true;
295 }
296 
tls13_derive_application_secrets(SSL_HANDSHAKE * hs)297 bool tls13_derive_application_secrets(SSL_HANDSHAKE *hs) {
298   SSL *const ssl = hs->ssl;
299   ssl->s3->exporter_secret_len = hs->hash_len;
300   if (!derive_secret(hs, hs->client_traffic_secret_0, hs->hash_len,
301                      kTLS13LabelClientApplicationTraffic,
302                      strlen(kTLS13LabelClientApplicationTraffic)) ||
303       !ssl_log_secret(ssl, "CLIENT_TRAFFIC_SECRET_0",
304                       hs->client_traffic_secret_0, hs->hash_len) ||
305       !derive_secret(hs, hs->server_traffic_secret_0, hs->hash_len,
306                      kTLS13LabelServerApplicationTraffic,
307                      strlen(kTLS13LabelServerApplicationTraffic)) ||
308       !ssl_log_secret(ssl, "SERVER_TRAFFIC_SECRET_0",
309                       hs->server_traffic_secret_0, hs->hash_len) ||
310       !derive_secret(hs, ssl->s3->exporter_secret, hs->hash_len,
311                      kTLS13LabelExporter, strlen(kTLS13LabelExporter)) ||
312       !ssl_log_secret(ssl, "EXPORTER_SECRET", ssl->s3->exporter_secret,
313                       hs->hash_len)) {
314     return false;
315   }
316 
317   if (ssl->quic_method != nullptr) {
318     if (ssl->server) {
319       if (!ssl->quic_method->set_encryption_secrets(
320               ssl, ssl_encryption_application, hs->client_traffic_secret_0,
321               hs->server_traffic_secret_0, hs->hash_len)) {
322         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
323         return false;
324       }
325     } else {
326       if (!ssl->quic_method->set_encryption_secrets(
327               ssl, ssl_encryption_application, hs->server_traffic_secret_0,
328               hs->client_traffic_secret_0, hs->hash_len)) {
329         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
330         return false;
331       }
332     }
333   }
334 
335   return true;
336 }
337 
338 static const char kTLS13LabelApplicationTraffic[] = "traffic upd";
339 
tls13_rotate_traffic_key(SSL * ssl,enum evp_aead_direction_t direction)340 bool tls13_rotate_traffic_key(SSL *ssl, enum evp_aead_direction_t direction) {
341   uint8_t *secret;
342   size_t secret_len;
343   if (direction == evp_aead_open) {
344     secret = ssl->s3->read_traffic_secret;
345     secret_len = ssl->s3->read_traffic_secret_len;
346   } else {
347     secret = ssl->s3->write_traffic_secret;
348     secret_len = ssl->s3->write_traffic_secret_len;
349   }
350 
351   const EVP_MD *digest = ssl_session_get_digest(SSL_get_session(ssl));
352   if (!hkdf_expand_label(secret, digest, secret, secret_len,
353                          kTLS13LabelApplicationTraffic,
354                          strlen(kTLS13LabelApplicationTraffic), NULL, 0,
355                          secret_len)) {
356     return false;
357   }
358 
359   return tls13_set_traffic_key(ssl, ssl_encryption_application, direction,
360                                secret, secret_len);
361 }
362 
363 static const char kTLS13LabelResumption[] = "res master";
364 
tls13_derive_resumption_secret(SSL_HANDSHAKE * hs)365 bool tls13_derive_resumption_secret(SSL_HANDSHAKE *hs) {
366   if (hs->hash_len > SSL_MAX_MASTER_KEY_LENGTH) {
367     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
368     return false;
369   }
370   hs->new_session->master_key_length = hs->hash_len;
371   return derive_secret(hs, hs->new_session->master_key,
372                        hs->new_session->master_key_length,
373                        kTLS13LabelResumption, strlen(kTLS13LabelResumption));
374 }
375 
376 static const char kTLS13LabelFinished[] = "finished";
377 
378 // tls13_verify_data sets |out| to be the HMAC of |context| using a derived
379 // Finished key for both Finished messages and the PSK binder.
tls13_verify_data(const EVP_MD * digest,uint16_t version,uint8_t * out,size_t * out_len,const uint8_t * secret,size_t hash_len,uint8_t * context,size_t context_len)380 static bool tls13_verify_data(const EVP_MD *digest, uint16_t version,
381                               uint8_t *out, size_t *out_len,
382                               const uint8_t *secret, size_t hash_len,
383                               uint8_t *context, size_t context_len) {
384   uint8_t key[EVP_MAX_MD_SIZE];
385   unsigned len;
386   if (!hkdf_expand_label(key, digest, secret, hash_len, kTLS13LabelFinished,
387                          strlen(kTLS13LabelFinished), NULL, 0, hash_len) ||
388       HMAC(digest, key, hash_len, context, context_len, out, &len) == NULL) {
389     return false;
390   }
391   *out_len = len;
392   return true;
393 }
394 
tls13_finished_mac(SSL_HANDSHAKE * hs,uint8_t * out,size_t * out_len,bool is_server)395 bool tls13_finished_mac(SSL_HANDSHAKE *hs, uint8_t *out, size_t *out_len,
396                         bool is_server) {
397   const uint8_t *traffic_secret;
398   if (is_server) {
399     traffic_secret = hs->server_handshake_secret;
400   } else {
401     traffic_secret = hs->client_handshake_secret;
402   }
403 
404   uint8_t context_hash[EVP_MAX_MD_SIZE];
405   size_t context_hash_len;
406   if (!hs->transcript.GetHash(context_hash, &context_hash_len) ||
407       !tls13_verify_data(hs->transcript.Digest(), hs->ssl->version, out,
408                          out_len, traffic_secret, hs->hash_len, context_hash,
409                          context_hash_len)) {
410     return 0;
411   }
412   return 1;
413 }
414 
415 static const char kTLS13LabelResumptionPSK[] = "resumption";
416 
tls13_derive_session_psk(SSL_SESSION * session,Span<const uint8_t> nonce)417 bool tls13_derive_session_psk(SSL_SESSION *session, Span<const uint8_t> nonce) {
418   const EVP_MD *digest = ssl_session_get_digest(session);
419   return hkdf_expand_label(session->master_key, digest, session->master_key,
420                            session->master_key_length, kTLS13LabelResumptionPSK,
421                            strlen(kTLS13LabelResumptionPSK), nonce.data(),
422                            nonce.size(), session->master_key_length);
423 }
424 
425 static const char kTLS13LabelExportKeying[] = "exporter";
426 
tls13_export_keying_material(SSL * ssl,Span<uint8_t> out,Span<const uint8_t> secret,Span<const char> label,Span<const uint8_t> context)427 bool tls13_export_keying_material(SSL *ssl, Span<uint8_t> out,
428                                   Span<const uint8_t> secret,
429                                   Span<const char> label,
430                                   Span<const uint8_t> context) {
431   if (secret.empty()) {
432     assert(0);
433     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
434     return false;
435   }
436 
437   const EVP_MD *digest = ssl_session_get_digest(SSL_get_session(ssl));
438 
439   uint8_t hash[EVP_MAX_MD_SIZE];
440   uint8_t export_context[EVP_MAX_MD_SIZE];
441   uint8_t derived_secret[EVP_MAX_MD_SIZE];
442   unsigned hash_len;
443   unsigned export_context_len;
444   unsigned derived_secret_len = EVP_MD_size(digest);
445   return EVP_Digest(context.data(), context.size(), hash, &hash_len, digest,
446                     nullptr) &&
447          EVP_Digest(nullptr, 0, export_context, &export_context_len, digest,
448                     nullptr) &&
449          hkdf_expand_label(derived_secret, digest, secret.data(), secret.size(),
450                            label.data(), label.size(), export_context,
451                            export_context_len, derived_secret_len) &&
452          hkdf_expand_label(out.data(), digest, derived_secret,
453                            derived_secret_len, kTLS13LabelExportKeying,
454                            strlen(kTLS13LabelExportKeying), hash, hash_len,
455                            out.size());
456 }
457 
458 static const char kTLS13LabelPSKBinder[] = "res binder";
459 
tls13_psk_binder(uint8_t * out,uint16_t version,const EVP_MD * digest,uint8_t * psk,size_t psk_len,uint8_t * context,size_t context_len,size_t hash_len)460 static bool tls13_psk_binder(uint8_t *out, uint16_t version,
461                              const EVP_MD *digest, uint8_t *psk, size_t psk_len,
462                              uint8_t *context, size_t context_len,
463                              size_t hash_len) {
464   uint8_t binder_context[EVP_MAX_MD_SIZE];
465   unsigned binder_context_len;
466   if (!EVP_Digest(NULL, 0, binder_context, &binder_context_len, digest, NULL)) {
467     return false;
468   }
469 
470   uint8_t early_secret[EVP_MAX_MD_SIZE] = {0};
471   size_t early_secret_len;
472   if (!HKDF_extract(early_secret, &early_secret_len, digest, psk, hash_len,
473                     NULL, 0)) {
474     return false;
475   }
476 
477   uint8_t binder_key[EVP_MAX_MD_SIZE] = {0};
478   size_t len;
479   if (!hkdf_expand_label(binder_key, digest, early_secret, hash_len,
480                          kTLS13LabelPSKBinder, strlen(kTLS13LabelPSKBinder),
481                          binder_context, binder_context_len, hash_len) ||
482       !tls13_verify_data(digest, version, out, &len, binder_key, hash_len,
483                          context, context_len)) {
484     return false;
485   }
486 
487   return true;
488 }
489 
tls13_write_psk_binder(SSL_HANDSHAKE * hs,uint8_t * msg,size_t len)490 bool tls13_write_psk_binder(SSL_HANDSHAKE *hs, uint8_t *msg, size_t len) {
491   SSL *const ssl = hs->ssl;
492   const EVP_MD *digest = ssl_session_get_digest(ssl->session.get());
493   size_t hash_len = EVP_MD_size(digest);
494 
495   if (len < hash_len + 3) {
496     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
497     return false;
498   }
499 
500   ScopedEVP_MD_CTX ctx;
501   uint8_t context[EVP_MAX_MD_SIZE];
502   unsigned context_len;
503 
504   if (!EVP_DigestInit_ex(ctx.get(), digest, NULL) ||
505       !EVP_DigestUpdate(ctx.get(), hs->transcript.buffer().data(),
506                         hs->transcript.buffer().size()) ||
507       !EVP_DigestUpdate(ctx.get(), msg, len - hash_len - 3) ||
508       !EVP_DigestFinal_ex(ctx.get(), context, &context_len)) {
509     return false;
510   }
511 
512   uint8_t verify_data[EVP_MAX_MD_SIZE] = {0};
513   if (!tls13_psk_binder(verify_data, ssl->session->ssl_version, digest,
514                         ssl->session->master_key,
515                         ssl->session->master_key_length, context, context_len,
516                         hash_len)) {
517     return false;
518   }
519 
520   OPENSSL_memcpy(msg + len - hash_len, verify_data, hash_len);
521   return true;
522 }
523 
tls13_verify_psk_binder(SSL_HANDSHAKE * hs,SSL_SESSION * session,const SSLMessage & msg,CBS * binders)524 bool tls13_verify_psk_binder(SSL_HANDSHAKE *hs, SSL_SESSION *session,
525                              const SSLMessage &msg, CBS *binders) {
526   size_t hash_len = hs->transcript.DigestLen();
527 
528   // The message must be large enough to exclude the binders.
529   if (CBS_len(&msg.raw) < CBS_len(binders) + 2) {
530     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
531     return false;
532   }
533 
534   // Hash a ClientHello prefix up to the binders. This includes the header. For
535   // now, this assumes we only ever verify PSK binders on initial
536   // ClientHellos.
537   uint8_t context[EVP_MAX_MD_SIZE];
538   unsigned context_len;
539   if (!EVP_Digest(CBS_data(&msg.raw), CBS_len(&msg.raw) - CBS_len(binders) - 2,
540                   context, &context_len, hs->transcript.Digest(), NULL)) {
541     return false;
542   }
543 
544   uint8_t verify_data[EVP_MAX_MD_SIZE] = {0};
545   CBS binder;
546   if (!tls13_psk_binder(verify_data, hs->ssl->version, hs->transcript.Digest(),
547                         session->master_key, session->master_key_length,
548                         context, context_len, hash_len) ||
549       // We only consider the first PSK, so compare against the first binder.
550       !CBS_get_u8_length_prefixed(binders, &binder)) {
551     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
552     return false;
553   }
554 
555   bool binder_ok = CBS_len(&binder) == hash_len &&
556                    CRYPTO_memcmp(CBS_data(&binder), verify_data, hash_len) == 0;
557 #if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
558   binder_ok = true;
559 #endif
560   if (!binder_ok) {
561     OPENSSL_PUT_ERROR(SSL, SSL_R_DIGEST_CHECK_FAILED);
562     return false;
563   }
564 
565   return true;
566 }
567 
568 BSSL_NAMESPACE_END
569