1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "keymint_benchmark"
18 
#include <deque>
#include <iostream>
#include <memory>

#include <base/command_line.h>
#include <benchmark/benchmark.h>

#include <aidl/Vintf.h>
#include <aidl/android/hardware/security/keymint/ErrorCode.h>
#include <aidl/android/hardware/security/keymint/IKeyMintDevice.h>
#include <android/binder_manager.h>
#include <binder/IServiceManager.h>
#include <keymint_support/authorization_set.h>
29 
// Message sizes (in bytes of 'x') used by the *_MSGS benchmark registration macros.
// These stay preprocessor macros rather than constexpr constants because the
// registration macros stringify and paste their arguments into benchmark names and
// transform strings; turning them into constants would change those generated names.
#define SMALL_MESSAGE_SIZE 64
#define MEDIUM_MESSAGE_SIZE 1024
#define LARGE_MESSAGE_SIZE 131072
33 
34 namespace aidl::android::hardware::security::keymint::test {
35 
36 ::std::ostream& operator<<(::std::ostream& os, const keymint::AuthorizationSet& set);
37 
38 using ::android::sp;
39 using Status = ::ndk::ScopedAStatus;
40 using ::std::optional;
41 using ::std::shared_ptr;
42 using ::std::string;
43 using ::std::vector;
44 
45 class KeyMintBenchmarkTest {
46   public:
KeyMintBenchmarkTest()47     KeyMintBenchmarkTest() {
48         message_cache_.push_back(string(SMALL_MESSAGE_SIZE, 'x'));
49         message_cache_.push_back(string(MEDIUM_MESSAGE_SIZE, 'x'));
50         message_cache_.push_back(string(LARGE_MESSAGE_SIZE, 'x'));
51     }
52 
newInstance(const char * instanceName)53     static KeyMintBenchmarkTest* newInstance(const char* instanceName) {
54         if (AServiceManager_isDeclared(instanceName)) {
55             ::ndk::SpAIBinder binder(AServiceManager_waitForService(instanceName));
56             KeyMintBenchmarkTest* test = new KeyMintBenchmarkTest();
57             test->InitializeKeyMint(IKeyMintDevice::fromBinder(binder));
58             return test;
59         } else {
60             return nullptr;
61         }
62     }
63 
getError()64     int getError() { return static_cast<int>(error_); }
65 
GenerateMessage(int size)66     const string& GenerateMessage(int size) {
67         for (const string& message : message_cache_) {
68             if (message.size() == size) {
69                 return message;
70             }
71         }
72         string message = string(size, 'x');
73         message_cache_.push_back(message);
74         return std::move(message);
75     }
76 
getBlockMode(string transform)77     optional<BlockMode> getBlockMode(string transform) {
78         if (transform.find("/ECB") != string::npos) {
79             return BlockMode::ECB;
80         } else if (transform.find("/CBC") != string::npos) {
81             return BlockMode::CBC;
82         } else if (transform.find("/CTR") != string::npos) {
83             return BlockMode::CTR;
84         } else if (transform.find("/GCM") != string::npos) {
85             return BlockMode::GCM;
86         }
87         return {};
88     }
89 
getPadding(string transform,bool sign)90     PaddingMode getPadding(string transform, bool sign) {
91         if (transform.find("/PKCS7") != string::npos) {
92             return PaddingMode::PKCS7;
93         } else if (transform.find("/PSS") != string::npos) {
94             return PaddingMode::RSA_PSS;
95         } else if (transform.find("/OAEP") != string::npos) {
96             return PaddingMode::RSA_OAEP;
97         } else if (transform.find("/PKCS1") != string::npos) {
98             return sign ? PaddingMode::RSA_PKCS1_1_5_SIGN : PaddingMode::RSA_PKCS1_1_5_ENCRYPT;
99         } else if (sign && transform.find("RSA") != string::npos) {
100             // RSA defaults to PKCS1 for sign
101             return PaddingMode::RSA_PKCS1_1_5_SIGN;
102         }
103         return PaddingMode::NONE;
104     }
105 
getAlgorithm(string transform)106     optional<Algorithm> getAlgorithm(string transform) {
107         if (transform.find("AES") != string::npos) {
108             return Algorithm::AES;
109         } else if (transform.find("Hmac") != string::npos) {
110             return Algorithm::HMAC;
111         } else if (transform.find("DESede") != string::npos) {
112             return Algorithm::TRIPLE_DES;
113         } else if (transform.find("RSA") != string::npos) {
114             return Algorithm::RSA;
115         } else if (transform.find("EC") != string::npos) {
116             return Algorithm::EC;
117         }
118         std::cerr << "Can't find algorithm for " << transform << std::endl;
119         return {};
120     }
121 
getDigest(string transform)122     Digest getDigest(string transform) {
123         if (transform.find("MD5") != string::npos) {
124             return Digest::MD5;
125         } else if (transform.find("SHA1") != string::npos ||
126                    transform.find("SHA-1") != string::npos) {
127             return Digest::SHA1;
128         } else if (transform.find("SHA224") != string::npos) {
129             return Digest::SHA_2_224;
130         } else if (transform.find("SHA256") != string::npos) {
131             return Digest::SHA_2_256;
132         } else if (transform.find("SHA384") != string::npos) {
133             return Digest::SHA_2_384;
134         } else if (transform.find("SHA512") != string::npos) {
135             return Digest::SHA_2_512;
136         } else if (transform.find("RSA") != string::npos &&
137                    transform.find("OAEP") != string::npos) {
138             return Digest::SHA1;
139         } else if (transform.find("Hmac") != string::npos) {
140             return Digest::SHA_2_256;
141         }
142         return Digest::NONE;
143     }
144 
GenerateKey(string transform,int keySize,bool sign=false)145     bool GenerateKey(string transform, int keySize, bool sign = false) {
146         if (transform == key_transform_) {
147             return true;
148         } else if (key_transform_ != "") {
149             // Deleting old key first
150             key_transform_ = "";
151             if (DeleteKey() != ErrorCode::OK) {
152                 return false;
153             }
154         }
155         std::optional<Algorithm> algorithm = getAlgorithm(transform);
156         if (!algorithm) {
157             std::cerr << "Error: invalid algorithm " << transform << std::endl;
158             return false;
159         }
160         key_transform_ = transform;
161         AuthorizationSetBuilder authSet = AuthorizationSetBuilder()
162                                                   .Authorization(TAG_NO_AUTH_REQUIRED)
163                                                   .Authorization(TAG_PURPOSE, KeyPurpose::ENCRYPT)
164                                                   .Authorization(TAG_PURPOSE, KeyPurpose::DECRYPT)
165                                                   .Authorization(TAG_PURPOSE, KeyPurpose::SIGN)
166                                                   .Authorization(TAG_PURPOSE, KeyPurpose::VERIFY)
167                                                   .Authorization(TAG_KEY_SIZE, keySize)
168                                                   .Authorization(TAG_ALGORITHM, algorithm.value())
169                                                   .Digest(getDigest(transform))
170                                                   .Padding(getPadding(transform, sign));
171         std::optional<BlockMode> blockMode = getBlockMode(transform);
172         if (blockMode) {
173             authSet.BlockMode(blockMode.value());
174             if (blockMode == BlockMode::GCM) {
175                 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
176             }
177         }
178         if (algorithm == Algorithm::HMAC) {
179             authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
180         }
181         if (algorithm == Algorithm::RSA) {
182             authSet.Authorization(TAG_RSA_PUBLIC_EXPONENT, 65537U);
183             authSet.SetDefaultValidity();
184         }
185         if (algorithm == Algorithm::EC) {
186             authSet.SetDefaultValidity();
187         }
188         error_ = GenerateKey(authSet);
189         return error_ == ErrorCode::OK;
190     }
191 
getOperationParams(string transform,bool sign=false)192     AuthorizationSet getOperationParams(string transform, bool sign = false) {
193         AuthorizationSetBuilder builder = AuthorizationSetBuilder()
194                                                   .Padding(getPadding(transform, sign))
195                                                   .Digest(getDigest(transform));
196         std::optional<BlockMode> blockMode = getBlockMode(transform);
197         if (sign && (transform.find("Hmac") != string::npos)) {
198             builder.Authorization(TAG_MAC_LENGTH, 128);
199         }
200         if (blockMode) {
201             builder.BlockMode(*blockMode);
202             if (blockMode == BlockMode::GCM) {
203                 builder.Authorization(TAG_MAC_LENGTH, 128);
204             }
205         }
206         return std::move(builder);
207     }
208 
Process(const string & message,const string & signature="")209     optional<string> Process(const string& message, const string& signature = "") {
210         ErrorCode result;
211 
212         string output;
213         result = Finish(message, signature, &output);
214         if (result != ErrorCode::OK) {
215             error_ = result;
216             return {};
217         }
218         return output;
219     }
220 
DeleteKey()221     ErrorCode DeleteKey() {
222         Status result = keymint_->deleteKey(key_blob_);
223         key_blob_ = vector<uint8_t>();
224         return GetReturnErrorCode(result);
225     }
226 
Begin(KeyPurpose purpose,const AuthorizationSet & in_params,AuthorizationSet * out_params)227     ErrorCode Begin(KeyPurpose purpose, const AuthorizationSet& in_params,
228                     AuthorizationSet* out_params) {
229         Status result;
230         BeginResult out;
231         result = keymint_->begin(purpose, key_blob_, in_params.vector_data(), std::nullopt, &out);
232         if (result.isOk()) {
233             *out_params = out.params;
234             op_ = out.operation;
235         }
236         return GetReturnErrorCode(result);
237     }
238 
239     SecurityLevel securityLevel_;
240     string name_;
241 
242   private:
GenerateKey(const AuthorizationSet & key_desc,const optional<AttestationKey> & attest_key=std::nullopt)243     ErrorCode GenerateKey(const AuthorizationSet& key_desc,
244                           const optional<AttestationKey>& attest_key = std::nullopt) {
245         key_blob_.clear();
246         KeyCreationResult creationResult;
247         Status result = keymint_->generateKey(key_desc.vector_data(), attest_key, &creationResult);
248         if (result.isOk()) {
249             key_blob_ = std::move(creationResult.keyBlob);
250             creationResult.keyCharacteristics.clear();
251             creationResult.certificateChain.clear();
252         }
253         return GetReturnErrorCode(result);
254     }
255 
InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint)256     void InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint) {
257         if (!keyMint) {
258             std::cerr << "Trying initialize nullptr in InitializeKeyMint" << std::endl;
259             return;
260         }
261         keymint_ = std::move(keyMint);
262         KeyMintHardwareInfo info;
263         Status result = keymint_->getHardwareInfo(&info);
264         if (!result.isOk()) {
265             std::cerr << "InitializeKeyMint: getHardwareInfo failed with "
266                       << result.getServiceSpecificError() << std::endl;
267         }
268         securityLevel_ = info.securityLevel;
269         name_.assign(info.keyMintName.begin(), info.keyMintName.end());
270     }
271 
Finish(const string & input,const string & signature,string * output)272     ErrorCode Finish(const string& input, const string& signature, string* output) {
273         if (!op_) {
274             std::cerr << "Finish: Operation is nullptr" << std::endl;
275             return ErrorCode::UNEXPECTED_NULL_POINTER;
276         }
277 
278         vector<uint8_t> oPut;
279         Status result =
280                 op_->finish(vector<uint8_t>(input.begin(), input.end()),
281                             vector<uint8_t>(signature.begin(), signature.end()), {} /* authToken */,
282                             {} /* timestampToken */, {} /* confirmationToken */, &oPut);
283 
284         if (result.isOk()) output->append(oPut.begin(), oPut.end());
285 
286         op_.reset();
287         return GetReturnErrorCode(result);
288     }
289 
Update(const string & input,string * output)290     ErrorCode Update(const string& input, string* output) {
291         Status result;
292         if (!op_) {
293             std::cerr << "Update: Operation is nullptr" << std::endl;
294             return ErrorCode::UNEXPECTED_NULL_POINTER;
295         }
296 
297         std::vector<uint8_t> o_put;
298         result = op_->update(vector<uint8_t>(input.begin(), input.end()), {} /* authToken */,
299                              {} /* timestampToken */, &o_put);
300 
301         if (result.isOk() && output) *output = {o_put.begin(), o_put.end()};
302         return GetReturnErrorCode(result);
303     }
304 
GetReturnErrorCode(const Status & result)305     ErrorCode GetReturnErrorCode(const Status& result) {
306         error_ = static_cast<ErrorCode>(result.getServiceSpecificError());
307         if (result.isOk()) return ErrorCode::OK;
308 
309         if (result.getExceptionCode() == EX_SERVICE_SPECIFIC) {
310             return static_cast<ErrorCode>(result.getServiceSpecificError());
311         }
312 
313         return ErrorCode::UNKNOWN_ERROR;
314     }
315 
316     std::shared_ptr<IKeyMintOperation> op_;
317     vector<Certificate> cert_chain_;
318     vector<uint8_t> key_blob_;
319     vector<KeyCharacteristics> key_characteristics_;
320     std::shared_ptr<IKeyMintDevice> keymint_;
321     std::vector<string> message_cache_;
322     std::string key_transform_;
323     ErrorCode error_;
324 };
325 
326 KeyMintBenchmarkTest* keymintTest;
327 
settings(benchmark::internal::Benchmark * benchmark)328 static void settings(benchmark::internal::Benchmark* benchmark) {
329     benchmark->Unit(benchmark::kMillisecond);
330 }
331 
addDefaultLabel(benchmark::State & state)332 static void addDefaultLabel(benchmark::State& state) {
333     std::string secLevel;
334     switch (keymintTest->securityLevel_) {
335         case SecurityLevel::STRONGBOX:
336             secLevel = "STRONGBOX";
337             break;
338         case SecurityLevel::SOFTWARE:
339             secLevel = "SOFTWARE";
340             break;
341         case SecurityLevel::TRUSTED_ENVIRONMENT:
342             secLevel = "TEE";
343             break;
344         case SecurityLevel::KEYSTORE:
345             secLevel = "KEYSTORE";
346             break;
347     }
348     state.SetLabel("hardware_name:" + keymintTest->name_ + " sec_level:" + secLevel);
349 }
350 
// clang-format off
// Registration helpers.
//
// Each helper registers `func` under the benchmark name built from its arguments, passes a
// transform string (built with # stringification) plus the numeric arguments, and applies
// `settings`.
//
// Note: # stringifies its argument without macro-expanding it. So when msgSize arrives as a
// literal *_MESSAGE_SIZE token, the transform string contains that macro name (e.g.
// "SMALL_MESSAGE_SIZE"), while the benchmark name and the numeric argument use the expanded
// value. The transform string is also the fixture's key-cache identity.
// No comments may be placed inside the backslash-continued macro bodies below.
#define BENCHMARK_KM(func, transform, keySize) \
    BENCHMARK_CAPTURE(func, transform/keySize, #transform "/" #keySize, keySize)->Apply(settings);
#define BENCHMARK_KM_MSG(func, transform, keySize, msgSize)                                      \
    BENCHMARK_CAPTURE(func, transform/keySize/msgSize, #transform "/" #keySize "/" #msgSize, \
                      keySize, msgSize)                                                          \
            ->Apply(settings);

// Registers func for the small, medium and large message sizes.
#define BENCHMARK_KM_ALL_MSGS(func, transform, keySize)             \
    BENCHMARK_KM_MSG(func, transform, keySize, SMALL_MESSAGE_SIZE)  \
    BENCHMARK_KM_MSG(func, transform, keySize, MEDIUM_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, LARGE_MESSAGE_SIZE)

// Registers an encrypt/decrypt pair for one message size.
#define BENCHMARK_KM_CIPHER(transform, keySize, msgSize)   \
    BENCHMARK_KM_MSG(encrypt, transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

// Registers encrypt and decrypt for all three message sizes.
#define BENCHMARK_KM_CIPHER_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(encrypt, transform, keySize)   \
    BENCHMARK_KM_ALL_MSGS(decrypt, transform, keySize)

// Registers sign and verify for all three message sizes.
#define BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)         \
    BENCHMARK_KM_ALL_MSGS(verify, transform, keySize)
// clang-format on
376 
377 /*
378  * ============= KeyGen TESTS ==================
379  */
keygen(benchmark::State & state,string transform,int keySize)380 static void keygen(benchmark::State& state, string transform, int keySize) {
381     addDefaultLabel(state);
382     for (auto _ : state) {
383         if (!keymintTest->GenerateKey(transform, keySize)) {
384             state.SkipWithError(
385                     ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
386         }
387         state.PauseTiming();
388 
389         keymintTest->DeleteKey();
390         state.ResumeTiming();
391     }
392 }
393 
// Key-generation benchmarks, one per algorithm / key size (bits) pair.
BENCHMARK_KM(keygen, AES, 128);
BENCHMARK_KM(keygen, AES, 256);

BENCHMARK_KM(keygen, RSA, 2048);
BENCHMARK_KM(keygen, RSA, 3072);
BENCHMARK_KM(keygen, RSA, 4096);

BENCHMARK_KM(keygen, EC, 224);
BENCHMARK_KM(keygen, EC, 256);
BENCHMARK_KM(keygen, EC, 384);
BENCHMARK_KM(keygen, EC, 521);

BENCHMARK_KM(keygen, DESede, 168);

BENCHMARK_KM(keygen, Hmac, 64);
BENCHMARK_KM(keygen, Hmac, 128);
BENCHMARK_KM(keygen, Hmac, 256);
BENCHMARK_KM(keygen, Hmac, 512);
412 
413 /*
414  * ============= SIGNATURE TESTS ==================
415  */
416 
sign(benchmark::State & state,string transform,int keySize,int msgSize)417 static void sign(benchmark::State& state, string transform, int keySize, int msgSize) {
418     addDefaultLabel(state);
419     if (!keymintTest->GenerateKey(transform, keySize, true)) {
420         state.SkipWithError(
421                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
422         return;
423     }
424 
425     auto in_params = keymintTest->getOperationParams(transform, true);
426     AuthorizationSet out_params;
427     string message = keymintTest->GenerateMessage(msgSize);
428 
429     for (auto _ : state) {
430         state.PauseTiming();
431         ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
432         if (error != ErrorCode::OK) {
433             state.SkipWithError(
434                     ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
435             return;
436         }
437         state.ResumeTiming();
438         out_params.Clear();
439         if (!keymintTest->Process(message)) {
440             state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
441             break;
442         }
443     }
444 }
445 
verify(benchmark::State & state,string transform,int keySize,int msgSize)446 static void verify(benchmark::State& state, string transform, int keySize, int msgSize) {
447     addDefaultLabel(state);
448     if (!keymintTest->GenerateKey(transform, keySize, true)) {
449         state.SkipWithError(
450                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
451         return;
452     }
453     AuthorizationSet out_params;
454     auto in_params = keymintTest->getOperationParams(transform, true);
455     string message = keymintTest->GenerateMessage(msgSize);
456     ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
457     if (error != ErrorCode::OK) {
458         state.SkipWithError(
459                 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
460         return;
461     }
462     std::optional<string> signature = keymintTest->Process(message);
463     if (!signature) {
464         state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
465         return;
466     }
467     out_params.Clear();
468     if (transform.find("Hmac") != string::npos) {
469         in_params = keymintTest->getOperationParams(transform, false);
470     }
471     for (auto _ : state) {
472         state.PauseTiming();
473         error = keymintTest->Begin(KeyPurpose::VERIFY, in_params, &out_params);
474         if (error != ErrorCode::OK) {
475             state.SkipWithError(
476                     ("Verify begin error, " + std::to_string(keymintTest->getError())).c_str());
477             return;
478         }
479         state.ResumeTiming();
480         if (!keymintTest->Process(message, *signature)) {
481             state.SkipWithError(
482                     ("Verify error, " + std::to_string(keymintTest->getError())).c_str());
483             break;
484         }
485     }
486 }
487 
488 // clang-format off
489 #define BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(transform) \
490     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 64)      \
491     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 128)     \
492     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256)     \
493     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 512)
494 
495 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA1)
496 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
497 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA224)
498 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
499 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA384)
500 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA512)
501 
502 #define BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(transform) \
503     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 224)      \
504     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256)      \
505     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 384)      \
506     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 521)
507 
508 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(NONEwithECDSA);
509 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA1withECDSA);
510 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA224withECDSA);
511 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA256withECDSA);
512 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA384withECDSA);
513 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA512withECDSA);
514 
515 #define BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(transform) \
516     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 2048)   \
517     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 3072)   \
518     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 4096)
519 
520 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA);
521 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA);
522 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA);
523 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA);
524 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA);
525 
526 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA/PSS);
527 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA/PSS);
528 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA/PSS);
529 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA/PSS);
530 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA/PSS);
531 // clang-format on
532 
533 /*
534  * ============= CIPHER TESTS ==================
535  */
536 
encrypt(benchmark::State & state,string transform,int keySize,int msgSize)537 static void encrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
538     addDefaultLabel(state);
539     if (!keymintTest->GenerateKey(transform, keySize)) {
540         state.SkipWithError(
541                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
542         return;
543     }
544     auto in_params = keymintTest->getOperationParams(transform);
545     AuthorizationSet out_params;
546     string message = keymintTest->GenerateMessage(msgSize);
547 
548     for (auto _ : state) {
549         state.PauseTiming();
550         auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
551         if (error != ErrorCode::OK) {
552             state.SkipWithError(
553                     ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
554             return;
555         }
556         out_params.Clear();
557         state.ResumeTiming();
558         if (!keymintTest->Process(message)) {
559             state.SkipWithError(
560                     ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
561             break;
562         }
563     }
564 }
565 
decrypt(benchmark::State & state,string transform,int keySize,int msgSize)566 static void decrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
567     addDefaultLabel(state);
568     if (!keymintTest->GenerateKey(transform, keySize)) {
569         state.SkipWithError(
570                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
571         return;
572     }
573     AuthorizationSet out_params;
574     AuthorizationSet in_params = keymintTest->getOperationParams(transform);
575     string message = keymintTest->GenerateMessage(msgSize);
576     auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
577     if (error != ErrorCode::OK) {
578         state.SkipWithError(
579                 ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
580         return;
581     }
582     auto encryptedMessage = keymintTest->Process(message);
583     if (!encryptedMessage) {
584         state.SkipWithError(
585                 ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
586         return;
587     }
588     in_params.push_back(out_params);
589     out_params.Clear();
590     for (auto _ : state) {
591         state.PauseTiming();
592         error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
593         if (error != ErrorCode::OK) {
594             state.SkipWithError(
595                     ("Decryption begin error, " + std::to_string(keymintTest->getError())).c_str());
596             return;
597         }
598         state.ResumeTiming();
599         if (!keymintTest->Process(*encryptedMessage)) {
600             state.SkipWithError(
601                     ("Decryption error, " + std::to_string(keymintTest->getError())).c_str());
602             break;
603         }
604     }
605 }
606 
// clang-format off
// Cipher benchmarks: each registration produces encrypt and decrypt benchmarks.
// AES
#define BENCHMARK_KM_CIPHER_ALL_AES_KEYS(transform) \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 128)    \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 256)

BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CTR/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/GCM/NoPadding);

// Triple DES
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/PKCS7Padding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/PKCS7Padding, 168);

// RSA: registered per key size for a single caller-chosen message size
// (only SMALL_MESSAGE_SIZE is used below).
#define BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(transform, msgSize) \
    BENCHMARK_KM_CIPHER(transform, 2048, msgSize)            \
    BENCHMARK_KM_CIPHER(transform, 3072, msgSize)            \
    BENCHMARK_KM_CIPHER(transform, 4096, msgSize)

BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/NoPadding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/PKCS1Padding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/OAEPPadding, SMALL_MESSAGE_SIZE);

// clang-format on
636 }  // namespace aidl::android::hardware::security::keymint::test
637 
main(int argc,char ** argv)638 int main(int argc, char** argv) {
639     ::benchmark::Initialize(&argc, argv);
640     base::CommandLine::Init(argc, argv);
641     base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
642     auto service_name = command_line->GetSwitchValueASCII("service_name");
643     if (service_name.empty()) {
644         service_name =
645                 std::string(
646                         aidl::android::hardware::security::keymint::IKeyMintDevice::descriptor) +
647                 "/default";
648     }
649     std::cerr << service_name << std::endl;
650     aidl::android::hardware::security::keymint::test::keymintTest =
651             aidl::android::hardware::security::keymint::test::KeyMintBenchmarkTest::newInstance(
652                     service_name.c_str());
653     if (!aidl::android::hardware::security::keymint::test::keymintTest) {
654         return 1;
655     }
656     ::benchmark::RunSpecifiedBenchmarks();
657 }
658