1 /*
2  * Copyright 2019, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "IdentityCredentialSupport"
18 
19 #include <android/hardware/identity/support/IdentityCredentialSupport.h>
20 
21 #define _POSIX_C_SOURCE 199309L
22 
23 #include <ctype.h>
24 #include <stdarg.h>
25 #include <stdio.h>
26 #include <time.h>
27 #include <chrono>
28 #include <iomanip>
29 
30 #include <openssl/aes.h>
31 #include <openssl/bn.h>
32 #include <openssl/crypto.h>
33 #include <openssl/ec.h>
34 #include <openssl/err.h>
35 #include <openssl/evp.h>
36 #include <openssl/hkdf.h>
37 #include <openssl/hmac.h>
38 #include <openssl/objects.h>
39 #include <openssl/pem.h>
40 #include <openssl/pkcs12.h>
41 #include <openssl/rand.h>
42 #include <openssl/x509.h>
43 #include <openssl/x509_vfy.h>
44 
45 #include <android-base/logging.h>
46 #include <android-base/stringprintf.h>
47 #include <charconv>
48 
49 #include <cppbor.h>
50 #include <cppbor_parse.h>
51 
52 #include <android/hardware/keymaster/4.0/types.h>
53 #include <keymaster/authorization_set.h>
54 #include <keymaster/contexts/pure_soft_keymaster_context.h>
55 #include <keymaster/contexts/soft_attestation_cert.h>
56 #include <keymaster/keymaster_tags.h>
57 #include <keymaster/km_openssl/attestation_utils.h>
58 #include <keymaster/km_openssl/certificate_utils.h>
59 
60 namespace android {
61 namespace hardware {
62 namespace identity {
63 namespace support {
64 
65 using ::std::pair;
66 using ::std::unique_ptr;
67 
68 // ---------------------------------------------------------------------------
69 // Miscellaneous utilities.
70 // ---------------------------------------------------------------------------
71 
// Writes a hex dump of |data| to stderr, labelled with |name|.
//
// Each output line covers up to 16 bytes: the offset of the first byte, the bytes
// in hex (padded so the columns line up on a short final line), and the same
// bytes as characters, with '.' standing in for non-printable ones.
void hexdump(const string& name, const vector<uint8_t>& data) {
    // size() returns the unsigned size_t, whose printf conversion is %zu.
    fprintf(stderr, "%s: dumping %zu bytes\n", name.c_str(), data.size());
    size_t n, m, o;
    for (n = 0; n < data.size(); n += 16) {
        fprintf(stderr, "%04zx  ", n);
        for (m = 0; m < 16 && n + m < data.size(); m++) {
            fprintf(stderr, "%02x ", data[n + m]);
        }
        // Pad a short final line so the character column stays aligned.
        for (o = m; o < 16; o++) {
            fprintf(stderr, "   ");
        }
        fprintf(stderr, " ");
        for (m = 0; m < 16 && n + m < data.size(); m++) {
            int c = data[n + m];
            fprintf(stderr, "%c", isprint(c) ? c : '.');
        }
        fprintf(stderr, "\n");
    }
    fprintf(stderr, "\n");
}
92 
// Returns the |dataLen| bytes at |data| as lowercase hexadecimal: two characters
// per byte, most significant nibble first, no separators.
string encodeHex(const uint8_t* data, size_t dataLen) {
    static const char kNibbleChars[] = "0123456789abcdef";

    string hex(dataLen * 2, '\0');
    size_t out = 0;
    for (const uint8_t* p = data; p != data + dataLen; ++p) {
        const uint8_t value = *p;
        hex[out++] = kNibbleChars[value >> 4];
        hex[out++] = kNibbleChars[value & 0x0f];
    }
    return hex;
}
107 
encodeHex(const string & str)108 string encodeHex(const string& str) {
109     return encodeHex(reinterpret_cast<const uint8_t*>(str.data()), str.size());
110 }
111 
encodeHex(const vector<uint8_t> & data)112 string encodeHex(const vector<uint8_t>& data) {
113     return encodeHex(data.data(), data.size());
114 }
115 
// Maps a single hexadecimal character ('0'-'9', 'a'-'f', 'A'-'F') to its value in
// the range 0 through 15, both inclusive. Any other character yields -1.
int parseHexDigit(char hexDigit) {
    if ('0' <= hexDigit && hexDigit <= '9') {
        return hexDigit - '0';
    }
    if ('a' <= hexDigit && hexDigit <= 'f') {
        return 10 + (hexDigit - 'a');
    }
    if ('A' <= hexDigit && hexDigit <= 'F') {
        return 10 + (hexDigit - 'A');
    }
    return -1;
}
127 
decodeHex(const string & hexEncoded)128 optional<vector<uint8_t>> decodeHex(const string& hexEncoded) {
129     vector<uint8_t> out;
130     size_t hexSize = hexEncoded.size();
131     if ((hexSize & 1) != 0) {
132         LOG(ERROR) << "Size of data cannot be odd";
133         return {};
134     }
135 
136     out.resize(hexSize / 2);
137     for (size_t n = 0; n < hexSize / 2; n++) {
138         int upperNibble = parseHexDigit(hexEncoded[n * 2]);
139         int lowerNibble = parseHexDigit(hexEncoded[n * 2 + 1]);
140         if (upperNibble == -1 || lowerNibble == -1) {
141             LOG(ERROR) << "Invalid hex digit at position " << n;
142             return {};
143         }
144         out[n] = (upperNibble << 4) + lowerNibble;
145     }
146 
147     return out;
148 }
149 
150 // ---------------------------------------------------------------------------
151 // Crypto functionality / abstraction.
152 // ---------------------------------------------------------------------------
153 
154 using EvpCipherCtxPtr = bssl::UniquePtr<EVP_CIPHER_CTX>;
155 using EC_KEY_Ptr = bssl::UniquePtr<EC_KEY>;
156 using EVP_PKEY_Ptr = bssl::UniquePtr<EVP_PKEY>;
157 using EVP_PKEY_CTX_Ptr = bssl::UniquePtr<EVP_PKEY_CTX>;
158 using EC_GROUP_Ptr = bssl::UniquePtr<EC_GROUP>;
159 using EC_POINT_Ptr = bssl::UniquePtr<EC_POINT>;
160 using ECDSA_SIG_Ptr = bssl::UniquePtr<ECDSA_SIG>;
161 using X509_Ptr = bssl::UniquePtr<X509>;
162 using PKCS12_Ptr = bssl::UniquePtr<PKCS12>;
163 using BIGNUM_Ptr = bssl::UniquePtr<BIGNUM>;
164 using ASN1_INTEGER_Ptr = bssl::UniquePtr<ASN1_INTEGER>;
165 using ASN1_TIME_Ptr = bssl::UniquePtr<ASN1_TIME>;
166 using ASN1_OCTET_STRING_Ptr = bssl::UniquePtr<ASN1_OCTET_STRING>;
167 using ASN1_OBJECT_Ptr = bssl::UniquePtr<ASN1_OBJECT>;
168 using X509_NAME_Ptr = bssl::UniquePtr<X509_NAME>;
169 using X509_EXTENSION_Ptr = bssl::UniquePtr<X509_EXTENSION>;
170 
171 // bool getRandom(size_t numBytes, vector<uint8_t>& output) {
getRandom(size_t numBytes)172 optional<vector<uint8_t>> getRandom(size_t numBytes) {
173     vector<uint8_t> output;
174     output.resize(numBytes);
175     if (RAND_bytes(output.data(), numBytes) != 1) {
176         LOG(ERROR) << "RAND_bytes: failed getting " << numBytes << " random";
177         return {};
178     }
179     return output;
180 }
181 
decryptAes128Gcm(const vector<uint8_t> & key,const vector<uint8_t> & encryptedData,const vector<uint8_t> & additionalAuthenticatedData)182 optional<vector<uint8_t>> decryptAes128Gcm(const vector<uint8_t>& key,
183                                            const vector<uint8_t>& encryptedData,
184                                            const vector<uint8_t>& additionalAuthenticatedData) {
185     int cipherTextSize = int(encryptedData.size()) - kAesGcmIvSize - kAesGcmTagSize;
186     if (cipherTextSize < 0) {
187         LOG(ERROR) << "encryptedData too small";
188         return {};
189     }
190     unsigned char* nonce = (unsigned char*)encryptedData.data();
191     unsigned char* cipherText = nonce + kAesGcmIvSize;
192     unsigned char* tag = cipherText + cipherTextSize;
193 
194     vector<uint8_t> plainText;
195     plainText.resize(cipherTextSize);
196 
197     auto ctx = EvpCipherCtxPtr(EVP_CIPHER_CTX_new());
198     if (ctx.get() == nullptr) {
199         LOG(ERROR) << "EVP_CIPHER_CTX_new: failed";
200         return {};
201     }
202 
203     if (EVP_DecryptInit_ex(ctx.get(), EVP_aes_128_gcm(), NULL, NULL, NULL) != 1) {
204         LOG(ERROR) << "EVP_DecryptInit_ex: failed";
205         return {};
206     }
207 
208     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize, NULL) != 1) {
209         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting nonce length";
210         return {};
211     }
212 
213     if (EVP_DecryptInit_ex(ctx.get(), NULL, NULL, (unsigned char*)key.data(), nonce) != 1) {
214         LOG(ERROR) << "EVP_DecryptInit_ex: failed";
215         return {};
216     }
217 
218     int numWritten;
219     if (additionalAuthenticatedData.size() > 0) {
220         if (EVP_DecryptUpdate(ctx.get(), NULL, &numWritten,
221                               (unsigned char*)additionalAuthenticatedData.data(),
222                               additionalAuthenticatedData.size()) != 1) {
223             LOG(ERROR) << "EVP_DecryptUpdate: failed for additionalAuthenticatedData";
224             return {};
225         }
226         if ((size_t)numWritten != additionalAuthenticatedData.size()) {
227             LOG(ERROR) << "EVP_DecryptUpdate: Unexpected outl=" << numWritten << " (expected "
228                        << additionalAuthenticatedData.size() << ") for additionalAuthenticatedData";
229             return {};
230         }
231     }
232 
233     if (EVP_DecryptUpdate(ctx.get(), (unsigned char*)plainText.data(), &numWritten, cipherText,
234                           cipherTextSize) != 1) {
235         LOG(ERROR) << "EVP_DecryptUpdate: failed";
236         return {};
237     }
238     if (numWritten != cipherTextSize) {
239         LOG(ERROR) << "EVP_DecryptUpdate: Unexpected outl=" << numWritten << " (expected "
240                    << cipherTextSize << ")";
241         return {};
242     }
243 
244     if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, kAesGcmTagSize, tag)) {
245         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting expected tag";
246         return {};
247     }
248 
249     int ret = EVP_DecryptFinal_ex(ctx.get(), (unsigned char*)plainText.data() + numWritten,
250                                   &numWritten);
251     if (ret != 1) {
252         LOG(ERROR) << "EVP_DecryptFinal_ex: failed";
253         return {};
254     }
255     if (numWritten != 0) {
256         LOG(ERROR) << "EVP_DecryptFinal_ex: Unexpected non-zero outl=" << numWritten;
257         return {};
258     }
259 
260     return plainText;
261 }
262 
encryptAes128Gcm(const vector<uint8_t> & key,const vector<uint8_t> & nonce,const vector<uint8_t> & data,const vector<uint8_t> & additionalAuthenticatedData)263 optional<vector<uint8_t>> encryptAes128Gcm(const vector<uint8_t>& key, const vector<uint8_t>& nonce,
264                                            const vector<uint8_t>& data,
265                                            const vector<uint8_t>& additionalAuthenticatedData) {
266     if (key.size() != kAes128GcmKeySize) {
267         LOG(ERROR) << "key is not kAes128GcmKeySize bytes";
268         return {};
269     }
270     if (nonce.size() != kAesGcmIvSize) {
271         LOG(ERROR) << "nonce is not kAesGcmIvSize bytes";
272         return {};
273     }
274 
275     // The result is the nonce (kAesGcmIvSize bytes), the ciphertext, and
276     // finally the tag (kAesGcmTagSize bytes).
277     vector<uint8_t> encryptedData;
278     encryptedData.resize(data.size() + kAesGcmIvSize + kAesGcmTagSize);
279     unsigned char* noncePtr = (unsigned char*)encryptedData.data();
280     unsigned char* cipherText = noncePtr + kAesGcmIvSize;
281     unsigned char* tag = cipherText + data.size();
282     memcpy(noncePtr, nonce.data(), kAesGcmIvSize);
283 
284     auto ctx = EvpCipherCtxPtr(EVP_CIPHER_CTX_new());
285     if (ctx.get() == nullptr) {
286         LOG(ERROR) << "EVP_CIPHER_CTX_new: failed";
287         return {};
288     }
289 
290     if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_gcm(), NULL, NULL, NULL) != 1) {
291         LOG(ERROR) << "EVP_EncryptInit_ex: failed";
292         return {};
293     }
294 
295     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize, NULL) != 1) {
296         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting nonce length";
297         return {};
298     }
299 
300     if (EVP_EncryptInit_ex(ctx.get(), NULL, NULL, (unsigned char*)key.data(),
301                            (unsigned char*)nonce.data()) != 1) {
302         LOG(ERROR) << "EVP_EncryptInit_ex: failed";
303         return {};
304     }
305 
306     int numWritten;
307     if (additionalAuthenticatedData.size() > 0) {
308         if (EVP_EncryptUpdate(ctx.get(), NULL, &numWritten,
309                               (unsigned char*)additionalAuthenticatedData.data(),
310                               additionalAuthenticatedData.size()) != 1) {
311             LOG(ERROR) << "EVP_EncryptUpdate: failed for additionalAuthenticatedData";
312             return {};
313         }
314         if ((size_t)numWritten != additionalAuthenticatedData.size()) {
315             LOG(ERROR) << "EVP_EncryptUpdate: Unexpected outl=" << numWritten << " (expected "
316                        << additionalAuthenticatedData.size() << ") for additionalAuthenticatedData";
317             return {};
318         }
319     }
320 
321     if (data.size() > 0) {
322         if (EVP_EncryptUpdate(ctx.get(), cipherText, &numWritten, (unsigned char*)data.data(),
323                               data.size()) != 1) {
324             LOG(ERROR) << "EVP_EncryptUpdate: failed";
325             return {};
326         }
327         if ((size_t)numWritten != data.size()) {
328             LOG(ERROR) << "EVP_EncryptUpdate: Unexpected outl=" << numWritten << " (expected "
329                        << data.size() << ")";
330             return {};
331         }
332     }
333 
334     if (EVP_EncryptFinal_ex(ctx.get(), cipherText + numWritten, &numWritten) != 1) {
335         LOG(ERROR) << "EVP_EncryptFinal_ex: failed";
336         return {};
337     }
338     if (numWritten != 0) {
339         LOG(ERROR) << "EVP_EncryptFinal_ex: Unexpected non-zero outl=" << numWritten;
340         return {};
341     }
342 
343     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, kAesGcmTagSize, tag) != 1) {
344         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed getting tag";
345         return {};
346     }
347 
348     return encryptedData;
349 }
350 
// Concatenates the certificates in |certificateChain|, in order, into a single
// byte vector.
vector<uint8_t> certificateChainJoin(const vector<vector<uint8_t>>& certificateChain) {
    // Size the output once up front instead of regrowing it per certificate.
    size_t totalSize = 0;
    for (const vector<uint8_t>& certificate : certificateChain) {
        totalSize += certificate.size();
    }

    vector<uint8_t> ret;
    ret.reserve(totalSize);
    for (const vector<uint8_t>& certificate : certificateChain) {
        ret.insert(ret.end(), certificate.begin(), certificate.end());
    }
    return ret;
}
358 
certificateChainSplit(const vector<uint8_t> & certificateChain)359 optional<vector<vector<uint8_t>>> certificateChainSplit(const vector<uint8_t>& certificateChain) {
360     const unsigned char* pStart = (unsigned char*)certificateChain.data();
361     const unsigned char* p = pStart;
362     const unsigned char* pEnd = p + certificateChain.size();
363     vector<vector<uint8_t>> certificates;
364     while (p < pEnd) {
365         size_t begin = p - pStart;
366         auto x509 = X509_Ptr(d2i_X509(nullptr, &p, pEnd - p));
367         size_t next = p - pStart;
368         if (x509 == nullptr) {
369             LOG(ERROR) << "Error parsing X509 certificate";
370             return {};
371         }
372         vector<uint8_t> cert =
373                 vector<uint8_t>(certificateChain.begin() + begin, certificateChain.begin() + next);
374         certificates.push_back(std::move(cert));
375     }
376     return certificates;
377 }
378 
parseX509Certificates(const vector<uint8_t> & certificateChain,vector<X509_Ptr> & parsedCertificates)379 static bool parseX509Certificates(const vector<uint8_t>& certificateChain,
380                                   vector<X509_Ptr>& parsedCertificates) {
381     const unsigned char* p = (unsigned char*)certificateChain.data();
382     const unsigned char* pEnd = p + certificateChain.size();
383     parsedCertificates.resize(0);
384     while (p < pEnd) {
385         auto x509 = X509_Ptr(d2i_X509(nullptr, &p, pEnd - p));
386         if (x509 == nullptr) {
387             LOG(ERROR) << "Error parsing X509 certificate";
388             return false;
389         }
390         parsedCertificates.push_back(std::move(x509));
391     }
392     return true;
393 }
394 
certificateSignedByPublicKey(const vector<uint8_t> & certificate,const vector<uint8_t> & publicKey)395 bool certificateSignedByPublicKey(const vector<uint8_t>& certificate,
396                                   const vector<uint8_t>& publicKey) {
397     const unsigned char* p = certificate.data();
398     auto x509 = X509_Ptr(d2i_X509(nullptr, &p, certificate.size()));
399     if (x509 == nullptr) {
400         LOG(ERROR) << "Error parsing X509 certificate";
401         return false;
402     }
403 
404     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
405     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
406     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
407         1) {
408         LOG(ERROR) << "Error decoding publicKey";
409         return false;
410     }
411     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
412     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
413     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
414         LOG(ERROR) << "Memory allocation failed";
415         return false;
416     }
417     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
418         LOG(ERROR) << "Error setting group";
419         return false;
420     }
421     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
422         LOG(ERROR) << "Error setting point";
423         return false;
424     }
425     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
426         LOG(ERROR) << "Error setting key";
427         return false;
428     }
429 
430     if (X509_verify(x509.get(), pkey.get()) != 1) {
431         return false;
432     }
433 
434     return true;
435 }
436 
437 // TODO: Right now the only check we perform is to check that each certificate
438 //       is signed by its successor. We should - but currently don't - also check
439 //       things like valid dates etc.
440 //
441 //       It would be nice to use X509_verify_cert() instead of doing our own thing.
442 //
certificateChainValidate(const vector<uint8_t> & certificateChain)443 bool certificateChainValidate(const vector<uint8_t>& certificateChain) {
444     vector<X509_Ptr> certs;
445 
446     if (!parseX509Certificates(certificateChain, certs)) {
447         LOG(ERROR) << "Error parsing X509 certificates";
448         return false;
449     }
450 
451     if (certs.size() == 1) {
452         return true;
453     }
454 
455     for (size_t n = 1; n < certs.size(); n++) {
456         const X509_Ptr& keyCert = certs[n - 1];
457         const X509_Ptr& signingCert = certs[n];
458         EVP_PKEY_Ptr signingPubkey(X509_get_pubkey(signingCert.get()));
459         if (X509_verify(keyCert.get(), signingPubkey.get()) != 1) {
460             LOG(ERROR) << "Error validating cert at index " << n - 1
461                        << " is signed by its successor";
462             return false;
463         }
464     }
465 
466     return true;
467 }
468 
checkEcDsaSignature(const vector<uint8_t> & digest,const vector<uint8_t> & signature,const vector<uint8_t> & publicKey)469 bool checkEcDsaSignature(const vector<uint8_t>& digest, const vector<uint8_t>& signature,
470                          const vector<uint8_t>& publicKey) {
471     const unsigned char* p = (unsigned char*)signature.data();
472     auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, signature.size()));
473     if (sig.get() == nullptr) {
474         LOG(ERROR) << "Error decoding DER encoded signature";
475         return false;
476     }
477 
478     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
479     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
480     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
481         1) {
482         LOG(ERROR) << "Error decoding publicKey";
483         return false;
484     }
485     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
486     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
487     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
488         LOG(ERROR) << "Memory allocation failed";
489         return false;
490     }
491     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
492         LOG(ERROR) << "Error setting group";
493         return false;
494     }
495     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
496         LOG(ERROR) << "Error setting point";
497         return false;
498     }
499     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
500         LOG(ERROR) << "Error setting key";
501         return false;
502     }
503 
504     int rc = ECDSA_do_verify(digest.data(), digest.size(), sig.get(), ecKey.get());
505     if (rc != 1) {
506         LOG(ERROR) << "Error verifying signature (rc=" << rc << ")";
507         return false;
508     }
509 
510     return true;
511 }
512 
sha256(const vector<uint8_t> & data)513 vector<uint8_t> sha256(const vector<uint8_t>& data) {
514     vector<uint8_t> ret;
515     ret.resize(SHA256_DIGEST_LENGTH);
516     SHA256_CTX ctx;
517     SHA256_Init(&ctx);
518     SHA256_Update(&ctx, data.data(), data.size());
519     SHA256_Final((unsigned char*)ret.data(), &ctx);
520     return ret;
521 }
522 
signEcDsaDigest(const vector<uint8_t> & key,const vector<uint8_t> & dataDigest)523 optional<vector<uint8_t>> signEcDsaDigest(const vector<uint8_t>& key,
524                                           const vector<uint8_t>& dataDigest) {
525     auto bn = BIGNUM_Ptr(BN_bin2bn(key.data(), key.size(), nullptr));
526     if (bn.get() == nullptr) {
527         LOG(ERROR) << "Error creating BIGNUM";
528         return {};
529     }
530 
531     auto ec_key = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
532     if (EC_KEY_set_private_key(ec_key.get(), bn.get()) != 1) {
533         LOG(ERROR) << "Error setting private key from BIGNUM";
534         return {};
535     }
536 
537     ECDSA_SIG* sig = ECDSA_do_sign(dataDigest.data(), dataDigest.size(), ec_key.get());
538     if (sig == nullptr) {
539         LOG(ERROR) << "Error signing digest";
540         return {};
541     }
542     size_t len = i2d_ECDSA_SIG(sig, nullptr);
543     vector<uint8_t> signature;
544     signature.resize(len);
545     unsigned char* p = (unsigned char*)signature.data();
546     i2d_ECDSA_SIG(sig, &p);
547     ECDSA_SIG_free(sig);
548     return signature;
549 }
550 
signEcDsa(const vector<uint8_t> & key,const vector<uint8_t> & data)551 optional<vector<uint8_t>> signEcDsa(const vector<uint8_t>& key, const vector<uint8_t>& data) {
552     return signEcDsaDigest(key, sha256(data));
553 }
554 
hmacSha256(const vector<uint8_t> & key,const vector<uint8_t> & data)555 optional<vector<uint8_t>> hmacSha256(const vector<uint8_t>& key, const vector<uint8_t>& data) {
556     HMAC_CTX ctx;
557     HMAC_CTX_init(&ctx);
558     if (HMAC_Init_ex(&ctx, key.data(), key.size(), EVP_sha256(), nullptr /* impl */) != 1) {
559         LOG(ERROR) << "Error initializing HMAC_CTX";
560         return {};
561     }
562     if (HMAC_Update(&ctx, data.data(), data.size()) != 1) {
563         LOG(ERROR) << "Error updating HMAC_CTX";
564         return {};
565     }
566     vector<uint8_t> hmac;
567     hmac.resize(32);
568     unsigned int size = 0;
569     if (HMAC_Final(&ctx, hmac.data(), &size) != 1) {
570         LOG(ERROR) << "Error finalizing HMAC_CTX";
571         return {};
572     }
573     if (size != 32) {
574         LOG(ERROR) << "Expected 32 bytes from HMAC_Final, got " << size;
575         return {};
576     }
577     return hmac;
578 }
579 
parseDigits(const char ** s,int numDigits)580 int parseDigits(const char** s, int numDigits) {
581     int result;
582     auto [_, ec] = std::from_chars(*s, *s + numDigits, result);
583     if (ec != std::errc()) {
584         LOG(ERROR) << "Error parsing " << numDigits << " digits "
585                    << " from " << s;
586         return 0;
587     }
588     *s += numDigits;
589     return result;
590 }
591 
parseAsn1Time(const ASN1_TIME * asn1Time,time_t * outTime)592 bool parseAsn1Time(const ASN1_TIME* asn1Time, time_t* outTime) {
593     struct tm tm;
594 
595     memset(&tm, '\0', sizeof(tm));
596     const char* timeStr = (const char*)asn1Time->data;
597     const char* s = timeStr;
598     if (asn1Time->type == V_ASN1_UTCTIME) {
599         tm.tm_year = parseDigits(&s, 2);
600         if (tm.tm_year < 70) {
601             tm.tm_year += 100;
602         }
603     } else if (asn1Time->type == V_ASN1_GENERALIZEDTIME) {
604         tm.tm_year = parseDigits(&s, 4) - 1900;
605         tm.tm_year -= 1900;
606     } else {
607         LOG(ERROR) << "Unsupported ASN1_TIME type " << asn1Time->type;
608         return false;
609     }
610     tm.tm_mon = parseDigits(&s, 2) - 1;
611     tm.tm_mday = parseDigits(&s, 2);
612     tm.tm_hour = parseDigits(&s, 2);
613     tm.tm_min = parseDigits(&s, 2);
614     tm.tm_sec = parseDigits(&s, 2);
615     // This may need to be updated if someone create certificates using +/- instead of Z.
616     //
617     if (*s != 'Z') {
618         LOG(ERROR) << "Expected Z in string '" << timeStr << "' at offset " << (s - timeStr);
619         return false;
620     }
621 
622     time_t t = timegm(&tm);
623     if (t == -1) {
624         LOG(ERROR) << "Error converting broken-down time to time_t";
625         return false;
626     }
627     *outTime = t;
628     return true;
629 }
630 
631 // Generates the attestation certificate with the parameters passed in.  Note
632 // that the passed in |activeTimeMilliSeconds| |expireTimeMilliSeconds| are in
633 // milli seconds since epoch.  We are setting them to milliseconds due to
634 // requirement in AuthorizationSet KM_DATE fields.  The certificate created is
635 // actually in seconds.
636 //
637 // If 0 is passed for expiration time, the expiration time from batch
638 // certificate will be used.
639 //
createAttestation(const EVP_PKEY * key,const vector<uint8_t> & applicationId,const vector<uint8_t> & challenge,uint64_t activeTimeMilliSeconds,uint64_t expireTimeMilliSeconds,bool isTestCredential)640 optional<vector<vector<uint8_t>>> createAttestation(
641         const EVP_PKEY* key, const vector<uint8_t>& applicationId, const vector<uint8_t>& challenge,
642         uint64_t activeTimeMilliSeconds, uint64_t expireTimeMilliSeconds, bool isTestCredential) {
643     // Pretend to be implemented in a trusted environment just so we can pass
644     // the VTS tests. Of course, this is a pretend-only game since hopefully no
645     // relying party is ever going to trust our batch key and those keys above
646     // it.
647     ::keymaster::PureSoftKeymasterContext context(::keymaster::KmVersion::KEYMINT_1,
648                                                   KM_SECURITY_LEVEL_TRUSTED_ENVIRONMENT);
649 
650     keymaster_error_t error;
651     ::keymaster::CertificateChain attestation_chain =
652             context.GetAttestationChain(KM_ALGORITHM_EC, &error);
653     if (KM_ERROR_OK != error) {
654         LOG(ERROR) << "Error getting attestation chain " << error;
655         return {};
656     }
657     if (expireTimeMilliSeconds == 0) {
658         if (attestation_chain.entry_count < 1) {
659             LOG(ERROR) << "Expected at least one entry in attestation chain";
660             return {};
661         }
662         keymaster_blob_t* bcBlob = &(attestation_chain.entries[0]);
663         const uint8_t* bcData = bcBlob->data;
664         auto bc = X509_Ptr(d2i_X509(nullptr, &bcData, bcBlob->data_length));
665         time_t bcNotAfter;
666         if (!parseAsn1Time(X509_get0_notAfter(bc.get()), &bcNotAfter)) {
667             LOG(ERROR) << "Error getting notAfter from batch certificate";
668             return {};
669         }
670         expireTimeMilliSeconds = bcNotAfter * 1000;
671     }
672 
673     ::keymaster::X509_NAME_Ptr subjectName;
674     if (KM_ERROR_OK !=
675         ::keymaster::make_name_from_str("Android Identity Credential Key", &subjectName)) {
676         LOG(ERROR) << "Cannot create attestation subject";
677         return {};
678     }
679 
680     vector<uint8_t> subject(i2d_X509_NAME(subjectName.get(), NULL));
681     unsigned char* subjectPtr = subject.data();
682 
683     i2d_X509_NAME(subjectName.get(), &subjectPtr);
684 
685     ::keymaster::AuthorizationSet auth_set(
686             ::keymaster::AuthorizationSetBuilder()
687                     .Authorization(::keymaster::TAG_CERTIFICATE_NOT_BEFORE, activeTimeMilliSeconds)
688                     .Authorization(::keymaster::TAG_CERTIFICATE_NOT_AFTER, expireTimeMilliSeconds)
689                     .Authorization(::keymaster::TAG_ATTESTATION_CHALLENGE, challenge.data(),
690                                    challenge.size())
691                     .Authorization(::keymaster::TAG_ACTIVE_DATETIME, activeTimeMilliSeconds)
692                     // Even though identity attestation hal said the application
693                     // id should be in software enforced authentication set,
694                     // keymaster portable lib expect the input in this
695                     // parameter because the software enforced in input to keymaster
696                     // refers to the key software enforced properties. And this
697                     // parameter refers to properties of the attestation which
698                     // includes app id.
699                     .Authorization(::keymaster::TAG_ATTESTATION_APPLICATION_ID,
700                                    applicationId.data(), applicationId.size())
701                     .Authorization(::keymaster::TAG_CERTIFICATE_SUBJECT, subject.data(),
702                                    subject.size())
703                     .Authorization(::keymaster::TAG_USAGE_EXPIRE_DATETIME, expireTimeMilliSeconds));
704 
705     // Unique id and device id is not applicable for identity credential attestation,
706     // so we don't need to set those or application id.
707     ::keymaster::AuthorizationSet swEnforced(::keymaster::AuthorizationSetBuilder().Authorization(
708             ::keymaster::TAG_CREATION_DATETIME, activeTimeMilliSeconds));
709 
710     ::keymaster::AuthorizationSetBuilder hwEnforcedBuilder =
711             ::keymaster::AuthorizationSetBuilder()
712                     .Authorization(::keymaster::TAG_PURPOSE, KM_PURPOSE_SIGN)
713                     .Authorization(::keymaster::TAG_KEY_SIZE, 256)
714                     .Authorization(::keymaster::TAG_ALGORITHM, KM_ALGORITHM_EC)
715                     .Authorization(::keymaster::TAG_NO_AUTH_REQUIRED)
716                     .Authorization(::keymaster::TAG_DIGEST, KM_DIGEST_SHA_2_256)
717                     .Authorization(::keymaster::TAG_EC_CURVE, KM_EC_CURVE_P_256)
718                     .Authorization(::keymaster::TAG_OS_VERSION, 42)
719                     .Authorization(::keymaster::TAG_OS_PATCHLEVEL, 43);
720 
721     // Only include TAG_IDENTITY_CREDENTIAL_KEY if it's not a test credential
722     if (!isTestCredential) {
723         hwEnforcedBuilder.Authorization(::keymaster::TAG_IDENTITY_CREDENTIAL_KEY);
724     }
725     ::keymaster::AuthorizationSet hwEnforced(hwEnforcedBuilder);
726 
727     ::keymaster::CertificateChain cert_chain_out = generate_attestation(
728             key, swEnforced, hwEnforced, auth_set, {} /* attest_key */, context, &error);
729 
730     if (KM_ERROR_OK != error) {
731         LOG(ERROR) << "Error generating attestation from EVP key: " << error;
732         return {};
733     }
734 
735     // translate certificate format from keymaster_cert_chain_t to vector<vector<uint8_t>>.
736     vector<vector<uint8_t>> attestationCertificate;
737     for (std::size_t i = 0; i < cert_chain_out.entry_count; i++) {
738         attestationCertificate.insert(
739                 attestationCertificate.end(),
740                 vector<uint8_t>(
741                         cert_chain_out.entries[i].data,
742                         cert_chain_out.entries[i].data + cert_chain_out.entries[i].data_length));
743     }
744 
745     return attestationCertificate;
746 }
747 
createEcKeyPairAndAttestation(const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId,bool isTestCredential)748 optional<std::pair<vector<uint8_t>, vector<vector<uint8_t>>>> createEcKeyPairAndAttestation(
749         const vector<uint8_t>& challenge, const vector<uint8_t>& applicationId,
750         bool isTestCredential) {
751     auto ec_key = ::keymaster::EC_KEY_Ptr(EC_KEY_new());
752     auto pkey = ::keymaster::EVP_PKEY_Ptr(EVP_PKEY_new());
753     auto group = ::keymaster::EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
754 
755     if (ec_key.get() == nullptr || pkey.get() == nullptr) {
756         LOG(ERROR) << "Memory allocation failed";
757         return {};
758     }
759 
760     if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
761         EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
762         LOG(ERROR) << "Error generating key";
763         return {};
764     }
765 
766     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ec_key.get()) != 1) {
767         LOG(ERROR) << "Error getting private key";
768         return {};
769     }
770 
771     uint64_t nowMs = time(nullptr) * 1000;
772     uint64_t expireTimeMs = 0;  // Set to same as batch certificate
773 
774     optional<vector<vector<uint8_t>>> attestationCert = createAttestation(
775             pkey.get(), applicationId, challenge, nowMs, expireTimeMs, isTestCredential);
776     if (!attestationCert) {
777         LOG(ERROR) << "Error create attestation from key and challenge";
778         return {};
779     }
780 
781     int size = i2d_PrivateKey(pkey.get(), nullptr);
782     if (size == 0) {
783         LOG(ERROR) << "Error generating public key encoding";
784         return {};
785     }
786 
787     vector<uint8_t> keyPair(size);
788     unsigned char* p = keyPair.data();
789     i2d_PrivateKey(pkey.get(), &p);
790 
791     return make_pair(keyPair, attestationCert.value());
792 }
793 
createAttestationForEcPublicKey(const vector<uint8_t> & publicKey,const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId)794 optional<vector<vector<uint8_t>>> createAttestationForEcPublicKey(
795         const vector<uint8_t>& publicKey, const vector<uint8_t>& challenge,
796         const vector<uint8_t>& applicationId) {
797     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
798     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
799     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
800         1) {
801         LOG(ERROR) << "Error decoding publicKey";
802         return {};
803     }
804     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
805     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
806     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
807         LOG(ERROR) << "Memory allocation failed";
808         return {};
809     }
810     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
811         LOG(ERROR) << "Error setting group";
812         return {};
813     }
814     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
815         LOG(ERROR) << "Error setting point";
816         return {};
817     }
818     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
819         LOG(ERROR) << "Error setting key";
820         return {};
821     }
822 
823     uint64_t nowMs = time(nullptr) * 1000;
824     uint64_t expireTimeMs = 0;  // Set to same as batch certificate
825 
826     optional<vector<vector<uint8_t>>> attestationCert =
827             createAttestation(pkey.get(), applicationId, challenge, nowMs, expireTimeMs,
828                               false /* isTestCredential */);
829     if (!attestationCert) {
830         LOG(ERROR) << "Error create attestation from key and challenge";
831         return {};
832     }
833 
834     return attestationCert.value();
835 }
836 
createEcKeyPair()837 optional<vector<uint8_t>> createEcKeyPair() {
838     auto ec_key = EC_KEY_Ptr(EC_KEY_new());
839     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
840     if (ec_key.get() == nullptr || pkey.get() == nullptr) {
841         LOG(ERROR) << "Memory allocation failed";
842         return {};
843     }
844 
845     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
846     if (group.get() == nullptr) {
847         LOG(ERROR) << "Error creating EC group by curve name";
848         return {};
849     }
850 
851     if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
852         EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
853         LOG(ERROR) << "Error generating key";
854         return {};
855     }
856 
857     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ec_key.get()) != 1) {
858         LOG(ERROR) << "Error getting private key";
859         return {};
860     }
861 
862     int size = i2d_PrivateKey(pkey.get(), nullptr);
863     if (size == 0) {
864         LOG(ERROR) << "Error generating public key encoding";
865         return {};
866     }
867     vector<uint8_t> keyPair;
868     keyPair.resize(size);
869     unsigned char* p = keyPair.data();
870     i2d_PrivateKey(pkey.get(), &p);
871     return keyPair;
872 }
873 
ecKeyPairGetPublicKey(const vector<uint8_t> & keyPair)874 optional<vector<uint8_t>> ecKeyPairGetPublicKey(const vector<uint8_t>& keyPair) {
875     const unsigned char* p = (const unsigned char*)keyPair.data();
876     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
877     if (pkey.get() == nullptr) {
878         LOG(ERROR) << "Error parsing keyPair";
879         return {};
880     }
881 
882     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
883     if (ecKey.get() == nullptr) {
884         LOG(ERROR) << "Failed getting EC key";
885         return {};
886     }
887 
888     auto ecGroup = EC_KEY_get0_group(ecKey.get());
889     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
890     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
891                                   nullptr);
892     if (size == 0) {
893         LOG(ERROR) << "Error generating public key encoding";
894         return {};
895     }
896 
897     vector<uint8_t> publicKey;
898     publicKey.resize(size);
899     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
900                        publicKey.size(), nullptr);
901     return publicKey;
902 }
903 
ecKeyPairGetPrivateKey(const vector<uint8_t> & keyPair)904 optional<vector<uint8_t>> ecKeyPairGetPrivateKey(const vector<uint8_t>& keyPair) {
905     const unsigned char* p = (const unsigned char*)keyPair.data();
906     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
907     if (pkey.get() == nullptr) {
908         LOG(ERROR) << "Error parsing keyPair";
909         return {};
910     }
911 
912     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
913     if (ecKey.get() == nullptr) {
914         LOG(ERROR) << "Failed getting EC key";
915         return {};
916     }
917 
918     const BIGNUM* bignum = EC_KEY_get0_private_key(ecKey.get());
919     if (bignum == nullptr) {
920         LOG(ERROR) << "Error getting bignum from private key";
921         return {};
922     }
923     vector<uint8_t> privateKey;
924 
925     // Note that this may return fewer than 32 bytes so pad with zeroes since we
926     // want to always return 32 bytes.
927     size_t numBytes = BN_num_bytes(bignum);
928     if (numBytes > 32) {
929         LOG(ERROR) << "Size is " << numBytes << ", expected this to be 32 or less";
930         return {};
931     }
932     privateKey.resize(32);
933     for (size_t n = 0; n < 32 - numBytes; n++) {
934         privateKey[n] = 0x00;
935     }
936     BN_bn2bin(bignum, privateKey.data() + 32 - numBytes);
937     return privateKey;
938 }
939 
ecPrivateKeyToKeyPair(const vector<uint8_t> & privateKey)940 optional<vector<uint8_t>> ecPrivateKeyToKeyPair(const vector<uint8_t>& privateKey) {
941     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
942     if (bn.get() == nullptr) {
943         LOG(ERROR) << "Error creating BIGNUM";
944         return {};
945     }
946 
947     auto ecKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
948     if (EC_KEY_set_private_key(ecKey.get(), bn.get()) != 1) {
949         LOG(ERROR) << "Error setting private key from BIGNUM";
950         return {};
951     }
952 
953     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
954     if (pkey.get() == nullptr) {
955         LOG(ERROR) << "Memory allocation failed";
956         return {};
957     }
958 
959     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
960         LOG(ERROR) << "Error getting private key";
961         return {};
962     }
963 
964     int size = i2d_PrivateKey(pkey.get(), nullptr);
965     if (size == 0) {
966         LOG(ERROR) << "Error generating public key encoding";
967         return {};
968     }
969     vector<uint8_t> keyPair;
970     keyPair.resize(size);
971     unsigned char* p = keyPair.data();
972     i2d_PrivateKey(pkey.get(), &p);
973     return keyPair;
974 }
975 
ecKeyPairGetPkcs12(const vector<uint8_t> & keyPair,const string & name,const string & serialDecimal,const string & issuer,const string & subject,time_t validityNotBefore,time_t validityNotAfter)976 optional<vector<uint8_t>> ecKeyPairGetPkcs12(const vector<uint8_t>& keyPair, const string& name,
977                                              const string& serialDecimal, const string& issuer,
978                                              const string& subject, time_t validityNotBefore,
979                                              time_t validityNotAfter) {
980     const unsigned char* p = (const unsigned char*)keyPair.data();
981     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
982     if (pkey.get() == nullptr) {
983         LOG(ERROR) << "Error parsing keyPair";
984         return {};
985     }
986 
987     auto x509 = X509_Ptr(X509_new());
988     if (!x509.get()) {
989         LOG(ERROR) << "Error creating X509 certificate";
990         return {};
991     }
992 
993     if (!X509_set_version(x509.get(), 2 /* version 3, but zero-based */)) {
994         LOG(ERROR) << "Error setting version to 3";
995         return {};
996     }
997 
998     if (X509_set_pubkey(x509.get(), pkey.get()) != 1) {
999         LOG(ERROR) << "Error setting public key";
1000         return {};
1001     }
1002 
1003     BIGNUM* bignumSerial = nullptr;
1004     if (BN_dec2bn(&bignumSerial, serialDecimal.c_str()) == 0) {
1005         LOG(ERROR) << "Error parsing serial";
1006         return {};
1007     }
1008     auto bignumSerialPtr = BIGNUM_Ptr(bignumSerial);
1009     auto asnSerial = ASN1_INTEGER_Ptr(BN_to_ASN1_INTEGER(bignumSerial, nullptr));
1010     if (X509_set_serialNumber(x509.get(), asnSerial.get()) != 1) {
1011         LOG(ERROR) << "Error setting serial";
1012         return {};
1013     }
1014 
1015     auto x509Issuer = X509_NAME_Ptr(X509_NAME_new());
1016     if (x509Issuer.get() == nullptr ||
1017         X509_NAME_add_entry_by_txt(x509Issuer.get(), "CN", MBSTRING_ASC,
1018                                    (const uint8_t*)issuer.c_str(), issuer.size(), -1 /* loc */,
1019                                    0 /* set */) != 1 ||
1020         X509_set_issuer_name(x509.get(), x509Issuer.get()) != 1) {
1021         LOG(ERROR) << "Error setting issuer";
1022         return {};
1023     }
1024 
1025     auto x509Subject = X509_NAME_Ptr(X509_NAME_new());
1026     if (x509Subject.get() == nullptr ||
1027         X509_NAME_add_entry_by_txt(x509Subject.get(), "CN", MBSTRING_ASC,
1028                                    (const uint8_t*)subject.c_str(), subject.size(), -1 /* loc */,
1029                                    0 /* set */) != 1 ||
1030         X509_set_subject_name(x509.get(), x509Subject.get()) != 1) {
1031         LOG(ERROR) << "Error setting subject";
1032         return {};
1033     }
1034 
1035     auto asnNotBefore = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotBefore));
1036     if (asnNotBefore.get() == nullptr || X509_set_notBefore(x509.get(), asnNotBefore.get()) != 1) {
1037         LOG(ERROR) << "Error setting notBefore";
1038         return {};
1039     }
1040 
1041     auto asnNotAfter = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotAfter));
1042     if (asnNotAfter.get() == nullptr || X509_set_notAfter(x509.get(), asnNotAfter.get()) != 1) {
1043         LOG(ERROR) << "Error setting notAfter";
1044         return {};
1045     }
1046 
1047     if (X509_sign(x509.get(), pkey.get(), EVP_sha256()) == 0) {
1048         LOG(ERROR) << "Error signing X509 certificate";
1049         return {};
1050     }
1051 
1052     // Ideally we wouldn't encrypt it (we're only using this function for
1053     // sending a key-pair over binder to the Android app) but BoringSSL does not
1054     // support this: from pkcs8_x509.c in BoringSSL: "In OpenSSL, -1 here means
1055     // to use no encryption, which we do not currently support."
1056     //
1057     // Passing nullptr as |pass|, though, means "no password". So we'll do that.
1058     // Compare with the receiving side - CredstoreIdentityCredential.java - where
1059     // an empty char[] is passed as the password.
1060     //
1061     auto pkcs12 = PKCS12_Ptr(PKCS12_create(nullptr, name.c_str(), pkey.get(), x509.get(),
1062                                            nullptr,  // ca
1063                                            0,        // nid_key
1064                                            0,        // nid_cert
1065                                            0,        // iter,
1066                                            0,        // mac_iter,
1067                                            0));      // keytype
1068     if (pkcs12.get() == nullptr) {
1069         char buf[128];
1070         long errCode = ERR_get_error();
1071         ERR_error_string_n(errCode, buf, sizeof buf);
1072         LOG(ERROR) << "Error creating PKCS12, code " << errCode << ": " << buf;
1073         return {};
1074     }
1075 
1076     unsigned char* buffer = nullptr;
1077     int length = i2d_PKCS12(pkcs12.get(), &buffer);
1078     if (length < 0) {
1079         LOG(ERROR) << "Error encoding PKCS12";
1080         return {};
1081     }
1082     vector<uint8_t> pkcs12Bytes;
1083     pkcs12Bytes.resize(length);
1084     memcpy(pkcs12Bytes.data(), buffer, length);
1085     OPENSSL_free(buffer);
1086 
1087     return pkcs12Bytes;
1088 }
1089 
ecPublicKeyGenerateCertificate(const vector<uint8_t> & publicKey,const vector<uint8_t> & signingKey,const string & serialDecimal,const string & issuer,const string & subject,time_t validityNotBefore,time_t validityNotAfter,const map<string,vector<uint8_t>> & extensions)1090 optional<vector<uint8_t>> ecPublicKeyGenerateCertificate(
1091         const vector<uint8_t>& publicKey, const vector<uint8_t>& signingKey,
1092         const string& serialDecimal, const string& issuer, const string& subject,
1093         time_t validityNotBefore, time_t validityNotAfter,
1094         const map<string, vector<uint8_t>>& extensions) {
1095     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1096     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
1097     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
1098         1) {
1099         LOG(ERROR) << "Error decoding publicKey";
1100         return {};
1101     }
1102     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
1103     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1104     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
1105         LOG(ERROR) << "Memory allocation failed";
1106         return {};
1107     }
1108     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
1109         LOG(ERROR) << "Error setting group";
1110         return {};
1111     }
1112     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
1113         LOG(ERROR) << "Error setting point";
1114         return {};
1115     }
1116     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1117         LOG(ERROR) << "Error setting key";
1118         return {};
1119     }
1120 
1121     auto bn = BIGNUM_Ptr(BN_bin2bn(signingKey.data(), signingKey.size(), nullptr));
1122     if (bn.get() == nullptr) {
1123         LOG(ERROR) << "Error creating BIGNUM for private key";
1124         return {};
1125     }
1126     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1127     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
1128         LOG(ERROR) << "Error setting private key from BIGNUM";
1129         return {};
1130     }
1131     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1132     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
1133         LOG(ERROR) << "Error setting private key";
1134         return {};
1135     }
1136 
1137     auto x509 = X509_Ptr(X509_new());
1138     if (!x509.get()) {
1139         LOG(ERROR) << "Error creating X509 certificate";
1140         return {};
1141     }
1142 
1143     if (!X509_set_version(x509.get(), 2 /* version 3, but zero-based */)) {
1144         LOG(ERROR) << "Error setting version to 3";
1145         return {};
1146     }
1147 
1148     if (X509_set_pubkey(x509.get(), pkey.get()) != 1) {
1149         LOG(ERROR) << "Error setting public key";
1150         return {};
1151     }
1152 
1153     BIGNUM* bignumSerial = nullptr;
1154     if (BN_dec2bn(&bignumSerial, serialDecimal.c_str()) == 0) {
1155         LOG(ERROR) << "Error parsing serial";
1156         return {};
1157     }
1158     auto bignumSerialPtr = BIGNUM_Ptr(bignumSerial);
1159     auto asnSerial = ASN1_INTEGER_Ptr(BN_to_ASN1_INTEGER(bignumSerial, nullptr));
1160     if (X509_set_serialNumber(x509.get(), asnSerial.get()) != 1) {
1161         LOG(ERROR) << "Error setting serial";
1162         return {};
1163     }
1164 
1165     auto x509Issuer = X509_NAME_Ptr(X509_NAME_new());
1166     if (x509Issuer.get() == nullptr ||
1167         X509_NAME_add_entry_by_txt(x509Issuer.get(), "CN", MBSTRING_ASC,
1168                                    (const uint8_t*)issuer.c_str(), issuer.size(), -1 /* loc */,
1169                                    0 /* set */) != 1 ||
1170         X509_set_issuer_name(x509.get(), x509Issuer.get()) != 1) {
1171         LOG(ERROR) << "Error setting issuer";
1172         return {};
1173     }
1174 
1175     auto x509Subject = X509_NAME_Ptr(X509_NAME_new());
1176     if (x509Subject.get() == nullptr ||
1177         X509_NAME_add_entry_by_txt(x509Subject.get(), "CN", MBSTRING_ASC,
1178                                    (const uint8_t*)subject.c_str(), subject.size(), -1 /* loc */,
1179                                    0 /* set */) != 1 ||
1180         X509_set_subject_name(x509.get(), x509Subject.get()) != 1) {
1181         LOG(ERROR) << "Error setting subject";
1182         return {};
1183     }
1184 
1185     auto asnNotBefore = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotBefore));
1186     if (asnNotBefore.get() == nullptr || X509_set_notBefore(x509.get(), asnNotBefore.get()) != 1) {
1187         LOG(ERROR) << "Error setting notBefore";
1188         return {};
1189     }
1190 
1191     auto asnNotAfter = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotAfter));
1192     if (asnNotAfter.get() == nullptr || X509_set_notAfter(x509.get(), asnNotAfter.get()) != 1) {
1193         LOG(ERROR) << "Error setting notAfter";
1194         return {};
1195     }
1196 
1197     for (auto const& [oidStr, blob] : extensions) {
1198         ASN1_OBJECT_Ptr oid(
1199                 OBJ_txt2obj(oidStr.c_str(), 1));  // accept numerical dotted string form only
1200         if (!oid.get()) {
1201             LOG(ERROR) << "Error setting OID";
1202             return {};
1203         }
1204         ASN1_OCTET_STRING_Ptr octetString(ASN1_OCTET_STRING_new());
1205         if (!ASN1_OCTET_STRING_set(octetString.get(), blob.data(), blob.size())) {
1206             LOG(ERROR) << "Error setting octet string for extension";
1207             return {};
1208         }
1209 
1210         X509_EXTENSION_Ptr extension = X509_EXTENSION_Ptr(X509_EXTENSION_new());
1211         extension.reset(X509_EXTENSION_create_by_OBJ(nullptr, oid.get(), 0 /* not critical */,
1212                                                      octetString.get()));
1213         if (!extension.get()) {
1214             LOG(ERROR) << "Error setting extension";
1215             return {};
1216         }
1217         if (!X509_add_ext(x509.get(), extension.get(), -1)) {
1218             LOG(ERROR) << "Error adding extension";
1219             return {};
1220         }
1221     }
1222 
1223     if (X509_sign(x509.get(), privPkey.get(), EVP_sha256()) == 0) {
1224         LOG(ERROR) << "Error signing X509 certificate";
1225         return {};
1226     }
1227 
1228     unsigned char* buffer = nullptr;
1229     int length = i2d_X509(x509.get(), &buffer);
1230     if (length < 0) {
1231         LOG(ERROR) << "Error DER encoding X509 certificate";
1232         return {};
1233     }
1234 
1235     vector<uint8_t> certificate;
1236     certificate.resize(length);
1237     memcpy(certificate.data(), buffer, length);
1238     OPENSSL_free(buffer);
1239     return certificate;
1240 }
1241 
ecdh(const vector<uint8_t> & publicKey,const vector<uint8_t> & privateKey)1242 optional<vector<uint8_t>> ecdh(const vector<uint8_t>& publicKey,
1243                                const vector<uint8_t>& privateKey) {
1244     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1245     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
1246     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
1247         1) {
1248         LOG(ERROR) << "Error decoding publicKey";
1249         return {};
1250     }
1251     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
1252     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1253     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
1254         LOG(ERROR) << "Memory allocation failed";
1255         return {};
1256     }
1257     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
1258         LOG(ERROR) << "Error setting group";
1259         return {};
1260     }
1261     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
1262         LOG(ERROR) << "Error setting point";
1263         return {};
1264     }
1265     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1266         LOG(ERROR) << "Error setting key";
1267         return {};
1268     }
1269 
1270     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
1271     if (bn.get() == nullptr) {
1272         LOG(ERROR) << "Error creating BIGNUM for private key";
1273         return {};
1274     }
1275     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1276     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
1277         LOG(ERROR) << "Error setting private key from BIGNUM";
1278         return {};
1279     }
1280     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1281     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
1282         LOG(ERROR) << "Error setting private key";
1283         return {};
1284     }
1285 
1286     auto ctx = EVP_PKEY_CTX_Ptr(EVP_PKEY_CTX_new(privPkey.get(), NULL));
1287     if (ctx.get() == nullptr) {
1288         LOG(ERROR) << "Error creating context";
1289         return {};
1290     }
1291 
1292     if (EVP_PKEY_derive_init(ctx.get()) != 1) {
1293         LOG(ERROR) << "Error initializing context";
1294         return {};
1295     }
1296 
1297     if (EVP_PKEY_derive_set_peer(ctx.get(), pkey.get()) != 1) {
1298         LOG(ERROR) << "Error setting peer";
1299         return {};
1300     }
1301 
1302     /* Determine buffer length for shared secret */
1303     size_t secretLen = 0;
1304     if (EVP_PKEY_derive(ctx.get(), NULL, &secretLen) != 1) {
1305         LOG(ERROR) << "Error determing length of shared secret";
1306         return {};
1307     }
1308     vector<uint8_t> sharedSecret;
1309     sharedSecret.resize(secretLen);
1310 
1311     if (EVP_PKEY_derive(ctx.get(), sharedSecret.data(), &secretLen) != 1) {
1312         LOG(ERROR) << "Error deriving shared secret";
1313         return {};
1314     }
1315     return sharedSecret;
1316 }
1317 
hkdf(const vector<uint8_t> & sharedSecret,const vector<uint8_t> & salt,const vector<uint8_t> & info,size_t size)1318 optional<vector<uint8_t>> hkdf(const vector<uint8_t>& sharedSecret, const vector<uint8_t>& salt,
1319                                const vector<uint8_t>& info, size_t size) {
1320     vector<uint8_t> derivedKey;
1321     derivedKey.resize(size);
1322     if (HKDF(derivedKey.data(), derivedKey.size(), EVP_sha256(), sharedSecret.data(),
1323              sharedSecret.size(), salt.data(), salt.size(), info.data(), info.size()) != 1) {
1324         LOG(ERROR) << "Error deriving key";
1325         return {};
1326     }
1327     return derivedKey;
1328 }
1329 
removeLeadingZeroes(vector<uint8_t> & vec)1330 void removeLeadingZeroes(vector<uint8_t>& vec) {
1331     while (vec.size() >= 1 && vec[0] == 0x00) {
1332         vec.erase(vec.begin());
1333     }
1334 }
1335 
ecPublicKeyGetXandY(const vector<uint8_t> & publicKey)1336 tuple<bool, vector<uint8_t>, vector<uint8_t>> ecPublicKeyGetXandY(
1337         const vector<uint8_t>& publicKey) {
1338     if (publicKey.size() != 65 || publicKey[0] != 0x04) {
1339         LOG(ERROR) << "publicKey is not in the expected format";
1340         return std::make_tuple(false, vector<uint8_t>(), vector<uint8_t>());
1341     }
1342     vector<uint8_t> x, y;
1343     x.resize(32);
1344     y.resize(32);
1345     memcpy(x.data(), publicKey.data() + 1, 32);
1346     memcpy(y.data(), publicKey.data() + 33, 32);
1347 
1348     removeLeadingZeroes(x);
1349     removeLeadingZeroes(y);
1350 
1351     return std::make_tuple(true, x, y);
1352 }
1353 
certificateChainGetTopMostKey(const vector<uint8_t> & certificateChain)1354 optional<vector<uint8_t>> certificateChainGetTopMostKey(const vector<uint8_t>& certificateChain) {
1355     vector<X509_Ptr> certs;
1356     if (!parseX509Certificates(certificateChain, certs)) {
1357         return {};
1358     }
1359     if (certs.size() < 1) {
1360         LOG(ERROR) << "No certificates in chain";
1361         return {};
1362     }
1363 
1364     auto pkey = EVP_PKEY_Ptr(X509_get_pubkey(certs[0].get()));
1365     if (pkey.get() == nullptr) {
1366         LOG(ERROR) << "No public key";
1367         return {};
1368     }
1369 
1370     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1371     if (ecKey.get() == nullptr) {
1372         LOG(ERROR) << "Failed getting EC key";
1373         return {};
1374     }
1375 
1376     auto ecGroup = EC_KEY_get0_group(ecKey.get());
1377     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
1378     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
1379                                   nullptr);
1380     if (size == 0) {
1381         LOG(ERROR) << "Error generating public key encoding";
1382         return {};
1383     }
1384     vector<uint8_t> publicKey;
1385     publicKey.resize(size);
1386     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1387                        publicKey.size(), nullptr);
1388     return publicKey;
1389 }
1390 
certificateGetExtension(const vector<uint8_t> & x509Certificate,const string & oidStr)1391 optional<vector<uint8_t>> certificateGetExtension(const vector<uint8_t>& x509Certificate,
1392                                                   const string& oidStr) {
1393     vector<X509_Ptr> certs;
1394     if (!parseX509Certificates(x509Certificate, certs)) {
1395         return {};
1396     }
1397     if (certs.size() < 1) {
1398         LOG(ERROR) << "No certificates in chain";
1399         return {};
1400     }
1401 
1402     ASN1_OBJECT_Ptr oid(
1403             OBJ_txt2obj(oidStr.c_str(), 1));  // accept numerical dotted string form only
1404     if (!oid.get()) {
1405         LOG(ERROR) << "Error setting OID";
1406         return {};
1407     }
1408 
1409     int location = X509_get_ext_by_OBJ(certs[0].get(), oid.get(), -1 /* search from beginning */);
1410     if (location == -1) {
1411         return {};
1412     }
1413 
1414     X509_EXTENSION* ext = X509_get_ext(certs[0].get(), location);
1415     if (ext == nullptr) {
1416         return {};
1417     }
1418 
1419     ASN1_OCTET_STRING* octetString = X509_EXTENSION_get_data(ext);
1420     if (octetString == nullptr) {
1421         return {};
1422     }
1423     vector<uint8_t> result;
1424     result.resize(octetString->length);
1425     memcpy(result.data(), octetString->data, octetString->length);
1426     return result;
1427 }
1428 
certificateFindPublicKey(const vector<uint8_t> & x509Certificate)1429 optional<pair<size_t, size_t>> certificateFindPublicKey(const vector<uint8_t>& x509Certificate) {
1430     vector<X509_Ptr> certs;
1431     if (!parseX509Certificates(x509Certificate, certs)) {
1432         return {};
1433     }
1434     if (certs.size() < 1) {
1435         LOG(ERROR) << "No certificates in chain";
1436         return {};
1437     }
1438 
1439     auto pkey = EVP_PKEY_Ptr(X509_get_pubkey(certs[0].get()));
1440     if (pkey.get() == nullptr) {
1441         LOG(ERROR) << "No public key";
1442         return {};
1443     }
1444 
1445     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1446     if (ecKey.get() == nullptr) {
1447         LOG(ERROR) << "Failed getting EC key";
1448         return {};
1449     }
1450 
1451     auto ecGroup = EC_KEY_get0_group(ecKey.get());
1452     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
1453     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
1454                                   nullptr);
1455     if (size == 0) {
1456         LOG(ERROR) << "Error generating public key encoding";
1457         return {};
1458     }
1459     vector<uint8_t> publicKey;
1460     publicKey.resize(size);
1461     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1462                        publicKey.size(), nullptr);
1463 
1464     size_t publicKeyOffset = 0;
1465     size_t publicKeySize = (size_t)size;
1466     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1467                             (const void*)publicKey.data(), publicKey.size());
1468 
1469     if (location == NULL) {
1470         LOG(ERROR) << "Error finding publicKey from x509Certificate";
1471         return {};
1472     }
1473     publicKeyOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1474 
1475     return std::make_pair(publicKeyOffset, publicKeySize);
1476 }
1477 
certificateTbsCertificate(const vector<uint8_t> & x509Certificate)1478 optional<pair<size_t, size_t>> certificateTbsCertificate(const vector<uint8_t>& x509Certificate) {
1479     vector<X509_Ptr> certs;
1480     if (!parseX509Certificates(x509Certificate, certs)) {
1481         return {};
1482     }
1483     if (certs.size() < 1) {
1484         LOG(ERROR) << "No certificates in chain";
1485         return {};
1486     }
1487 
1488     unsigned char* buf = NULL;
1489     int len = i2d_re_X509_tbs(certs[0].get(), &buf);
1490     if ((len < 0) || (buf == NULL)) {
1491         LOG(ERROR) << "fail to extract tbsCertificate in x509Certificate";
1492         return {};
1493     }
1494 
1495     vector<uint8_t> tbsCertificate(len);
1496     memcpy(tbsCertificate.data(), buf, len);
1497 
1498     size_t tbsCertificateOffset = 0;
1499     size_t tbsCertificateSize = (size_t)len;
1500     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1501                             (const void*)tbsCertificate.data(), tbsCertificate.size());
1502 
1503     if (location == NULL) {
1504         LOG(ERROR) << "Error finding tbsCertificate from x509Certificate";
1505         return {};
1506     }
1507     tbsCertificateOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1508 
1509     return std::make_pair(tbsCertificateOffset, tbsCertificateSize);
1510 }
1511 
certificateGetValidity(const vector<uint8_t> & x509Certificate)1512 optional<pair<time_t, time_t>> certificateGetValidity(const vector<uint8_t>& x509Certificate) {
1513     vector<X509_Ptr> certs;
1514     if (!parseX509Certificates(x509Certificate, certs)) {
1515         LOG(ERROR) << "Error parsing certificates";
1516         return {};
1517     }
1518     if (certs.size() < 1) {
1519         LOG(ERROR) << "No certificates in chain";
1520         return {};
1521     }
1522 
1523     time_t notBefore;
1524     time_t notAfter;
1525     if (!parseAsn1Time(X509_get0_notBefore(certs[0].get()), &notBefore)) {
1526         LOG(ERROR) << "Error parsing notBefore";
1527         return {};
1528     }
1529 
1530     if (!parseAsn1Time(X509_get0_notAfter(certs[0].get()), &notAfter)) {
1531         LOG(ERROR) << "Error parsing notAfter";
1532         return {};
1533     }
1534 
1535     return std::make_pair(notBefore, notAfter);
1536 }
1537 
certificateFindSignature(const vector<uint8_t> & x509Certificate)1538 optional<pair<size_t, size_t>> certificateFindSignature(const vector<uint8_t>& x509Certificate) {
1539     vector<X509_Ptr> certs;
1540     if (!parseX509Certificates(x509Certificate, certs)) {
1541         return {};
1542     }
1543     if (certs.size() < 1) {
1544         LOG(ERROR) << "No certificates in chain";
1545         return {};
1546     }
1547 
1548     ASN1_BIT_STRING* psig;
1549     X509_ALGOR* palg;
1550     X509_get0_signature((const ASN1_BIT_STRING**)&psig, (const X509_ALGOR**)&palg, certs[0].get());
1551 
1552     vector<char> signature(psig->length);
1553     memcpy(signature.data(), psig->data, psig->length);
1554 
1555     size_t signatureOffset = 0;
1556     size_t signatureSize = (size_t)psig->length;
1557     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1558                             (const void*)signature.data(), signature.size());
1559 
1560     if (location == NULL) {
1561         LOG(ERROR) << "Error finding signature from x509Certificate";
1562         return {};
1563     }
1564     signatureOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1565 
1566     return std::make_pair(signatureOffset, signatureSize);
1567 }
1568 
1569 // ---------------------------------------------------------------------------
1570 // COSE Utility Functions
1571 // ---------------------------------------------------------------------------
1572 
coseBuildToBeSigned(const vector<uint8_t> & encodedProtectedHeaders,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)1573 vector<uint8_t> coseBuildToBeSigned(const vector<uint8_t>& encodedProtectedHeaders,
1574                                     const vector<uint8_t>& data,
1575                                     const vector<uint8_t>& detachedContent) {
1576     cppbor::Array sigStructure;
1577     sigStructure.add("Signature1");
1578     sigStructure.add(encodedProtectedHeaders);
1579 
1580     // We currently don't support Externally Supplied Data (RFC 8152 section 4.3)
1581     // so external_aad is the empty bstr
1582     vector<uint8_t> emptyExternalAad;
1583     sigStructure.add(emptyExternalAad);
1584 
1585     // Next field is the payload, independently of how it's transported (RFC
1586     // 8152 section 4.4). Since our API specifies only one of |data| and
1587     // |detachedContent| can be non-empty, it's simply just the non-empty one.
1588     if (data.size() > 0) {
1589         sigStructure.add(data);
1590     } else {
1591         sigStructure.add(detachedContent);
1592     }
1593     return sigStructure.encode();
1594 }
1595 
coseEncodeHeaders(const cppbor::Map & protectedHeaders)1596 vector<uint8_t> coseEncodeHeaders(const cppbor::Map& protectedHeaders) {
1597     if (protectedHeaders.size() == 0) {
1598         cppbor::Bstr emptyBstr(vector<uint8_t>({}));
1599         return emptyBstr.encode();
1600     }
1601     return protectedHeaders.encode();
1602 }
1603 
// COSE header labels, from https://tools.ietf.org/html/rfc8152
constexpr int COSE_LABEL_ALG = 1;
constexpr int COSE_LABEL_X5CHAIN = 33;  // x5chain; originally used here as a temporary identifier

// Algorithm identifiers, from the IANA "COSE Algorithms" registry.
constexpr int COSE_ALG_ECDSA_256 = -7;    // ES256: ECDSA with SHA-256
constexpr int COSE_ALG_HMAC_256_256 = 5;  // HMAC 256/256: HMAC with SHA-256
1611 
ecdsaSignatureCoseToDer(const vector<uint8_t> & ecdsaCoseSignature,vector<uint8_t> & ecdsaDerSignature)1612 bool ecdsaSignatureCoseToDer(const vector<uint8_t>& ecdsaCoseSignature,
1613                              vector<uint8_t>& ecdsaDerSignature) {
1614     if (ecdsaCoseSignature.size() != 64) {
1615         LOG(ERROR) << "COSE signature length is " << ecdsaCoseSignature.size() << ", expected 64";
1616         return false;
1617     }
1618 
1619     auto rBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data(), 32, nullptr));
1620     if (rBn.get() == nullptr) {
1621         LOG(ERROR) << "Error creating BIGNUM for r";
1622         return false;
1623     }
1624 
1625     auto sBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data() + 32, 32, nullptr));
1626     if (sBn.get() == nullptr) {
1627         LOG(ERROR) << "Error creating BIGNUM for s";
1628         return false;
1629     }
1630 
1631     ECDSA_SIG sig;
1632     sig.r = rBn.get();
1633     sig.s = sBn.get();
1634 
1635     size_t len = i2d_ECDSA_SIG(&sig, nullptr);
1636     ecdsaDerSignature.resize(len);
1637     unsigned char* p = (unsigned char*)ecdsaDerSignature.data();
1638     i2d_ECDSA_SIG(&sig, &p);
1639 
1640     return true;
1641 }
1642 
ecdsaSignatureDerToCose(const vector<uint8_t> & ecdsaDerSignature,vector<uint8_t> & ecdsaCoseSignature)1643 bool ecdsaSignatureDerToCose(const vector<uint8_t>& ecdsaDerSignature,
1644                              vector<uint8_t>& ecdsaCoseSignature) {
1645     ECDSA_SIG* sig;
1646     const unsigned char* p = ecdsaDerSignature.data();
1647     sig = d2i_ECDSA_SIG(nullptr, &p, ecdsaDerSignature.size());
1648     if (sig == nullptr) {
1649         LOG(ERROR) << "Error decoding DER signature";
1650         return false;
1651     }
1652 
1653     ecdsaCoseSignature.clear();
1654     ecdsaCoseSignature.resize(64);
1655     if (BN_bn2binpad(ECDSA_SIG_get0_r(sig), ecdsaCoseSignature.data(), 32) != 32) {
1656         LOG(ERROR) << "Error encoding r";
1657         return false;
1658     }
1659     if (BN_bn2binpad(ECDSA_SIG_get0_s(sig), ecdsaCoseSignature.data() + 32, 32) != 32) {
1660         LOG(ERROR) << "Error encoding s";
1661         return false;
1662     }
1663     return true;
1664 }
1665 
coseSignEcDsaWithSignature(const vector<uint8_t> & signatureToBeSigned,const vector<uint8_t> & data,const vector<uint8_t> & certificateChain)1666 optional<vector<uint8_t>> coseSignEcDsaWithSignature(const vector<uint8_t>& signatureToBeSigned,
1667                                                      const vector<uint8_t>& data,
1668                                                      const vector<uint8_t>& certificateChain) {
1669     if (signatureToBeSigned.size() != 64) {
1670         LOG(ERROR) << "Invalid size for signatureToBeSigned, expected 64 got "
1671                    << signatureToBeSigned.size();
1672         return {};
1673     }
1674 
1675     cppbor::Map unprotectedHeaders;
1676     cppbor::Map protectedHeaders;
1677 
1678     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_ECDSA_256);
1679 
1680     if (certificateChain.size() != 0) {
1681         optional<vector<vector<uint8_t>>> certs = support::certificateChainSplit(certificateChain);
1682         if (!certs) {
1683             LOG(ERROR) << "Error splitting certificate chain";
1684             return {};
1685         }
1686         if (certs.value().size() == 1) {
1687             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, certs.value()[0]);
1688         } else {
1689             cppbor::Array certArray;
1690             for (const vector<uint8_t>& cert : certs.value()) {
1691                 certArray.add(cert);
1692             }
1693             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, std::move(certArray));
1694         }
1695     }
1696 
1697     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
1698 
1699     cppbor::Array coseSign1;
1700     coseSign1.add(encodedProtectedHeaders);
1701     coseSign1.add(std::move(unprotectedHeaders));
1702     if (data.size() == 0) {
1703         cppbor::Null nullValue;
1704         coseSign1.add(std::move(nullValue));
1705     } else {
1706         coseSign1.add(data);
1707     }
1708     coseSign1.add(signatureToBeSigned);
1709     vector<uint8_t> signatureCoseSign1;
1710     signatureCoseSign1 = coseSign1.encode();
1711 
1712     return signatureCoseSign1;
1713 }
1714 
coseSignEcDsa(const vector<uint8_t> & key,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent,const vector<uint8_t> & certificateChain)1715 optional<vector<uint8_t>> coseSignEcDsa(const vector<uint8_t>& key, const vector<uint8_t>& data,
1716                                         const vector<uint8_t>& detachedContent,
1717                                         const vector<uint8_t>& certificateChain) {
1718     cppbor::Map unprotectedHeaders;
1719     cppbor::Map protectedHeaders;
1720 
1721     if (data.size() > 0 && detachedContent.size() > 0) {
1722         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
1723         return {};
1724     }
1725 
1726     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_ECDSA_256);
1727 
1728     if (certificateChain.size() != 0) {
1729         optional<vector<vector<uint8_t>>> certs = support::certificateChainSplit(certificateChain);
1730         if (!certs) {
1731             LOG(ERROR) << "Error splitting certificate chain";
1732             return {};
1733         }
1734         if (certs.value().size() == 1) {
1735             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, certs.value()[0]);
1736         } else {
1737             cppbor::Array certArray;
1738             for (const vector<uint8_t>& cert : certs.value()) {
1739                 certArray.add(cert);
1740             }
1741             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, std::move(certArray));
1742         }
1743     }
1744 
1745     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
1746     vector<uint8_t> toBeSigned =
1747             coseBuildToBeSigned(encodedProtectedHeaders, data, detachedContent);
1748 
1749     optional<vector<uint8_t>> derSignature = signEcDsa(key, toBeSigned);
1750     if (!derSignature) {
1751         LOG(ERROR) << "Error signing toBeSigned data";
1752         return {};
1753     }
1754     vector<uint8_t> coseSignature;
1755     if (!ecdsaSignatureDerToCose(derSignature.value(), coseSignature)) {
1756         LOG(ERROR) << "Error converting ECDSA signature from DER to COSE format";
1757         return {};
1758     }
1759 
1760     cppbor::Array coseSign1;
1761     coseSign1.add(encodedProtectedHeaders);
1762     coseSign1.add(std::move(unprotectedHeaders));
1763     if (data.size() == 0) {
1764         cppbor::Null nullValue;
1765         coseSign1.add(std::move(nullValue));
1766     } else {
1767         coseSign1.add(data);
1768     }
1769     coseSign1.add(coseSignature);
1770     vector<uint8_t> signatureCoseSign1;
1771     signatureCoseSign1 = coseSign1.encode();
1772     return signatureCoseSign1;
1773 }
1774 
coseCheckEcDsaSignature(const vector<uint8_t> & signatureCoseSign1,const vector<uint8_t> & detachedContent,const vector<uint8_t> & publicKey)1775 bool coseCheckEcDsaSignature(const vector<uint8_t>& signatureCoseSign1,
1776                              const vector<uint8_t>& detachedContent,
1777                              const vector<uint8_t>& publicKey) {
1778     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1779     if (item == nullptr) {
1780         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1781         return false;
1782     }
1783     const cppbor::Array* array = item->asArray();
1784     if (array == nullptr) {
1785         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1786         return false;
1787     }
1788     if (array->size() != 4) {
1789         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
1790         return false;
1791     }
1792 
1793     const cppbor::Bstr* encodedProtectedHeadersBstr = (*array)[0]->asBstr();
1794     ;
1795     if (encodedProtectedHeadersBstr == nullptr) {
1796         LOG(ERROR) << "Value for encodedProtectedHeaders is not a bstr";
1797         return false;
1798     }
1799     const vector<uint8_t> encodedProtectedHeaders = encodedProtectedHeadersBstr->value();
1800 
1801     const cppbor::Map* unprotectedHeaders = (*array)[1]->asMap();
1802     if (unprotectedHeaders == nullptr) {
1803         LOG(ERROR) << "Value for unprotectedHeaders is not a map";
1804         return false;
1805     }
1806 
1807     vector<uint8_t> data;
1808     const cppbor::Simple* payloadAsSimple = (*array)[2]->asSimple();
1809     if (payloadAsSimple != nullptr) {
1810         if (payloadAsSimple->asNull() == nullptr) {
1811             LOG(ERROR) << "Value for payload is not null or a bstr";
1812             return false;
1813         }
1814     } else {
1815         const cppbor::Bstr* payloadAsBstr = (*array)[2]->asBstr();
1816         if (payloadAsBstr == nullptr) {
1817             LOG(ERROR) << "Value for payload is not null or a bstr";
1818             return false;
1819         }
1820         data = payloadAsBstr->value();  // TODO: avoid copy
1821     }
1822 
1823     if (data.size() > 0 && detachedContent.size() > 0) {
1824         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
1825         return false;
1826     }
1827 
1828     const cppbor::Bstr* signatureBstr = (*array)[3]->asBstr();
1829     if (signatureBstr == nullptr) {
1830         LOG(ERROR) << "Value for signature is a bstr";
1831         return false;
1832     }
1833     const vector<uint8_t>& coseSignature = signatureBstr->value();
1834 
1835     vector<uint8_t> derSignature;
1836     if (!ecdsaSignatureCoseToDer(coseSignature, derSignature)) {
1837         LOG(ERROR) << "Error converting ECDSA signature from COSE to DER format";
1838         return false;
1839     }
1840 
1841     vector<uint8_t> toBeSigned =
1842             coseBuildToBeSigned(encodedProtectedHeaders, data, detachedContent);
1843     if (!checkEcDsaSignature(support::sha256(toBeSigned), derSignature, publicKey)) {
1844         LOG(ERROR) << "Signature check failed";
1845         return false;
1846     }
1847     return true;
1848 }
1849 
1850 // Extracts the signature (of the ToBeSigned CBOR) from a COSE_Sign1.
coseSignGetSignature(const vector<uint8_t> & signatureCoseSign1)1851 optional<vector<uint8_t>> coseSignGetSignature(const vector<uint8_t>& signatureCoseSign1) {
1852     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1853     if (item == nullptr) {
1854         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1855         return {};
1856     }
1857     const cppbor::Array* array = item->asArray();
1858     if (array == nullptr) {
1859         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1860         return {};
1861     }
1862     if (array->size() != 4) {
1863         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
1864         return {};
1865     }
1866 
1867     vector<uint8_t> signature;
1868     const cppbor::Bstr* signatureAsBstr = (*array)[3]->asBstr();
1869     if (signatureAsBstr == nullptr) {
1870         LOG(ERROR) << "Value for signature is not a bstr";
1871         return {};
1872     }
1873     // Copy payload into |data|
1874     signature = signatureAsBstr->value();
1875 
1876     return signature;
1877 }
1878 
coseSignGetPayload(const vector<uint8_t> & signatureCoseSign1)1879 optional<vector<uint8_t>> coseSignGetPayload(const vector<uint8_t>& signatureCoseSign1) {
1880     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1881     if (item == nullptr) {
1882         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1883         return {};
1884     }
1885     const cppbor::Array* array = item->asArray();
1886     if (array == nullptr) {
1887         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1888         return {};
1889     }
1890     if (array->size() != 4) {
1891         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
1892         return {};
1893     }
1894 
1895     vector<uint8_t> data;
1896     const cppbor::Simple* payloadAsSimple = (*array)[2]->asSimple();
1897     if (payloadAsSimple != nullptr) {
1898         if (payloadAsSimple->asNull() == nullptr) {
1899             LOG(ERROR) << "Value for payload is not null or a bstr";
1900             return {};
1901         }
1902         // payload is null, so |data| should be empty (as it is)
1903     } else {
1904         const cppbor::Bstr* payloadAsBstr = (*array)[2]->asBstr();
1905         if (payloadAsBstr == nullptr) {
1906             LOG(ERROR) << "Value for payload is not null or a bstr";
1907             return {};
1908         }
1909         // Copy payload into |data|
1910         data = payloadAsBstr->value();
1911     }
1912 
1913     return data;
1914 }
1915 
coseSignGetAlg(const vector<uint8_t> & signatureCoseSign1)1916 optional<int> coseSignGetAlg(const vector<uint8_t>& signatureCoseSign1) {
1917     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1918     if (item == nullptr) {
1919         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1920         return {};
1921     }
1922     const cppbor::Array* array = item->asArray();
1923     if (array == nullptr) {
1924         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1925         return {};
1926     }
1927     if (array->size() != 4) {
1928         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
1929         return {};
1930     }
1931 
1932     const cppbor::Bstr* protectedHeadersBytes = (*array)[0]->asBstr();
1933     if (protectedHeadersBytes == nullptr) {
1934         LOG(ERROR) << "Value for protectedHeaders is not a bstr";
1935         return {};
1936     }
1937     auto [item2, _2, message2] = cppbor::parse(protectedHeadersBytes->value());
1938     if (item2 == nullptr) {
1939         LOG(ERROR) << "Error parsing protectedHeaders: " << message2;
1940         return {};
1941     }
1942     const cppbor::Map* protectedHeaders = item2->asMap();
1943     if (protectedHeaders == nullptr) {
1944         LOG(ERROR) << "Decoded CBOR for protectedHeaders is not a map";
1945         return {};
1946     }
1947 
1948     for (size_t n = 0; n < protectedHeaders->size(); n++) {
1949         auto& [keyItem, valueItem] = (*protectedHeaders)[n];
1950         const cppbor::Int* number = keyItem->asInt();
1951         if (number == nullptr) {
1952             LOG(ERROR) << "Key item in top-level map is not a number";
1953             return {};
1954         }
1955         int label = number->value();
1956         if (label == COSE_LABEL_ALG) {
1957             const cppbor::Int* number = valueItem->asInt();
1958             if (number != nullptr) {
1959                 return number->value();
1960             }
1961             LOG(ERROR) << "Value for COSE_LABEL_ALG label is not a number";
1962             return {};
1963         }
1964     }
1965     LOG(ERROR) << "Did not find COSE_LABEL_ALG label in protected headers";
1966     return {};
1967 }
1968 
coseSignGetX5Chain(const vector<uint8_t> & signatureCoseSign1)1969 optional<vector<uint8_t>> coseSignGetX5Chain(const vector<uint8_t>& signatureCoseSign1) {
1970     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1971     if (item == nullptr) {
1972         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1973         return {};
1974     }
1975     const cppbor::Array* array = item->asArray();
1976     if (array == nullptr) {
1977         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
1978         return {};
1979     }
1980     if (array->size() != 4) {
1981         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
1982         return {};
1983     }
1984 
1985     const cppbor::Map* unprotectedHeaders = (*array)[1]->asMap();
1986     if (unprotectedHeaders == nullptr) {
1987         LOG(ERROR) << "Value for unprotectedHeaders is not a map";
1988         return {};
1989     }
1990 
1991     for (size_t n = 0; n < unprotectedHeaders->size(); n++) {
1992         auto& [keyItem, valueItem] = (*unprotectedHeaders)[n];
1993         const cppbor::Int* number = keyItem->asInt();
1994         if (number == nullptr) {
1995             LOG(ERROR) << "Key item in top-level map is not a number";
1996             return {};
1997         }
1998         int label = number->value();
1999         if (label == COSE_LABEL_X5CHAIN) {
2000             const cppbor::Bstr* bstr = valueItem->asBstr();
2001             if (bstr != nullptr) {
2002                 return bstr->value();
2003             }
2004             const cppbor::Array* array = valueItem->asArray();
2005             if (array != nullptr) {
2006                 vector<uint8_t> certs;
2007                 for (size_t m = 0; m < array->size(); m++) {
2008                     const cppbor::Bstr* bstr = ((*array)[m])->asBstr();
2009                     if (bstr == nullptr) {
2010                         LOG(ERROR) << "Item in x5chain array is not a bstr";
2011                         return {};
2012                     }
2013                     const vector<uint8_t>& certValue = bstr->value();
2014                     certs.insert(certs.end(), certValue.begin(), certValue.end());
2015                 }
2016                 return certs;
2017             }
2018             LOG(ERROR) << "Value for x5chain label is not a bstr or array";
2019             return {};
2020         }
2021     }
2022     LOG(ERROR) << "Did not find x5chain label in unprotected headers";
2023     return {};
2024 }
2025 
coseBuildToBeMACed(const vector<uint8_t> & encodedProtectedHeaders,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)2026 vector<uint8_t> coseBuildToBeMACed(const vector<uint8_t>& encodedProtectedHeaders,
2027                                    const vector<uint8_t>& data,
2028                                    const vector<uint8_t>& detachedContent) {
2029     cppbor::Array macStructure;
2030     macStructure.add("MAC0");
2031     macStructure.add(encodedProtectedHeaders);
2032 
2033     // We currently don't support Externally Supplied Data (RFC 8152 section 4.3)
2034     // so external_aad is the empty bstr
2035     vector<uint8_t> emptyExternalAad;
2036     macStructure.add(emptyExternalAad);
2037 
2038     // Next field is the payload, independently of how it's transported (RFC
2039     // 8152 section 4.4). Since our API specifies only one of |data| and
2040     // |detachedContent| can be non-empty, it's simply just the non-empty one.
2041     if (data.size() > 0) {
2042         macStructure.add(data);
2043     } else {
2044         macStructure.add(detachedContent);
2045     }
2046 
2047     return macStructure.encode();
2048 }
2049 
coseMac0(const vector<uint8_t> & key,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)2050 optional<vector<uint8_t>> coseMac0(const vector<uint8_t>& key, const vector<uint8_t>& data,
2051                                    const vector<uint8_t>& detachedContent) {
2052     cppbor::Map unprotectedHeaders;
2053     cppbor::Map protectedHeaders;
2054 
2055     if (data.size() > 0 && detachedContent.size() > 0) {
2056         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
2057         return {};
2058     }
2059 
2060     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_HMAC_256_256);
2061 
2062     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
2063     vector<uint8_t> toBeMACed = coseBuildToBeMACed(encodedProtectedHeaders, data, detachedContent);
2064 
2065     optional<vector<uint8_t>> mac = hmacSha256(key, toBeMACed);
2066     if (!mac) {
2067         LOG(ERROR) << "Error MACing toBeMACed data";
2068         return {};
2069     }
2070 
2071     cppbor::Array array;
2072     array.add(encodedProtectedHeaders);
2073     array.add(std::move(unprotectedHeaders));
2074     if (data.size() == 0) {
2075         cppbor::Null nullValue;
2076         array.add(std::move(nullValue));
2077     } else {
2078         array.add(data);
2079     }
2080     array.add(mac.value());
2081     return array.encode();
2082 }
2083 
coseMacWithDigest(const vector<uint8_t> & digestToBeMaced,const vector<uint8_t> & data)2084 optional<vector<uint8_t>> coseMacWithDigest(const vector<uint8_t>& digestToBeMaced,
2085                                             const vector<uint8_t>& data) {
2086     cppbor::Map unprotectedHeaders;
2087     cppbor::Map protectedHeaders;
2088 
2089     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_HMAC_256_256);
2090 
2091     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
2092 
2093     cppbor::Array array;
2094     array.add(encodedProtectedHeaders);
2095     array.add(std::move(unprotectedHeaders));
2096     if (data.size() == 0) {
2097         cppbor::Null nullValue;
2098         array.add(std::move(nullValue));
2099     } else {
2100         array.add(data);
2101     }
2102     array.add(digestToBeMaced);
2103     return array.encode();
2104 }
2105 
2106 // ---------------------------------------------------------------------------
2107 // Utility functions specific to IdentityCredential.
2108 // ---------------------------------------------------------------------------
2109 
calcEMacKey(const vector<uint8_t> & privateKey,const vector<uint8_t> & publicKey,const vector<uint8_t> & sessionTranscriptBytes)2110 optional<vector<uint8_t>> calcEMacKey(const vector<uint8_t>& privateKey,
2111                                       const vector<uint8_t>& publicKey,
2112                                       const vector<uint8_t>& sessionTranscriptBytes) {
2113     optional<vector<uint8_t>> sharedSecret = support::ecdh(publicKey, privateKey);
2114     if (!sharedSecret) {
2115         LOG(ERROR) << "Error performing ECDH";
2116         return {};
2117     }
2118     vector<uint8_t> salt = support::sha256(sessionTranscriptBytes);
2119     vector<uint8_t> info = {'E', 'M', 'a', 'c', 'K', 'e', 'y'};
2120     optional<vector<uint8_t>> derivedKey = support::hkdf(sharedSecret.value(), salt, info, 32);
2121     if (!derivedKey) {
2122         LOG(ERROR) << "Error performing HKDF";
2123         return {};
2124     }
2125     return derivedKey.value();
2126 }
2127 
calcMac(const vector<uint8_t> & sessionTranscriptEncoded,const string & docType,const vector<uint8_t> & deviceNameSpacesEncoded,const vector<uint8_t> & eMacKey)2128 optional<vector<uint8_t>> calcMac(const vector<uint8_t>& sessionTranscriptEncoded,
2129                                   const string& docType,
2130                                   const vector<uint8_t>& deviceNameSpacesEncoded,
2131                                   const vector<uint8_t>& eMacKey) {
2132     auto [sessionTranscriptItem, _, errMsg] = cppbor::parse(sessionTranscriptEncoded);
2133     if (sessionTranscriptItem == nullptr) {
2134         LOG(ERROR) << "Error parsing sessionTranscriptEncoded: " << errMsg;
2135         return {};
2136     }
2137     // The data that is MACed is ["DeviceAuthentication", sessionTranscript, docType,
2138     // deviceNameSpacesBytes] so build up that structure
2139     cppbor::Array deviceAuthentication =
2140             cppbor::Array()
2141                     .add("DeviceAuthentication")
2142                     .add(std::move(sessionTranscriptItem))
2143                     .add(docType)
2144                     .add(cppbor::SemanticTag(kSemanticTagEncodedCbor, deviceNameSpacesEncoded));
2145     vector<uint8_t> deviceAuthenticationBytes =
2146             cppbor::SemanticTag(kSemanticTagEncodedCbor, deviceAuthentication.encode()).encode();
2147     optional<vector<uint8_t>> calculatedMac =
2148             support::coseMac0(eMacKey, {},                 // payload
2149                               deviceAuthenticationBytes);  // detached content
2150     return calculatedMac;
2151 }
2152 
// Splits |content| into consecutive chunks of at most |maxChunkSize| bytes.
// Only the last chunk may be shorter.
//
// Content that already fits, including empty content, is returned as a single
// chunk. A |maxChunkSize| of 0 is treated as "no limit" and also yields a
// single chunk; the previous code divided by zero in that case.
vector<vector<uint8_t>> chunkVector(const vector<uint8_t>& content, size_t maxChunkSize) {
    vector<vector<uint8_t>> ret;

    const size_t contentSize = content.size();
    if (maxChunkSize == 0 || contentSize <= maxChunkSize) {
        ret.push_back(content);
        return ret;
    }

    ret.reserve((contentSize + maxChunkSize - 1) / maxChunkSize);
    for (size_t pos = 0; pos < contentSize; pos += maxChunkSize) {
        const size_t remaining = contentSize - pos;
        const size_t size = remaining < maxChunkSize ? remaining : maxChunkSize;
        ret.emplace_back(content.begin() + pos, content.begin() + pos + size);
    }

    return ret;
}
2178 
// Fixed 16-byte, all-zero key used as the test hardware-bound key.
vector<uint8_t> testHardwareBoundKey(16, 0);

// Returns a reference to testHardwareBoundKey (16 zero bytes).
const vector<uint8_t>& getTestHardwareBoundKey() {
    const vector<uint8_t>& key = testHardwareBoundKey;
    return key;
}
2184 
2185 }  // namespace support
2186 }  // namespace identity
2187 }  // namespace hardware
2188 }  // namespace android
2189