1 /**
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <dlfcn.h>
18 #include <string.h>
19 #include <openssl/ssl.h>
20 #include <openssl/crypto.h>
21 #include <openssl/bn.h>
22 #include <memory>
23 #include "../includes/common.h"
24 
/** NOTE: ALLOCATION_SIZE and sMallocSkipCount are tuned for the BIGNUM that */
/** main() builds from kTest and must be updated if kTest is changed.        */
/**                                                                          */
/** ALLOCATION_SIZE is the exact request size (in bytes) that the malloc()   */
/** override watches for while sOverloadMalloc is set. malloc() compares     */
/** sCount against sMallocSkipCount[loopIndex]: once sCount has reached that */
/** value, further ALLOCATION_SIZE-byte requests return nullptr.             */
#if _32_BIT
#define ALLOCATION_SIZE      52
static const int sMallocSkipCount[] = {1, 0};
#else
#define ALLOCATION_SIZE      56
static const int sMallocSkipCount[] = {0, 0};
#endif

// Decimal input that main() parses with BN_dec2bn and expects BN_bn2dec to
// reproduce exactly. The pointer itself is const so it cannot be reseated.
static const char *const kTest =
    "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890";

// Number of ALLOCATION_SIZE-byte requests malloc() has let through while
// sOverloadMalloc was set.
static int sCount = 0;
// While true, malloc() applies the failure injection described above.
static bool sOverloadMalloc = false;
// Current iteration of main()'s loop; selects the sMallocSkipCount entry.
int loopIndex = 0;
41 
// Generic unique_ptr deleter for memory handed out by OpenSSL; the buffer is
// released with OPENSSL_free.
template <typename T>
struct OpenSSLFree {
    void operator()(T *buf) { OPENSSL_free(buf); }
};

// Owning pointer for a char buffer allocated by OpenSSL, such as the string
// returned by BN_bn2dec() in main().
using ScopedOpenSSLString = std::unique_ptr<char, OpenSSLFree<char>>;
50 
51 namespace crypto {
52 template<typename T, void (*func)(T*)>
53 struct OpenSSLDeleter {
operator ()crypto::OpenSSLDeleter54     void operator()(T *obj) {
55         func(obj);
56     }
57 };
58 
59 template<typename Type, void (*Destroyer)(Type*)>
60 struct OpenSSLDestroyer {
operator ()crypto::OpenSSLDestroyer61     void operator()(Type* ptr) const {
62         Destroyer(ptr);
63     }
64 };
65 
66 template<typename T, void (*func)(T*)>
67 using ScopedOpenSSLType = std::unique_ptr<T, OpenSSLDeleter<T, func>>;
68 
69 template<typename PointerType, void (*Destroyer)(PointerType*)>
70 using ScopedOpenSSL =
71 std::unique_ptr<PointerType, OpenSSLDestroyer<PointerType, Destroyer>>;
72 
73 struct OpenSSLFree {
operator ()crypto::OpenSSLFree74     void operator()(uint8_t* ptr) const {
75         OPENSSL_free(ptr);
76     }
77 };
78 
79 using ScopedBIGNUM = ScopedOpenSSL<BIGNUM, BN_free>;
80 using ScopedBN_CTX = ScopedOpenSSLType<BN_CTX, BN_CTX_free>;
81 }  // namespace crypto
82 
DecimalToBIGNUM(crypto::ScopedBIGNUM * out,const char * in)83 static int DecimalToBIGNUM(crypto::ScopedBIGNUM *out, const char *in) {
84     BIGNUM *raw = nullptr;
85     int ret = BN_dec2bn(&raw, in);
86     out->reset(raw);
87     return ret;
88 }
89 
// The next "malloc" in symbol lookup order, i.e. the allocator that the
// malloc() override in this file forwards to. Null until mtraceInit() runs.
void* (*realMalloc)(size_t) = nullptr;

// Resolves the next definition of "malloc" after this object with
// dlsym(RTLD_NEXT, ...) and stores it in realMalloc. If the lookup fails,
// realMalloc is left null.
void mtraceInit(void) {
    void *symbol = dlsym(RTLD_NEXT, "malloc");
    realMalloc = reinterpret_cast<void *(*)(size_t)>(symbol);
}
96 
malloc(size_t size)97 void *malloc(size_t size) {
98     if (realMalloc == nullptr) {
99         mtraceInit();
100     }
101     if (!sOverloadMalloc) {
102         return realMalloc(size);
103     }
104     if (size == ALLOCATION_SIZE) {
105         if (sCount >= sMallocSkipCount[loopIndex]) {
106             return nullptr;
107         }
108         ++sCount;
109     }
110     return realMalloc(size);
111 }
112 
113 using namespace crypto;
114 
main()115 int main() {
116     CRYPTO_library_init();
117     ScopedBN_CTX ctx(BN_CTX_new());
118     if (!ctx) {
119         return EXIT_FAILURE;
120     }
121     for(loopIndex = 0; loopIndex < 2; ++loopIndex) {
122         ScopedBIGNUM bn;
123         int ret = DecimalToBIGNUM(&bn, kTest);
124         if (!ret) {
125             return EXIT_FAILURE;
126         }
127         sOverloadMalloc = true;
128         ScopedOpenSSLString dec(BN_bn2dec(bn.get()));
129         sOverloadMalloc = false;
130         if (!dec) {
131             return EXIT_FAILURE;
132         }
133         if (strcmp(dec.get(), kTest)) {
134             return EXIT_FAILURE;
135         }
136     }
137     return EXIT_SUCCESS;
138 }
139