1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef marl_waitgroup_h
16 #define marl_waitgroup_h
17 
18 #include "conditionvariable.h"
19 #include "debug.h"
20 
21 #include <atomic>
22 #include <mutex>
23 
24 namespace marl {
25 
// WaitGroup is a synchronization primitive that holds an internal counter that
// can be incremented, decremented and waited on until it reaches 0.
// WaitGroups can be used as a simple mechanism for waiting on a number of
// concurrently executing tasks to complete.
30 //
31 // Example:
32 //
33 //  void runTasksConcurrently(int numConcurrentTasks)
34 //  {
35 //      // Construct the WaitGroup with an initial count of numConcurrentTasks.
36 //      marl::WaitGroup wg(numConcurrentTasks);
37 //      for (int i = 0; i < numConcurrentTasks; i++)
38 //      {
39 //          // Schedule a task to be run asynchronously.
40 //          // These may all be run concurrently.
41 //          marl::schedule([=] {
42 //              // Once the task has finished, decrement the waitgroup counter
43 //              // to signal that this has completed.
44 //              defer(wg.done());
45 //              doSomeWork();
46 //          });
47 //      }
48 //      // Block until all tasks have completed.
49 //      wg.wait();
50 //  }
51 class WaitGroup {
52  public:
53   // Constructs the WaitGroup with the specified initial count.
54   MARL_NO_EXPORT inline WaitGroup(unsigned int initialCount = 0,
55                                   Allocator* allocator = Allocator::Default);
56 
57   // add() increments the internal counter by count.
58   MARL_NO_EXPORT inline void add(unsigned int count = 1) const;
59 
60   // done() decrements the internal counter by one.
61   // Returns true if the internal count has reached zero.
62   MARL_NO_EXPORT inline bool done() const;
63 
64   // wait() blocks until the WaitGroup counter reaches zero.
65   MARL_NO_EXPORT inline void wait() const;
66 
67  private:
68   struct Data {
69     MARL_NO_EXPORT inline Data(Allocator* allocator);
70 
71     std::atomic<unsigned int> count = {0};
72     ConditionVariable cv;
73     marl::mutex mutex;
74   };
75   const std::shared_ptr<Data> data;
76 };
77 
Data(Allocator * allocator)78 WaitGroup::Data::Data(Allocator* allocator) : cv(allocator) {}
79 
WaitGroup(unsigned int initialCount,Allocator * allocator)80 WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */,
81                      Allocator* allocator /* = Allocator::Default */)
82     : data(std::make_shared<Data>(allocator)) {
83   data->count = initialCount;
84 }
85 
add(unsigned int count)86 void WaitGroup::add(unsigned int count /* = 1 */) const {
87   data->count += count;
88 }
89 
done()90 bool WaitGroup::done() const {
91   MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times");
92   auto count = --data->count;
93   if (count == 0) {
94     marl::lock lock(data->mutex);
95     data->cv.notify_all();
96     return true;
97   }
98   return false;
99 }
100 
wait()101 void WaitGroup::wait() const {
102   marl::lock lock(data->mutex);
103   data->cv.wait(lock, [this] { return data->count == 0; });
104 }
105 
106 }  // namespace marl
107 
108 #endif  // marl_waitgroup_h
109