// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef marl_scheduler_h
#define marl_scheduler_h

#include "containers.h"
#include "debug.h"
#include "deprecated.h"
#include "export.h"
#include "memory.h"
#include "mutex.h"
#include "task.h"
#include "thread.h"

#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <thread>

namespace marl {

class OSFiber;

// Scheduler asynchronously processes Tasks.
// A scheduler can be bound to one or more threads using the bind() method.
// Once bound to a thread, that thread can call marl::schedule() to enqueue
// work tasks to be executed asynchronously.
// Schedulers are initially constructed in single-threaded mode.
// Call setWorkerThreadCount() to spawn dedicated worker threads.
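//
// A minimal usage sketch (illustrative; defer() is provided by marl/defer.h):
//
//   marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
//   scheduler.bind();
//   defer(scheduler.unbind());  // unbind when leaving this scope
//
//   marl::schedule([] {
//     // This task runs asynchronously on a worker thread.
//   });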
class Scheduler {
  class Worker;

 public:
  using TimePoint = std::chrono::system_clock::time_point;
  using Predicate = std::function<bool()>;
  using ThreadInitializer = std::function<void(int workerId)>;

  // Config holds scheduler configuration settings that can be passed to the
  // Scheduler constructor.
  struct Config {
    static constexpr size_t DefaultFiberStackSize = 1024 * 1024;

    // Per-worker-thread settings.
    struct WorkerThread {
      // Total number of dedicated worker threads to spawn for the scheduler.
      int count = 0;

      // Initializer function to call after thread creation and before any work
      // is run by the thread.
      ThreadInitializer initializer;

      // Thread affinity policy to use for worker threads.
      std::shared_ptr<Thread::Affinity::Policy> affinityPolicy;
    };

    WorkerThread workerThread;

    // Memory allocator to use for the scheduler and internal allocations.
    Allocator* allocator = Allocator::Default;

    // Size of each fiber stack. This may be rounded up to the nearest
    // allocation granularity for the given platform.
    size_t fiberStackSize = DefaultFiberStackSize;

    // allCores() returns a Config with a worker thread for each of the logical
    // cpus available to the process.
    MARL_EXPORT
    static Config allCores();

    // Fluent setters that return this Config so set calls can be chained.
    MARL_NO_EXPORT inline Config& setAllocator(Allocator*);
    MARL_NO_EXPORT inline Config& setFiberStackSize(size_t);
    MARL_NO_EXPORT inline Config& setWorkerThreadCount(int);
    MARL_NO_EXPORT inline Config& setWorkerThreadInitializer(
        const ThreadInitializer&);
    MARL_NO_EXPORT inline Config& setWorkerThreadAffinityPolicy(
        const std::shared_ptr<Thread::Affinity::Policy>&);
  };
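
  // A sketch (illustrative; the values are arbitrary) of building a Config
  // with the fluent setters above:
  //
  //   marl::Scheduler::Config cfg;
  //   cfg.setWorkerThreadCount(4)
  //      .setFiberStackSize(64 * 1024)
  //      .setWorkerThreadInitializer([](int workerId) {
  //        // Hypothetical per-thread setup (e.g. set the thread name).
  //      });
  //   marl::Scheduler scheduler(cfg);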

  // Constructor.
  MARL_EXPORT
  Scheduler(const Config&);

  // Destructor.
  // Blocks until the scheduler is unbound from all threads before returning.
  MARL_EXPORT
  ~Scheduler();

  // get() returns the scheduler bound to the current thread.
  MARL_EXPORT
  static Scheduler* get();

  // bind() binds this scheduler to the current thread.
  // There must be no existing scheduler bound to the thread prior to calling.
  MARL_EXPORT
  void bind();

  // unbind() unbinds the scheduler currently bound to the current thread.
  // There must be an existing scheduler bound to the thread prior to calling.
  // unbind() flushes any enqueued tasks on the single-threaded worker before
  // returning.
  MARL_EXPORT
  static void unbind();

  // enqueue() queues the task for asynchronous execution.
  MARL_EXPORT
  void enqueue(Task&& task);

  // config() returns the Config that was used to build the scheduler.
  MARL_EXPORT
  const Config& config() const;

  // Fibers expose methods to perform cooperative multitasking and are
  // automatically created by the Scheduler.
  //
  // The currently executing Fiber can be obtained by calling Fiber::current().
  //
  // When execution becomes blocked, yield() can be called to suspend execution
  // of the fiber and start executing other pending work. Once the block has
  // been lifted, schedule() can be called to reschedule the Fiber on the same
  // thread that previously executed it.
  class Fiber {
   public:
    // current() returns the currently executing fiber, or nullptr if called
    // without a bound scheduler.
    MARL_EXPORT
    static Fiber* current();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify().
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // lock must be locked before calling, and is unlocked by wait() just before
    // the Fiber is suspended, and re-locked before the fiber is resumed. lock
    // will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    MARL_EXPORT
    void wait(marl::lock& lock, const Predicate& pred);
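
    // An illustrative sketch of the wait()/notify() pairing (ready is a
    // hypothetical condition; marl::mutex and marl::lock are declared in
    // marl/mutex.h):
    //
    //   marl::mutex mutex;
    //   bool ready = false;  // guarded by mutex
    //
    //   // On the waiting fiber:
    //   marl::lock lock(mutex);
    //   Fiber::current()->wait(lock, [&] { return ready; });
    //
    //   // On the notifying thread, after setting ready under the lock:
    //   fiber->notify();  // fiber was obtained earlier via Fiber::current()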

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true, or sometime after
    // the timeout is reached.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify() or will be woken sometime after the timeout is
    // reached.
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // lock must be locked before calling, and is unlocked by wait() just before
    // the Fiber is suspended, and re-locked before the fiber is resumed. lock
    // will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        marl::lock& lock,
        const std::chrono::time_point<Clock, Duration>& timeout,
        const Predicate& pred);
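
    // For example (illustrative; mutex and ready are as in the sketch above):
    //
    //   auto deadline = std::chrono::system_clock::now() +
    //                   std::chrono::seconds(1);
    //   marl::lock lock(mutex);
    //   bool ok = Fiber::current()->wait(lock, deadline, [&] { return ready; });
    //   // ok is false if the timeout was reached before pred returned true.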

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify().
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. In
    // this scenario you may deadlock. For this reason, it is only ever
    // recommended to use this overload if you can guarantee that the calls to
    // wait() and notify() are made by the same thread.
    //
    // Use with extreme caution.
    MARL_NO_EXPORT inline void wait();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify(), or sometime after the timeout is reached.
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. For
    // this reason, it is only ever recommended to use this overload if you can
    // guarantee that the calls to wait() and notify() are made by the same
    // thread.
    //
    // Use with extreme caution.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        const std::chrono::time_point<Clock, Duration>& timeout);

    // notify() reschedules the suspended Fiber for execution.
    // notify() is usually only called when the predicate for one or more wait()
    // calls will likely return true.
    MARL_EXPORT
    void notify();

    // id is the thread-unique identifier of the Fiber.
    uint32_t const id;

   private:
    friend class Allocator;
    friend class Scheduler;

    enum class State {
      // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
      // ready to be recycled.
      Idle,

      // Yielded: the Fiber is currently blocked on a wait() call with no
      // timeout.
      Yielded,

      // Waiting: the Fiber is currently blocked on a wait() call with a
      // timeout. The fiber is sitting in the Worker::Work::waiting queue.
      Waiting,

      // Queued: the Fiber is currently queued for execution in the
      // Worker::Work::fibers queue.
      Queued,

      // Running: the Fiber is currently executing.
      Running,
    };

    Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);

    // switchTo() switches execution to the given fiber.
    // switchTo() must only be called on the currently executing fiber.
    void switchTo(Fiber*);

    // create() constructs and returns a new fiber with the given identifier
    // and stack size, which will execute func when switched to.
    static Allocator::unique_ptr<Fiber> create(
        Allocator* allocator,
        uint32_t id,
        size_t stackSize,
        const std::function<void()>& func);

    // createFromCurrentThread() constructs and returns a new fiber with the
    // given identifier for the current thread.
    static Allocator::unique_ptr<Fiber> createFromCurrentThread(
        Allocator* allocator,
        uint32_t id);

    // toString() returns a string representation of the given State.
    // Used for debugging.
    static const char* toString(State state);

    Allocator::unique_ptr<OSFiber> const impl;
    Worker* const worker;
    State state = State::Running;  // Guarded by Worker's work.mutex.
  };

 private:
  Scheduler(const Scheduler&) = delete;
  Scheduler(Scheduler&&) = delete;
  Scheduler& operator=(const Scheduler&) = delete;
  Scheduler& operator=(Scheduler&&) = delete;

  // Maximum number of worker threads.
  static constexpr size_t MaxWorkerThreads = 256;

  // WaitingFibers holds all the fibers waiting on a timeout.
  struct WaitingFibers {
    inline WaitingFibers(Allocator*);

    // operator bool() returns true iff there are any waiting fibers.
    inline operator bool() const;

    // take() returns the next fiber that has exceeded its timeout, or nullptr
    // if there are no fibers that have yet exceeded their timeouts.
    inline Fiber* take(const TimePoint& timeout);

    // next() returns the timepoint of the next fiber to timeout.
    // next() can only be called if operator bool() returns true.
    inline TimePoint next() const;

    // add() adds another fiber and timeout to the list of waiting fibers.
    inline void add(const TimePoint& timeout, Fiber* fiber);

    // erase() removes the fiber from the waiting list.
    inline void erase(Fiber* fiber);

    // contains() returns true if fiber is waiting.
    inline bool contains(Fiber* fiber) const;

   private:
    struct Timeout {
      TimePoint timepoint;
      Fiber* fiber;
      inline bool operator<(const Timeout&) const;
    };
    containers::set<Timeout, std::less<Timeout>> timeouts;
    containers::unordered_map<Fiber*, TimePoint> fibers;
  };

  // TODO: Implement a queue that recycles elements to reduce number of
  // heap allocations.
  using TaskQueue = containers::deque<Task>;
  using FiberQueue = containers::deque<Fiber*>;
  using FiberSet = containers::unordered_set<Fiber*>;

  // A Worker executes Tasks on a single thread.
  // Once a task is started, it may yield to other tasks on the same Worker.
  // Tasks are always resumed by the same Worker.
  class Worker {
   public:
    enum class Mode {
      // Worker will spawn a background thread to process tasks.
      MultiThreaded,

      // Worker will execute tasks whenever it yields.
      SingleThreaded,
    };

    Worker(Scheduler* scheduler, Mode mode, uint32_t id);

    // start() begins execution of the worker.
    void start() EXCLUDES(work.mutex);

    // stop() ceases execution of the worker, blocking until all pending
    // tasks have fully finished.
    void stop() EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the predicate pred
    // returns true or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
        EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the fiber is
    // notified, or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);

    // suspend() suspends the currently executing Fiber until the fiber is
    // woken with a call to enqueue(Fiber*), or automatically sometime after the
    // optional timeout.
    void suspend(const TimePoint* timeout) REQUIRES(work.mutex);

    // enqueue(Fiber*) enqueues resuming of a suspended fiber.
    void enqueue(Fiber* fiber) EXCLUDES(work.mutex);

    // enqueue(Task&&) enqueues a new, unstarted task.
    void enqueue(Task&& task) EXCLUDES(work.mutex);

    // tryLock() attempts to lock the worker for task enqueuing.
    // If the lock was successful then true is returned, and the caller must
    // call enqueueAndUnlock().
    bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);

    // enqueueAndUnlock() enqueues the task and unlocks the worker.
    // Must only be called after a call to tryLock() which returned true.
    // _Releases_lock_(work.mutex)
    void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);

    // runUntilShutdown() processes all tasks and fibers until there are no
    // more and shutdown is true, at which point runUntilShutdown() returns.
    void runUntilShutdown() REQUIRES(work.mutex);

    // steal() attempts to steal a Task from the worker for another worker.
    // Returns true if a task was taken and assigned to out, otherwise false.
    bool steal(Task& out) EXCLUDES(work.mutex);

    // getCurrent() returns the Worker currently bound to the current
    // thread.
    static inline Worker* getCurrent();

    // getCurrentFiber() returns the Fiber currently being executed.
    inline Fiber* getCurrentFiber() const;

    // Unique identifier of the Worker.
    const uint32_t id;

   private:
    // run() is the task processing function for the worker.
    // run() processes tasks until stop() is called.
    void run() REQUIRES(work.mutex);

    // createWorkerFiber() creates a new fiber that when executed calls
    // run().
    Fiber* createWorkerFiber() REQUIRES(work.mutex);

    // switchToFiber() switches execution to the given fiber. The fiber
    // must belong to this worker.
    void switchToFiber(Fiber*) REQUIRES(work.mutex);

    // runUntilIdle() executes all pending tasks and then returns.
    void runUntilIdle() REQUIRES(work.mutex);

    // waitForWork() blocks until new work is available, potentially calling
    // spinForWork().
    void waitForWork() REQUIRES(work.mutex);

    // spinForWork() attempts to steal work from another Worker, and keeps
    // the thread awake for a short duration. This reduces overheads of
    // frequently putting the thread to sleep and re-waking.
    void spinForWork();

    // enqueueFiberTimeouts() enqueues all the fibers that have finished
    // waiting.
    void enqueueFiberTimeouts() REQUIRES(work.mutex);

    inline void changeFiberState(Fiber* fiber,
                                 Fiber::State from,
                                 Fiber::State to) const REQUIRES(work.mutex);

    inline void setFiberState(Fiber* fiber, Fiber::State to) const
        REQUIRES(work.mutex);

    // Work holds tasks and fibers that are enqueued on the Worker.
    struct Work {
      inline Work(Allocator*);

      std::atomic<uint64_t> num = {0};  // tasks.size() + fibers.size()
      GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
      GUARDED_BY(mutex) TaskQueue tasks;
      GUARDED_BY(mutex) FiberQueue fibers;
      GUARDED_BY(mutex) WaitingFibers waiting;
      GUARDED_BY(mutex) bool notifyAdded = true;
      std::condition_variable added;
      marl::mutex mutex;

      template <typename F>
      inline void wait(F&&) REQUIRES(mutex);
    };

    // https://en.wikipedia.org/wiki/Xorshift
    class FastRnd {
     public:
      inline uint64_t operator()() {
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        return x;
      }

     private:
      uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
    };

    // The current worker bound to the current thread.
    static thread_local Worker* current;

    Mode const mode;
    Scheduler* const scheduler;
    Allocator::unique_ptr<Fiber> mainFiber;
    Fiber* currentFiber = nullptr;
    Thread thread;
    Work work;
    FiberSet idleFibers;  // Fibers that have completed and can be reused.
    containers::vector<Allocator::unique_ptr<Fiber>, 16>
        workerFibers;  // All fibers created by this worker.
    FastRnd rng;
    bool shutdown = false;
  };

  // stealWork() attempts to steal a task from the worker with the given id.
  // Returns true if a task was stolen and assigned to out, otherwise false.
  bool stealWork(Worker* thief, uint64_t from, Task& out);

  // onBeginSpinning() is called when a Worker calls spinForWork().
  // The scheduler will prioritize this worker for new tasks to try to prevent
  // it from going to sleep.
  void onBeginSpinning(int workerId);

  // The scheduler currently bound to the current thread.
  static thread_local Scheduler* bound;

  // The immutable configuration used to build the scheduler.
  const Config cfg;

  std::array<std::atomic<int>, 8> spinningWorkers;
  std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};

  std::atomic<unsigned int> nextEnqueueIndex = {0};
  std::array<Worker*, MaxWorkerThreads> workerThreads;

  struct SingleThreadedWorkers {
    inline SingleThreadedWorkers(Allocator*);

    using WorkerByTid =
        containers::unordered_map<std::thread::id,
                                  Allocator::unique_ptr<Worker>>;
    marl::mutex mutex;
    GUARDED_BY(mutex) std::condition_variable unbind;
    GUARDED_BY(mutex) WorkerByTid byTid;
  };
  SingleThreadedWorkers singleThreadedWorkers;
};

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
Scheduler::Config& Scheduler::Config::setAllocator(Allocator* alloc) {
  allocator = alloc;
  return *this;
}

Scheduler::Config& Scheduler::Config::setFiberStackSize(size_t size) {
  fiberStackSize = size;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadCount(int count) {
  workerThread.count = count;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadInitializer(
    const ThreadInitializer& initializer) {
  workerThread.initializer = initializer;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadAffinityPolicy(
    const std::shared_ptr<Thread::Affinity::Policy>& policy) {
  workerThread.affinityPolicy = policy;
  return *this;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait(
    marl::lock& lock,
    const std::chrono::time_point<Clock, Duration>& timeout,
    const Predicate& pred) {
  using ToDuration = typename TimePoint::duration;
  using ToClock = typename TimePoint::clock;
  auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
  return worker->wait(lock, &tp, pred);
}

void Scheduler::Fiber::wait() {
  worker->wait(nullptr);
}

template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait(
    const std::chrono::time_point<Clock, Duration>& timeout) {
  using ToDuration = typename TimePoint::duration;
  using ToClock = typename TimePoint::clock;
  auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
  return worker->wait(&tp);
}

Scheduler::Worker* Scheduler::Worker::getCurrent() {
  return Worker::current;
}

Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
  return currentFiber;
}

// schedule() schedules the task t to be asynchronously called using the
// currently bound scheduler.
inline void schedule(Task&& t) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(std::move(t));
}

// schedule() schedules the function f to be asynchronously called with the
// given arguments using the currently bound scheduler.
template <typename Function, typename... Args>
inline void schedule(Function&& f, Args&&... args) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(
      Task(std::bind(std::forward<Function>(f), std::forward<Args>(args)...)));
}

// schedule() schedules the function f to be asynchronously called using the
// currently bound scheduler.
template <typename Function>
inline void schedule(Function&& f) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(Task(std::forward<Function>(f)));
}
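
// For example (illustrative; doWork is a hypothetical function):
//
//   marl::schedule([] { doWork(0); });             // callable, no arguments
//   marl::schedule([](int n) { doWork(n); }, 42);  // callable + bound arguments
//   marl::schedule(marl::Task([] { doWork(0); })); // explicit Task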

}  // namespace marl

#endif  // marl_scheduler_h