1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "ohttp_jni.h"
17 #include "jni_util.h"
18 
19 #include <android/log.h>
20 #include <openssl/digest.h>
21 #include <openssl/hkdf.h>
22 #include <openssl/hpke.h>
23 
24 #include <iostream>
25 #include <string_view>
26 #include <vector>
27 
// Tag under which every message from this file is written to the Android log.
constexpr const char *LOG_TAG = "OhttpJniWrapper";

// JNI (slash-separated) class names of the Java exceptions raised from here.
constexpr const char *IllegalArgumentExceptionClass =
    "java/lang/IllegalArgumentException";
constexpr const char *IllegalStateExceptionClass =
    "java/lang/IllegalStateException";

// TODO(b/274425716) : Use macros similar to Conscrypt's JNI_TRACE for cleaner
// logging
// TODO(b/274598556) : Add error throwing convenience methods
35 
36 JNIEXPORT jlong JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeKemDhkemX25519HkdfSha256(JNIEnv * env,jclass)37 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeKemDhkemX25519HkdfSha256(
38     JNIEnv *env, jclass) {
39   __android_log_write(ANDROID_LOG_INFO, LOG_TAG,
40                       "hpkeKemDhkemX25519HkdfSha256");
41 
42   const EVP_HPKE_KEM *ctx = EVP_hpke_x25519_hkdf_sha256();
43   return reinterpret_cast<jlong>(ctx);
44 }
45 
46 JNIEXPORT jlong JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeKdfHkdfSha256(JNIEnv * env,jclass)47 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeKdfHkdfSha256(JNIEnv *env,
48                                                                     jclass) {
49   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "hpkeKdfHkdfSha256");
50 
51   const EVP_HPKE_KDF *ctx = EVP_hpke_hkdf_sha256();
52   return reinterpret_cast<jlong>(ctx);
53 }
54 
55 JNIEXPORT jlong JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeAeadAes256Gcm(JNIEnv * env,jclass)56 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeAeadAes256Gcm(JNIEnv *env,
57                                                                     jclass) {
58   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "hpkeAeadAes256Gcm");
59 
60   const EVP_HPKE_AEAD *ctx = EVP_hpke_aes_256_gcm();
61   return reinterpret_cast<jlong>(ctx);
62 }
63 
64 JNIEXPORT jlong JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hkdfSha256MessageDigest(JNIEnv * env,jclass)65 Java_com_android_adservices_ohttp_OhttpJniWrapper_hkdfSha256MessageDigest(
66     JNIEnv *env, jclass) {
67   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "hkdfSha256MessageDigest");
68 
69   const EVP_MD *evp_md = EVP_sha256();
70   return reinterpret_cast<jlong>(evp_md);
71 }
72 
73 JNIEXPORT void JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxFree(JNIEnv * env,jclass,jlong hpkeCtxRef)74 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxFree(
75     JNIEnv *env, jclass, jlong hpkeCtxRef) {
76   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "hpkeCtxFree");
77 
78   EVP_HPKE_CTX *ctx = reinterpret_cast<EVP_HPKE_CTX *>(hpkeCtxRef);
79   if (ctx != nullptr) {
80     EVP_HPKE_CTX_free(ctx);
81   }
82 }
83 
84 JNIEXPORT jlong JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxNew(JNIEnv * env,jclass)85 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxNew(JNIEnv *env,
86                                                              jclass) {
87   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "hpkeCtxNew");
88 
89   const EVP_HPKE_CTX *ctx = EVP_HPKE_CTX_new();
90   return reinterpret_cast<jlong>(ctx);
91 }
92 
// Defining EVP_HPKE_KEM struct with only the field needed to call the
// function "EVP_HPKE_CTX_setup_sender_with_seed_for_testing" using
// "kem->seed_len"
//
// NOTE(review): this is a local, partial redefinition of a struct that the
// crypto library defines itself. Reading kem->seed_len through it is only
// correct if seed_len sits at offset 0 of the library's real struct. If it
// does not, the read returns an unrelated field. Two different definitions of
// the same struct are also an ODR violation. Confirm against the library's
// definition, or avoid depending on this layout.
struct evp_hpke_kem_st {
  // Seed length in bytes, as read via kem->seed_len.
  size_t seed_len;
};
99 
100 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxSetupSenderWithSeed(JNIEnv * env,jclass,jlong senderHpkeCtxRef,jlong evpKemRef,jlong evpKdfRef,jlong evpAeadRef,jbyteArray publicKeyArray,jbyteArray infoArray,jbyteArray seedArray)101 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxSetupSenderWithSeed(
102     JNIEnv *env, jclass, jlong senderHpkeCtxRef, jlong evpKemRef,
103     jlong evpKdfRef, jlong evpAeadRef, jbyteArray publicKeyArray,
104     jbyteArray infoArray, jbyteArray seedArray) {
105   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "hpkeCtxSetupSenderWithSeed");
106 
107   EVP_HPKE_CTX *ctx = reinterpret_cast<EVP_HPKE_CTX *>(senderHpkeCtxRef);
108   if (ctx == nullptr) {
109     // TODO(b/274598556) : throw NullPointerException
110     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "hpke context is null");
111     return {};
112   }
113 
114   const EVP_HPKE_KEM *kem = reinterpret_cast<const EVP_HPKE_KEM *>(evpKemRef);
115   const EVP_HPKE_KDF *kdf = reinterpret_cast<const EVP_HPKE_KDF *>(evpKdfRef);
116   const EVP_HPKE_AEAD *aead =
117       reinterpret_cast<const EVP_HPKE_AEAD *>(evpAeadRef);
118 
119   __android_log_print(
120       ANDROID_LOG_INFO, LOG_TAG,
121       "EVP_HPKE_CTX_setup_sender_with_seed(%p, %ld, %ld, %ld, %p, %p, %p)", ctx,
122       (long)evpKemRef, (long)evpKdfRef, (long)evpAeadRef, publicKeyArray,
123       infoArray, seedArray);
124 
125   if (kem == nullptr || kdf == nullptr || aead == nullptr) {
126     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
127                         "kem or kdf or aead is null");
128     return {};
129   }
130 
131   if (publicKeyArray == nullptr || seedArray == nullptr) {
132     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
133                         "public key array or seed array is null");
134     return {};
135   }
136 
137   jbyte *peer_public_key = env->GetByteArrayElements(publicKeyArray, 0);
138   jbyte *seed = env->GetByteArrayElements(seedArray, 0);
139 
140   jbyte *infoArrayBytes = nullptr;
141   const uint8_t *info = nullptr;
142   size_t infoLen = 0;
143   if (infoArray != nullptr) {
144     infoArrayBytes = env->GetByteArrayElements(infoArray, 0);
145     info = reinterpret_cast<const uint8_t *>(infoArrayBytes);
146     infoLen = env->GetArrayLength(infoArray);
147   }
148 
149   size_t encapsulatedSharedSecretLen;
150   std::vector<uint8_t> encapsulatedSharedSecret(EVP_HPKE_MAX_ENC_LENGTH);
151   if (!EVP_HPKE_CTX_setup_sender_with_seed_for_testing(
152           /* ctx= */ ctx,
153           /* out_enc= */ encapsulatedSharedSecret.data(),
154           /* out_enc_len= */ &encapsulatedSharedSecretLen,
155           /* max_enc= */ encapsulatedSharedSecret.size(),
156           /* kem= */ kem,
157           /* kdf= */ kdf,
158           /* aead= */ aead,
159           /* peer_public_key= */
160           reinterpret_cast<const uint8_t *>(peer_public_key),
161           /* peer_public_key_len= */ env->GetArrayLength(publicKeyArray),
162           /* info= */ info,
163           /* info_len= */ infoLen,
164           /* seed= */ reinterpret_cast<const uint8_t *>(seed),
165           /* seed_len= */ kem->seed_len)) {
166     env->ReleaseByteArrayElements(publicKeyArray, peer_public_key, JNI_ABORT);
167     env->ReleaseByteArrayElements(seedArray, seed, JNI_ABORT);
168 
169     if (infoArrayBytes != nullptr) {
170       env->ReleaseByteArrayElements(infoArray, infoArrayBytes, JNI_ABORT);
171     }
172 
173     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "setup sender returned 0");
174     return {};
175   }
176 
177   env->ReleaseByteArrayElements(publicKeyArray, peer_public_key, JNI_ABORT);
178   env->ReleaseByteArrayElements(seedArray, seed, JNI_ABORT);
179 
180   if (infoArrayBytes != nullptr) {
181     env->ReleaseByteArrayElements(infoArray, infoArrayBytes, JNI_ABORT);
182   }
183 
184   jbyteArray encArray = env->NewByteArray(encapsulatedSharedSecretLen);
185   env->SetByteArrayRegion(
186       encArray, 0, encapsulatedSharedSecretLen,
187       reinterpret_cast<const jbyte *>(encapsulatedSharedSecret.data()));
188 
189   return encArray;
190 }
191 
192 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_gatewayDecrypt(JNIEnv * env,jclass javaClass,jlong hpkeCtxRef,jlong evpKemRef,jlong evpKdfRef,jlong evpAeadRef,jbyteArray encryptedDataArray)193 Java_com_android_adservices_ohttp_OhttpJniWrapper_gatewayDecrypt(
194     JNIEnv *env,
195     jclass javaClass,
196     jlong hpkeCtxRef,
197     jlong evpKemRef,
198     jlong evpKdfRef,
199     jlong evpAeadRef,
200     jbyteArray encryptedDataArray)
201     {
202   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "gatewayDecrypt");
203 
204   EVP_HPKE_CTX *gatewayCtx = reinterpret_cast<EVP_HPKE_CTX *>(hpkeCtxRef);
205 
206   jbyte* encryptedDataPtr = env->GetByteArrayElements(encryptedDataArray, 0);
207   size_t encryptedDataLen = env->GetArrayLength(encryptedDataArray);
208 
209   std::string decrypted(encryptedDataLen, '\0');
210   size_t decryptedLen;
211   if (!EVP_HPKE_CTX_open(
212           gatewayCtx, reinterpret_cast<uint8_t*>(decrypted.data()),
213           &decryptedLen, decrypted.size(),
214           reinterpret_cast<const uint8_t*>(encryptedDataPtr),
215           encryptedDataLen, nullptr, 0)) {
216     env->ReleaseByteArrayElements(encryptedDataArray, encryptedDataPtr, JNI_ABORT);
217     jni_util::JniUtil::ThrowJavaException(
218             env,
219             IllegalStateExceptionClass,
220             "Could't decrypt ciphertext");
221     return {};
222   }
223 
224   decrypted.resize(decryptedLen);
225   env->ReleaseByteArrayElements(encryptedDataArray, encryptedDataPtr, JNI_ABORT);
226 
227   jbyteArray decryptedArray = env->NewByteArray(decryptedLen);
228   env->SetByteArrayRegion(decryptedArray, 0, decryptedLen,
229                           reinterpret_cast<const jbyte *>(decrypted.data()));
230   return decryptedArray;
231 }
232 
233 JNIEXPORT jboolean JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeSetupRecipient(JNIEnv * env,jclass,jlong hpkeCtxRef,jlong evpKemRef,jlong evpKdfRef,jlong evpAeadRef,jbyteArray privKeyArray,jbyteArray encArray,jbyteArray infoArray)234 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeSetupRecipient(
235     JNIEnv *env, jclass,
236     jlong hpkeCtxRef,
237     jlong evpKemRef,
238     jlong evpKdfRef,
239     jlong evpAeadRef,
240     jbyteArray privKeyArray,
241     jbyteArray encArray,
242     jbyteArray infoArray)
243     {
244   __android_log_write(ANDROID_LOG_INFO, LOG_TAG, "hpkeSetupRecipient");
245 
246   if (infoArray == nullptr ||
247         encArray == nullptr ||
248         privKeyArray == nullptr) {
249     jni_util::JniUtil::ThrowJavaException(
250             env,
251             IllegalArgumentExceptionClass,
252             "One of required input parameters is null");
253     return {};
254   }
255 
256   const EVP_HPKE_KEM *kem = reinterpret_cast<const EVP_HPKE_KEM *>(evpKemRef);
257   const EVP_HPKE_KDF *kdf = reinterpret_cast<const EVP_HPKE_KDF *>(evpKdfRef);
258   const EVP_HPKE_AEAD *aead =
259       reinterpret_cast<const EVP_HPKE_AEAD *>(evpAeadRef);
260 
261   if (kem == nullptr || kdf == nullptr || aead == nullptr) {
262     jni_util::JniUtil::ThrowJavaException(
263             env,
264             IllegalArgumentExceptionClass,
265             "One of HPKE Algorithms is null");
266     return (jboolean) 0;
267   }
268 
269   jbyte* privKeyPtr = env->GetByteArrayElements(privKeyArray, 0);
270   size_t keyLength = env->GetArrayLength(privKeyArray);
271 
272   EVP_HPKE_KEY *recipientKey = EVP_HPKE_KEY_new();
273   if (recipientKey == nullptr) {
274     env->ReleaseByteArrayElements(privKeyArray, privKeyPtr, JNI_ABORT);
275     jni_util::JniUtil::ThrowJavaException(
276             env,
277             IllegalStateExceptionClass,
278             "Could't create new ENV_HPKE_KEY");
279     return (jboolean) 0;
280   }
281 
282   if (!EVP_HPKE_KEY_init(
283           recipientKey, kem,
284           reinterpret_cast<const uint8_t*>(privKeyPtr),
285           keyLength)) {
286     env->ReleaseByteArrayElements(privKeyArray, privKeyPtr, JNI_ABORT);
287     jni_util::JniUtil::ThrowJavaException(
288             env,
289             IllegalStateExceptionClass,
290             "Could't initialize ENV_HPKE_KEY with gateway private key");
291       return (jboolean) 0;
292   }
293 
294   jbyte* encPtr = env->GetByteArrayElements(encArray, 0);
295   size_t encLength = env->GetArrayLength(encArray);
296 
297   jbyte* infoPtr = env->GetByteArrayElements(infoArray, 0);
298   size_t infoLen = env->GetArrayLength(infoArray);
299 
300   EVP_HPKE_CTX *gatewayCtx = reinterpret_cast<EVP_HPKE_CTX *>(hpkeCtxRef);
301   if (gatewayCtx == nullptr) {
302     env->ReleaseByteArrayElements(privKeyArray, privKeyPtr, JNI_ABORT);
303     env->ReleaseByteArrayElements(encArray, encPtr, JNI_ABORT);
304     env->ReleaseByteArrayElements(infoArray, infoPtr, JNI_ABORT);
305     jni_util::JniUtil::ThrowJavaException(
306             env,
307             IllegalStateExceptionClass,
308             "Could't get HPKE context");
309     return (jboolean) 0;
310   }
311 
312   if (!EVP_HPKE_CTX_setup_recipient(
313           gatewayCtx, recipientKey, kdf,
314           aead,
315           reinterpret_cast<const uint8_t*>(encPtr),
316           encLength,
317           reinterpret_cast<const uint8_t*>(infoPtr), infoLen)) {
318     env->ReleaseByteArrayElements(privKeyArray, privKeyPtr, JNI_ABORT);
319     env->ReleaseByteArrayElements(encArray, encPtr, JNI_ABORT);
320     env->ReleaseByteArrayElements(infoArray, infoPtr, JNI_ABORT);
321     jni_util::JniUtil::ThrowJavaException(
322             env,
323             IllegalStateExceptionClass,
324             "Could't setup receiver context");
325     return (jboolean) 0;
326   }
327 
328   // Release resources
329   env->ReleaseByteArrayElements(privKeyArray, privKeyPtr, JNI_ABORT);
330   env->ReleaseByteArrayElements(encArray, encPtr, JNI_ABORT);
331   env->ReleaseByteArrayElements(infoArray, infoPtr, JNI_ABORT);
332 
333   return (jboolean) 1;
334 }
335 
336 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxSeal(JNIEnv * env,jclass,jlong senderHpkeCtxRef,jbyteArray plaintextArray,jbyteArray aadArray)337 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeCtxSeal(
338     JNIEnv *env, jclass, jlong senderHpkeCtxRef, jbyteArray plaintextArray,
339     jbyteArray aadArray) {
340   __android_log_print(ANDROID_LOG_INFO, LOG_TAG,
341                       "EVP_HPKE_CTX_seal(%ld, %p, %p)", (long)senderHpkeCtxRef,
342                       plaintextArray, aadArray);
343 
344   EVP_HPKE_CTX *ctx = reinterpret_cast<EVP_HPKE_CTX *>(senderHpkeCtxRef);
345   if (ctx == nullptr) {
346     // TODO(b/274598556) : throw NullPointerException
347     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "hpke context is null");
348     return {};
349   }
350 
351   if (plaintextArray == nullptr) {
352     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "plaintext array is null");
353     return {};
354   }
355 
356   jbyte *plaintext = env->GetByteArrayElements(plaintextArray, 0);
357 
358   jbyte *aadArrayElement = nullptr;
359   const uint8_t *aad = nullptr;
360   size_t aadLen = 0;
361   if (aadArray != nullptr) {
362     aadArrayElement = env->GetByteArrayElements(aadArray, 0);
363     aad = reinterpret_cast<const uint8_t *>(aadArrayElement);
364     aadLen = env->GetArrayLength(aadArray);
365   }
366 
367 
368   size_t encryptedLen;
369   std::vector<uint8_t> encrypted(env->GetArrayLength(plaintextArray) +
370                                  EVP_HPKE_CTX_max_overhead(ctx));
371 
372   if (!EVP_HPKE_CTX_seal(/* ctx= */ ctx,
373                          /* out= */ encrypted.data(),
374                          /* out_len= */ &encryptedLen,
375                          /* max_out_len= */ encrypted.size(),
376                          /* in= */ reinterpret_cast<const uint8_t *>(plaintext),
377                          /* in_len= */ env->GetArrayLength(plaintextArray),
378                          /* aad= */ aad,
379                          /* aad_len= */ aadLen)) {
380     env->ReleaseByteArrayElements(plaintextArray, plaintext, JNI_ABORT);
381     if (aadArrayElement != nullptr) {
382       env->ReleaseByteArrayElements(aadArray, aadArrayElement, JNI_ABORT);
383     }
384 
385     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "EVP_HPKE_CTX_seal failed");
386     return {};
387   }
388 
389   env->ReleaseByteArrayElements(plaintextArray, plaintext, JNI_ABORT);
390   if (aadArrayElement != nullptr) {
391     env->ReleaseByteArrayElements(aadArray, aadArrayElement, JNI_ABORT);
392   }
393 
394   jbyteArray ciphertextArray = env->NewByteArray(encryptedLen);
395   env->SetByteArrayRegion(ciphertextArray, 0, encryptedLen,
396                           reinterpret_cast<const jbyte *>(encrypted.data()));
397   return ciphertextArray;
398 }
399 
400 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeExport(JNIEnv * env,jclass,jlong hpkeCtxRef,jbyteArray exporterCtxArray,jint length)401 Java_com_android_adservices_ohttp_OhttpJniWrapper_hpkeExport(
402     JNIEnv *env, jclass, jlong hpkeCtxRef, jbyteArray exporterCtxArray,
403     jint length) {
404   __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "HPKE_Export(%ld, %p, %d)",
405                       (long)hpkeCtxRef, exporterCtxArray, (int)length);
406   EVP_HPKE_CTX *ctx = reinterpret_cast<EVP_HPKE_CTX *>(hpkeCtxRef);
407 
408   jbyte *exporterCtxArrayElement = nullptr;
409   const uint8_t *exporterCtx = nullptr;
410   size_t exporterCtxLen = 0;
411   if (exporterCtxArray != nullptr) {
412     exporterCtxArrayElement = env->GetByteArrayElements(exporterCtxArray, 0);
413     exporterCtx = reinterpret_cast<const uint8_t *>(exporterCtxArrayElement);
414     exporterCtxLen = env->GetArrayLength(exporterCtxArray);
415   }
416 
417   size_t exportedLen = length;
418   std::vector<uint8_t> exported(exportedLen);
419 
420   if (!EVP_HPKE_CTX_export(/* ctx= */ ctx,
421                            /* out= */ exported.data(),
422                            /* secret_len= */ exportedLen,
423                            /* context= */ exporterCtx,
424                            /* context_len= */ exporterCtxLen)) {
425     if (exporterCtxArrayElement != nullptr) {
426       env->ReleaseByteArrayElements(exporterCtxArray, exporterCtxArrayElement,
427                                     JNI_ABORT);
428     }
429     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "HPKE_Export failed");
430     return {};
431   }
432 
433   if (exporterCtxArrayElement != nullptr) {
434     env->ReleaseByteArrayElements(exporterCtxArray, exporterCtxArrayElement,
435                                   JNI_ABORT);
436   }
437 
438   jbyteArray exportedArray = env->NewByteArray(exportedLen);
439   env->SetByteArrayRegion(exportedArray, 0, exportedLen,
440                           reinterpret_cast<const jbyte *>(exported.data()));
441   return exportedArray;
442 }
443 
444 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hkdfExtract(JNIEnv * env,jclass,jlong hkdfMd,jbyteArray secretArray,jbyteArray saltArray)445 Java_com_android_adservices_ohttp_OhttpJniWrapper_hkdfExtract(
446     JNIEnv *env, jclass, jlong hkdfMd, jbyteArray secretArray,
447     jbyteArray saltArray) {
448   __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "HKDF_extract(%ld, %p, %p)",
449                       (long)hkdfMd, secretArray, saltArray);
450 
451   const EVP_MD *evp_md = reinterpret_cast<const EVP_MD *>(hkdfMd);
452 
453   jbyte *secret = env->GetByteArrayElements(secretArray, 0);
454   size_t secretLen = env->GetArrayLength(secretArray);
455 
456   jbyte *salt = env->GetByteArrayElements(saltArray, 0);
457   size_t saltLen = env->GetArrayLength(saltArray);
458 
459   std::vector<uint8_t> pseudorandom_key(EVP_MAX_MD_SIZE);
460   size_t prk_len;
461 
462   if (!HKDF_extract(reinterpret_cast<uint8_t *>(pseudorandom_key.data()),
463                     &prk_len, evp_md, reinterpret_cast<const uint8_t *>(secret),
464                     secretLen, reinterpret_cast<const uint8_t *>(salt),
465                     saltLen)) {
466     env->ReleaseByteArrayElements(secretArray, secret, JNI_ABORT);
467     env->ReleaseByteArrayElements(saltArray, salt, JNI_ABORT);
468     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "HKDF_Extract failed");
469     return {};
470   }
471 
472   env->ReleaseByteArrayElements(secretArray, secret, JNI_ABORT);
473   env->ReleaseByteArrayElements(saltArray, salt, JNI_ABORT);
474 
475   pseudorandom_key.resize(prk_len);
476 
477   jbyteArray prkArray = env->NewByteArray(prk_len);
478   env->SetByteArrayRegion(
479       prkArray, 0, prk_len,
480       reinterpret_cast<const jbyte *>(pseudorandom_key.data()));
481   return prkArray;
482 }
483 
484 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_hkdfExpand(JNIEnv * env,jclass,jlong hkdfMd,jbyteArray prkArray,jbyteArray infoArray,jint key_len)485 Java_com_android_adservices_ohttp_OhttpJniWrapper_hkdfExpand(
486     JNIEnv *env, jclass, jlong hkdfMd, jbyteArray prkArray,
487     jbyteArray infoArray, jint key_len) {
488   __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "HKDF_expand(%ld, %p, %p)",
489                       (long)hkdfMd, prkArray, infoArray);
490 
491   const EVP_MD *evp_md = reinterpret_cast<const EVP_MD *>(hkdfMd);
492 
493   jbyte *prk = env->GetByteArrayElements(prkArray, 0);
494   size_t prkLen = env->GetArrayLength(prkArray);
495 
496   jbyte *info = env->GetByteArrayElements(infoArray, 0);
497   size_t infoLen = env->GetArrayLength(infoArray);
498 
499   std::vector<uint8_t> out_key(key_len);
500 
501   if (!HKDF_expand(reinterpret_cast<uint8_t *>(out_key.data()), key_len, evp_md,
502                    reinterpret_cast<const uint8_t *>(prk), prkLen,
503                    reinterpret_cast<const uint8_t *>(info), infoLen)) {
504     env->ReleaseByteArrayElements(prkArray, prk, JNI_ABORT);
505     env->ReleaseByteArrayElements(infoArray, info, JNI_ABORT);
506     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "HKDF_Expand failed");
507     return {};
508   }
509 
510   env->ReleaseByteArrayElements(prkArray, prk, JNI_ABORT);
511   env->ReleaseByteArrayElements(infoArray, info, JNI_ABORT);
512 
513   jbyteArray responseArray = env->NewByteArray(key_len);
514   env->SetByteArrayRegion(responseArray, 0, key_len,
515                           reinterpret_cast<const jbyte *>(out_key.data()));
516   return responseArray;
517 }
518 
519 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_aeadOpen(JNIEnv * env,jclass,jlong evpAeadRef,jbyteArray keyArray,jbyteArray nonceArray,jbyteArray cipherTextArray)520 Java_com_android_adservices_ohttp_OhttpJniWrapper_aeadOpen(
521     JNIEnv *env, jclass, jlong evpAeadRef, jbyteArray keyArray,
522     jbyteArray nonceArray, jbyteArray cipherTextArray) {
523   __android_log_print(ANDROID_LOG_INFO, LOG_TAG,
524                       "EVP_HPKE_AEAD_CTX_open(%p, %p, %p)", keyArray,
525                       nonceArray, cipherTextArray);
526 
527   const EVP_HPKE_AEAD *hpkeAead = reinterpret_cast<const EVP_HPKE_AEAD *>(evpAeadRef);
528   const EVP_AEAD *aead = EVP_HPKE_AEAD_aead(hpkeAead);
529 
530   if (aead == nullptr) {
531       __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "aead == null");
532       return {};
533   }
534 
535   jbyte *key = env->GetByteArrayElements(keyArray, 0);
536   size_t keyLen = env->GetArrayLength(keyArray);
537 
538   EVP_AEAD_CTX *aead_ctx =
539       EVP_AEAD_CTX_new(aead, reinterpret_cast<const uint8_t *>(key), keyLen, 0);
540 
541   if (aead_ctx == nullptr) {
542     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "aead ctx == null");
543     return {};
544   }
545 
546   jbyte *nonce = env->GetByteArrayElements(nonceArray, 0);
547   size_t nonceLen = env->GetArrayLength(nonceArray);
548 
549   jbyte *ciphertext = env->GetByteArrayElements(cipherTextArray, 0);
550   size_t ciphertextLen = env->GetArrayLength(cipherTextArray);
551 
552   std::vector<uint8_t> plaintext(ciphertextLen);
553   size_t plaintextLen;
554 
555   if (!EVP_AEAD_CTX_open(aead_ctx,
556                          reinterpret_cast<uint8_t *>(plaintext.data()),
557                          &plaintextLen, plaintext.size(),
558                          reinterpret_cast<const uint8_t *>(nonce), nonceLen,
559                          reinterpret_cast<const uint8_t *>(ciphertext),
560                          ciphertextLen, nullptr, 0)) {
561     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "EVP_AEAD_CTX_open failed");
562     env->ReleaseByteArrayElements(keyArray, key, JNI_ABORT);
563     env->ReleaseByteArrayElements(nonceArray, nonce, JNI_ABORT);
564     env->ReleaseByteArrayElements(cipherTextArray, ciphertext, JNI_ABORT);
565     return {};
566   }
567 
568   env->ReleaseByteArrayElements(keyArray, key, JNI_ABORT);
569   env->ReleaseByteArrayElements(nonceArray, nonce, JNI_ABORT);
570   env->ReleaseByteArrayElements(cipherTextArray, ciphertext, JNI_ABORT);
571   plaintext.resize(plaintextLen);
572 
573   jbyteArray plaintextArray = env->NewByteArray(plaintextLen);
574   env->SetByteArrayRegion(plaintextArray, 0, plaintextLen,
575                           reinterpret_cast<const jbyte *>(plaintext.data()));
576   return plaintextArray;
577 }
578 
579 
580 JNIEXPORT jbyteArray JNICALL
Java_com_android_adservices_ohttp_OhttpJniWrapper_aeadSeal(JNIEnv * env,jclass,jlong evpAeadRef,jbyteArray keyArray,jbyteArray nonceArray,jbyteArray plainTextArray)581 Java_com_android_adservices_ohttp_OhttpJniWrapper_aeadSeal(
582     JNIEnv *env,
583     jclass,
584     jlong evpAeadRef,
585     jbyteArray keyArray,
586     jbyteArray nonceArray,
587     jbyteArray plainTextArray) {
588   __android_log_print(ANDROID_LOG_INFO,
589                         LOG_TAG,
590                        "aead_Seal(%p, %p, %p)",
591                         keyArray,
592                         nonceArray,
593                         plainTextArray);
594 
595   const EVP_HPKE_AEAD *hpkeAead = reinterpret_cast<const EVP_HPKE_AEAD *>(evpAeadRef);
596   const EVP_AEAD *aead = EVP_HPKE_AEAD_aead(hpkeAead);
597 
598   if (aead == nullptr) {
599       jni_util::JniUtil::ThrowJavaException(
600               env,
601               IllegalArgumentExceptionClass,
602               "Unable to initialize AEAD object");
603       return {};
604   }
605 
606   jbyte *keyPtr = env->GetByteArrayElements(keyArray, 0);
607   size_t keyLen = env->GetArrayLength(keyArray);
608 
609   EVP_AEAD_CTX *aeadCtx =
610       EVP_AEAD_CTX_new(aead, reinterpret_cast<const uint8_t *>(keyPtr), keyLen, 0);
611 
612   if (aeadCtx == nullptr) {
613     env->ReleaseByteArrayElements(keyArray, keyPtr, JNI_ABORT);
614     jni_util::JniUtil::ThrowJavaException(
615             env,
616             IllegalArgumentExceptionClass,
617             "Unable to initialize AEAD ctx object");
618     return {};
619   }
620 
621   jbyte *noncePtr = env->GetByteArrayElements(nonceArray, 0);
622   size_t nonceLen = env->GetArrayLength(nonceArray);
623 
624   jbyte *plainTextPtr = env->GetByteArrayElements(plainTextArray, 0);
625   size_t plainTextLen = env->GetArrayLength(plainTextArray);
626 
627   const size_t maxEncryptedDataSize =
628       nonceLen + plainTextLen + EVP_AEAD_max_overhead(aead);
629   std::string encryptedData(maxEncryptedDataSize, '\0');
630 
631   size_t ciphertextLen;
632   if (!EVP_AEAD_CTX_seal(aeadCtx,
633                          reinterpret_cast<uint8_t *>(encryptedData.data()),
634                          &ciphertextLen,
635                          encryptedData.size() - nonceLen,
636                          reinterpret_cast<const uint8_t *>(noncePtr),
637                          nonceLen,
638                          reinterpret_cast<const uint8_t *>(plainTextPtr),
639                          plainTextLen,
640                          nullptr,
641                          0)) {
642     env->ReleaseByteArrayElements(keyArray, keyPtr, JNI_ABORT);
643     env->ReleaseByteArrayElements(nonceArray, noncePtr, JNI_ABORT);
644     env->ReleaseByteArrayElements(plainTextArray, plainTextPtr, JNI_ABORT);
645     jni_util::JniUtil::ThrowJavaException(
646             env,
647             IllegalStateExceptionClass,
648             "EVP_AEAD_CTX_seal failed");
649     return {};
650   }
651 
652   env->ReleaseByteArrayElements(keyArray, keyPtr, JNI_ABORT);
653   env->ReleaseByteArrayElements(nonceArray, noncePtr, JNI_ABORT);
654   env->ReleaseByteArrayElements(plainTextArray, plainTextPtr, JNI_ABORT);
655   encryptedData.resize(ciphertextLen + nonceLen);
656 
657   jbyteArray encryptedDataArray = env->NewByteArray(ciphertextLen);
658   env->SetByteArrayRegion(encryptedDataArray, 0, ciphertextLen,
659                           reinterpret_cast<const jbyte *>(encryptedData.data()));
660   return encryptedDataArray;
661 }
662