1 /**
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <dlfcn.h>
18 #include <string.h>
19 #include <openssl/ssl.h>
20 #include <openssl/crypto.h>
21 #include <openssl/bn.h>
22 #include <memory>
23 #include "../includes/common.h"
24
// Fault-injection parameters for the malloc() override in this file.
// NOTE: These values depend on the BIGNUM that main() builds from kTest and
// must be updated if kTest is changed.
//   ALLOCATION_SIZE  - request size, in bytes, whose allocations are targeted.
//   sMallocSkipCount - one entry per main() loop iteration (indexed by
//                      loopIndex): how many ALLOCATION_SIZE requests may
//                      still succeed before malloc() starts returning nullptr.
#if _32_BIT
static constexpr size_t ALLOCATION_SIZE = 52;
static const int sMallocSkipCount[] = {1, 0};
#else
static constexpr size_t ALLOCATION_SIZE = 56;
static const int sMallocSkipCount[] = {0, 0};
#endif
34
// Decimal input that main() parses with BN_dec2bn() and re-serializes with
// BN_bn2dec(). The pointer is const so the fixed test input cannot be
// reassigned.
static const char *const kTest =
    "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890";

// Number of ALLOCATION_SIZE requests the malloc() override has let through.
static int sCount = 0;
// When true, malloc() applies fault injection; main() enables it only around
// the BN_bn2dec() call.
static bool sOverloadMalloc = false;
// Current main() loop iteration; malloc() uses it to pick the
// sMallocSkipCount entry.
int loopIndex = 0;
41
// unique_ptr deleter that releases a buffer with OPENSSL_free(). It is used
// for memory handed back by OpenSSL APIs, e.g. the string BN_bn2dec() returns
// in main(). operator() is const, matching the deleters in namespace crypto.
template <typename T>
struct OpenSSLFree {
  void operator()(T *buf) const {
    OPENSSL_free(buf);
  }
};

// Owning handle for a C string returned by an OpenSSL API; freed with
// OPENSSL_free() when the handle goes out of scope.
using ScopedOpenSSLString = std::unique_ptr<char, OpenSSLFree<char>>;
50
51 namespace crypto {
52 template<typename T, void (*func)(T*)>
53 struct OpenSSLDeleter {
operator ()crypto::OpenSSLDeleter54 void operator()(T *obj) {
55 func(obj);
56 }
57 };
58
59 template<typename Type, void (*Destroyer)(Type*)>
60 struct OpenSSLDestroyer {
operator ()crypto::OpenSSLDestroyer61 void operator()(Type* ptr) const {
62 Destroyer(ptr);
63 }
64 };
65
66 template<typename T, void (*func)(T*)>
67 using ScopedOpenSSLType = std::unique_ptr<T, OpenSSLDeleter<T, func>>;
68
69 template<typename PointerType, void (*Destroyer)(PointerType*)>
70 using ScopedOpenSSL =
71 std::unique_ptr<PointerType, OpenSSLDestroyer<PointerType, Destroyer>>;
72
73 struct OpenSSLFree {
operator ()crypto::OpenSSLFree74 void operator()(uint8_t* ptr) const {
75 OPENSSL_free(ptr);
76 }
77 };
78
79 using ScopedBIGNUM = ScopedOpenSSL<BIGNUM, BN_free>;
80 using ScopedBN_CTX = ScopedOpenSSLType<BN_CTX, BN_CTX_free>;
81 } // namespace crypto
82
DecimalToBIGNUM(crypto::ScopedBIGNUM * out,const char * in)83 static int DecimalToBIGNUM(crypto::ScopedBIGNUM *out, const char *in) {
84 BIGNUM *raw = nullptr;
85 int ret = BN_dec2bn(&raw, in);
86 out->reset(raw);
87 return ret;
88 }
89
// The malloc implementation that the override in this file forwards to.
// It stays null until mtraceInit() resolves it.
void *(*realMalloc)(size_t) = nullptr;

// Looks up the next "malloc" symbol after this object (dlsym with RTLD_NEXT)
// and stores it in realMalloc.
void mtraceInit() {
  void *const symbol = dlsym(RTLD_NEXT, "malloc");
  realMalloc = reinterpret_cast<void *(*)(size_t)>(symbol);
}
96
malloc(size_t size)97 void *malloc(size_t size) {
98 if (realMalloc == nullptr) {
99 mtraceInit();
100 }
101 if (!sOverloadMalloc) {
102 return realMalloc(size);
103 }
104 if (size == ALLOCATION_SIZE) {
105 if (sCount >= sMallocSkipCount[loopIndex]) {
106 return nullptr;
107 }
108 ++sCount;
109 }
110 return realMalloc(size);
111 }
112
113 using namespace crypto;
114
main()115 int main() {
116 CRYPTO_library_init();
117 ScopedBN_CTX ctx(BN_CTX_new());
118 if (!ctx) {
119 return EXIT_FAILURE;
120 }
121 for(loopIndex = 0; loopIndex < 2; ++loopIndex) {
122 ScopedBIGNUM bn;
123 int ret = DecimalToBIGNUM(&bn, kTest);
124 if (!ret) {
125 return EXIT_FAILURE;
126 }
127 sOverloadMalloc = true;
128 ScopedOpenSSLString dec(BN_bn2dec(bn.get()));
129 sOverloadMalloc = false;
130 if (!dec) {
131 return EXIT_FAILURE;
132 }
133 if (strcmp(dec.get(), kTest)) {
134 return EXIT_FAILURE;
135 }
136 }
137 return EXIT_SUCCESS;
138 }
139