1 /*
2  * Copyright 2019, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "IdentityCredentialSupport"
18 
19 #include <android/hardware/identity/support/IdentityCredentialSupport.h>
20 
21 #define _POSIX_C_SOURCE 199309L
22 
23 #include <ctype.h>
24 #include <stdarg.h>
25 #include <stdio.h>
26 #include <time.h>
27 #include <chrono>
28 #include <iomanip>
29 
30 #include <openssl/aes.h>
31 #include <openssl/bn.h>
32 #include <openssl/crypto.h>
33 #include <openssl/ec.h>
34 #include <openssl/err.h>
35 #include <openssl/evp.h>
36 #include <openssl/hkdf.h>
37 #include <openssl/hmac.h>
38 #include <openssl/objects.h>
39 #include <openssl/pem.h>
40 #include <openssl/pkcs12.h>
41 #include <openssl/rand.h>
42 #include <openssl/x509.h>
43 #include <openssl/x509_vfy.h>
44 
45 #include <android-base/logging.h>
46 #include <android-base/stringprintf.h>
47 #include <charconv>
48 
49 #include <cppbor.h>
50 #include <cppbor_parse.h>
51 
52 #include <android/hardware/keymaster/4.0/types.h>
53 #include <keymaster/authorization_set.h>
54 #include <keymaster/contexts/pure_soft_keymaster_context.h>
55 #include <keymaster/contexts/soft_attestation_cert.h>
56 #include <keymaster/keymaster_tags.h>
57 #include <keymaster/km_openssl/asymmetric_key.h>
58 #include <keymaster/km_openssl/attestation_utils.h>
59 #include <keymaster/km_openssl/certificate_utils.h>
60 
61 namespace android {
62 namespace hardware {
63 namespace identity {
64 namespace support {
65 
66 using ::std::pair;
67 using ::std::unique_ptr;
68 
69 // ---------------------------------------------------------------------------
70 // Miscellaneous utilities.
71 // ---------------------------------------------------------------------------
72 
// Writes a debugging hex dump of |data| to stderr.
//
// The first line names the buffer and gives its size. Each following row shows
// the offset, up to 16 bytes in hex, and the same bytes as printable ASCII
// ('.' for anything isprint() rejects). A blank line ends the dump.
void hexdump(const string& name, const vector<uint8_t>& data) {
    // %zu is the conversion for size_t; %zd expects the signed counterpart.
    fprintf(stderr, "%s: dumping %zu bytes\n", name.c_str(), data.size());
    for (size_t n = 0; n < data.size(); n += 16) {
        fprintf(stderr, "%04zx  ", n);
        size_t m;
        for (m = 0; m < 16 && n + m < data.size(); m++) {
            fprintf(stderr, "%02x ", data[n + m]);
        }
        // Pad a short final row so the ASCII column stays aligned.
        for (size_t o = m; o < 16; o++) {
            fprintf(stderr, "   ");
        }
        fprintf(stderr, " ");
        for (m = 0; m < 16 && n + m < data.size(); m++) {
            int c = data[n + m];
            fprintf(stderr, "%c", isprint(c) ? c : '.');
        }
        fprintf(stderr, "\n");
    }
    fprintf(stderr, "\n");
}
93 
// Returns the lowercase hexadecimal encoding of the |dataLen| bytes at |data|.
// The result always holds exactly 2 * |dataLen| characters.
string encodeHex(const uint8_t* data, size_t dataLen) {
    static constexpr char kDigits[] = "0123456789abcdef";

    string hex(dataLen * 2, '\0');
    size_t pos = 0;
    for (const uint8_t* it = data; it != data + dataLen; ++it) {
        hex[pos++] = kDigits[*it >> 4];
        hex[pos++] = kDigits[*it & 0x0f];
    }
    return hex;
}

// Convenience overload: hex-encodes the raw bytes of |str|.
string encodeHex(const string& str) {
    const auto* bytes = reinterpret_cast<const uint8_t*>(str.data());
    return encodeHex(bytes, str.size());
}

// Convenience overload: hex-encodes the contents of |data|.
string encodeHex(const vector<uint8_t>& data) {
    const size_t byteCount = data.size();
    return encodeHex(data.data(), byteCount);
}
116 
// Maps one hexadecimal character ('0'-'9', 'a'-'f' or 'A'-'F') to its value,
// an integer in the range 0 through 15, both inclusive. Any other character
// yields -1.
int parseHexDigit(char hexDigit) {
    if ('0' <= hexDigit && hexDigit <= '9') {
        return hexDigit - '0';
    }
    if ('a' <= hexDigit && hexDigit <= 'f') {
        return 10 + (hexDigit - 'a');
    }
    if ('A' <= hexDigit && hexDigit <= 'F') {
        return 10 + (hexDigit - 'A');
    }
    return -1;
}
128 
decodeHex(const string & hexEncoded)129 optional<vector<uint8_t>> decodeHex(const string& hexEncoded) {
130     vector<uint8_t> out;
131     size_t hexSize = hexEncoded.size();
132     if ((hexSize & 1) != 0) {
133         LOG(ERROR) << "Size of data cannot be odd";
134         return {};
135     }
136 
137     out.resize(hexSize / 2);
138     for (size_t n = 0; n < hexSize / 2; n++) {
139         int upperNibble = parseHexDigit(hexEncoded[n * 2]);
140         int lowerNibble = parseHexDigit(hexEncoded[n * 2 + 1]);
141         if (upperNibble == -1 || lowerNibble == -1) {
142             LOG(ERROR) << "Invalid hex digit at position " << n;
143             return {};
144         }
145         out[n] = (upperNibble << 4) + lowerNibble;
146     }
147 
148     return out;
149 }
150 
151 // ---------------------------------------------------------------------------
152 // Crypto functionality / abstraction.
153 // ---------------------------------------------------------------------------
154 
155 using EvpCipherCtxPtr = bssl::UniquePtr<EVP_CIPHER_CTX>;
156 using EC_KEY_Ptr = bssl::UniquePtr<EC_KEY>;
157 using EVP_PKEY_Ptr = bssl::UniquePtr<EVP_PKEY>;
158 using EVP_PKEY_CTX_Ptr = bssl::UniquePtr<EVP_PKEY_CTX>;
159 using EC_GROUP_Ptr = bssl::UniquePtr<EC_GROUP>;
160 using EC_POINT_Ptr = bssl::UniquePtr<EC_POINT>;
161 using ECDSA_SIG_Ptr = bssl::UniquePtr<ECDSA_SIG>;
162 using X509_Ptr = bssl::UniquePtr<X509>;
163 using PKCS12_Ptr = bssl::UniquePtr<PKCS12>;
164 using BIGNUM_Ptr = bssl::UniquePtr<BIGNUM>;
165 using ASN1_INTEGER_Ptr = bssl::UniquePtr<ASN1_INTEGER>;
166 using ASN1_TIME_Ptr = bssl::UniquePtr<ASN1_TIME>;
167 using ASN1_OCTET_STRING_Ptr = bssl::UniquePtr<ASN1_OCTET_STRING>;
168 using ASN1_OBJECT_Ptr = bssl::UniquePtr<ASN1_OBJECT>;
169 using X509_NAME_Ptr = bssl::UniquePtr<X509_NAME>;
170 using X509_EXTENSION_Ptr = bssl::UniquePtr<X509_EXTENSION>;
171 
172 namespace {
173 
generateP256Key()174 EVP_PKEY_Ptr generateP256Key() {
175     EC_KEY_Ptr ec_key(EC_KEY_new());
176     EVP_PKEY_Ptr pkey(EVP_PKEY_new());
177     EC_GROUP_Ptr group(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
178 
179     if (ec_key.get() == nullptr || pkey.get() == nullptr) {
180         LOG(ERROR) << "Memory allocation failed";
181         return {};
182     }
183 
184     if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
185         EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
186         LOG(ERROR) << "Error generating key";
187         return {};
188     }
189 
190     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ec_key.get()) != 1) {
191         LOG(ERROR) << "Error getting private key";
192         return {};
193     }
194 
195     return pkey;
196 }
197 
derEncodeKeyPair(const EVP_PKEY & pkey)198 optional<vector<uint8_t>> derEncodeKeyPair(const EVP_PKEY& pkey) {
199     int size = i2d_PrivateKey(&pkey, nullptr);
200     if (size == 0) {
201         LOG(ERROR) << "Error generating public key encoding";
202         return std::nullopt;
203     }
204 
205     vector<uint8_t> keyPair(size);
206     unsigned char* p = keyPair.data();
207     i2d_PrivateKey(&pkey, &p);
208 
209     return keyPair;
210 }
211 
212 // Generates the attestation certificate with the parameters passed in.  Note
213 // that the passed in |activeTimeMilliSeconds| |expireTimeMilliSeconds| are in
214 // milli seconds since epoch.  We are setting them to milliseconds due to
215 // requirement in AuthorizationSet KM_DATE fields.  The certificate created is
216 // actually in seconds.
217 //
signAttestationCertificate(const::keymaster::PureSoftKeymasterContext & context,const EVP_PKEY * key,const vector<uint8_t> & applicationId,const vector<uint8_t> & challenge,const vector<uint8_t> & attestationKeyBlob,const vector<uint8_t> & derAttestationCertSubjectName,uint64_t activeTimeMilliSeconds,uint64_t expireTimeMilliSeconds,bool isTestCredential)218 optional<vector<vector<uint8_t>>> signAttestationCertificate(
219         const ::keymaster::PureSoftKeymasterContext& context, const EVP_PKEY* key,
220         const vector<uint8_t>& applicationId, const vector<uint8_t>& challenge,
221         const vector<uint8_t>& attestationKeyBlob,
222         const vector<uint8_t>& derAttestationCertSubjectName, uint64_t activeTimeMilliSeconds,
223         uint64_t expireTimeMilliSeconds, bool isTestCredential) {
224     ::keymaster::X509_NAME_Ptr subjectName;
225     if (KM_ERROR_OK !=
226         ::keymaster::make_name_from_str("Android Identity Credential Key", &subjectName)) {
227         LOG(ERROR) << "Cannot create attestation subject";
228         return {};
229     }
230 
231     vector<uint8_t> subject(i2d_X509_NAME(subjectName.get(), NULL));
232     unsigned char* subjectPtr = subject.data();
233 
234     i2d_X509_NAME(subjectName.get(), &subjectPtr);
235 
236     ::keymaster::AuthorizationSet auth_set(
237             ::keymaster::AuthorizationSetBuilder()
238                     .Authorization(::keymaster::TAG_CERTIFICATE_NOT_BEFORE, activeTimeMilliSeconds)
239                     .Authorization(::keymaster::TAG_CERTIFICATE_NOT_AFTER, expireTimeMilliSeconds)
240                     .Authorization(::keymaster::TAG_ATTESTATION_CHALLENGE, challenge.data(),
241                                    challenge.size())
242                     .Authorization(::keymaster::TAG_ACTIVE_DATETIME, activeTimeMilliSeconds)
243                     // Even though identity attestation hal said the application
244                     // id should be in software enforced authentication set,
245                     // keymaster portable lib expect the input in this
246                     // parameter because the software enforced in input to keymaster
247                     // refers to the key software enforced properties. And this
248                     // parameter refers to properties of the attestation which
249                     // includes app id.
250                     .Authorization(::keymaster::TAG_ATTESTATION_APPLICATION_ID,
251                                    applicationId.data(), applicationId.size())
252                     .Authorization(::keymaster::TAG_CERTIFICATE_SUBJECT, subject.data(),
253                                    subject.size())
254                     .Authorization(::keymaster::TAG_USAGE_EXPIRE_DATETIME, expireTimeMilliSeconds));
255 
256     // Unique id and device id is not applicable for identity credential attestation,
257     // so we don't need to set those or application id.
258     ::keymaster::AuthorizationSet swEnforced(::keymaster::AuthorizationSetBuilder().Authorization(
259             ::keymaster::TAG_CREATION_DATETIME, activeTimeMilliSeconds));
260 
261     ::keymaster::AuthorizationSetBuilder hwEnforcedBuilder =
262             ::keymaster::AuthorizationSetBuilder()
263                     .Authorization(::keymaster::TAG_PURPOSE, KM_PURPOSE_SIGN)
264                     .Authorization(::keymaster::TAG_KEY_SIZE, 256)
265                     .Authorization(::keymaster::TAG_ALGORITHM, KM_ALGORITHM_EC)
266                     .Authorization(::keymaster::TAG_NO_AUTH_REQUIRED)
267                     .Authorization(::keymaster::TAG_DIGEST, KM_DIGEST_SHA_2_256)
268                     .Authorization(::keymaster::TAG_EC_CURVE, KM_EC_CURVE_P_256)
269                     .Authorization(::keymaster::TAG_OS_VERSION, 42)
270                     .Authorization(::keymaster::TAG_OS_PATCHLEVEL, 43);
271 
272     // Only include TAG_IDENTITY_CREDENTIAL_KEY if it's not a test credential
273     if (!isTestCredential) {
274         hwEnforcedBuilder.Authorization(::keymaster::TAG_IDENTITY_CREDENTIAL_KEY);
275     }
276     ::keymaster::AuthorizationSet hwEnforced(hwEnforcedBuilder);
277 
278     keymaster_error_t error;
279     ::keymaster::AttestKeyInfo attestKeyInfo;
280     ::keymaster::KeymasterBlob issuerSubjectNameBlob;
281     if (!attestationKeyBlob.empty()) {
282         ::keymaster::KeymasterKeyBlob blob(attestationKeyBlob.data(), attestationKeyBlob.size());
283         ::keymaster::UniquePtr<::keymaster::Key> parsedKey;
284         error = context.ParseKeyBlob(blob, /*additional_params=*/{}, &parsedKey);
285         if (error != KM_ERROR_OK) {
286             LOG(ERROR) << "Error loading attestation key: " << error;
287             return std::nullopt;
288         }
289 
290         attestKeyInfo.signing_key =
291                 static_cast<::keymaster::AsymmetricKey&>(*parsedKey).InternalToEvp();
292         issuerSubjectNameBlob = ::keymaster::KeymasterBlob(derAttestationCertSubjectName.data(),
293                                                            derAttestationCertSubjectName.size());
294         attestKeyInfo.issuer_subject = &issuerSubjectNameBlob;
295     }
296 
297     ::keymaster::CertificateChain certChain = generate_attestation(
298             key, swEnforced, hwEnforced, auth_set, std::move(attestKeyInfo), context, &error);
299 
300     if (KM_ERROR_OK != error) {
301         LOG(ERROR) << "Error generating attestation from EVP key: " << error;
302         return std::nullopt;
303     }
304 
305     vector<vector<uint8_t>> vectors(certChain.entry_count);
306     for (std::size_t i = 0; i < certChain.entry_count; i++) {
307         vectors[i] = {certChain.entries[i].data,
308                       certChain.entries[i].data + certChain.entries[i].data_length};
309     }
310     return vectors;
311 }
312 
parseDigits(const char ** s,int numDigits)313 int parseDigits(const char** s, int numDigits) {
314     int result;
315     auto [_, ec] = std::from_chars(*s, *s + numDigits, result);
316     if (ec != std::errc()) {
317         LOG(ERROR) << "Error parsing " << numDigits << " digits "
318                    << " from " << s;
319         return 0;
320     }
321     *s += numDigits;
322     return result;
323 }
324 
parseAsn1Time(const ASN1_TIME * asn1Time,time_t * outTime)325 bool parseAsn1Time(const ASN1_TIME* asn1Time, time_t* outTime) {
326     struct tm tm;
327 
328     memset(&tm, '\0', sizeof(tm));
329     const char* timeStr = (const char*)asn1Time->data;
330     const char* s = timeStr;
331     if (asn1Time->type == V_ASN1_UTCTIME) {
332         tm.tm_year = parseDigits(&s, 2);
333         if (tm.tm_year < 70) {
334             tm.tm_year += 100;
335         }
336     } else if (asn1Time->type == V_ASN1_GENERALIZEDTIME) {
337         tm.tm_year = parseDigits(&s, 4) - 1900;
338         tm.tm_year -= 1900;
339     } else {
340         LOG(ERROR) << "Unsupported ASN1_TIME type " << asn1Time->type;
341         return false;
342     }
343     tm.tm_mon = parseDigits(&s, 2) - 1;
344     tm.tm_mday = parseDigits(&s, 2);
345     tm.tm_hour = parseDigits(&s, 2);
346     tm.tm_min = parseDigits(&s, 2);
347     tm.tm_sec = parseDigits(&s, 2);
348     // This may need to be updated if someone create certificates using +/- instead of Z.
349     //
350     if (*s != 'Z') {
351         LOG(ERROR) << "Expected Z in string '" << timeStr << "' at offset " << (s - timeStr);
352         return false;
353     }
354 
355     time_t t = timegm(&tm);
356     if (t == -1) {
357         LOG(ERROR) << "Error converting broken-down time to time_t";
358         return false;
359     }
360     *outTime = t;
361     return true;
362 }
363 
getCertificateExpiryAsMillis(const uint8_t * derCert,size_t derCertSize)364 optional<uint64_t> getCertificateExpiryAsMillis(const uint8_t* derCert, size_t derCertSize) {
365     X509_Ptr x509Cert(d2i_X509(nullptr, &derCert, derCertSize));
366     if (!x509Cert) {
367         LOG(ERROR) << "Error parsing certificate";
368         return std::nullopt;
369     }
370 
371     time_t notAfter;
372     if (!parseAsn1Time(X509_get0_notAfter(x509Cert.get()), &notAfter)) {
373         LOG(ERROR) << "Error getting notAfter from batch certificate";
374         return std::nullopt;
375     }
376 
377     return notAfter * 1000;
378 }
379 
createAttestation(EVP_PKEY * pkey,const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId,bool isTestCredential)380 optional<vector<vector<uint8_t>>> createAttestation(EVP_PKEY* pkey,
381                                                     const vector<uint8_t>& challenge,
382                                                     const vector<uint8_t>& applicationId,
383                                                     bool isTestCredential) {
384     // Pretend to be implemented in a trusted environment just so we can pass
385     // the VTS tests. Of course, this is a pretend-only game since hopefully no
386     // relying party is ever going to trust our batch key and those keys above
387     // it.
388     ::keymaster::PureSoftKeymasterContext context(::keymaster::KmVersion::KEYMINT_1,
389                                                   KM_SECURITY_LEVEL_TRUSTED_ENVIRONMENT);
390 
391     keymaster_error_t error;
392     ::keymaster::CertificateChain attestation_chain =
393             context.GetAttestationChain(KM_ALGORITHM_EC, &error);
394     if (KM_ERROR_OK != error) {
395         LOG(ERROR) << "Error getting attestation chain " << error;
396         return std::nullopt;
397     }
398 
399     if (attestation_chain.entry_count < 1) {
400         LOG(ERROR) << "Expected at least one entry in attestation chain";
401         return std::nullopt;
402     }
403 
404     uint64_t activeTimeMs = time(nullptr) * 1000;
405     optional<uint64_t> expireTimeMs = getCertificateExpiryAsMillis(
406             attestation_chain.entries[0].data, attestation_chain.entries[0].data_length);
407     if (!expireTimeMs) {
408         LOG(ERROR) << "Error getting expiration time for batch cert";
409         return std::nullopt;
410     }
411 
412     return signAttestationCertificate(context, pkey, applicationId, challenge,
413                                       /*attestationKeyBlob=*/{},
414                                       /*derAttestationCertSubjectName=*/{}, activeTimeMs,
415                                       *expireTimeMs, isTestCredential);
416 }
417 
418 }  // namespace
419 
420 // bool getRandom(size_t numBytes, vector<uint8_t>& output) {
getRandom(size_t numBytes)421 optional<vector<uint8_t>> getRandom(size_t numBytes) {
422     vector<uint8_t> output;
423     output.resize(numBytes);
424     if (RAND_bytes(output.data(), numBytes) != 1) {
425         LOG(ERROR) << "RAND_bytes: failed getting " << numBytes << " random";
426         return {};
427     }
428     return output;
429 }
430 
decryptAes128Gcm(const vector<uint8_t> & key,const vector<uint8_t> & encryptedData,const vector<uint8_t> & additionalAuthenticatedData)431 optional<vector<uint8_t>> decryptAes128Gcm(const vector<uint8_t>& key,
432                                            const vector<uint8_t>& encryptedData,
433                                            const vector<uint8_t>& additionalAuthenticatedData) {
434     int cipherTextSize = int(encryptedData.size()) - kAesGcmIvSize - kAesGcmTagSize;
435     if (cipherTextSize < 0) {
436         LOG(ERROR) << "encryptedData too small";
437         return {};
438     }
439     unsigned char* nonce = (unsigned char*)encryptedData.data();
440     unsigned char* cipherText = nonce + kAesGcmIvSize;
441     unsigned char* tag = cipherText + cipherTextSize;
442 
443     vector<uint8_t> plainText;
444     plainText.resize(cipherTextSize);
445 
446     auto ctx = EvpCipherCtxPtr(EVP_CIPHER_CTX_new());
447     if (ctx.get() == nullptr) {
448         LOG(ERROR) << "EVP_CIPHER_CTX_new: failed";
449         return {};
450     }
451 
452     if (EVP_DecryptInit_ex(ctx.get(), EVP_aes_128_gcm(), NULL, NULL, NULL) != 1) {
453         LOG(ERROR) << "EVP_DecryptInit_ex: failed";
454         return {};
455     }
456 
457     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize, NULL) != 1) {
458         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting nonce length";
459         return {};
460     }
461 
462     if (EVP_DecryptInit_ex(ctx.get(), NULL, NULL, (unsigned char*)key.data(), nonce) != 1) {
463         LOG(ERROR) << "EVP_DecryptInit_ex: failed";
464         return {};
465     }
466 
467     int numWritten;
468     if (additionalAuthenticatedData.size() > 0) {
469         if (EVP_DecryptUpdate(ctx.get(), NULL, &numWritten,
470                               (unsigned char*)additionalAuthenticatedData.data(),
471                               additionalAuthenticatedData.size()) != 1) {
472             LOG(ERROR) << "EVP_DecryptUpdate: failed for additionalAuthenticatedData";
473             return {};
474         }
475         if ((size_t)numWritten != additionalAuthenticatedData.size()) {
476             LOG(ERROR) << "EVP_DecryptUpdate: Unexpected outl=" << numWritten << " (expected "
477                        << additionalAuthenticatedData.size() << ") for additionalAuthenticatedData";
478             return {};
479         }
480     }
481 
482     if (EVP_DecryptUpdate(ctx.get(), (unsigned char*)plainText.data(), &numWritten, cipherText,
483                           cipherTextSize) != 1) {
484         LOG(ERROR) << "EVP_DecryptUpdate: failed";
485         return {};
486     }
487     if (numWritten != cipherTextSize) {
488         LOG(ERROR) << "EVP_DecryptUpdate: Unexpected outl=" << numWritten << " (expected "
489                    << cipherTextSize << ")";
490         return {};
491     }
492 
493     if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, kAesGcmTagSize, tag)) {
494         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting expected tag";
495         return {};
496     }
497 
498     int ret = EVP_DecryptFinal_ex(ctx.get(), (unsigned char*)plainText.data() + numWritten,
499                                   &numWritten);
500     if (ret != 1) {
501         LOG(ERROR) << "EVP_DecryptFinal_ex: failed";
502         return {};
503     }
504     if (numWritten != 0) {
505         LOG(ERROR) << "EVP_DecryptFinal_ex: Unexpected non-zero outl=" << numWritten;
506         return {};
507     }
508 
509     return plainText;
510 }
511 
encryptAes128Gcm(const vector<uint8_t> & key,const vector<uint8_t> & nonce,const vector<uint8_t> & data,const vector<uint8_t> & additionalAuthenticatedData)512 optional<vector<uint8_t>> encryptAes128Gcm(const vector<uint8_t>& key, const vector<uint8_t>& nonce,
513                                            const vector<uint8_t>& data,
514                                            const vector<uint8_t>& additionalAuthenticatedData) {
515     if (key.size() != kAes128GcmKeySize) {
516         LOG(ERROR) << "key is not kAes128GcmKeySize bytes";
517         return {};
518     }
519     if (nonce.size() != kAesGcmIvSize) {
520         LOG(ERROR) << "nonce is not kAesGcmIvSize bytes";
521         return {};
522     }
523 
524     // The result is the nonce (kAesGcmIvSize bytes), the ciphertext, and
525     // finally the tag (kAesGcmTagSize bytes).
526     vector<uint8_t> encryptedData;
527     encryptedData.resize(data.size() + kAesGcmIvSize + kAesGcmTagSize);
528     unsigned char* noncePtr = (unsigned char*)encryptedData.data();
529     unsigned char* cipherText = noncePtr + kAesGcmIvSize;
530     unsigned char* tag = cipherText + data.size();
531     memcpy(noncePtr, nonce.data(), kAesGcmIvSize);
532 
533     auto ctx = EvpCipherCtxPtr(EVP_CIPHER_CTX_new());
534     if (ctx.get() == nullptr) {
535         LOG(ERROR) << "EVP_CIPHER_CTX_new: failed";
536         return {};
537     }
538 
539     if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_gcm(), NULL, NULL, NULL) != 1) {
540         LOG(ERROR) << "EVP_EncryptInit_ex: failed";
541         return {};
542     }
543 
544     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize, NULL) != 1) {
545         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting nonce length";
546         return {};
547     }
548 
549     if (EVP_EncryptInit_ex(ctx.get(), NULL, NULL, (unsigned char*)key.data(),
550                            (unsigned char*)nonce.data()) != 1) {
551         LOG(ERROR) << "EVP_EncryptInit_ex: failed";
552         return {};
553     }
554 
555     int numWritten;
556     if (additionalAuthenticatedData.size() > 0) {
557         if (EVP_EncryptUpdate(ctx.get(), NULL, &numWritten,
558                               (unsigned char*)additionalAuthenticatedData.data(),
559                               additionalAuthenticatedData.size()) != 1) {
560             LOG(ERROR) << "EVP_EncryptUpdate: failed for additionalAuthenticatedData";
561             return {};
562         }
563         if ((size_t)numWritten != additionalAuthenticatedData.size()) {
564             LOG(ERROR) << "EVP_EncryptUpdate: Unexpected outl=" << numWritten << " (expected "
565                        << additionalAuthenticatedData.size() << ") for additionalAuthenticatedData";
566             return {};
567         }
568     }
569 
570     if (data.size() > 0) {
571         if (EVP_EncryptUpdate(ctx.get(), cipherText, &numWritten, (unsigned char*)data.data(),
572                               data.size()) != 1) {
573             LOG(ERROR) << "EVP_EncryptUpdate: failed";
574             return {};
575         }
576         if ((size_t)numWritten != data.size()) {
577             LOG(ERROR) << "EVP_EncryptUpdate: Unexpected outl=" << numWritten << " (expected "
578                        << data.size() << ")";
579             return {};
580         }
581     }
582 
583     if (EVP_EncryptFinal_ex(ctx.get(), cipherText + numWritten, &numWritten) != 1) {
584         LOG(ERROR) << "EVP_EncryptFinal_ex: failed";
585         return {};
586     }
587     if (numWritten != 0) {
588         LOG(ERROR) << "EVP_EncryptFinal_ex: Unexpected non-zero outl=" << numWritten;
589         return {};
590     }
591 
592     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, kAesGcmTagSize, tag) != 1) {
593         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed getting tag";
594         return {};
595     }
596 
597     return encryptedData;
598 }
599 
// Concatenates the encoded certificates in |certificateChain| into a single
// byte vector, keeping their order.
vector<uint8_t> certificateChainJoin(const vector<vector<uint8_t>>& certificateChain) {
    size_t totalSize = 0;
    for (const auto& cert : certificateChain) {
        totalSize += cert.size();
    }

    vector<uint8_t> joined;
    joined.reserve(totalSize);
    for (const auto& cert : certificateChain) {
        joined.insert(joined.end(), cert.begin(), cert.end());
    }
    return joined;
}
607 
certificateChainSplit(const vector<uint8_t> & certificateChain)608 optional<vector<vector<uint8_t>>> certificateChainSplit(const vector<uint8_t>& certificateChain) {
609     const unsigned char* pStart = (unsigned char*)certificateChain.data();
610     const unsigned char* p = pStart;
611     const unsigned char* pEnd = p + certificateChain.size();
612     vector<vector<uint8_t>> certificates;
613     while (p < pEnd) {
614         size_t begin = p - pStart;
615         auto x509 = X509_Ptr(d2i_X509(nullptr, &p, pEnd - p));
616         size_t next = p - pStart;
617         if (x509 == nullptr) {
618             LOG(ERROR) << "Error parsing X509 certificate";
619             return {};
620         }
621         vector<uint8_t> cert =
622                 vector<uint8_t>(certificateChain.begin() + begin, certificateChain.begin() + next);
623         certificates.push_back(std::move(cert));
624     }
625     return certificates;
626 }
627 
parseX509Certificates(const vector<uint8_t> & certificateChain,vector<X509_Ptr> & parsedCertificates)628 static bool parseX509Certificates(const vector<uint8_t>& certificateChain,
629                                   vector<X509_Ptr>& parsedCertificates) {
630     const unsigned char* p = (unsigned char*)certificateChain.data();
631     const unsigned char* pEnd = p + certificateChain.size();
632     parsedCertificates.resize(0);
633     while (p < pEnd) {
634         auto x509 = X509_Ptr(d2i_X509(nullptr, &p, pEnd - p));
635         if (x509 == nullptr) {
636             LOG(ERROR) << "Error parsing X509 certificate";
637             return false;
638         }
639         parsedCertificates.push_back(std::move(x509));
640     }
641     return true;
642 }
643 
certificateSignedByPublicKey(const vector<uint8_t> & certificate,const vector<uint8_t> & publicKey)644 bool certificateSignedByPublicKey(const vector<uint8_t>& certificate,
645                                   const vector<uint8_t>& publicKey) {
646     const unsigned char* p = certificate.data();
647     auto x509 = X509_Ptr(d2i_X509(nullptr, &p, certificate.size()));
648     if (x509 == nullptr) {
649         LOG(ERROR) << "Error parsing X509 certificate";
650         return false;
651     }
652 
653     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
654     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
655     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
656         1) {
657         LOG(ERROR) << "Error decoding publicKey";
658         return false;
659     }
660     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
661     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
662     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
663         LOG(ERROR) << "Memory allocation failed";
664         return false;
665     }
666     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
667         LOG(ERROR) << "Error setting group";
668         return false;
669     }
670     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
671         LOG(ERROR) << "Error setting point";
672         return false;
673     }
674     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
675         LOG(ERROR) << "Error setting key";
676         return false;
677     }
678 
679     if (X509_verify(x509.get(), pkey.get()) != 1) {
680         return false;
681     }
682 
683     return true;
684 }
685 
686 // TODO: Right now the only check we perform is to check that each certificate
687 //       is signed by its successor. We should - but currently don't - also check
688 //       things like valid dates etc.
689 //
690 //       It would be nice to use X509_verify_cert() instead of doing our own thing.
691 //
certificateChainValidate(const vector<uint8_t> & certificateChain)692 bool certificateChainValidate(const vector<uint8_t>& certificateChain) {
693     vector<X509_Ptr> certs;
694 
695     if (!parseX509Certificates(certificateChain, certs)) {
696         LOG(ERROR) << "Error parsing X509 certificates";
697         return false;
698     }
699 
700     if (certs.size() == 1) {
701         return true;
702     }
703 
704     for (size_t n = 1; n < certs.size(); n++) {
705         const X509_Ptr& keyCert = certs[n - 1];
706         const X509_Ptr& signingCert = certs[n];
707         EVP_PKEY_Ptr signingPubkey(X509_get_pubkey(signingCert.get()));
708         if (X509_verify(keyCert.get(), signingPubkey.get()) != 1) {
709             LOG(ERROR) << "Error validating cert at index " << n - 1
710                        << " is signed by its successor";
711             return false;
712         }
713     }
714 
715     return true;
716 }
717 
checkEcDsaSignature(const vector<uint8_t> & digest,const vector<uint8_t> & signature,const vector<uint8_t> & publicKey)718 bool checkEcDsaSignature(const vector<uint8_t>& digest, const vector<uint8_t>& signature,
719                          const vector<uint8_t>& publicKey) {
720     const unsigned char* p = (unsigned char*)signature.data();
721     auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, signature.size()));
722     if (sig.get() == nullptr) {
723         LOG(ERROR) << "Error decoding DER encoded signature";
724         return false;
725     }
726 
727     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
728     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
729     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
730         1) {
731         LOG(ERROR) << "Error decoding publicKey";
732         return false;
733     }
734     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
735     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
736     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
737         LOG(ERROR) << "Memory allocation failed";
738         return false;
739     }
740     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
741         LOG(ERROR) << "Error setting group";
742         return false;
743     }
744     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
745         LOG(ERROR) << "Error setting point";
746         return false;
747     }
748     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
749         LOG(ERROR) << "Error setting key";
750         return false;
751     }
752 
753     int rc = ECDSA_do_verify(digest.data(), digest.size(), sig.get(), ecKey.get());
754     if (rc != 1) {
755         LOG(ERROR) << "Error verifying signature (rc=" << rc << ")";
756         return false;
757     }
758 
759     return true;
760 }
761 
// Hashes |data| with SHA-256 and returns the SHA256_DIGEST_LENGTH-byte digest.
vector<uint8_t> sha256(const vector<uint8_t>& data) {
    SHA256_CTX hashCtx;
    SHA256_Init(&hashCtx);
    SHA256_Update(&hashCtx, data.data(), data.size());

    vector<uint8_t> digest(SHA256_DIGEST_LENGTH);
    SHA256_Final(digest.data(), &hashCtx);
    return digest;
}
771 
signEcDsaDigest(const vector<uint8_t> & key,const vector<uint8_t> & dataDigest)772 optional<vector<uint8_t>> signEcDsaDigest(const vector<uint8_t>& key,
773                                           const vector<uint8_t>& dataDigest) {
774     auto bn = BIGNUM_Ptr(BN_bin2bn(key.data(), key.size(), nullptr));
775     if (bn.get() == nullptr) {
776         LOG(ERROR) << "Error creating BIGNUM";
777         return {};
778     }
779 
780     auto ec_key = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
781     if (EC_KEY_set_private_key(ec_key.get(), bn.get()) != 1) {
782         LOG(ERROR) << "Error setting private key from BIGNUM";
783         return {};
784     }
785 
786     ECDSA_SIG* sig = ECDSA_do_sign(dataDigest.data(), dataDigest.size(), ec_key.get());
787     if (sig == nullptr) {
788         LOG(ERROR) << "Error signing digest";
789         return {};
790     }
791     size_t len = i2d_ECDSA_SIG(sig, nullptr);
792     vector<uint8_t> signature;
793     signature.resize(len);
794     unsigned char* p = (unsigned char*)signature.data();
795     i2d_ECDSA_SIG(sig, &p);
796     ECDSA_SIG_free(sig);
797     return signature;
798 }
799 
signEcDsa(const vector<uint8_t> & key,const vector<uint8_t> & data)800 optional<vector<uint8_t>> signEcDsa(const vector<uint8_t>& key, const vector<uint8_t>& data) {
801     return signEcDsaDigest(key, sha256(data));
802 }
803 
hmacSha256(const vector<uint8_t> & key,const vector<uint8_t> & data)804 optional<vector<uint8_t>> hmacSha256(const vector<uint8_t>& key, const vector<uint8_t>& data) {
805     HMAC_CTX ctx;
806     HMAC_CTX_init(&ctx);
807     if (HMAC_Init_ex(&ctx, key.data(), key.size(), EVP_sha256(), nullptr /* impl */) != 1) {
808         LOG(ERROR) << "Error initializing HMAC_CTX";
809         return {};
810     }
811     if (HMAC_Update(&ctx, data.data(), data.size()) != 1) {
812         LOG(ERROR) << "Error updating HMAC_CTX";
813         return {};
814     }
815     vector<uint8_t> hmac;
816     hmac.resize(32);
817     unsigned int size = 0;
818     if (HMAC_Final(&ctx, hmac.data(), &size) != 1) {
819         LOG(ERROR) << "Error finalizing HMAC_CTX";
820         return {};
821     }
822     if (size != 32) {
823         LOG(ERROR) << "Expected 32 bytes from HMAC_Final, got " << size;
824         return {};
825     }
826     return hmac;
827 }
828 
createEcKeyPairAndAttestation(const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId,bool isTestCredential)829 optional<std::pair<vector<uint8_t>, vector<vector<uint8_t>>>> createEcKeyPairAndAttestation(
830         const vector<uint8_t>& challenge, const vector<uint8_t>& applicationId,
831         bool isTestCredential) {
832     EVP_PKEY_Ptr pkey = generateP256Key();
833 
834     optional<vector<vector<uint8_t>>> attestationCertChain =
835             createAttestation(pkey.get(), challenge, applicationId, isTestCredential);
836     if (!attestationCertChain) {
837         LOG(ERROR) << "Error create attestation from key and challenge";
838         return {};
839     }
840 
841     optional<vector<uint8_t>> keyPair = derEncodeKeyPair(*pkey);
842     if (!keyPair) {
843         return std::nullopt;
844     }
845 
846     return make_pair(*keyPair, *attestationCertChain);
847 }
848 
createEcKeyPairWithAttestationKey(const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId,const vector<uint8_t> & attestationKeyBlob,const vector<uint8_t> & attestationKeyCert,bool isTestCredential)849 optional<std::pair<vector<uint8_t>, vector<uint8_t>>> createEcKeyPairWithAttestationKey(
850         const vector<uint8_t>& challenge, const vector<uint8_t>& applicationId,
851         const vector<uint8_t>& attestationKeyBlob, const vector<uint8_t>& attestationKeyCert,
852         bool isTestCredential) {
853     // Pretend to be implemented in a trusted environment just so we can pass
854     // the VTS tests. Of course, this is a pretend-only game since hopefully no
855     // relying party is ever going to trust our batch key and those keys above
856     // it.
857     ::keymaster::PureSoftKeymasterContext context(::keymaster::KmVersion::KEYMINT_1,
858                                                   KM_SECURITY_LEVEL_TRUSTED_ENVIRONMENT);
859 
860     EVP_PKEY_Ptr pkey = generateP256Key();
861 
862     uint64_t validFromMs = time(nullptr) * 1000;
863     optional<uint64_t> notAfterMs =
864             getCertificateExpiryAsMillis(attestationKeyCert.data(), attestationKeyCert.size());
865     if (!notAfterMs) {
866         LOG(ERROR) << "Error getting expiration time for attestation cert";
867         return std::nullopt;
868     }
869 
870     optional<vector<uint8_t>> derIssuerSubject =
871             support::extractDerSubjectFromCertificate(attestationKeyCert);
872     if (!derIssuerSubject) {
873         LOG(ERROR) << "Error error extracting issuer name from the given certificate chain";
874         return std::nullopt;
875     }
876 
877     optional<vector<vector<uint8_t>>> attestationCertChain = signAttestationCertificate(
878             context, pkey.get(), applicationId, challenge, attestationKeyBlob, *derIssuerSubject,
879             validFromMs, *notAfterMs, isTestCredential);
880     if (!attestationCertChain) {
881         LOG(ERROR) << "Error signing attestation certificate";
882         return std::nullopt;
883     }
884     if (!attestationCertChain) {
885         LOG(ERROR) << "Error create attestation from key and challenge";
886         return std::nullopt;
887     }
888     if (attestationCertChain->size() != 1) {
889         LOG(ERROR) << "Expected exactly one attestation cert, got " << attestationCertChain->size();
890         return std::nullopt;
891     }
892 
893     optional<vector<uint8_t>> keyPair = derEncodeKeyPair(*pkey);
894     if (!keyPair) {
895         return std::nullopt;
896     }
897 
898     return make_pair(*keyPair, attestationCertChain->at(0));
899 }
900 
createAttestationForEcPublicKey(const vector<uint8_t> & publicKey,const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId)901 optional<vector<vector<uint8_t>>> createAttestationForEcPublicKey(
902         const vector<uint8_t>& publicKey, const vector<uint8_t>& challenge,
903         const vector<uint8_t>& applicationId) {
904     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
905     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
906     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
907         1) {
908         LOG(ERROR) << "Error decoding publicKey";
909         return {};
910     }
911     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
912     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
913     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
914         LOG(ERROR) << "Memory allocation failed";
915         return {};
916     }
917     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
918         LOG(ERROR) << "Error setting group";
919         return {};
920     }
921     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
922         LOG(ERROR) << "Error setting point";
923         return {};
924     }
925     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
926         LOG(ERROR) << "Error setting key";
927         return {};
928     }
929 
930     optional<vector<vector<uint8_t>>> attestationCert =
931             createAttestation(pkey.get(), applicationId, challenge, false /* isTestCredential */);
932     if (!attestationCert) {
933         LOG(ERROR) << "Error create attestation from key and challenge";
934         return {};
935     }
936 
937     return attestationCert.value();
938 }
939 
createEcKeyPair()940 optional<vector<uint8_t>> createEcKeyPair() {
941     auto ec_key = EC_KEY_Ptr(EC_KEY_new());
942     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
943     if (ec_key.get() == nullptr || pkey.get() == nullptr) {
944         LOG(ERROR) << "Memory allocation failed";
945         return {};
946     }
947 
948     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
949     if (group.get() == nullptr) {
950         LOG(ERROR) << "Error creating EC group by curve name";
951         return {};
952     }
953 
954     if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
955         EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
956         LOG(ERROR) << "Error generating key";
957         return {};
958     }
959 
960     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ec_key.get()) != 1) {
961         LOG(ERROR) << "Error getting private key";
962         return {};
963     }
964 
965     int size = i2d_PrivateKey(pkey.get(), nullptr);
966     if (size == 0) {
967         LOG(ERROR) << "Error generating public key encoding";
968         return {};
969     }
970     vector<uint8_t> keyPair;
971     keyPair.resize(size);
972     unsigned char* p = keyPair.data();
973     i2d_PrivateKey(pkey.get(), &p);
974     return keyPair;
975 }
976 
ecKeyPairGetPublicKey(const vector<uint8_t> & keyPair)977 optional<vector<uint8_t>> ecKeyPairGetPublicKey(const vector<uint8_t>& keyPair) {
978     const unsigned char* p = (const unsigned char*)keyPair.data();
979     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
980     if (pkey.get() == nullptr) {
981         LOG(ERROR) << "Error parsing keyPair";
982         return {};
983     }
984 
985     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
986     if (ecKey.get() == nullptr) {
987         LOG(ERROR) << "Failed getting EC key";
988         return {};
989     }
990 
991     auto ecGroup = EC_KEY_get0_group(ecKey.get());
992     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
993     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
994                                   nullptr);
995     if (size == 0) {
996         LOG(ERROR) << "Error generating public key encoding";
997         return {};
998     }
999 
1000     vector<uint8_t> publicKey;
1001     publicKey.resize(size);
1002     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1003                        publicKey.size(), nullptr);
1004     return publicKey;
1005 }
1006 
ecKeyPairGetPrivateKey(const vector<uint8_t> & keyPair)1007 optional<vector<uint8_t>> ecKeyPairGetPrivateKey(const vector<uint8_t>& keyPair) {
1008     const unsigned char* p = (const unsigned char*)keyPair.data();
1009     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
1010     if (pkey.get() == nullptr) {
1011         LOG(ERROR) << "Error parsing keyPair";
1012         return {};
1013     }
1014 
1015     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1016     if (ecKey.get() == nullptr) {
1017         LOG(ERROR) << "Failed getting EC key";
1018         return {};
1019     }
1020 
1021     const BIGNUM* bignum = EC_KEY_get0_private_key(ecKey.get());
1022     if (bignum == nullptr) {
1023         LOG(ERROR) << "Error getting bignum from private key";
1024         return {};
1025     }
1026     vector<uint8_t> privateKey;
1027 
1028     // Note that this may return fewer than 32 bytes so pad with zeroes since we
1029     // want to always return 32 bytes.
1030     size_t numBytes = BN_num_bytes(bignum);
1031     if (numBytes > 32) {
1032         LOG(ERROR) << "Size is " << numBytes << ", expected this to be 32 or less";
1033         return {};
1034     }
1035     privateKey.resize(32);
1036     for (size_t n = 0; n < 32 - numBytes; n++) {
1037         privateKey[n] = 0x00;
1038     }
1039     BN_bn2bin(bignum, privateKey.data() + 32 - numBytes);
1040     return privateKey;
1041 }
1042 
ecPrivateKeyToKeyPair(const vector<uint8_t> & privateKey)1043 optional<vector<uint8_t>> ecPrivateKeyToKeyPair(const vector<uint8_t>& privateKey) {
1044     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
1045     if (bn.get() == nullptr) {
1046         LOG(ERROR) << "Error creating BIGNUM";
1047         return {};
1048     }
1049 
1050     auto ecKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1051     if (EC_KEY_set_private_key(ecKey.get(), bn.get()) != 1) {
1052         LOG(ERROR) << "Error setting private key from BIGNUM";
1053         return {};
1054     }
1055 
1056     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1057     if (pkey.get() == nullptr) {
1058         LOG(ERROR) << "Memory allocation failed";
1059         return {};
1060     }
1061 
1062     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1063         LOG(ERROR) << "Error getting private key";
1064         return {};
1065     }
1066 
1067     int size = i2d_PrivateKey(pkey.get(), nullptr);
1068     if (size == 0) {
1069         LOG(ERROR) << "Error generating public key encoding";
1070         return {};
1071     }
1072     vector<uint8_t> keyPair;
1073     keyPair.resize(size);
1074     unsigned char* p = keyPair.data();
1075     i2d_PrivateKey(pkey.get(), &p);
1076     return keyPair;
1077 }
1078 
ecKeyPairGetPkcs12(const vector<uint8_t> & keyPair,const string & name,const string & serialDecimal,const string & issuer,const string & subject,time_t validityNotBefore,time_t validityNotAfter)1079 optional<vector<uint8_t>> ecKeyPairGetPkcs12(const vector<uint8_t>& keyPair, const string& name,
1080                                              const string& serialDecimal, const string& issuer,
1081                                              const string& subject, time_t validityNotBefore,
1082                                              time_t validityNotAfter) {
1083     const unsigned char* p = (const unsigned char*)keyPair.data();
1084     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
1085     if (pkey.get() == nullptr) {
1086         LOG(ERROR) << "Error parsing keyPair";
1087         return {};
1088     }
1089 
1090     auto x509 = X509_Ptr(X509_new());
1091     if (!x509.get()) {
1092         LOG(ERROR) << "Error creating X509 certificate";
1093         return {};
1094     }
1095 
1096     if (!X509_set_version(x509.get(), 2 /* version 3, but zero-based */)) {
1097         LOG(ERROR) << "Error setting version to 3";
1098         return {};
1099     }
1100 
1101     if (X509_set_pubkey(x509.get(), pkey.get()) != 1) {
1102         LOG(ERROR) << "Error setting public key";
1103         return {};
1104     }
1105 
1106     BIGNUM* bignumSerial = nullptr;
1107     if (BN_dec2bn(&bignumSerial, serialDecimal.c_str()) == 0) {
1108         LOG(ERROR) << "Error parsing serial";
1109         return {};
1110     }
1111     auto bignumSerialPtr = BIGNUM_Ptr(bignumSerial);
1112     auto asnSerial = ASN1_INTEGER_Ptr(BN_to_ASN1_INTEGER(bignumSerial, nullptr));
1113     if (X509_set_serialNumber(x509.get(), asnSerial.get()) != 1) {
1114         LOG(ERROR) << "Error setting serial";
1115         return {};
1116     }
1117 
1118     auto x509Issuer = X509_NAME_Ptr(X509_NAME_new());
1119     if (x509Issuer.get() == nullptr ||
1120         X509_NAME_add_entry_by_txt(x509Issuer.get(), "CN", MBSTRING_ASC,
1121                                    (const uint8_t*)issuer.c_str(), issuer.size(), -1 /* loc */,
1122                                    0 /* set */) != 1 ||
1123         X509_set_issuer_name(x509.get(), x509Issuer.get()) != 1) {
1124         LOG(ERROR) << "Error setting issuer";
1125         return {};
1126     }
1127 
1128     auto x509Subject = X509_NAME_Ptr(X509_NAME_new());
1129     if (x509Subject.get() == nullptr ||
1130         X509_NAME_add_entry_by_txt(x509Subject.get(), "CN", MBSTRING_ASC,
1131                                    (const uint8_t*)subject.c_str(), subject.size(), -1 /* loc */,
1132                                    0 /* set */) != 1 ||
1133         X509_set_subject_name(x509.get(), x509Subject.get()) != 1) {
1134         LOG(ERROR) << "Error setting subject";
1135         return {};
1136     }
1137 
1138     auto asnNotBefore = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotBefore));
1139     if (asnNotBefore.get() == nullptr || X509_set_notBefore(x509.get(), asnNotBefore.get()) != 1) {
1140         LOG(ERROR) << "Error setting notBefore";
1141         return {};
1142     }
1143 
1144     auto asnNotAfter = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotAfter));
1145     if (asnNotAfter.get() == nullptr || X509_set_notAfter(x509.get(), asnNotAfter.get()) != 1) {
1146         LOG(ERROR) << "Error setting notAfter";
1147         return {};
1148     }
1149 
1150     if (X509_sign(x509.get(), pkey.get(), EVP_sha256()) == 0) {
1151         LOG(ERROR) << "Error signing X509 certificate";
1152         return {};
1153     }
1154 
1155     // Ideally we wouldn't encrypt it (we're only using this function for
1156     // sending a key-pair over binder to the Android app) but BoringSSL does not
1157     // support this: from pkcs8_x509.c in BoringSSL: "In OpenSSL, -1 here means
1158     // to use no encryption, which we do not currently support."
1159     //
1160     // Passing nullptr as |pass|, though, means "no password". So we'll do that.
1161     // Compare with the receiving side - CredstoreIdentityCredential.java - where
1162     // an empty char[] is passed as the password.
1163     //
1164     auto pkcs12 = PKCS12_Ptr(PKCS12_create(nullptr, name.c_str(), pkey.get(), x509.get(),
1165                                            nullptr,  // ca
1166                                            0,        // nid_key
1167                                            0,        // nid_cert
1168                                            0,        // iter,
1169                                            0,        // mac_iter,
1170                                            0));      // keytype
1171     if (pkcs12.get() == nullptr) {
1172         char buf[128];
1173         long errCode = ERR_get_error();
1174         ERR_error_string_n(errCode, buf, sizeof buf);
1175         LOG(ERROR) << "Error creating PKCS12, code " << errCode << ": " << buf;
1176         return {};
1177     }
1178 
1179     unsigned char* buffer = nullptr;
1180     int length = i2d_PKCS12(pkcs12.get(), &buffer);
1181     if (length < 0) {
1182         LOG(ERROR) << "Error encoding PKCS12";
1183         return {};
1184     }
1185     vector<uint8_t> pkcs12Bytes;
1186     pkcs12Bytes.resize(length);
1187     memcpy(pkcs12Bytes.data(), buffer, length);
1188     OPENSSL_free(buffer);
1189 
1190     return pkcs12Bytes;
1191 }
1192 
ecPublicKeyGenerateCertificate(const vector<uint8_t> & publicKey,const vector<uint8_t> & signingKey,const string & serialDecimal,const string & issuer,const string & subject,time_t validityNotBefore,time_t validityNotAfter,const map<string,vector<uint8_t>> & extensions)1193 optional<vector<uint8_t>> ecPublicKeyGenerateCertificate(
1194         const vector<uint8_t>& publicKey, const vector<uint8_t>& signingKey,
1195         const string& serialDecimal, const string& issuer, const string& subject,
1196         time_t validityNotBefore, time_t validityNotAfter,
1197         const map<string, vector<uint8_t>>& extensions) {
1198     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1199     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
1200     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
1201         1) {
1202         LOG(ERROR) << "Error decoding publicKey";
1203         return {};
1204     }
1205     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
1206     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1207     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
1208         LOG(ERROR) << "Memory allocation failed";
1209         return {};
1210     }
1211     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
1212         LOG(ERROR) << "Error setting group";
1213         return {};
1214     }
1215     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
1216         LOG(ERROR) << "Error setting point";
1217         return {};
1218     }
1219     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1220         LOG(ERROR) << "Error setting key";
1221         return {};
1222     }
1223 
1224     auto bn = BIGNUM_Ptr(BN_bin2bn(signingKey.data(), signingKey.size(), nullptr));
1225     if (bn.get() == nullptr) {
1226         LOG(ERROR) << "Error creating BIGNUM for private key";
1227         return {};
1228     }
1229     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1230     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
1231         LOG(ERROR) << "Error setting private key from BIGNUM";
1232         return {};
1233     }
1234     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1235     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
1236         LOG(ERROR) << "Error setting private key";
1237         return {};
1238     }
1239 
1240     return ecPublicKeyGenerateCertificate(pkey.get(), privPkey.get(), serialDecimal, issuer,
1241                                           subject, validityNotBefore, validityNotAfter, extensions);
1242 }
1243 
ecPublicKeyGenerateCertificate(EVP_PKEY * publicKey,EVP_PKEY * signingKey,const string & serialDecimal,const string & issuer,const string & subject,time_t validityNotBefore,time_t validityNotAfter,const map<string,vector<uint8_t>> & extensions)1244 optional<vector<uint8_t>> ecPublicKeyGenerateCertificate(
1245         EVP_PKEY* publicKey, EVP_PKEY* signingKey, const string& serialDecimal,
1246         const string& issuer, const string& subject, time_t validityNotBefore,
1247         time_t validityNotAfter, const map<string, vector<uint8_t>>& extensions) {
1248     auto x509 = X509_Ptr(X509_new());
1249     if (!x509.get()) {
1250         LOG(ERROR) << "Error creating X509 certificate";
1251         return {};
1252     }
1253 
1254     if (!X509_set_version(x509.get(), 2 /* version 3, but zero-based */)) {
1255         LOG(ERROR) << "Error setting version to 3";
1256         return {};
1257     }
1258 
1259     if (X509_set_pubkey(x509.get(), publicKey) != 1) {
1260         LOG(ERROR) << "Error setting public key";
1261         return {};
1262     }
1263 
1264     BIGNUM* bignumSerial = nullptr;
1265     if (BN_dec2bn(&bignumSerial, serialDecimal.c_str()) == 0) {
1266         LOG(ERROR) << "Error parsing serial";
1267         return {};
1268     }
1269     auto bignumSerialPtr = BIGNUM_Ptr(bignumSerial);
1270     auto asnSerial = ASN1_INTEGER_Ptr(BN_to_ASN1_INTEGER(bignumSerial, nullptr));
1271     if (X509_set_serialNumber(x509.get(), asnSerial.get()) != 1) {
1272         LOG(ERROR) << "Error setting serial";
1273         return {};
1274     }
1275 
1276     auto x509Issuer = X509_NAME_Ptr(X509_NAME_new());
1277     if (x509Issuer.get() == nullptr ||
1278         X509_NAME_add_entry_by_txt(x509Issuer.get(), "CN", MBSTRING_ASC,
1279                                    (const uint8_t*)issuer.c_str(), issuer.size(), -1 /* loc */,
1280                                    0 /* set */) != 1 ||
1281         X509_set_issuer_name(x509.get(), x509Issuer.get()) != 1) {
1282         LOG(ERROR) << "Error setting issuer";
1283         return {};
1284     }
1285 
1286     auto x509Subject = X509_NAME_Ptr(X509_NAME_new());
1287     if (x509Subject.get() == nullptr ||
1288         X509_NAME_add_entry_by_txt(x509Subject.get(), "CN", MBSTRING_ASC,
1289                                    (const uint8_t*)subject.c_str(), subject.size(), -1 /* loc */,
1290                                    0 /* set */) != 1 ||
1291         X509_set_subject_name(x509.get(), x509Subject.get()) != 1) {
1292         LOG(ERROR) << "Error setting subject";
1293         return {};
1294     }
1295 
1296     auto asnNotBefore = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotBefore));
1297     if (asnNotBefore.get() == nullptr || X509_set_notBefore(x509.get(), asnNotBefore.get()) != 1) {
1298         LOG(ERROR) << "Error setting notBefore";
1299         return {};
1300     }
1301 
1302     auto asnNotAfter = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotAfter));
1303     if (asnNotAfter.get() == nullptr || X509_set_notAfter(x509.get(), asnNotAfter.get()) != 1) {
1304         LOG(ERROR) << "Error setting notAfter";
1305         return {};
1306     }
1307 
1308     for (auto const& [oidStr, blob] : extensions) {
1309         ASN1_OBJECT_Ptr oid(
1310                 OBJ_txt2obj(oidStr.c_str(), 1));  // accept numerical dotted string form only
1311         if (!oid.get()) {
1312             LOG(ERROR) << "Error setting OID";
1313             return {};
1314         }
1315         ASN1_OCTET_STRING_Ptr octetString(ASN1_OCTET_STRING_new());
1316         if (!ASN1_OCTET_STRING_set(octetString.get(), blob.data(), blob.size())) {
1317             LOG(ERROR) << "Error setting octet string for extension";
1318             return {};
1319         }
1320 
1321         X509_EXTENSION_Ptr extension = X509_EXTENSION_Ptr(X509_EXTENSION_new());
1322         extension.reset(X509_EXTENSION_create_by_OBJ(nullptr, oid.get(), 0 /* not critical */,
1323                                                      octetString.get()));
1324         if (!extension.get()) {
1325             LOG(ERROR) << "Error setting extension";
1326             return {};
1327         }
1328         if (!X509_add_ext(x509.get(), extension.get(), -1)) {
1329             LOG(ERROR) << "Error adding extension";
1330             return {};
1331         }
1332     }
1333 
1334     if (X509_sign(x509.get(), signingKey, EVP_sha256()) == 0) {
1335         LOG(ERROR) << "Error signing X509 certificate";
1336         return {};
1337     }
1338 
1339     unsigned char* buffer = nullptr;
1340     int length = i2d_X509(x509.get(), &buffer);
1341     if (length < 0) {
1342         LOG(ERROR) << "Error DER encoding X509 certificate";
1343         return {};
1344     }
1345 
1346     vector<uint8_t> certificate;
1347     certificate.resize(length);
1348     memcpy(certificate.data(), buffer, length);
1349     OPENSSL_free(buffer);
1350     return certificate;
1351 }
1352 
ecdh(const vector<uint8_t> & publicKey,const vector<uint8_t> & privateKey)1353 optional<vector<uint8_t>> ecdh(const vector<uint8_t>& publicKey,
1354                                const vector<uint8_t>& privateKey) {
1355     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1356     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
1357     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
1358         1) {
1359         LOG(ERROR) << "Error decoding publicKey";
1360         return {};
1361     }
1362     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
1363     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1364     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
1365         LOG(ERROR) << "Memory allocation failed";
1366         return {};
1367     }
1368     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
1369         LOG(ERROR) << "Error setting group";
1370         return {};
1371     }
1372     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
1373         LOG(ERROR) << "Error setting point";
1374         return {};
1375     }
1376     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1377         LOG(ERROR) << "Error setting key";
1378         return {};
1379     }
1380 
1381     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
1382     if (bn.get() == nullptr) {
1383         LOG(ERROR) << "Error creating BIGNUM for private key";
1384         return {};
1385     }
1386     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1387     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
1388         LOG(ERROR) << "Error setting private key from BIGNUM";
1389         return {};
1390     }
1391     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1392     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
1393         LOG(ERROR) << "Error setting private key";
1394         return {};
1395     }
1396 
1397     auto ctx = EVP_PKEY_CTX_Ptr(EVP_PKEY_CTX_new(privPkey.get(), NULL));
1398     if (ctx.get() == nullptr) {
1399         LOG(ERROR) << "Error creating context";
1400         return {};
1401     }
1402 
1403     if (EVP_PKEY_derive_init(ctx.get()) != 1) {
1404         LOG(ERROR) << "Error initializing context";
1405         return {};
1406     }
1407 
1408     if (EVP_PKEY_derive_set_peer(ctx.get(), pkey.get()) != 1) {
1409         LOG(ERROR) << "Error setting peer";
1410         return {};
1411     }
1412 
1413     /* Determine buffer length for shared secret */
1414     size_t secretLen = 0;
1415     if (EVP_PKEY_derive(ctx.get(), NULL, &secretLen) != 1) {
1416         LOG(ERROR) << "Error determing length of shared secret";
1417         return {};
1418     }
1419     vector<uint8_t> sharedSecret;
1420     sharedSecret.resize(secretLen);
1421 
1422     if (EVP_PKEY_derive(ctx.get(), sharedSecret.data(), &secretLen) != 1) {
1423         LOG(ERROR) << "Error deriving shared secret";
1424         return {};
1425     }
1426     return sharedSecret;
1427 }
1428 
hkdf(const vector<uint8_t> & sharedSecret,const vector<uint8_t> & salt,const vector<uint8_t> & info,size_t size)1429 optional<vector<uint8_t>> hkdf(const vector<uint8_t>& sharedSecret, const vector<uint8_t>& salt,
1430                                const vector<uint8_t>& info, size_t size) {
1431     vector<uint8_t> derivedKey;
1432     derivedKey.resize(size);
1433     if (HKDF(derivedKey.data(), derivedKey.size(), EVP_sha256(), sharedSecret.data(),
1434              sharedSecret.size(), salt.data(), salt.size(), info.data(), info.size()) != 1) {
1435         LOG(ERROR) << "Error deriving key";
1436         return {};
1437     }
1438     return derivedKey;
1439 }
1440 
// Removes all leading 0x00 bytes from |vec| in place. An all-zero or empty
// vector ends up empty.
void removeLeadingZeroes(vector<uint8_t>& vec) {
    // Find the first non-zero byte, then erase the prefix in one call. The old
    // approach erased one byte at a time from the front, which was O(n^2).
    size_t firstNonZero = 0;
    while (firstNonZero < vec.size() && vec[firstNonZero] == 0x00) {
        firstNonZero++;
    }
    vec.erase(vec.begin(), vec.begin() + firstNonZero);
}
1446 
ecPublicKeyGetXandY(const vector<uint8_t> & publicKey)1447 tuple<bool, vector<uint8_t>, vector<uint8_t>> ecPublicKeyGetXandY(
1448         const vector<uint8_t>& publicKey) {
1449     if (publicKey.size() != 65 || publicKey[0] != 0x04) {
1450         LOG(ERROR) << "publicKey is not in the expected format";
1451         return std::make_tuple(false, vector<uint8_t>(), vector<uint8_t>());
1452     }
1453     vector<uint8_t> x, y;
1454     x.resize(32);
1455     y.resize(32);
1456     memcpy(x.data(), publicKey.data() + 1, 32);
1457     memcpy(y.data(), publicKey.data() + 33, 32);
1458 
1459     removeLeadingZeroes(x);
1460     removeLeadingZeroes(y);
1461 
1462     return std::make_tuple(true, x, y);
1463 }
1464 
certificateChainGetTopMostKey(const vector<uint8_t> & certificateChain)1465 optional<vector<uint8_t>> certificateChainGetTopMostKey(const vector<uint8_t>& certificateChain) {
1466     vector<X509_Ptr> certs;
1467     if (!parseX509Certificates(certificateChain, certs)) {
1468         return {};
1469     }
1470     if (certs.size() < 1) {
1471         LOG(ERROR) << "No certificates in chain";
1472         return {};
1473     }
1474 
1475     auto pkey = EVP_PKEY_Ptr(X509_get_pubkey(certs[0].get()));
1476     if (pkey.get() == nullptr) {
1477         LOG(ERROR) << "No public key";
1478         return {};
1479     }
1480 
1481     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1482     if (ecKey.get() == nullptr) {
1483         LOG(ERROR) << "Failed getting EC key";
1484         return {};
1485     }
1486 
1487     auto ecGroup = EC_KEY_get0_group(ecKey.get());
1488     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
1489     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
1490                                   nullptr);
1491     if (size == 0) {
1492         LOG(ERROR) << "Error generating public key encoding";
1493         return {};
1494     }
1495     vector<uint8_t> publicKey;
1496     publicKey.resize(size);
1497     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1498                        publicKey.size(), nullptr);
1499     return publicKey;
1500 }
1501 
certificateGetExtension(const vector<uint8_t> & x509Certificate,const string & oidStr)1502 optional<vector<uint8_t>> certificateGetExtension(const vector<uint8_t>& x509Certificate,
1503                                                   const string& oidStr) {
1504     vector<X509_Ptr> certs;
1505     if (!parseX509Certificates(x509Certificate, certs)) {
1506         return {};
1507     }
1508     if (certs.size() < 1) {
1509         LOG(ERROR) << "No certificates in chain";
1510         return {};
1511     }
1512 
1513     ASN1_OBJECT_Ptr oid(
1514             OBJ_txt2obj(oidStr.c_str(), 1));  // accept numerical dotted string form only
1515     if (!oid.get()) {
1516         LOG(ERROR) << "Error setting OID";
1517         return {};
1518     }
1519 
1520     int location = X509_get_ext_by_OBJ(certs[0].get(), oid.get(), -1 /* search from beginning */);
1521     if (location == -1) {
1522         return {};
1523     }
1524 
1525     X509_EXTENSION* ext = X509_get_ext(certs[0].get(), location);
1526     if (ext == nullptr) {
1527         return {};
1528     }
1529 
1530     ASN1_OCTET_STRING* octetString = X509_EXTENSION_get_data(ext);
1531     if (octetString == nullptr) {
1532         return {};
1533     }
1534     vector<uint8_t> result;
1535     result.resize(octetString->length);
1536     memcpy(result.data(), octetString->data, octetString->length);
1537     return result;
1538 }
1539 
certificateFindPublicKey(const vector<uint8_t> & x509Certificate)1540 optional<pair<size_t, size_t>> certificateFindPublicKey(const vector<uint8_t>& x509Certificate) {
1541     vector<X509_Ptr> certs;
1542     if (!parseX509Certificates(x509Certificate, certs)) {
1543         return {};
1544     }
1545     if (certs.size() < 1) {
1546         LOG(ERROR) << "No certificates in chain";
1547         return {};
1548     }
1549 
1550     auto pkey = EVP_PKEY_Ptr(X509_get_pubkey(certs[0].get()));
1551     if (pkey.get() == nullptr) {
1552         LOG(ERROR) << "No public key";
1553         return {};
1554     }
1555 
1556     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1557     if (ecKey.get() == nullptr) {
1558         LOG(ERROR) << "Failed getting EC key";
1559         return {};
1560     }
1561 
1562     auto ecGroup = EC_KEY_get0_group(ecKey.get());
1563     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
1564     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
1565                                   nullptr);
1566     if (size == 0) {
1567         LOG(ERROR) << "Error generating public key encoding";
1568         return {};
1569     }
1570     vector<uint8_t> publicKey;
1571     publicKey.resize(size);
1572     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1573                        publicKey.size(), nullptr);
1574 
1575     size_t publicKeyOffset = 0;
1576     size_t publicKeySize = (size_t)size;
1577     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1578                             (const void*)publicKey.data(), publicKey.size());
1579 
1580     if (location == NULL) {
1581         LOG(ERROR) << "Error finding publicKey from x509Certificate";
1582         return {};
1583     }
1584     publicKeyOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1585 
1586     return std::make_pair(publicKeyOffset, publicKeySize);
1587 }
1588 
certificateTbsCertificate(const vector<uint8_t> & x509Certificate)1589 optional<pair<size_t, size_t>> certificateTbsCertificate(const vector<uint8_t>& x509Certificate) {
1590     vector<X509_Ptr> certs;
1591     if (!parseX509Certificates(x509Certificate, certs)) {
1592         return {};
1593     }
1594     if (certs.size() < 1) {
1595         LOG(ERROR) << "No certificates in chain";
1596         return {};
1597     }
1598 
1599     unsigned char* buf = NULL;
1600     int len = i2d_re_X509_tbs(certs[0].get(), &buf);
1601     if ((len < 0) || (buf == NULL)) {
1602         LOG(ERROR) << "fail to extract tbsCertificate in x509Certificate";
1603         return {};
1604     }
1605 
1606     vector<uint8_t> tbsCertificate(len);
1607     memcpy(tbsCertificate.data(), buf, len);
1608 
1609     size_t tbsCertificateOffset = 0;
1610     size_t tbsCertificateSize = (size_t)len;
1611     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1612                             (const void*)tbsCertificate.data(), tbsCertificate.size());
1613 
1614     if (location == NULL) {
1615         LOG(ERROR) << "Error finding tbsCertificate from x509Certificate";
1616         return {};
1617     }
1618     tbsCertificateOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1619 
1620     return std::make_pair(tbsCertificateOffset, tbsCertificateSize);
1621 }
1622 
certificateGetValidity(const vector<uint8_t> & x509Certificate)1623 optional<pair<time_t, time_t>> certificateGetValidity(const vector<uint8_t>& x509Certificate) {
1624     vector<X509_Ptr> certs;
1625     if (!parseX509Certificates(x509Certificate, certs)) {
1626         LOG(ERROR) << "Error parsing certificates";
1627         return {};
1628     }
1629     if (certs.size() < 1) {
1630         LOG(ERROR) << "No certificates in chain";
1631         return {};
1632     }
1633 
1634     time_t notBefore;
1635     time_t notAfter;
1636     if (!parseAsn1Time(X509_get0_notBefore(certs[0].get()), &notBefore)) {
1637         LOG(ERROR) << "Error parsing notBefore";
1638         return {};
1639     }
1640 
1641     if (!parseAsn1Time(X509_get0_notAfter(certs[0].get()), &notAfter)) {
1642         LOG(ERROR) << "Error parsing notAfter";
1643         return {};
1644     }
1645 
1646     return std::make_pair(notBefore, notAfter);
1647 }
1648 
certificateFindSignature(const vector<uint8_t> & x509Certificate)1649 optional<pair<size_t, size_t>> certificateFindSignature(const vector<uint8_t>& x509Certificate) {
1650     vector<X509_Ptr> certs;
1651     if (!parseX509Certificates(x509Certificate, certs)) {
1652         return {};
1653     }
1654     if (certs.size() < 1) {
1655         LOG(ERROR) << "No certificates in chain";
1656         return {};
1657     }
1658 
1659     ASN1_BIT_STRING* psig;
1660     X509_ALGOR* palg;
1661     X509_get0_signature((const ASN1_BIT_STRING**)&psig, (const X509_ALGOR**)&palg, certs[0].get());
1662 
1663     vector<char> signature(psig->length);
1664     memcpy(signature.data(), psig->data, psig->length);
1665 
1666     size_t signatureOffset = 0;
1667     size_t signatureSize = (size_t)psig->length;
1668     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1669                             (const void*)signature.data(), signature.size());
1670 
1671     if (location == NULL) {
1672         LOG(ERROR) << "Error finding signature from x509Certificate";
1673         return {};
1674     }
1675     signatureOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1676 
1677     return std::make_pair(signatureOffset, signatureSize);
1678 }
1679 
1680 // ---------------------------------------------------------------------------
1681 // COSE Utility Functions
1682 // ---------------------------------------------------------------------------
1683 
coseBuildToBeSigned(const vector<uint8_t> & encodedProtectedHeaders,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)1684 vector<uint8_t> coseBuildToBeSigned(const vector<uint8_t>& encodedProtectedHeaders,
1685                                     const vector<uint8_t>& data,
1686                                     const vector<uint8_t>& detachedContent) {
1687     cppbor::Array sigStructure;
1688     sigStructure.add("Signature1");
1689     sigStructure.add(encodedProtectedHeaders);
1690 
1691     // We currently don't support Externally Supplied Data (RFC 8152 section 4.3)
1692     // so external_aad is the empty bstr
1693     vector<uint8_t> emptyExternalAad;
1694     sigStructure.add(emptyExternalAad);
1695 
1696     // Next field is the payload, independently of how it's transported (RFC
1697     // 8152 section 4.4). Since our API specifies only one of |data| and
1698     // |detachedContent| can be non-empty, it's simply just the non-empty one.
1699     if (data.size() > 0) {
1700         sigStructure.add(data);
1701     } else {
1702         sigStructure.add(detachedContent);
1703     }
1704     return sigStructure.encode();
1705 }
1706 
coseEncodeHeaders(const cppbor::Map & protectedHeaders)1707 vector<uint8_t> coseEncodeHeaders(const cppbor::Map& protectedHeaders) {
1708     if (protectedHeaders.size() == 0) {
1709         cppbor::Bstr emptyBstr(vector<uint8_t>({}));
1710         return emptyBstr.encode();
1711     }
1712     return protectedHeaders.encode();
1713 }
1714 
// COSE header parameter labels (https://tools.ietf.org/html/rfc8152).
constexpr int COSE_LABEL_ALG = 1;
// x5chain: temporary identifier (later registered as x5chain by RFC 9360).
constexpr int COSE_LABEL_X5CHAIN = 33;

// Algorithm identifiers from the IANA "COSE Algorithms" registry.
constexpr int COSE_ALG_ECDSA_256 = -7;     // ES256: ECDSA with SHA-256
constexpr int COSE_ALG_HMAC_256_256 = 5;   // HMAC 256/256: HMAC with SHA-256
1722 
ecdsaSignatureCoseToDer(const vector<uint8_t> & ecdsaCoseSignature,vector<uint8_t> & ecdsaDerSignature)1723 bool ecdsaSignatureCoseToDer(const vector<uint8_t>& ecdsaCoseSignature,
1724                              vector<uint8_t>& ecdsaDerSignature) {
1725     if (ecdsaCoseSignature.size() != 64) {
1726         LOG(ERROR) << "COSE signature length is " << ecdsaCoseSignature.size() << ", expected 64";
1727         return false;
1728     }
1729 
1730     auto rBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data(), 32, nullptr));
1731     if (rBn.get() == nullptr) {
1732         LOG(ERROR) << "Error creating BIGNUM for r";
1733         return false;
1734     }
1735 
1736     auto sBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data() + 32, 32, nullptr));
1737     if (sBn.get() == nullptr) {
1738         LOG(ERROR) << "Error creating BIGNUM for s";
1739         return false;
1740     }
1741 
1742     ECDSA_SIG sig;
1743     sig.r = rBn.get();
1744     sig.s = sBn.get();
1745 
1746     size_t len = i2d_ECDSA_SIG(&sig, nullptr);
1747     ecdsaDerSignature.resize(len);
1748     unsigned char* p = (unsigned char*)ecdsaDerSignature.data();
1749     i2d_ECDSA_SIG(&sig, &p);
1750 
1751     return true;
1752 }
1753 
ecdsaSignatureDerToCose(const vector<uint8_t> & ecdsaDerSignature,vector<uint8_t> & ecdsaCoseSignature)1754 bool ecdsaSignatureDerToCose(const vector<uint8_t>& ecdsaDerSignature,
1755                              vector<uint8_t>& ecdsaCoseSignature) {
1756     ECDSA_SIG* sig;
1757     const unsigned char* p = ecdsaDerSignature.data();
1758     sig = d2i_ECDSA_SIG(nullptr, &p, ecdsaDerSignature.size());
1759     if (sig == nullptr) {
1760         LOG(ERROR) << "Error decoding DER signature";
1761         return false;
1762     }
1763 
1764     ecdsaCoseSignature.clear();
1765     ecdsaCoseSignature.resize(64);
1766     if (BN_bn2binpad(ECDSA_SIG_get0_r(sig), ecdsaCoseSignature.data(), 32) != 32) {
1767         LOG(ERROR) << "Error encoding r";
1768         return false;
1769     }
1770     if (BN_bn2binpad(ECDSA_SIG_get0_s(sig), ecdsaCoseSignature.data() + 32, 32) != 32) {
1771         LOG(ERROR) << "Error encoding s";
1772         return false;
1773     }
1774     return true;
1775 }
1776 
coseSignEcDsaWithSignature(const vector<uint8_t> & signatureToBeSigned,const vector<uint8_t> & data,const vector<uint8_t> & certificateChain)1777 optional<vector<uint8_t>> coseSignEcDsaWithSignature(const vector<uint8_t>& signatureToBeSigned,
1778                                                      const vector<uint8_t>& data,
1779                                                      const vector<uint8_t>& certificateChain) {
1780     if (signatureToBeSigned.size() != 64) {
1781         LOG(ERROR) << "Invalid size for signatureToBeSigned, expected 64 got "
1782                    << signatureToBeSigned.size();
1783         return {};
1784     }
1785 
1786     cppbor::Map unprotectedHeaders;
1787     cppbor::Map protectedHeaders;
1788 
1789     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_ECDSA_256);
1790 
1791     if (certificateChain.size() != 0) {
1792         optional<vector<vector<uint8_t>>> certs = support::certificateChainSplit(certificateChain);
1793         if (!certs) {
1794             LOG(ERROR) << "Error splitting certificate chain";
1795             return {};
1796         }
1797         if (certs.value().size() == 1) {
1798             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, certs.value()[0]);
1799         } else {
1800             cppbor::Array certArray;
1801             for (const vector<uint8_t>& cert : certs.value()) {
1802                 certArray.add(cert);
1803             }
1804             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, std::move(certArray));
1805         }
1806     }
1807 
1808     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
1809 
1810     cppbor::Array coseSign1;
1811     coseSign1.add(encodedProtectedHeaders);
1812     coseSign1.add(std::move(unprotectedHeaders));
1813     if (data.size() == 0) {
1814         cppbor::Null nullValue;
1815         coseSign1.add(std::move(nullValue));
1816     } else {
1817         coseSign1.add(data);
1818     }
1819     coseSign1.add(signatureToBeSigned);
1820     vector<uint8_t> signatureCoseSign1;
1821     signatureCoseSign1 = coseSign1.encode();
1822 
1823     return signatureCoseSign1;
1824 }
1825 
coseSignEcDsa(const vector<uint8_t> & key,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent,const vector<uint8_t> & certificateChain)1826 optional<vector<uint8_t>> coseSignEcDsa(const vector<uint8_t>& key, const vector<uint8_t>& data,
1827                                         const vector<uint8_t>& detachedContent,
1828                                         const vector<uint8_t>& certificateChain) {
1829     cppbor::Map unprotectedHeaders;
1830     cppbor::Map protectedHeaders;
1831 
1832     if (data.size() > 0 && detachedContent.size() > 0) {
1833         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
1834         return {};
1835     }
1836 
1837     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_ECDSA_256);
1838 
1839     if (certificateChain.size() != 0) {
1840         optional<vector<vector<uint8_t>>> certs = support::certificateChainSplit(certificateChain);
1841         if (!certs) {
1842             LOG(ERROR) << "Error splitting certificate chain";
1843             return {};
1844         }
1845         if (certs.value().size() == 1) {
1846             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, certs.value()[0]);
1847         } else {
1848             cppbor::Array certArray;
1849             for (const vector<uint8_t>& cert : certs.value()) {
1850                 certArray.add(cert);
1851             }
1852             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, std::move(certArray));
1853         }
1854     }
1855 
1856     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
1857     vector<uint8_t> toBeSigned =
1858             coseBuildToBeSigned(encodedProtectedHeaders, data, detachedContent);
1859 
1860     optional<vector<uint8_t>> derSignature = signEcDsa(key, toBeSigned);
1861     if (!derSignature) {
1862         LOG(ERROR) << "Error signing toBeSigned data";
1863         return {};
1864     }
1865     vector<uint8_t> coseSignature;
1866     if (!ecdsaSignatureDerToCose(derSignature.value(), coseSignature)) {
1867         LOG(ERROR) << "Error converting ECDSA signature from DER to COSE format";
1868         return {};
1869     }
1870 
1871     cppbor::Array coseSign1;
1872     coseSign1.add(encodedProtectedHeaders);
1873     coseSign1.add(std::move(unprotectedHeaders));
1874     if (data.size() == 0) {
1875         cppbor::Null nullValue;
1876         coseSign1.add(std::move(nullValue));
1877     } else {
1878         coseSign1.add(data);
1879     }
1880     coseSign1.add(coseSignature);
1881     vector<uint8_t> signatureCoseSign1;
1882     signatureCoseSign1 = coseSign1.encode();
1883     return signatureCoseSign1;
1884 }
1885 
coseCheckEcDsaSignature(const vector<uint8_t> & signatureCoseSign1,const vector<uint8_t> & detachedContent,const vector<uint8_t> & publicKey)1886 bool coseCheckEcDsaSignature(const vector<uint8_t>& signatureCoseSign1,
1887                              const vector<uint8_t>& detachedContent,
1888                              const vector<uint8_t>& publicKey) {
1889     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1890     if (item == nullptr) {
1891         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1892         return false;
1893     }
1894     const cppbor::Array* array = item->asArray();
1895     if (array == nullptr) {
1896         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1897         return false;
1898     }
1899     if (array->size() != 4) {
1900         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
1901         return false;
1902     }
1903 
1904     const cppbor::Bstr* encodedProtectedHeadersBstr = (*array)[0]->asBstr();
1905     ;
1906     if (encodedProtectedHeadersBstr == nullptr) {
1907         LOG(ERROR) << "Value for encodedProtectedHeaders is not a bstr";
1908         return false;
1909     }
1910     const vector<uint8_t> encodedProtectedHeaders = encodedProtectedHeadersBstr->value();
1911 
1912     const cppbor::Map* unprotectedHeaders = (*array)[1]->asMap();
1913     if (unprotectedHeaders == nullptr) {
1914         LOG(ERROR) << "Value for unprotectedHeaders is not a map";
1915         return false;
1916     }
1917 
1918     vector<uint8_t> data;
1919     const cppbor::Simple* payloadAsSimple = (*array)[2]->asSimple();
1920     if (payloadAsSimple != nullptr) {
1921         if (payloadAsSimple->asNull() == nullptr) {
1922             LOG(ERROR) << "Value for payload is not null or a bstr";
1923             return false;
1924         }
1925     } else {
1926         const cppbor::Bstr* payloadAsBstr = (*array)[2]->asBstr();
1927         if (payloadAsBstr == nullptr) {
1928             LOG(ERROR) << "Value for payload is not null or a bstr";
1929             return false;
1930         }
1931         data = payloadAsBstr->value();  // TODO: avoid copy
1932     }
1933 
1934     if (data.size() > 0 && detachedContent.size() > 0) {
1935         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
1936         return false;
1937     }
1938 
1939     const cppbor::Bstr* signatureBstr = (*array)[3]->asBstr();
1940     if (signatureBstr == nullptr) {
1941         LOG(ERROR) << "Value for signature is a bstr";
1942         return false;
1943     }
1944     const vector<uint8_t>& coseSignature = signatureBstr->value();
1945 
1946     vector<uint8_t> derSignature;
1947     if (!ecdsaSignatureCoseToDer(coseSignature, derSignature)) {
1948         LOG(ERROR) << "Error converting ECDSA signature from COSE to DER format";
1949         return false;
1950     }
1951 
1952     vector<uint8_t> toBeSigned =
1953             coseBuildToBeSigned(encodedProtectedHeaders, data, detachedContent);
1954     if (!checkEcDsaSignature(support::sha256(toBeSigned), derSignature, publicKey)) {
1955         LOG(ERROR) << "Signature check failed";
1956         return false;
1957     }
1958     return true;
1959 }
1960 
1961 // Extracts the signature (of the ToBeSigned CBOR) from a COSE_Sign1.
coseSignGetSignature(const vector<uint8_t> & signatureCoseSign1)1962 optional<vector<uint8_t>> coseSignGetSignature(const vector<uint8_t>& signatureCoseSign1) {
1963     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1964     if (item == nullptr) {
1965         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1966         return {};
1967     }
1968     const cppbor::Array* array = item->asArray();
1969     if (array == nullptr) {
1970         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1971         return {};
1972     }
1973     if (array->size() != 4) {
1974         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
1975         return {};
1976     }
1977 
1978     vector<uint8_t> signature;
1979     const cppbor::Bstr* signatureAsBstr = (*array)[3]->asBstr();
1980     if (signatureAsBstr == nullptr) {
1981         LOG(ERROR) << "Value for signature is not a bstr";
1982         return {};
1983     }
1984     // Copy payload into |data|
1985     signature = signatureAsBstr->value();
1986 
1987     return signature;
1988 }
1989 
coseSignGetPayload(const vector<uint8_t> & signatureCoseSign1)1990 optional<vector<uint8_t>> coseSignGetPayload(const vector<uint8_t>& signatureCoseSign1) {
1991     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1992     if (item == nullptr) {
1993         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1994         return {};
1995     }
1996     const cppbor::Array* array = item->asArray();
1997     if (array == nullptr) {
1998         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1999         return {};
2000     }
2001     if (array->size() != 4) {
2002         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2003         return {};
2004     }
2005 
2006     vector<uint8_t> data;
2007     const cppbor::Simple* payloadAsSimple = (*array)[2]->asSimple();
2008     if (payloadAsSimple != nullptr) {
2009         if (payloadAsSimple->asNull() == nullptr) {
2010             LOG(ERROR) << "Value for payload is not null or a bstr";
2011             return {};
2012         }
2013         // payload is null, so |data| should be empty (as it is)
2014     } else {
2015         const cppbor::Bstr* payloadAsBstr = (*array)[2]->asBstr();
2016         if (payloadAsBstr == nullptr) {
2017             LOG(ERROR) << "Value for payload is not null or a bstr";
2018             return {};
2019         }
2020         // Copy payload into |data|
2021         data = payloadAsBstr->value();
2022     }
2023 
2024     return data;
2025 }
2026 
coseSignGetAlg(const vector<uint8_t> & signatureCoseSign1)2027 optional<int> coseSignGetAlg(const vector<uint8_t>& signatureCoseSign1) {
2028     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
2029     if (item == nullptr) {
2030         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
2031         return {};
2032     }
2033     const cppbor::Array* array = item->asArray();
2034     if (array == nullptr) {
2035         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
2036         return {};
2037     }
2038     if (array->size() != 4) {
2039         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2040         return {};
2041     }
2042 
2043     const cppbor::Bstr* protectedHeadersBytes = (*array)[0]->asBstr();
2044     if (protectedHeadersBytes == nullptr) {
2045         LOG(ERROR) << "Value for protectedHeaders is not a bstr";
2046         return {};
2047     }
2048     auto [item2, _2, message2] = cppbor::parse(protectedHeadersBytes->value());
2049     if (item2 == nullptr) {
2050         LOG(ERROR) << "Error parsing protectedHeaders: " << message2;
2051         return {};
2052     }
2053     const cppbor::Map* protectedHeaders = item2->asMap();
2054     if (protectedHeaders == nullptr) {
2055         LOG(ERROR) << "Decoded CBOR for protectedHeaders is not a map";
2056         return {};
2057     }
2058 
2059     for (size_t n = 0; n < protectedHeaders->size(); n++) {
2060         auto& [keyItem, valueItem] = (*protectedHeaders)[n];
2061         const cppbor::Int* number = keyItem->asInt();
2062         if (number == nullptr) {
2063             LOG(ERROR) << "Key item in top-level map is not a number";
2064             return {};
2065         }
2066         int label = number->value();
2067         if (label == COSE_LABEL_ALG) {
2068             const cppbor::Int* number = valueItem->asInt();
2069             if (number != nullptr) {
2070                 return number->value();
2071             }
2072             LOG(ERROR) << "Value for COSE_LABEL_ALG label is not a number";
2073             return {};
2074         }
2075     }
2076     LOG(ERROR) << "Did not find COSE_LABEL_ALG label in protected headers";
2077     return {};
2078 }
2079 
coseSignGetX5Chain(const vector<uint8_t> & signatureCoseSign1)2080 optional<vector<uint8_t>> coseSignGetX5Chain(const vector<uint8_t>& signatureCoseSign1) {
2081     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
2082     if (item == nullptr) {
2083         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
2084         return {};
2085     }
2086     const cppbor::Array* array = item->asArray();
2087     if (array == nullptr) {
2088         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
2089         return {};
2090     }
2091     if (array->size() != 4) {
2092         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2093         return {};
2094     }
2095 
2096     const cppbor::Map* unprotectedHeaders = (*array)[1]->asMap();
2097     if (unprotectedHeaders == nullptr) {
2098         LOG(ERROR) << "Value for unprotectedHeaders is not a map";
2099         return {};
2100     }
2101 
2102     for (size_t n = 0; n < unprotectedHeaders->size(); n++) {
2103         auto& [keyItem, valueItem] = (*unprotectedHeaders)[n];
2104         const cppbor::Int* number = keyItem->asInt();
2105         if (number == nullptr) {
2106             LOG(ERROR) << "Key item in top-level map is not a number";
2107             return {};
2108         }
2109         int label = number->value();
2110         if (label == COSE_LABEL_X5CHAIN) {
2111             const cppbor::Bstr* bstr = valueItem->asBstr();
2112             if (bstr != nullptr) {
2113                 return bstr->value();
2114             }
2115             const cppbor::Array* array = valueItem->asArray();
2116             if (array != nullptr) {
2117                 vector<uint8_t> certs;
2118                 for (size_t m = 0; m < array->size(); m++) {
2119                     const cppbor::Bstr* bstr = ((*array)[m])->asBstr();
2120                     if (bstr == nullptr) {
2121                         LOG(ERROR) << "Item in x5chain array is not a bstr";
2122                         return {};
2123                     }
2124                     const vector<uint8_t>& certValue = bstr->value();
2125                     certs.insert(certs.end(), certValue.begin(), certValue.end());
2126                 }
2127                 return certs;
2128             }
2129             LOG(ERROR) << "Value for x5chain label is not a bstr or array";
2130             return {};
2131         }
2132     }
2133     LOG(ERROR) << "Did not find x5chain label in unprotected headers";
2134     return {};
2135 }
2136 
coseBuildToBeMACed(const vector<uint8_t> & encodedProtectedHeaders,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)2137 vector<uint8_t> coseBuildToBeMACed(const vector<uint8_t>& encodedProtectedHeaders,
2138                                    const vector<uint8_t>& data,
2139                                    const vector<uint8_t>& detachedContent) {
2140     cppbor::Array macStructure;
2141     macStructure.add("MAC0");
2142     macStructure.add(encodedProtectedHeaders);
2143 
2144     // We currently don't support Externally Supplied Data (RFC 8152 section 4.3)
2145     // so external_aad is the empty bstr
2146     vector<uint8_t> emptyExternalAad;
2147     macStructure.add(emptyExternalAad);
2148 
2149     // Next field is the payload, independently of how it's transported (RFC
2150     // 8152 section 4.4). Since our API specifies only one of |data| and
2151     // |detachedContent| can be non-empty, it's simply just the non-empty one.
2152     if (data.size() > 0) {
2153         macStructure.add(data);
2154     } else {
2155         macStructure.add(detachedContent);
2156     }
2157 
2158     return macStructure.encode();
2159 }
2160 
coseMac0(const vector<uint8_t> & key,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)2161 optional<vector<uint8_t>> coseMac0(const vector<uint8_t>& key, const vector<uint8_t>& data,
2162                                    const vector<uint8_t>& detachedContent) {
2163     cppbor::Map unprotectedHeaders;
2164     cppbor::Map protectedHeaders;
2165 
2166     if (data.size() > 0 && detachedContent.size() > 0) {
2167         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
2168         return {};
2169     }
2170 
2171     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_HMAC_256_256);
2172 
2173     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
2174     vector<uint8_t> toBeMACed = coseBuildToBeMACed(encodedProtectedHeaders, data, detachedContent);
2175 
2176     optional<vector<uint8_t>> mac = hmacSha256(key, toBeMACed);
2177     if (!mac) {
2178         LOG(ERROR) << "Error MACing toBeMACed data";
2179         return {};
2180     }
2181 
2182     cppbor::Array array;
2183     array.add(encodedProtectedHeaders);
2184     array.add(std::move(unprotectedHeaders));
2185     if (data.size() == 0) {
2186         cppbor::Null nullValue;
2187         array.add(std::move(nullValue));
2188     } else {
2189         array.add(data);
2190     }
2191     array.add(mac.value());
2192     return array.encode();
2193 }
2194 
coseMacWithDigest(const vector<uint8_t> & digestToBeMaced,const vector<uint8_t> & data)2195 optional<vector<uint8_t>> coseMacWithDigest(const vector<uint8_t>& digestToBeMaced,
2196                                             const vector<uint8_t>& data) {
2197     cppbor::Map unprotectedHeaders;
2198     cppbor::Map protectedHeaders;
2199 
2200     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_HMAC_256_256);
2201 
2202     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
2203 
2204     cppbor::Array array;
2205     array.add(encodedProtectedHeaders);
2206     array.add(std::move(unprotectedHeaders));
2207     if (data.size() == 0) {
2208         cppbor::Null nullValue;
2209         array.add(std::move(nullValue));
2210     } else {
2211         array.add(data);
2212     }
2213     array.add(digestToBeMaced);
2214     return array.encode();
2215 }
2216 
2217 // ---------------------------------------------------------------------------
2218 // Utility functions specific to IdentityCredential.
2219 // ---------------------------------------------------------------------------
2220 
calcEMacKey(const vector<uint8_t> & privateKey,const vector<uint8_t> & publicKey,const vector<uint8_t> & sessionTranscriptBytes)2221 optional<vector<uint8_t>> calcEMacKey(const vector<uint8_t>& privateKey,
2222                                       const vector<uint8_t>& publicKey,
2223                                       const vector<uint8_t>& sessionTranscriptBytes) {
2224     optional<vector<uint8_t>> sharedSecret = support::ecdh(publicKey, privateKey);
2225     if (!sharedSecret) {
2226         LOG(ERROR) << "Error performing ECDH";
2227         return {};
2228     }
2229     vector<uint8_t> salt = support::sha256(sessionTranscriptBytes);
2230     vector<uint8_t> info = {'E', 'M', 'a', 'c', 'K', 'e', 'y'};
2231     optional<vector<uint8_t>> derivedKey = support::hkdf(sharedSecret.value(), salt, info, 32);
2232     if (!derivedKey) {
2233         LOG(ERROR) << "Error performing HKDF";
2234         return {};
2235     }
2236     return derivedKey.value();
2237 }
2238 
calcMac(const vector<uint8_t> & sessionTranscriptEncoded,const string & docType,const vector<uint8_t> & deviceNameSpacesEncoded,const vector<uint8_t> & eMacKey)2239 optional<vector<uint8_t>> calcMac(const vector<uint8_t>& sessionTranscriptEncoded,
2240                                   const string& docType,
2241                                   const vector<uint8_t>& deviceNameSpacesEncoded,
2242                                   const vector<uint8_t>& eMacKey) {
2243     auto [sessionTranscriptItem, _, errMsg] = cppbor::parse(sessionTranscriptEncoded);
2244     if (sessionTranscriptItem == nullptr) {
2245         LOG(ERROR) << "Error parsing sessionTranscriptEncoded: " << errMsg;
2246         return {};
2247     }
2248     // The data that is MACed is ["DeviceAuthentication", sessionTranscript, docType,
2249     // deviceNameSpacesBytes] so build up that structure
2250     cppbor::Array deviceAuthentication =
2251             cppbor::Array()
2252                     .add("DeviceAuthentication")
2253                     .add(std::move(sessionTranscriptItem))
2254                     .add(docType)
2255                     .add(cppbor::SemanticTag(kSemanticTagEncodedCbor, deviceNameSpacesEncoded));
2256     vector<uint8_t> deviceAuthenticationBytes =
2257             cppbor::SemanticTag(kSemanticTagEncodedCbor, deviceAuthentication.encode()).encode();
2258     optional<vector<uint8_t>> calculatedMac =
2259             support::coseMac0(eMacKey, {},                 // payload
2260                               deviceAuthenticationBytes);  // detached content
2261     return calculatedMac;
2262 }
2263 
/*
 * Splits `content` into consecutive chunks of at most `maxChunkSize` bytes,
 * preserving order. Every chunk except possibly the last is exactly
 * `maxChunkSize` bytes long.
 *
 * The result always holds at least one element. If `content` already fits in
 * one chunk (including when it is empty), it is returned as the sole element.
 *
 * A `maxChunkSize` of 0 cannot split anything (and would otherwise divide by
 * zero when computing the chunk count), so it is treated as "no limit": the
 * whole of `content` is returned as a single chunk.
 */
vector<vector<uint8_t>> chunkVector(const vector<uint8_t>& content, size_t maxChunkSize) {
    const size_t contentSize = content.size();
    if (maxChunkSize == 0 || contentSize <= maxChunkSize) {
        return {content};
    }

    vector<vector<uint8_t>> chunks;
    // Ceiling division; cannot overflow because maxChunkSize < contentSize here.
    chunks.reserve((contentSize + maxChunkSize - 1) / maxChunkSize);
    for (size_t pos = 0; pos < contentSize; pos += maxChunkSize) {
        const size_t remaining = contentSize - pos;
        const size_t chunkSize = remaining < maxChunkSize ? remaining : maxChunkSize;
        chunks.emplace_back(content.begin() + pos, content.begin() + pos + chunkSize);
    }
    return chunks;
}
2289 
// Fixed 16-byte, all-zero key. Presumably a stand-in for a real hardware-bound
// key in test configurations — confirm against callers.
vector<uint8_t> testHardwareBoundKey(16, 0);

// Returns a reference to the all-zero 16-byte key defined above. Every call
// yields the same object.
const vector<uint8_t>& getTestHardwareBoundKey() {
    const vector<uint8_t>& key = testHardwareBoundKey;
    return key;
}
2295 
extractDerSubjectFromCertificate(const vector<uint8_t> & certificate)2296 optional<vector<uint8_t>> extractDerSubjectFromCertificate(const vector<uint8_t>& certificate) {
2297     const uint8_t* input = certificate.data();
2298     X509_Ptr cert(d2i_X509(/*cert=*/nullptr, &input, certificate.size()));
2299     if (!cert) {
2300         LOG(ERROR) << "Failed to parse certificate";
2301         return std::nullopt;
2302     }
2303 
2304     X509_NAME* subject = X509_get_subject_name(cert.get());
2305     if (!subject) {
2306         LOG(ERROR) << "Failed to retrieve subject name";
2307         return std::nullopt;
2308     }
2309 
2310     int encodedSubjectLength = i2d_X509_NAME(subject, /*out=*/nullptr);
2311     if (encodedSubjectLength < 0) {
2312         LOG(ERROR) << "Error obtaining encoded subject name length";
2313         return std::nullopt;
2314     }
2315 
2316     vector<uint8_t> encodedSubject(encodedSubjectLength);
2317     uint8_t* out = encodedSubject.data();
2318     if (encodedSubjectLength != i2d_X509_NAME(subject, &out)) {
2319         LOG(ERROR) << "Error encoding subject name";
2320         return std::nullopt;
2321     }
2322 
2323     return encodedSubject;
2324 }
2325 
2326 }  // namespace support
2327 }  // namespace identity
2328 }  // namespace hardware
2329 }  // namespace android
2330