/* * Test program that illustrates how to annotate a smart pointer * implementation. In a multithreaded program the following is relevant when * working with smart pointers: * - whether or not the objects pointed at are shared over threads. * - whether or not the methods of the objects pointed at are thread-safe. * - whether or not the smart pointer objects are shared over threads. * - whether or not the smart pointer object itself is thread-safe. * * Most smart pointer implemenations are not thread-safe * (e.g. boost::shared_ptr<>, tr1::shared_ptr<> and the smart_ptr<> * implementation below). This means that it is not safe to modify a shared * pointer object that is shared over threads without proper synchronization. * * Even for non-thread-safe smart pointers it is possible to have different * threads access the same object via smart pointers without triggering data * races on the smart pointer objects. * * A smart pointer implementation guarantees that the destructor of the object * pointed at is invoked after the last smart pointer that points to that * object has been destroyed or reset. Data race detection tools cannot detect * this ordering without explicit annotation for smart pointers that track * references without invoking synchronization operations recognized by data * race detection tools. */ #include // assert() #include // PTHREAD_STACK_MIN #include // std::cerr #include // atoi() #include #ifdef _WIN32 #include // _beginthreadex() #include // CRITICAL_SECTION #else #include // pthread_mutex_t #endif #include "unified_annotations.h" static bool s_enable_annotations; #ifdef _WIN32 class AtomicInt32 { public: AtomicInt32(const int value = 0) : m_value(value) { } ~AtomicInt32() { } LONG operator++() { return InterlockedIncrement(&m_value); } LONG operator--() { return InterlockedDecrement(&m_value); } private: volatile LONG m_value; }; class Mutex { public: Mutex() : m_mutex() { InitializeCriticalSection(&m_mutex); } ~Mutex() { DeleteCriticalSection(&m_mutex); } void Lock() { EnterCriticalSection(&m_mutex); } void Unlock() { LeaveCriticalSection(&m_mutex); } private: CRITICAL_SECTION m_mutex; }; class Thread { public: Thread() : m_thread(INVALID_HANDLE_VALUE) { } ~Thread() { } void Create(void* (*pf)(void*), void* arg) { WrapperArgs* wrapper_arg_p = new WrapperArgs(pf, arg); m_thread = reinterpret_cast(_beginthreadex(NULL, 0, wrapper, wrapper_arg_p, 0, NULL)); } void Join() { WaitForSingleObject(m_thread, INFINITE); } private: struct WrapperArgs { WrapperArgs(void* (*pf)(void*), void* arg) : m_pf(pf), m_arg(arg) { } void* (*m_pf)(void*); void* m_arg; }; static unsigned int __stdcall wrapper(void* arg) { WrapperArgs* wrapper_arg_p = reinterpret_cast(arg); WrapperArgs wa = *wrapper_arg_p; delete wrapper_arg_p; return reinterpret_cast((wa.m_pf)(wa.m_arg)); } HANDLE m_thread; }; #else // _WIN32 class AtomicInt32 { public: AtomicInt32(const int value = 0) : m_value(value) { } ~AtomicInt32() { } int operator++() { return __sync_add_and_fetch(&m_value, 1); } int operator--() { return __sync_sub_and_fetch(&m_value, 1); } private: volatile int m_value; }; class Mutex { public: Mutex() : m_mutex() { pthread_mutex_init(&m_mutex, NULL); } ~Mutex() { pthread_mutex_destroy(&m_mutex); } void Lock() { pthread_mutex_lock(&m_mutex); } void Unlock() { pthread_mutex_unlock(&m_mutex); } private: pthread_mutex_t m_mutex; }; class Thread { public: Thread() : m_tid() { } ~Thread() { } void Create(void* (*pf)(void*), void* arg) { pthread_attr_t attr; pthread_attr_init(&attr); pthread_attr_setstacksize(&attr, PTHREAD_STACK_MIN + 4096); pthread_create(&m_tid, &attr, pf, arg); pthread_attr_destroy(&attr); } void Join() { pthread_join(m_tid, NULL); } private: pthread_t m_tid; }; #endif // !defined(_WIN32) template class smart_ptr { public: typedef AtomicInt32 counter_t; template friend class smart_ptr; explicit smart_ptr() : m_ptr(NULL), m_count_ptr(NULL) { } explicit smart_ptr(T* const pT) : m_ptr(NULL), m_count_ptr(NULL) { set(pT, pT ? new counter_t(0) : NULL); } template explicit smart_ptr(Q* const q) : m_ptr(NULL), m_count_ptr(NULL) { set(q, q ? new counter_t(0) : NULL); } ~smart_ptr() { set(NULL, NULL); } smart_ptr(const smart_ptr& sp) : m_ptr(NULL), m_count_ptr(NULL) { set(sp.m_ptr, sp.m_count_ptr); } template smart_ptr(const smart_ptr& sp) : m_ptr(NULL), m_count_ptr(NULL) { set(sp.m_ptr, sp.m_count_ptr); } smart_ptr& operator=(const smart_ptr& sp) { set(sp.m_ptr, sp.m_count_ptr); return *this; } smart_ptr& operator=(T* const p) { set(p, p ? new counter_t(0) : NULL); return *this; } template smart_ptr& operator=(Q* const q) { set(q, q ? new counter_t(0) : NULL); return *this; } T* operator->() const { assert(m_ptr); return m_ptr; } T& operator*() const { assert(m_ptr); return *m_ptr; } private: void set(T* const pT, counter_t* const count_ptr) { if (m_ptr != pT) { if (m_count_ptr) { if (s_enable_annotations) U_ANNOTATE_HAPPENS_BEFORE(m_count_ptr); if (--(*m_count_ptr) == 0) { if (s_enable_annotations) U_ANNOTATE_HAPPENS_AFTER(m_count_ptr); delete m_ptr; m_ptr = NULL; delete m_count_ptr; m_count_ptr = NULL; } } m_ptr = pT; m_count_ptr = count_ptr; if (count_ptr) ++(*m_count_ptr); } } T* m_ptr; counter_t* m_count_ptr; }; class counter { public: counter() : m_mutex(), m_count() { } ~counter() { // Data race detection tools that do not recognize the // ANNOTATE_HAPPENS_BEFORE() / ANNOTATE_HAPPENS_AFTER() annotations in the // smart_ptr<> implementation will report that the assignment below // triggers a data race. m_count = -1; } int get() const { int result; m_mutex.Lock(); result = m_count; m_mutex.Unlock(); return result; } int post_increment() { int result; m_mutex.Lock(); result = m_count++; m_mutex.Unlock(); return result; } private: mutable Mutex m_mutex; int m_count; }; static void* thread_func(void* arg) { smart_ptr* pp = reinterpret_cast*>(arg); (*pp)->post_increment(); *pp = NULL; delete pp; return NULL; } int main(int argc, char** argv) { const int nthreads = std::max(argc > 1 ? atoi(argv[1]) : 1, 1); const int iterations = std::max(argc > 2 ? atoi(argv[2]) : 1, 1); s_enable_annotations = argc > 3 ? !!atoi(argv[3]) : true; for (int j = 0; j < iterations; ++j) { std::vector T(nthreads); smart_ptr p(new counter); p->post_increment(); for (std::vector::iterator q = T.begin(); q != T.end(); q++) q->Create(thread_func, new smart_ptr(p)); { // Avoid that counter.m_mutex introduces a false ordering on the // counter.m_count accesses. const timespec delay = { 0, 100 * 1000 * 1000 }; nanosleep(&delay, 0); } p = NULL; for (std::vector::iterator q = T.begin(); q != T.end(); q++) q->Join(); } std::cerr << "Done.\n"; return 0; }