1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "keymint_benchmark"
18
19 #include <base/command_line.h>
20 #include <benchmark/benchmark.h>
21 #include <iostream>
22
23 #include <aidl/Vintf.h>
24 #include <aidl/android/hardware/security/keymint/ErrorCode.h>
25 #include <aidl/android/hardware/security/keymint/IKeyMintDevice.h>
26 #include <android/binder_manager.h>
27 #include <binder/IServiceManager.h>
28 #include <keymint_support/authorization_set.h>
29
// Message sizes, in bytes, used by the sign/verify and cipher benchmarks. The constructor of
// KeyMintBenchmarkTest pre-builds messages of exactly these sizes. These are kept as macros
// (rather than constexpr constants) because BENCHMARK_KM_MSG stringifies its message-size
// argument into the transform string and benchmark name.
#define SMALL_MESSAGE_SIZE 64
#define MEDIUM_MESSAGE_SIZE 1024
#define LARGE_MESSAGE_SIZE 131072
33
34 namespace aidl::android::hardware::security::keymint::test {
35
36 ::std::ostream& operator<<(::std::ostream& os, const keymint::AuthorizationSet& set);
37
38 using ::android::sp;
39 using Status = ::ndk::ScopedAStatus;
40 using ::std::optional;
41 using ::std::shared_ptr;
42 using ::std::string;
43 using ::std::vector;
44
45 class KeyMintBenchmarkTest {
46 public:
KeyMintBenchmarkTest()47 KeyMintBenchmarkTest() {
48 message_cache_.push_back(string(SMALL_MESSAGE_SIZE, 'x'));
49 message_cache_.push_back(string(MEDIUM_MESSAGE_SIZE, 'x'));
50 message_cache_.push_back(string(LARGE_MESSAGE_SIZE, 'x'));
51 }
52
newInstance(const char * instanceName)53 static KeyMintBenchmarkTest* newInstance(const char* instanceName) {
54 if (AServiceManager_isDeclared(instanceName)) {
55 ::ndk::SpAIBinder binder(AServiceManager_waitForService(instanceName));
56 KeyMintBenchmarkTest* test = new KeyMintBenchmarkTest();
57 test->InitializeKeyMint(IKeyMintDevice::fromBinder(binder));
58 return test;
59 } else {
60 return nullptr;
61 }
62 }
63
getError()64 int getError() { return static_cast<int>(error_); }
65
GenerateMessage(int size)66 const string& GenerateMessage(int size) {
67 for (const string& message : message_cache_) {
68 if (message.size() == size) {
69 return message;
70 }
71 }
72 string message = string(size, 'x');
73 message_cache_.push_back(message);
74 return std::move(message);
75 }
76
getBlockMode(string transform)77 optional<BlockMode> getBlockMode(string transform) {
78 if (transform.find("/ECB") != string::npos) {
79 return BlockMode::ECB;
80 } else if (transform.find("/CBC") != string::npos) {
81 return BlockMode::CBC;
82 } else if (transform.find("/CTR") != string::npos) {
83 return BlockMode::CTR;
84 } else if (transform.find("/GCM") != string::npos) {
85 return BlockMode::GCM;
86 }
87 return {};
88 }
89
getPadding(string transform,bool sign)90 PaddingMode getPadding(string transform, bool sign) {
91 if (transform.find("/PKCS7") != string::npos) {
92 return PaddingMode::PKCS7;
93 } else if (transform.find("/PSS") != string::npos) {
94 return PaddingMode::RSA_PSS;
95 } else if (transform.find("/OAEP") != string::npos) {
96 return PaddingMode::RSA_OAEP;
97 } else if (transform.find("/PKCS1") != string::npos) {
98 return sign ? PaddingMode::RSA_PKCS1_1_5_SIGN : PaddingMode::RSA_PKCS1_1_5_ENCRYPT;
99 } else if (sign && transform.find("RSA") != string::npos) {
100 // RSA defaults to PKCS1 for sign
101 return PaddingMode::RSA_PKCS1_1_5_SIGN;
102 }
103 return PaddingMode::NONE;
104 }
105
getAlgorithm(string transform)106 optional<Algorithm> getAlgorithm(string transform) {
107 if (transform.find("AES") != string::npos) {
108 return Algorithm::AES;
109 } else if (transform.find("Hmac") != string::npos) {
110 return Algorithm::HMAC;
111 } else if (transform.find("DESede") != string::npos) {
112 return Algorithm::TRIPLE_DES;
113 } else if (transform.find("RSA") != string::npos) {
114 return Algorithm::RSA;
115 } else if (transform.find("EC") != string::npos) {
116 return Algorithm::EC;
117 }
118 std::cerr << "Can't find algorithm for " << transform << std::endl;
119 return {};
120 }
121
getDigest(string transform)122 Digest getDigest(string transform) {
123 if (transform.find("MD5") != string::npos) {
124 return Digest::MD5;
125 } else if (transform.find("SHA1") != string::npos ||
126 transform.find("SHA-1") != string::npos) {
127 return Digest::SHA1;
128 } else if (transform.find("SHA224") != string::npos) {
129 return Digest::SHA_2_224;
130 } else if (transform.find("SHA256") != string::npos) {
131 return Digest::SHA_2_256;
132 } else if (transform.find("SHA384") != string::npos) {
133 return Digest::SHA_2_384;
134 } else if (transform.find("SHA512") != string::npos) {
135 return Digest::SHA_2_512;
136 } else if (transform.find("RSA") != string::npos &&
137 transform.find("OAEP") != string::npos) {
138 return Digest::SHA1;
139 } else if (transform.find("Hmac") != string::npos) {
140 return Digest::SHA_2_256;
141 }
142 return Digest::NONE;
143 }
144
GenerateKey(string transform,int keySize,bool sign=false)145 bool GenerateKey(string transform, int keySize, bool sign = false) {
146 if (transform == key_transform_) {
147 return true;
148 } else if (key_transform_ != "") {
149 // Deleting old key first
150 key_transform_ = "";
151 if (DeleteKey() != ErrorCode::OK) {
152 return false;
153 }
154 }
155 std::optional<Algorithm> algorithm = getAlgorithm(transform);
156 if (!algorithm) {
157 std::cerr << "Error: invalid algorithm " << transform << std::endl;
158 return false;
159 }
160 key_transform_ = transform;
161 AuthorizationSetBuilder authSet = AuthorizationSetBuilder()
162 .Authorization(TAG_NO_AUTH_REQUIRED)
163 .Authorization(TAG_PURPOSE, KeyPurpose::ENCRYPT)
164 .Authorization(TAG_PURPOSE, KeyPurpose::DECRYPT)
165 .Authorization(TAG_PURPOSE, KeyPurpose::SIGN)
166 .Authorization(TAG_PURPOSE, KeyPurpose::VERIFY)
167 .Authorization(TAG_KEY_SIZE, keySize)
168 .Authorization(TAG_ALGORITHM, algorithm.value())
169 .Digest(getDigest(transform))
170 .Padding(getPadding(transform, sign));
171 std::optional<BlockMode> blockMode = getBlockMode(transform);
172 if (blockMode) {
173 authSet.BlockMode(blockMode.value());
174 if (blockMode == BlockMode::GCM) {
175 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
176 }
177 }
178 if (algorithm == Algorithm::HMAC) {
179 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
180 }
181 if (algorithm == Algorithm::RSA) {
182 authSet.Authorization(TAG_RSA_PUBLIC_EXPONENT, 65537U);
183 authSet.SetDefaultValidity();
184 }
185 if (algorithm == Algorithm::EC) {
186 authSet.SetDefaultValidity();
187 }
188 error_ = GenerateKey(authSet);
189 return error_ == ErrorCode::OK;
190 }
191
getOperationParams(string transform,bool sign=false)192 AuthorizationSet getOperationParams(string transform, bool sign = false) {
193 AuthorizationSetBuilder builder = AuthorizationSetBuilder()
194 .Padding(getPadding(transform, sign))
195 .Digest(getDigest(transform));
196 std::optional<BlockMode> blockMode = getBlockMode(transform);
197 if (sign && (transform.find("Hmac") != string::npos)) {
198 builder.Authorization(TAG_MAC_LENGTH, 128);
199 }
200 if (blockMode) {
201 builder.BlockMode(*blockMode);
202 if (blockMode == BlockMode::GCM) {
203 builder.Authorization(TAG_MAC_LENGTH, 128);
204 }
205 }
206 return std::move(builder);
207 }
208
Process(const string & message,const string & signature="")209 optional<string> Process(const string& message, const string& signature = "") {
210 ErrorCode result;
211
212 string output;
213 result = Finish(message, signature, &output);
214 if (result != ErrorCode::OK) {
215 error_ = result;
216 return {};
217 }
218 return output;
219 }
220
DeleteKey()221 ErrorCode DeleteKey() {
222 Status result = keymint_->deleteKey(key_blob_);
223 key_blob_ = vector<uint8_t>();
224 return GetReturnErrorCode(result);
225 }
226
Begin(KeyPurpose purpose,const AuthorizationSet & in_params,AuthorizationSet * out_params)227 ErrorCode Begin(KeyPurpose purpose, const AuthorizationSet& in_params,
228 AuthorizationSet* out_params) {
229 Status result;
230 BeginResult out;
231 result = keymint_->begin(purpose, key_blob_, in_params.vector_data(), std::nullopt, &out);
232 if (result.isOk()) {
233 *out_params = out.params;
234 op_ = out.operation;
235 }
236 return GetReturnErrorCode(result);
237 }
238
239 SecurityLevel securityLevel_;
240 string name_;
241
242 private:
GenerateKey(const AuthorizationSet & key_desc,const optional<AttestationKey> & attest_key=std::nullopt)243 ErrorCode GenerateKey(const AuthorizationSet& key_desc,
244 const optional<AttestationKey>& attest_key = std::nullopt) {
245 key_blob_.clear();
246 KeyCreationResult creationResult;
247 Status result = keymint_->generateKey(key_desc.vector_data(), attest_key, &creationResult);
248 if (result.isOk()) {
249 key_blob_ = std::move(creationResult.keyBlob);
250 creationResult.keyCharacteristics.clear();
251 creationResult.certificateChain.clear();
252 }
253 return GetReturnErrorCode(result);
254 }
255
InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint)256 void InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint) {
257 if (!keyMint) {
258 std::cerr << "Trying initialize nullptr in InitializeKeyMint" << std::endl;
259 return;
260 }
261 keymint_ = std::move(keyMint);
262 KeyMintHardwareInfo info;
263 Status result = keymint_->getHardwareInfo(&info);
264 if (!result.isOk()) {
265 std::cerr << "InitializeKeyMint: getHardwareInfo failed with "
266 << result.getServiceSpecificError() << std::endl;
267 }
268 securityLevel_ = info.securityLevel;
269 name_.assign(info.keyMintName.begin(), info.keyMintName.end());
270 }
271
Finish(const string & input,const string & signature,string * output)272 ErrorCode Finish(const string& input, const string& signature, string* output) {
273 if (!op_) {
274 std::cerr << "Finish: Operation is nullptr" << std::endl;
275 return ErrorCode::UNEXPECTED_NULL_POINTER;
276 }
277
278 vector<uint8_t> oPut;
279 Status result =
280 op_->finish(vector<uint8_t>(input.begin(), input.end()),
281 vector<uint8_t>(signature.begin(), signature.end()), {} /* authToken */,
282 {} /* timestampToken */, {} /* confirmationToken */, &oPut);
283
284 if (result.isOk()) output->append(oPut.begin(), oPut.end());
285
286 op_.reset();
287 return GetReturnErrorCode(result);
288 }
289
Update(const string & input,string * output)290 ErrorCode Update(const string& input, string* output) {
291 Status result;
292 if (!op_) {
293 std::cerr << "Update: Operation is nullptr" << std::endl;
294 return ErrorCode::UNEXPECTED_NULL_POINTER;
295 }
296
297 std::vector<uint8_t> o_put;
298 result = op_->update(vector<uint8_t>(input.begin(), input.end()), {} /* authToken */,
299 {} /* timestampToken */, &o_put);
300
301 if (result.isOk() && output) *output = {o_put.begin(), o_put.end()};
302 return GetReturnErrorCode(result);
303 }
304
GetReturnErrorCode(const Status & result)305 ErrorCode GetReturnErrorCode(const Status& result) {
306 error_ = static_cast<ErrorCode>(result.getServiceSpecificError());
307 if (result.isOk()) return ErrorCode::OK;
308
309 if (result.getExceptionCode() == EX_SERVICE_SPECIFIC) {
310 return static_cast<ErrorCode>(result.getServiceSpecificError());
311 }
312
313 return ErrorCode::UNKNOWN_ERROR;
314 }
315
316 std::shared_ptr<IKeyMintOperation> op_;
317 vector<Certificate> cert_chain_;
318 vector<uint8_t> key_blob_;
319 vector<KeyCharacteristics> key_characteristics_;
320 std::shared_ptr<IKeyMintDevice> keymint_;
321 std::vector<string> message_cache_;
322 std::string key_transform_;
323 ErrorCode error_;
324 };
325
326 KeyMintBenchmarkTest* keymintTest;
327
settings(benchmark::internal::Benchmark * benchmark)328 static void settings(benchmark::internal::Benchmark* benchmark) {
329 benchmark->Unit(benchmark::kMillisecond);
330 }
331
addDefaultLabel(benchmark::State & state)332 static void addDefaultLabel(benchmark::State& state) {
333 std::string secLevel;
334 switch (keymintTest->securityLevel_) {
335 case SecurityLevel::STRONGBOX:
336 secLevel = "STRONGBOX";
337 break;
338 case SecurityLevel::SOFTWARE:
339 secLevel = "SOFTWARE";
340 break;
341 case SecurityLevel::TRUSTED_ENVIRONMENT:
342 secLevel = "TEE";
343 break;
344 case SecurityLevel::KEYSTORE:
345 secLevel = "KEYSTORE";
346 break;
347 }
348 state.SetLabel("hardware_name:" + keymintTest->name_ + " sec_level:" + secLevel);
349 }
350
// clang-format off
// Registers `func` with a single key size. The benchmark receives the transform string
// "<transform>/<keySize>" (the stringified macro arguments) and `keySize`.
#define BENCHMARK_KM(func, transform, keySize) \
    BENCHMARK_CAPTURE(func, transform/keySize, #transform "/" #keySize, keySize)->Apply(settings);
// Registers `func` for one key size and one message size. The transform string handed to the
// benchmark is "<transform>/<keySize>/<msgSize>", with each part spelled exactly as written in
// the invocation. Because GenerateKey() caches by this whole string, each distinct
// key/message-size combination gets its own key.
#define BENCHMARK_KM_MSG(func, transform, keySize, msgSize) \
    BENCHMARK_CAPTURE(func, transform/keySize/msgSize, #transform "/" #keySize "/" #msgSize, \
            keySize, msgSize) \
            ->Apply(settings);

// Registers `func` for the small, medium and large message sizes.
#define BENCHMARK_KM_ALL_MSGS(func, transform, keySize) \
    BENCHMARK_KM_MSG(func, transform, keySize, SMALL_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, MEDIUM_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, LARGE_MESSAGE_SIZE)

// Registers encrypt and decrypt for a single message size.
#define BENCHMARK_KM_CIPHER(transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(encrypt, transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

// Registers encrypt and decrypt for all three message sizes.
#define BENCHMARK_KM_CIPHER_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(encrypt, transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(decrypt, transform, keySize)

// Registers sign and verify for all three message sizes.
#define BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(verify, transform, keySize)
// clang-format on
376
377 /*
378 * ============= KeyGen TESTS ==================
379 */
keygen(benchmark::State & state,string transform,int keySize)380 static void keygen(benchmark::State& state, string transform, int keySize) {
381 addDefaultLabel(state);
382 for (auto _ : state) {
383 if (!keymintTest->GenerateKey(transform, keySize)) {
384 state.SkipWithError(
385 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
386 }
387 state.PauseTiming();
388
389 keymintTest->DeleteKey();
390 state.ResumeTiming();
391 }
392 }
393
// Key-generation benchmarks. keygen() receives the transform string "<transform>/<keySize>",
// e.g. "AES/128", from which the algorithm is parsed.
BENCHMARK_KM(keygen, AES, 128);
BENCHMARK_KM(keygen, AES, 256);

BENCHMARK_KM(keygen, RSA, 2048);
BENCHMARK_KM(keygen, RSA, 3072);
BENCHMARK_KM(keygen, RSA, 4096);

BENCHMARK_KM(keygen, EC, 224);
BENCHMARK_KM(keygen, EC, 256);
BENCHMARK_KM(keygen, EC, 384);
BENCHMARK_KM(keygen, EC, 521);

BENCHMARK_KM(keygen, DESede, 168);

BENCHMARK_KM(keygen, Hmac, 64);
BENCHMARK_KM(keygen, Hmac, 128);
BENCHMARK_KM(keygen, Hmac, 256);
BENCHMARK_KM(keygen, Hmac, 512);
412
413 /*
414 * ============= SIGNATURE TESTS ==================
415 */
416
sign(benchmark::State & state,string transform,int keySize,int msgSize)417 static void sign(benchmark::State& state, string transform, int keySize, int msgSize) {
418 addDefaultLabel(state);
419 if (!keymintTest->GenerateKey(transform, keySize, true)) {
420 state.SkipWithError(
421 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
422 return;
423 }
424
425 auto in_params = keymintTest->getOperationParams(transform, true);
426 AuthorizationSet out_params;
427 string message = keymintTest->GenerateMessage(msgSize);
428
429 for (auto _ : state) {
430 state.PauseTiming();
431 ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
432 if (error != ErrorCode::OK) {
433 state.SkipWithError(
434 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
435 return;
436 }
437 state.ResumeTiming();
438 out_params.Clear();
439 if (!keymintTest->Process(message)) {
440 state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
441 break;
442 }
443 }
444 }
445
verify(benchmark::State & state,string transform,int keySize,int msgSize)446 static void verify(benchmark::State& state, string transform, int keySize, int msgSize) {
447 addDefaultLabel(state);
448 if (!keymintTest->GenerateKey(transform, keySize, true)) {
449 state.SkipWithError(
450 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
451 return;
452 }
453 AuthorizationSet out_params;
454 auto in_params = keymintTest->getOperationParams(transform, true);
455 string message = keymintTest->GenerateMessage(msgSize);
456 ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
457 if (error != ErrorCode::OK) {
458 state.SkipWithError(
459 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
460 return;
461 }
462 std::optional<string> signature = keymintTest->Process(message);
463 if (!signature) {
464 state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
465 return;
466 }
467 out_params.Clear();
468 if (transform.find("Hmac") != string::npos) {
469 in_params = keymintTest->getOperationParams(transform, false);
470 }
471 for (auto _ : state) {
472 state.PauseTiming();
473 error = keymintTest->Begin(KeyPurpose::VERIFY, in_params, &out_params);
474 if (error != ErrorCode::OK) {
475 state.SkipWithError(
476 ("Verify begin error, " + std::to_string(keymintTest->getError())).c_str());
477 return;
478 }
479 state.ResumeTiming();
480 if (!keymintTest->Process(message, *signature)) {
481 state.SkipWithError(
482 ("Verify error, " + std::to_string(keymintTest->getError())).c_str());
483 break;
484 }
485 }
486 }
487
488 // clang-format off
489 #define BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(transform) \
490 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 64) \
491 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 128) \
492 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256) \
493 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 512)
494
495 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA1)
496 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
497 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA224)
498 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
499 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA384)
500 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA512)
501
502 #define BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(transform) \
503 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 224) \
504 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256) \
505 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 384) \
506 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 521)
507
508 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(NONEwithECDSA);
509 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA1withECDSA);
510 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA224withECDSA);
511 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA256withECDSA);
512 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA384withECDSA);
513 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA512withECDSA);
514
515 #define BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(transform) \
516 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 2048) \
517 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 3072) \
518 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 4096)
519
520 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA);
521 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA);
522 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA);
523 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA);
524 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA);
525
526 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA/PSS);
527 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA/PSS);
528 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA/PSS);
529 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA/PSS);
530 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA/PSS);
531 // clang-format on
532
533 /*
534 * ============= CIPHER TESTS ==================
535 */
536
encrypt(benchmark::State & state,string transform,int keySize,int msgSize)537 static void encrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
538 addDefaultLabel(state);
539 if (!keymintTest->GenerateKey(transform, keySize)) {
540 state.SkipWithError(
541 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
542 return;
543 }
544 auto in_params = keymintTest->getOperationParams(transform);
545 AuthorizationSet out_params;
546 string message = keymintTest->GenerateMessage(msgSize);
547
548 for (auto _ : state) {
549 state.PauseTiming();
550 auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
551 if (error != ErrorCode::OK) {
552 state.SkipWithError(
553 ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
554 return;
555 }
556 out_params.Clear();
557 state.ResumeTiming();
558 if (!keymintTest->Process(message)) {
559 state.SkipWithError(
560 ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
561 break;
562 }
563 }
564 }
565
decrypt(benchmark::State & state,string transform,int keySize,int msgSize)566 static void decrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
567 addDefaultLabel(state);
568 if (!keymintTest->GenerateKey(transform, keySize)) {
569 state.SkipWithError(
570 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
571 return;
572 }
573 AuthorizationSet out_params;
574 AuthorizationSet in_params = keymintTest->getOperationParams(transform);
575 string message = keymintTest->GenerateMessage(msgSize);
576 auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
577 if (error != ErrorCode::OK) {
578 state.SkipWithError(
579 ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
580 return;
581 }
582 auto encryptedMessage = keymintTest->Process(message);
583 if (!encryptedMessage) {
584 state.SkipWithError(
585 ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
586 return;
587 }
588 in_params.push_back(out_params);
589 out_params.Clear();
590 for (auto _ : state) {
591 state.PauseTiming();
592 error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
593 if (error != ErrorCode::OK) {
594 state.SkipWithError(
595 ("Decryption begin error, " + std::to_string(keymintTest->getError())).c_str());
596 return;
597 }
598 state.ResumeTiming();
599 if (!keymintTest->Process(*encryptedMessage)) {
600 state.SkipWithError(
601 ("Decryption error, " + std::to_string(keymintTest->getError())).c_str());
602 break;
603 }
604 }
605 }
606
// clang-format off
// AES
// Registers encrypt/decrypt benchmarks for one AES transform with 128- and 256-bit keys, across
// all three message sizes.
#define BENCHMARK_KM_CIPHER_ALL_AES_KEYS(transform) \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 128) \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 256)

BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CTR/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/GCM/NoPadding);

// Triple DES
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/PKCS7Padding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/PKCS7Padding, 168);

// Registers encrypt/decrypt benchmarks for one RSA transform across the RSA modulus sizes, at a
// single message size. RSA is only registered with the small size, presumably because RSA
// encryption input is bounded by the modulus size.
#define BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(transform, msgSize) \
    BENCHMARK_KM_CIPHER(transform, 2048, msgSize) \
    BENCHMARK_KM_CIPHER(transform, 3072, msgSize) \
    BENCHMARK_KM_CIPHER(transform, 4096, msgSize)

BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/NoPadding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/PKCS1Padding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/OAEPPadding, SMALL_MESSAGE_SIZE);

// clang-format on
636 } // namespace aidl::android::hardware::security::keymint::test
637
main(int argc,char ** argv)638 int main(int argc, char** argv) {
639 ::benchmark::Initialize(&argc, argv);
640 base::CommandLine::Init(argc, argv);
641 base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
642 auto service_name = command_line->GetSwitchValueASCII("service_name");
643 if (service_name.empty()) {
644 service_name =
645 std::string(
646 aidl::android::hardware::security::keymint::IKeyMintDevice::descriptor) +
647 "/default";
648 }
649 std::cerr << service_name << std::endl;
650 aidl::android::hardware::security::keymint::test::keymintTest =
651 aidl::android::hardware::security::keymint::test::KeyMintBenchmarkTest::newInstance(
652 service_name.c_str());
653 if (!aidl::android::hardware::security::keymint::test::keymintTest) {
654 return 1;
655 }
656 ::benchmark::RunSpecifiedBenchmarks();
657 }
658