1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
11 #define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
12 
13 
14 namespace Eigen {
15 
16 template <typename Environment>
17 class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
18  public:
19   typedef typename Environment::Task Task;
20   typedef RunQueue<Task, 1024> Queue;
21 
22   NonBlockingThreadPoolTempl(int num_threads, Environment env = Environment())
env_(env)23       : env_(env),
24         threads_(num_threads),
25         queues_(num_threads),
26         coprimes_(num_threads),
27         waiters_(num_threads),
28         blocked_(0),
29         spinning_(0),
30         done_(false),
31         ec_(waiters_) {
32     waiters_.resize(num_threads);
33 
34     // Calculate coprimes of num_threads.
35     // Coprimes are used for a random walk over all threads in Steal
36     // and NonEmptyQueueIndex. Iteration is based on the fact that if we take
37     // a walk starting thread index t and calculate num_threads - 1 subsequent
38     // indices as (t + coprime) % num_threads, we will cover all threads without
39     // repetitions (effectively getting a presudo-random permutation of thread
40     // indices).
41     for (int i = 1; i <= num_threads; i++) {
42       unsigned a = i;
43       unsigned b = num_threads;
44       // If GCD(a, b) == 1, then a and b are coprimes.
45       while (b != 0) {
46         unsigned tmp = a;
47         a = b;
48         b = tmp % b;
49       }
50       if (a == 1) {
51         coprimes_.push_back(i);
52       }
53     }
54     for (int i = 0; i < num_threads; i++) {
55       queues_.push_back(new Queue());
56     }
57     for (int i = 0; i < num_threads; i++) {
58       threads_.push_back(env_.CreateThread([this, i]() { WorkerLoop(i); }));
59     }
60   }
61 
~NonBlockingThreadPoolTempl()62   ~NonBlockingThreadPoolTempl() {
63     done_ = true;
64     // Now if all threads block without work, they will start exiting.
65     // But note that threads can continue to work arbitrary long,
66     // block, submit new work, unblock and otherwise live full life.
67     ec_.Notify(true);
68 
69     // Join threads explicitly to avoid destruction order issues.
70     for (size_t i = 0; i < threads_.size(); i++) delete threads_[i];
71     for (size_t i = 0; i < threads_.size(); i++) delete queues_[i];
72   }
73 
Schedule(std::function<void ()> fn)74   void Schedule(std::function<void()> fn) {
75     Task t = env_.CreateTask(std::move(fn));
76     PerThread* pt = GetPerThread();
77     if (pt->pool == this) {
78       // Worker thread of this pool, push onto the thread's queue.
79       Queue* q = queues_[pt->thread_id];
80       t = q->PushFront(std::move(t));
81     } else {
82       // A free-standing thread (or worker of another pool), push onto a random
83       // queue.
84       Queue* q = queues_[Rand(&pt->rand) % queues_.size()];
85       t = q->PushBack(std::move(t));
86     }
87     // Note: below we touch this after making w available to worker threads.
88     // Strictly speaking, this can lead to a racy-use-after-free. Consider that
89     // Schedule is called from a thread that is neither main thread nor a worker
90     // thread of this pool. Then, execution of w directly or indirectly
91     // completes overall computations, which in turn leads to destruction of
92     // this. We expect that such scenario is prevented by program, that is,
93     // this is kept alive while any threads can potentially be in Schedule.
94     if (!t.f)
95       ec_.Notify(false);
96     else
97       env_.ExecuteTask(t);  // Push failed, execute directly.
98   }
99 
NumThreads()100   int NumThreads() const final {
101     return static_cast<int>(threads_.size());
102   }
103 
CurrentThreadId()104   int CurrentThreadId() const final {
105     const PerThread* pt =
106         const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();
107     if (pt->pool == this) {
108       return pt->thread_id;
109     } else {
110       return -1;
111     }
112   }
113 
114  private:
115   typedef typename Environment::EnvThread Thread;
116 
117   struct PerThread {
PerThreadPerThread118     constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) { }
119     NonBlockingThreadPoolTempl* pool;  // Parent pool, or null for normal threads.
120     uint64_t rand;  // Random generator state.
121     int thread_id;  // Worker thread index in pool.
122   };
123 
124   Environment env_;
125   MaxSizeVector<Thread*> threads_;
126   MaxSizeVector<Queue*> queues_;
127   MaxSizeVector<unsigned> coprimes_;
128   MaxSizeVector<EventCount::Waiter> waiters_;
129   std::atomic<unsigned> blocked_;
130   std::atomic<bool> spinning_;
131   std::atomic<bool> done_;
132   EventCount ec_;
133 
134   // Main worker thread loop.
WorkerLoop(int thread_id)135   void WorkerLoop(int thread_id) {
136     PerThread* pt = GetPerThread();
137     pt->pool = this;
138     pt->rand = std::hash<std::thread::id>()(std::this_thread::get_id());
139     pt->thread_id = thread_id;
140     Queue* q = queues_[thread_id];
141     EventCount::Waiter* waiter = &waiters_[thread_id];
142     for (;;) {
143       Task t = q->PopFront();
144       if (!t.f) {
145         t = Steal();
146         if (!t.f) {
147           // Leave one thread spinning. This reduces latency.
148           // TODO(dvyukov): 1000 iterations is based on fair dice roll, tune it.
149           // Also, the time it takes to attempt to steal work 1000 times depends
150           // on the size of the thread pool. However the speed at which the user
151           // of the thread pool submit tasks is independent of the size of the
152           // pool. Consider a time based limit instead.
153           if (!spinning_ && !spinning_.exchange(true)) {
154             for (int i = 0; i < 1000 && !t.f; i++) {
155               t = Steal();
156             }
157             spinning_ = false;
158           }
159           if (!t.f) {
160             if (!WaitForWork(waiter, &t)) {
161               return;
162             }
163           }
164         }
165       }
166       if (t.f) {
167         env_.ExecuteTask(t);
168       }
169     }
170   }
171 
172   // Steal tries to steal work from other worker threads in best-effort manner.
Steal()173   Task Steal() {
174     PerThread* pt = GetPerThread();
175     const size_t size = queues_.size();
176     unsigned r = Rand(&pt->rand);
177     unsigned inc = coprimes_[r % coprimes_.size()];
178     unsigned victim = r % size;
179     for (unsigned i = 0; i < size; i++) {
180       Task t = queues_[victim]->PopBack();
181       if (t.f) {
182         return t;
183       }
184       victim += inc;
185       if (victim >= size) {
186         victim -= size;
187       }
188     }
189     return Task();
190   }
191 
192   // WaitForWork blocks until new work is available (returns true), or if it is
193   // time to exit (returns false). Can optionally return a task to execute in t
194   // (in such case t.f != nullptr on return).
WaitForWork(EventCount::Waiter * waiter,Task * t)195   bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
196     eigen_assert(!t->f);
197     // We already did best-effort emptiness check in Steal, so prepare for
198     // blocking.
199     ec_.Prewait(waiter);
200     // Now do a reliable emptiness check.
201     int victim = NonEmptyQueueIndex();
202     if (victim != -1) {
203       ec_.CancelWait(waiter);
204       *t = queues_[victim]->PopBack();
205       return true;
206     }
207     // Number of blocked threads is used as termination condition.
208     // If we are shutting down and all worker threads blocked without work,
209     // that's we are done.
210     blocked_++;
211     if (done_ && blocked_ == threads_.size()) {
212       ec_.CancelWait(waiter);
213       // Almost done, but need to re-check queues.
214       // Consider that all queues are empty and all worker threads are preempted
215       // right after incrementing blocked_ above. Now a free-standing thread
216       // submits work and calls destructor (which sets done_). If we don't
217       // re-check queues, we will exit leaving the work unexecuted.
218       if (NonEmptyQueueIndex() != -1) {
219         // Note: we must not pop from queues before we decrement blocked_,
220         // otherwise the following scenario is possible. Consider that instead
221         // of checking for emptiness we popped the only element from queues.
222         // Now other worker threads can start exiting, which is bad if the
223         // work item submits other work. So we just check emptiness here,
224         // which ensures that all worker threads exit at the same time.
225         blocked_--;
226         return true;
227       }
228       // Reached stable termination state.
229       ec_.Notify(true);
230       return false;
231     }
232     ec_.CommitWait(waiter);
233     blocked_--;
234     return true;
235   }
236 
NonEmptyQueueIndex()237   int NonEmptyQueueIndex() {
238     PerThread* pt = GetPerThread();
239     const size_t size = queues_.size();
240     unsigned r = Rand(&pt->rand);
241     unsigned inc = coprimes_[r % coprimes_.size()];
242     unsigned victim = r % size;
243     for (unsigned i = 0; i < size; i++) {
244       if (!queues_[victim]->Empty()) {
245         return victim;
246       }
247       victim += inc;
248       if (victim >= size) {
249         victim -= size;
250       }
251     }
252     return -1;
253   }
254 
GetPerThread()255   static EIGEN_STRONG_INLINE PerThread* GetPerThread() {
256     EIGEN_THREAD_LOCAL PerThread per_thread_;
257     PerThread* pt = &per_thread_;
258     return pt;
259   }
260 
Rand(uint64_t * state)261   static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
262     uint64_t current = *state;
263     // Update the internal state
264     *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
265     // Generate the random output (using the PCG-XSH-RS scheme)
266     return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61)));
267   }
268 };
269 
270 typedef NonBlockingThreadPoolTempl<StlThreadEnvironment> NonBlockingThreadPool;
271 
272 }  // namespace Eigen
273 
274 #endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
275