1 /*
2  * Test program that illustrates how to annotate a smart pointer
3  * implementation.  In a multithreaded program the following is relevant when
4  * working with smart pointers:
5  * - whether or not the objects pointed at are shared over threads.
6  * - whether or not the methods of the objects pointed at are thread-safe.
7  * - whether or not the smart pointer objects are shared over threads.
8  * - whether or not the smart pointer object itself is thread-safe.
9  *
 * Most smart pointer implementations are not thread-safe
 * (e.g. boost::shared_ptr<>, tr1::shared_ptr<> and the smart_ptr<>
 * implementation below). This means that it is not safe to modify a smart
 * pointer object that is shared over threads without proper synchronization.
14  *
15  * Even for non-thread-safe smart pointers it is possible to have different
16  * threads access the same object via smart pointers without triggering data
17  * races on the smart pointer objects.
18  *
19  * A smart pointer implementation guarantees that the destructor of the object
20  * pointed at is invoked after the last smart pointer that points to that
21  * object has been destroyed or reset. Data race detection tools cannot detect
22  * this ordering without explicit annotation for smart pointers that track
23  * references without invoking synchronization operations recognized by data
24  * race detection tools.
25  */
26 
27 
#include <algorithm>   // std::max()
#include <cassert>     // assert()
#include <climits>     // PTHREAD_STACK_MIN
#include <iostream>    // std::cerr
#include <stdlib.h>    // atoi()
#include <vector>
#ifdef _WIN32
#include <process.h>   // _beginthreadex()
#include <windows.h>   // CRITICAL_SECTION
#else
#include <pthread.h>   // pthread_mutex_t
#include <time.h>      // nanosleep()
#endif
#include "unified_annotations.h"
40 
41 
// Whether smart_ptr<> emits its happens-before / happens-after annotations.
// Assigned from the optional third command-line argument in main(), which
// defaults to enabled.
static bool s_enable_annotations = false;
43 
44 
45 #ifdef _WIN32
46 
47 class AtomicInt32
48 {
49 public:
AtomicInt32(const int value=0)50   AtomicInt32(const int value = 0) : m_value(value) { }
~AtomicInt32()51   ~AtomicInt32() { }
operator ++()52   LONG operator++() { return InterlockedIncrement(&m_value); }
operator --()53   LONG operator--() { return InterlockedDecrement(&m_value); }
54 
55 private:
56   volatile LONG m_value;
57 };
58 
59 class Mutex
60 {
61 public:
Mutex()62   Mutex() : m_mutex()
63   { InitializeCriticalSection(&m_mutex); }
~Mutex()64   ~Mutex()
65   { DeleteCriticalSection(&m_mutex); }
Lock()66   void Lock()
67   { EnterCriticalSection(&m_mutex); }
Unlock()68   void Unlock()
69   { LeaveCriticalSection(&m_mutex); }
70 
71 private:
72   CRITICAL_SECTION m_mutex;
73 };
74 
75 class Thread
76 {
77 public:
Thread()78   Thread() : m_thread(INVALID_HANDLE_VALUE) { }
~Thread()79   ~Thread() { }
Create(void * (* pf)(void *),void * arg)80   void Create(void* (*pf)(void*), void* arg)
81   {
82     WrapperArgs* wrapper_arg_p = new WrapperArgs(pf, arg);
83     m_thread = reinterpret_cast<HANDLE>(_beginthreadex(NULL, 0, wrapper,
84 						       wrapper_arg_p, 0, NULL));
85   }
Join()86   void Join()
87   { WaitForSingleObject(m_thread, INFINITE); }
88 
89 private:
90   struct WrapperArgs
91   {
WrapperArgsThread::WrapperArgs92     WrapperArgs(void* (*pf)(void*), void* arg) : m_pf(pf), m_arg(arg) { }
93 
94     void* (*m_pf)(void*);
95     void* m_arg;
96   };
wrapper(void * arg)97   static unsigned int __stdcall wrapper(void* arg)
98   {
99     WrapperArgs* wrapper_arg_p = reinterpret_cast<WrapperArgs*>(arg);
100     WrapperArgs wa = *wrapper_arg_p;
101     delete wrapper_arg_p;
102     return reinterpret_cast<unsigned>((wa.m_pf)(wa.m_arg));
103   }
104   HANDLE m_thread;
105 };
106 
107 #else // _WIN32
108 
// 32-bit integer whose increment and decrement are atomic (GCC __sync
// builtins). Used as the shared reference count of smart_ptr<>.
class AtomicInt32
{
public:
  AtomicInt32(const int value = 0) : m_value(value) { }
  ~AtomicInt32() { }

  // Atomically adds one; returns the incremented value.
  int operator++() { return add_and_fetch(1); }
  // Atomically subtracts one; returns the decremented value.
  int operator--() { return add_and_fetch(-1); }

private:
  // Atomically adds delta to m_value and returns the resulting value.
  int add_and_fetch(const int delta)
  { return __sync_add_and_fetch(&m_value, delta); }

  volatile int m_value;
};
119 
// Minimal mutex wrapper around a pthread_mutex_t with default attributes. The
// mutex is initialized in the constructor and destroyed in the destructor.
class Mutex
{
public:
  Mutex() : m_mutex()
  { pthread_mutex_init(&m_mutex, NULL); }
  ~Mutex()
  { pthread_mutex_destroy(&m_mutex); }
  void Lock()
  { pthread_mutex_lock(&m_mutex); }
  void Unlock()
  { pthread_mutex_unlock(&m_mutex); }

private:
  // Not copyable: POSIX leaves the use of a copy of a mutex undefined, and a
  // copy would make both destructors call pthread_mutex_destroy() on the same
  // state. Declared but intentionally not defined (pre-C++11 idiom).
  Mutex(const Mutex&);
  Mutex& operator=(const Mutex&);

  pthread_mutex_t m_mutex;
};
135 
// Thin wrapper around a POSIX thread.
class Thread
{
public:
  Thread() : m_tid() { }
  ~Thread() { }

  // Starts a new thread that runs pf(arg), requesting a stack of
  // PTHREAD_STACK_MIN + 4096 bytes. Aborts if the thread cannot be created:
  // pthread_join() on a thread that was never started is undefined behavior,
  // and arg would never be consumed.
  void Create(void* (*pf)(void*), void* arg)
  {
    pthread_attr_t attr;
    pthread_attr_init(&attr);
    // The result is deliberately not checked: if the platform rejects this
    // stack size (e.g. EINVAL), attr keeps its default stack size.
    pthread_attr_setstacksize(&attr, PTHREAD_STACK_MIN + 4096);
    const int rc = pthread_create(&m_tid, &attr, pf, arg);
    pthread_attr_destroy(&attr);
    if (rc != 0)
    {
      std::cerr << "pthread_create() failed with error " << rc << "\n";
      abort();
    }
  }

  // Waits for the thread started by Create() to terminate; its return value
  // is discarded.
  void Join()
  { pthread_join(m_tid, NULL); }

private:
  pthread_t m_tid;
};
154 
155 #endif // !defined(_WIN32)
156 
157 
158 template<class T>
159 class smart_ptr
160 {
161 public:
162   typedef AtomicInt32 counter_t;
163 
164   template <typename Q> friend class smart_ptr;
165 
smart_ptr()166   explicit smart_ptr()
167     : m_ptr(NULL), m_count_ptr(NULL)
168   { }
169 
smart_ptr(T * const pT)170   explicit smart_ptr(T* const pT)
171     : m_ptr(NULL), m_count_ptr(NULL)
172   {
173     set(pT, pT ? new counter_t(0) : NULL);
174   }
175 
176   template <typename Q>
smart_ptr(Q * const q)177   explicit smart_ptr(Q* const q)
178     : m_ptr(NULL), m_count_ptr(NULL)
179   {
180     set(q, q ? new counter_t(0) : NULL);
181   }
182 
~smart_ptr()183   ~smart_ptr()
184   {
185     set(NULL, NULL);
186   }
187 
smart_ptr(const smart_ptr<T> & sp)188   smart_ptr(const smart_ptr<T>& sp)
189     : m_ptr(NULL), m_count_ptr(NULL)
190   {
191     set(sp.m_ptr, sp.m_count_ptr);
192   }
193 
194   template <typename Q>
smart_ptr(const smart_ptr<Q> & sp)195   smart_ptr(const smart_ptr<Q>& sp)
196     : m_ptr(NULL), m_count_ptr(NULL)
197   {
198     set(sp.m_ptr, sp.m_count_ptr);
199   }
200 
operator =(const smart_ptr<T> & sp)201   smart_ptr& operator=(const smart_ptr<T>& sp)
202   {
203     set(sp.m_ptr, sp.m_count_ptr);
204     return *this;
205   }
206 
operator =(T * const p)207   smart_ptr& operator=(T* const p)
208   {
209     set(p, p ? new counter_t(0) : NULL);
210     return *this;
211   }
212 
213   template <typename Q>
operator =(Q * const q)214   smart_ptr& operator=(Q* const q)
215   {
216     set(q, q ? new counter_t(0) : NULL);
217     return *this;
218   }
219 
operator ->() const220   T* operator->() const
221   {
222     assert(m_ptr);
223     return m_ptr;
224   }
225 
operator *() const226   T& operator*() const
227   {
228     assert(m_ptr);
229     return *m_ptr;
230   }
231 
232 private:
set(T * const pT,counter_t * const count_ptr)233   void set(T* const pT, counter_t* const count_ptr)
234   {
235     if (m_ptr != pT)
236     {
237       if (m_count_ptr)
238       {
239 	if (s_enable_annotations)
240 	  U_ANNOTATE_HAPPENS_BEFORE(m_count_ptr);
241 	if (--(*m_count_ptr) == 0)
242 	{
243 	  if (s_enable_annotations)
244 	    U_ANNOTATE_HAPPENS_AFTER(m_count_ptr);
245 	  delete m_ptr;
246 	  m_ptr = NULL;
247 	  delete m_count_ptr;
248 	  m_count_ptr = NULL;
249 	}
250       }
251       m_ptr = pT;
252       m_count_ptr = count_ptr;
253       if (count_ptr)
254 	++(*m_count_ptr);
255     }
256   }
257 
258   T*         m_ptr;
259   counter_t* m_count_ptr;
260 };
261 
262 class counter
263 {
264 public:
counter()265   counter()
266     : m_mutex(), m_count()
267   { }
~counter()268   ~counter()
269   {
270     // Data race detection tools that do not recognize the
271     // ANNOTATE_HAPPENS_BEFORE() / ANNOTATE_HAPPENS_AFTER() annotations in the
272     // smart_ptr<> implementation will report that the assignment below
273     // triggers a data race.
274     m_count = -1;
275   }
get() const276   int get() const
277   {
278     int result;
279     m_mutex.Lock();
280     result = m_count;
281     m_mutex.Unlock();
282     return result;
283   }
post_increment()284   int post_increment()
285   {
286     int result;
287     m_mutex.Lock();
288     result = m_count++;
289     m_mutex.Unlock();
290     return result;
291   }
292 
293 private:
294   mutable Mutex m_mutex;
295   int           m_count;
296 };
297 
thread_func(void * arg)298 static void* thread_func(void* arg)
299 {
300   smart_ptr<counter>* pp = reinterpret_cast<smart_ptr<counter>*>(arg);
301   (*pp)->post_increment();
302   *pp = NULL;
303   delete pp;
304   return NULL;
305 }
306 
main(int argc,char ** argv)307 int main(int argc, char** argv)
308 {
309   const int nthreads = std::max(argc > 1 ? atoi(argv[1]) : 1, 1);
310   const int iterations = std::max(argc > 2 ? atoi(argv[2]) : 1, 1);
311   s_enable_annotations = argc > 3 ? !!atoi(argv[3]) : true;
312 
313   for (int j = 0; j < iterations; ++j)
314   {
315     std::vector<Thread> T(nthreads);
316     smart_ptr<counter> p(new counter);
317     p->post_increment();
318     for (std::vector<Thread>::iterator q = T.begin(); q != T.end(); q++)
319       q->Create(thread_func, new smart_ptr<counter>(p));
320     {
321       // Avoid that counter.m_mutex introduces a false ordering on the
322       // counter.m_count accesses.
323       const timespec delay = { 0, 100 * 1000 * 1000 };
324       nanosleep(&delay, 0);
325     }
326     p = NULL;
327     for (std::vector<Thread>::iterator q = T.begin(); q != T.end(); q++)
328       q->Join();
329   }
330   std::cerr << "Done.\n";
331   return 0;
332 }
333