1 //===----------------------------------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is dual licensed under the MIT and the University of Illinois Open
6 // Source Licenses. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #ifndef COUNT_NEW_HPP
11 #define COUNT_NEW_HPP
12 
13 # include <cstdlib>
14 # include <cassert>
15 # include <new>
16 
17 #include "test_macros.h"
18 
19 #if defined(TEST_HAS_SANITIZERS)
20 #define DISABLE_NEW_COUNT
21 #endif
22 
23 namespace detail
24 {
25    TEST_NORETURN
throw_bad_alloc_helper()26    inline void throw_bad_alloc_helper() {
27 #ifndef TEST_HAS_NO_EXCEPTIONS
28        throw std::bad_alloc();
29 #else
30        std::abort();
31 #endif
32    }
33 }
34 
35 class MemCounter
36 {
37 public:
38     // Make MemCounter super hard to accidentally construct or copy.
39     class MemCounterCtorArg_ {};
MemCounter(MemCounterCtorArg_)40     explicit MemCounter(MemCounterCtorArg_) { reset(); }
41 
42 private:
43     MemCounter(MemCounter const &);
44     MemCounter & operator=(MemCounter const &);
45 
46 public:
47     // All checks return true when disable_checking is enabled.
48     static const bool disable_checking;
49 
50     // Disallow any allocations from occurring. Useful for testing that
51     // code doesn't perform any allocations.
52     bool disable_allocations;
53 
54     // number of allocations to throw after. Default (unsigned)-1. If
55     // throw_after has the default value it will never be decremented.
56     static const unsigned never_throw_value = static_cast<unsigned>(-1);
57     unsigned throw_after;
58 
59     int outstanding_new;
60     int new_called;
61     int delete_called;
62     std::size_t last_new_size;
63 
64     int outstanding_array_new;
65     int new_array_called;
66     int delete_array_called;
67     std::size_t last_new_array_size;
68 
69 public:
newCalled(std::size_t s)70     void newCalled(std::size_t s)
71     {
72         assert(disable_allocations == false);
73         assert(s);
74         if (throw_after == 0) {
75             throw_after = never_throw_value;
76             detail::throw_bad_alloc_helper();
77         } else if (throw_after != never_throw_value) {
78             --throw_after;
79         }
80         ++new_called;
81         ++outstanding_new;
82         last_new_size = s;
83     }
84 
deleteCalled(void * p)85     void deleteCalled(void * p)
86     {
87         assert(p);
88         --outstanding_new;
89         ++delete_called;
90     }
91 
newArrayCalled(std::size_t s)92     void newArrayCalled(std::size_t s)
93     {
94         assert(disable_allocations == false);
95         assert(s);
96         if (throw_after == 0) {
97             throw_after = never_throw_value;
98             detail::throw_bad_alloc_helper();
99         } else {
100             // don't decrement throw_after here. newCalled will end up doing that.
101         }
102         ++outstanding_array_new;
103         ++new_array_called;
104         last_new_array_size = s;
105     }
106 
deleteArrayCalled(void * p)107     void deleteArrayCalled(void * p)
108     {
109         assert(p);
110         --outstanding_array_new;
111         ++delete_array_called;
112     }
113 
disableAllocations()114     void disableAllocations()
115     {
116         disable_allocations = true;
117     }
118 
enableAllocations()119     void enableAllocations()
120     {
121         disable_allocations = false;
122     }
123 
124 
reset()125     void reset()
126     {
127         disable_allocations = false;
128         throw_after = never_throw_value;
129 
130         outstanding_new = 0;
131         new_called = 0;
132         delete_called = 0;
133         last_new_size = 0;
134 
135         outstanding_array_new = 0;
136         new_array_called = 0;
137         delete_array_called = 0;
138         last_new_array_size = 0;
139     }
140 
141 public:
checkOutstandingNewEq(int n) const142     bool checkOutstandingNewEq(int n) const
143     {
144         return disable_checking || n == outstanding_new;
145     }
146 
checkOutstandingNewNotEq(int n) const147     bool checkOutstandingNewNotEq(int n) const
148     {
149         return disable_checking || n != outstanding_new;
150     }
151 
checkNewCalledEq(int n) const152     bool checkNewCalledEq(int n) const
153     {
154         return disable_checking || n == new_called;
155     }
156 
checkNewCalledNotEq(int n) const157     bool checkNewCalledNotEq(int n) const
158     {
159         return disable_checking || n != new_called;
160     }
161 
checkNewCalledGreaterThan(int n) const162     bool checkNewCalledGreaterThan(int n) const
163     {
164         return disable_checking || new_called > n;
165     }
166 
checkDeleteCalledEq(int n) const167     bool checkDeleteCalledEq(int n) const
168     {
169         return disable_checking || n == delete_called;
170     }
171 
checkDeleteCalledNotEq(int n) const172     bool checkDeleteCalledNotEq(int n) const
173     {
174         return disable_checking || n != delete_called;
175     }
176 
checkLastNewSizeEq(std::size_t n) const177     bool checkLastNewSizeEq(std::size_t n) const
178     {
179         return disable_checking || n == last_new_size;
180     }
181 
checkLastNewSizeNotEq(std::size_t n) const182     bool checkLastNewSizeNotEq(std::size_t n) const
183     {
184         return disable_checking || n != last_new_size;
185     }
186 
checkOutstandingArrayNewEq(int n) const187     bool checkOutstandingArrayNewEq(int n) const
188     {
189         return disable_checking || n == outstanding_array_new;
190     }
191 
checkOutstandingArrayNewNotEq(int n) const192     bool checkOutstandingArrayNewNotEq(int n) const
193     {
194         return disable_checking || n != outstanding_array_new;
195     }
196 
checkNewArrayCalledEq(int n) const197     bool checkNewArrayCalledEq(int n) const
198     {
199         return disable_checking || n == new_array_called;
200     }
201 
checkNewArrayCalledNotEq(int n) const202     bool checkNewArrayCalledNotEq(int n) const
203     {
204         return disable_checking || n != new_array_called;
205     }
206 
checkDeleteArrayCalledEq(int n) const207     bool checkDeleteArrayCalledEq(int n) const
208     {
209         return disable_checking || n == delete_array_called;
210     }
211 
checkDeleteArrayCalledNotEq(int n) const212     bool checkDeleteArrayCalledNotEq(int n) const
213     {
214         return disable_checking || n != delete_array_called;
215     }
216 
checkLastNewArraySizeEq(std::size_t n) const217     bool checkLastNewArraySizeEq(std::size_t n) const
218     {
219         return disable_checking || n == last_new_array_size;
220     }
221 
checkLastNewArraySizeNotEq(std::size_t n) const222     bool checkLastNewArraySizeNotEq(std::size_t n) const
223     {
224         return disable_checking || n != last_new_array_size;
225     }
226 };
227 
228 #ifdef DISABLE_NEW_COUNT
229   const bool MemCounter::disable_checking = true;
230 #else
231   const bool MemCounter::disable_checking = false;
232 #endif
233 
234 MemCounter globalMemCounter((MemCounter::MemCounterCtorArg_()));
235 
236 #ifndef DISABLE_NEW_COUNT
operator new(std::size_t s)237 void* operator new(std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
238 {
239     globalMemCounter.newCalled(s);
240     void* ret = std::malloc(s);
241     if (ret == nullptr)
242         detail::throw_bad_alloc_helper();
243     return ret;
244 }
245 
operator delete(void * p)246 void  operator delete(void* p) TEST_NOEXCEPT
247 {
248     globalMemCounter.deleteCalled(p);
249     std::free(p);
250 }
251 
252 
operator new[](std::size_t s)253 void* operator new[](std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
254 {
255     globalMemCounter.newArrayCalled(s);
256     return operator new(s);
257 }
258 
259 
operator delete[](void * p)260 void operator delete[](void* p) TEST_NOEXCEPT
261 {
262     globalMemCounter.deleteArrayCalled(p);
263     operator delete(p);
264 }
265 
266 #endif // DISABLE_NEW_COUNT
267 
268 
269 struct DisableAllocationGuard {
DisableAllocationGuardDisableAllocationGuard270     explicit DisableAllocationGuard(bool disable = true) : m_disabled(disable)
271     {
272         // Don't re-disable if already disabled.
273         if (globalMemCounter.disable_allocations == true) m_disabled = false;
274         if (m_disabled) globalMemCounter.disableAllocations();
275     }
276 
releaseDisableAllocationGuard277     void release() {
278         if (m_disabled) globalMemCounter.enableAllocations();
279         m_disabled = false;
280     }
281 
~DisableAllocationGuardDisableAllocationGuard282     ~DisableAllocationGuard() {
283         release();
284     }
285 
286 private:
287     bool m_disabled;
288 
289     DisableAllocationGuard(DisableAllocationGuard const&);
290     DisableAllocationGuard& operator=(DisableAllocationGuard const&);
291 };
292 
293 
294 struct RequireAllocationGuard {
RequireAllocationGuardRequireAllocationGuard295     explicit RequireAllocationGuard(std::size_t RequireAtLeast = 1)
296             : m_req_alloc(RequireAtLeast),
297               m_new_count_on_init(globalMemCounter.new_called),
298               m_outstanding_new_on_init(globalMemCounter.outstanding_new),
299               m_exactly(false)
300     {
301     }
302 
requireAtLeastRequireAllocationGuard303     void requireAtLeast(std::size_t N) { m_req_alloc = N; m_exactly = false; }
requireExactlyRequireAllocationGuard304     void requireExactly(std::size_t N) { m_req_alloc = N; m_exactly = true; }
305 
~RequireAllocationGuardRequireAllocationGuard306     ~RequireAllocationGuard() {
307         assert(globalMemCounter.checkOutstandingNewEq(static_cast<int>(m_outstanding_new_on_init)));
308         std::size_t Expect = m_new_count_on_init + m_req_alloc;
309         assert(globalMemCounter.checkNewCalledEq(static_cast<int>(Expect)) ||
310                (!m_exactly && globalMemCounter.checkNewCalledGreaterThan(static_cast<int>(Expect))));
311     }
312 
313 private:
314     std::size_t m_req_alloc;
315     const std::size_t m_new_count_on_init;
316     const std::size_t m_outstanding_new_on_init;
317     bool m_exactly;
318     RequireAllocationGuard(RequireAllocationGuard const&);
319     RequireAllocationGuard& operator=(RequireAllocationGuard const&);
320 };
321 
322 #endif /* COUNT_NEW_HPP */
323