1 // Copyright (C) 2019 The Android Open Source Project
2 // Copyright (C) 2019 Google Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
#include "aemu/base/threads/AndroidWorkPool.h"

#include "aemu/base/threads/AndroidFunctorThread.h"
#include "aemu/base/synchronization/AndroidLock.h"
#include "aemu/base/synchronization/AndroidConditionVariable.h"
#include "aemu/base/synchronization/AndroidMessageChannel.h"

#include <atomic>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <unordered_map>
#include <sys/time.h>
26 
27 using gfxstream::guest::AutoLock;
28 using gfxstream::guest::ConditionVariable;
29 using gfxstream::guest::FunctorThread;
30 using gfxstream::guest::Lock;
31 using gfxstream::guest::MessageChannel;
32 
33 namespace gfxstream {
34 namespace guest {
35 
36 class WaitGroup { // intrusive refcounted
37 public:
38 
WaitGroup(int numTasksRemaining)39     WaitGroup(int numTasksRemaining) :
40         mNumTasksInitial(numTasksRemaining),
41         mNumTasksRemaining(numTasksRemaining) { }
42 
43     ~WaitGroup() = default;
44 
getLock()45     gfxstream::guest::Lock& getLock() { return mLock; }
46 
acquire()47     void acquire() {
48         if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) {
49             ALOGE("%s: goofed, refcount0 acquire\n", __func__);
50             abort();
51         }
52     }
53 
release()54     bool release() {
55         if (0 == mRefCount) {
56             ALOGE("%s: goofed, refcount0 release\n", __func__);
57             abort();
58         }
59         if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) {
60             std::atomic_thread_fence(std::memory_order_acquire);
61             delete this;
62             return true;
63         }
64         return false;
65     }
66 
67     // wait on all of or any of the associated tasks to complete.
waitAllLocked(WorkPool::TimeoutUs timeout)68     bool waitAllLocked(WorkPool::TimeoutUs timeout) {
69         return conditionalTimeoutLocked(
70             [this] { return mNumTasksRemaining > 0; },
71             timeout);
72     }
73 
waitAnyLocked(WorkPool::TimeoutUs timeout)74     bool waitAnyLocked(WorkPool::TimeoutUs timeout) {
75         return conditionalTimeoutLocked(
76             [this] { return mNumTasksRemaining == mNumTasksInitial; },
77             timeout);
78     }
79 
80     // broadcasts to all waiters that there has been a new job that has completed
decrementBroadcast()81     bool decrementBroadcast() {
82         AutoLock<Lock> lock(mLock);
83         bool done =
84             (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst));
85         std::atomic_thread_fence(std::memory_order_acquire);
86         mCv.broadcast();
87         return done;
88     }
89 
90 private:
91 
doWait(WorkPool::TimeoutUs timeout)92     bool doWait(WorkPool::TimeoutUs timeout) {
93         if (timeout == ~0ULL) {
94             ALOGV("%s: uncond wait\n", __func__);
95             mCv.wait(&mLock);
96             return true;
97         } else {
98             return mCv.timedWait(&mLock, getDeadline(timeout));
99         }
100     }
101 
getDeadline(WorkPool::TimeoutUs relative)102     struct timespec getDeadline(WorkPool::TimeoutUs relative) {
103         struct timeval deadlineUs;
104         struct timespec deadlineNs;
105         gettimeofday(&deadlineUs, 0);
106 
107         auto prevDeadlineUs = deadlineUs.tv_usec;
108 
109         deadlineUs.tv_usec += relative;
110 
111         // Wrap around
112         if (prevDeadlineUs > deadlineUs.tv_usec) {
113             ++deadlineUs.tv_sec;
114         }
115 
116         deadlineNs.tv_sec = deadlineUs.tv_sec;
117         deadlineNs.tv_nsec = deadlineUs.tv_usec * 1000LL;
118         return deadlineNs;
119     }
120 
currTimeUs()121     uint64_t currTimeUs() {
122         struct timeval tv;
123         gettimeofday(&tv, 0);
124         return (uint64_t)(tv.tv_sec * 1000000LL + tv.tv_usec);
125     }
126 
conditionalTimeoutLocked(std::function<bool ()> conditionFunc,WorkPool::TimeoutUs timeout)127     bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) {
128         uint64_t currTime = currTimeUs();
129         WorkPool::TimeoutUs currTimeout = timeout;
130 
131         while (conditionFunc()) {
132             doWait(currTimeout);
133             if (conditionFunc()) {
134                 // Decrement timeout for wakeups
135                 uint64_t nextTime = currTimeUs();
136                 WorkPool::TimeoutUs waited =
137                     nextTime - currTime;
138                 currTime = nextTime;
139 
140                 if (currTimeout > waited) {
141                     currTimeout -= waited;
142                 } else {
143                     return conditionFunc();
144                 }
145             }
146         }
147 
148         return true;
149     }
150 
151     std::atomic<int> mRefCount = { 1 };
152     int mNumTasksInitial;
153     std::atomic<int> mNumTasksRemaining;
154 
155     Lock mLock;
156     ConditionVariable mCv;
157 };
158 
159 class WorkPoolThread {
160 public:
161     // State diagram for each work pool thread
162     //
163     // Unacquired: (Start state) When no one else has claimed the thread.
164     // Acquired: When the thread has been claimed for work,
165     // but work has not been issued to it yet.
166     // Scheduled: When the thread is running tasks from the acquirer.
167     // Exiting: cleanup
168     //
169     // Messages:
170     //
171     // Acquire
172     // Run
173     // Exit
174     //
175     // Transitions:
176     //
177     // Note: While task is being run, messages will come back with a failure value.
178     //
179     // Unacquired:
180     //     message Acquire -> Acquired. effect: return success value
181     //     message Run -> Unacquired. effect: return failure value
182     //     message Exit -> Exiting. effect: return success value
183     //
184     // Acquired:
185     //     message Acquire -> Acquired. effect: return failure value
186     //     message Run -> Scheduled. effect: run the task, return success
187     //     message Exit -> Exiting. effect: return success value
188     //
189     // Scheduled:
190     //     implicit effect: after task is run, transition back to Unacquired.
191     //     message Acquire -> Scheduled. effect: return failure value
192     //     message Run -> Scheduled. effect: return failure value
193     //     message Exit -> queue up exit message, then transition to Exiting after that is done.
194     //         effect: return success value
195     //
196     enum State {
197         Unacquired = 0,
198         Acquired = 1,
199         Scheduled = 2,
200         Exiting = 3,
201     };
202 
__anon141b1e2d0302null203     WorkPoolThread() : mThread([this] { threadFunc(); }) {
204         mThread.start();
205     }
206 
~WorkPoolThread()207     ~WorkPoolThread() {
208         exit();
209         mThread.wait();
210     }
211 
acquire()212     bool acquire() {
213         AutoLock<Lock> lock(mLock);
214         switch (mState) {
215             case State::Unacquired:
216                 mState = State::Acquired;
217                 return true;
218             case State::Acquired:
219             case State::Scheduled:
220             case State::Exiting:
221                 return false;
222             default:
223                 return false;
224         }
225     }
226 
run(WorkPool::WaitGroupHandle waitGroupHandle,WaitGroup * waitGroup,WorkPool::Task task)227     bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) {
228         AutoLock<Lock> lock(mLock);
229         switch (mState) {
230             case State::Unacquired:
231                 return false;
232             case State::Acquired: {
233                 mState = State::Scheduled;
234                 mToCleanupWaitGroupHandle = waitGroupHandle;
235                 waitGroup->acquire();
236                 mToCleanupWaitGroup = waitGroup;
237                 mShouldCleanupWaitGroup = false;
238                 TaskInfo msg = {
239                     Command::Run,
240                     waitGroup, task,
241                 };
242                 mRunMessages.send(msg);
243                 return true;
244             }
245             case State::Scheduled:
246             case State::Exiting:
247                 return false;
248             default:
249                 return false;
250         }
251     }
252 
shouldCleanupWaitGroup(WorkPool::WaitGroupHandle * waitGroupHandle,WaitGroup ** waitGroup)253     bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) {
254         AutoLock<Lock> lock(mLock);
255         bool res = mShouldCleanupWaitGroup;
256         *waitGroupHandle = mToCleanupWaitGroupHandle;
257         *waitGroup = mToCleanupWaitGroup;
258         mShouldCleanupWaitGroup = false;
259         return res;
260     }
261 
262 private:
263     enum Command {
264         Run = 0,
265         Exit = 1,
266     };
267 
268     struct TaskInfo {
269         Command cmd;
270         WaitGroup* waitGroup = nullptr;
271         WorkPool::Task task = {};
272     };
273 
exit()274     bool exit() {
275         AutoLock<Lock> lock(mLock);
276         TaskInfo msg { Command::Exit, };
277         mRunMessages.send(msg);
278         return true;
279     }
280 
threadFunc()281     void threadFunc() {
282         TaskInfo taskInfo;
283         bool done = false;
284 
285         while (!done) {
286             mRunMessages.receive(&taskInfo);
287             switch (taskInfo.cmd) {
288                 case Command::Run:
289                     doRun(taskInfo);
290                     break;
291                 case Command::Exit: {
292                     AutoLock<Lock> lock(mLock);
293                     mState = State::Exiting;
294                     break;
295                 }
296             }
297             AutoLock<Lock> lock(mLock);
298             done = mState == State::Exiting;
299         }
300     }
301 
302     // Assumption: the wait group refcount is >= 1 when entering
303     // this function (before decrement)..
304     // at least it doesn't get to 0
doRun(TaskInfo & msg)305     void doRun(TaskInfo& msg) {
306         WaitGroup* waitGroup = msg.waitGroup;
307 
308         if (msg.task) msg.task();
309 
310         bool lastTask =
311             waitGroup->decrementBroadcast();
312 
313         AutoLock<Lock> lock(mLock);
314         mState = State::Unacquired;
315 
316         if (lastTask) {
317             mShouldCleanupWaitGroup = true;
318         }
319 
320         waitGroup->release();
321     }
322 
323     FunctorThread mThread;
324     Lock mLock;
325     State mState = State::Unacquired;
326     MessageChannel<TaskInfo, 4> mRunMessages;
327     WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0;
328     WaitGroup* mToCleanupWaitGroup = nullptr;
329     bool mShouldCleanupWaitGroup = false;
330 };
331 
332 class WorkPool::Impl {
333 public:
Impl(int numInitialThreads)334     Impl(int numInitialThreads) : mThreads(numInitialThreads) {
335         for (size_t i = 0; i < mThreads.size(); ++i) {
336             mThreads[i].reset(new WorkPoolThread);
337         }
338     }
339 
340     ~Impl() = default;
341 
schedule(const std::vector<WorkPool::Task> & tasks)342     WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) {
343 
344         if (tasks.empty()) abort();
345 
346         AutoLock<Lock> lock(mLock);
347 
348         // Sweep old wait groups
349         for (size_t i = 0; i < mThreads.size(); ++i) {
350             WaitGroupHandle handle;
351             WaitGroup* waitGroup;
352             bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup);
353             if (cleanup) {
354                 mWaitGroups.erase(handle);
355                 waitGroup->release();
356             }
357         }
358 
359         WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked();
360         WaitGroup* waitGroup =
361             new WaitGroup(tasks.size());
362 
363         mWaitGroups[resHandle] = waitGroup;
364 
365         std::vector<size_t> threadIndices;
366 
367         while (threadIndices.size() < tasks.size()) {
368             for (size_t i = 0; i < mThreads.size(); ++i) {
369                 if (!mThreads[i]->acquire()) continue;
370                 threadIndices.push_back(i);
371                 if (threadIndices.size() == tasks.size()) break;
372             }
373             if (threadIndices.size() < tasks.size()) {
374                 mThreads.resize(mThreads.size() + 1);
375                 mThreads[mThreads.size() - 1].reset(new WorkPoolThread);
376             }
377         }
378 
379         // every thread here is acquired
380         for (size_t i = 0; i < threadIndices.size(); ++i) {
381             mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]);
382         }
383 
384         return resHandle;
385     }
386 
waitAny(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)387     bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
388         AutoLock<Lock> lock(mLock);
389         auto it = mWaitGroups.find(waitGroupHandle);
390         if (it == mWaitGroups.end()) return true;
391 
392         auto waitGroup = it->second;
393         waitGroup->acquire();
394         lock.unlock();
395 
396         bool waitRes = false;
397 
398         {
399             AutoLock<Lock> waitGroupLock(waitGroup->getLock());
400             waitRes = waitGroup->waitAnyLocked(timeout);
401         }
402 
403         waitGroup->release();
404 
405         return waitRes;
406     }
407 
waitAll(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)408     bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
409         auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle);
410         if (!waitGroup) return true;
411 
412         bool waitRes = false;
413 
414         {
415             AutoLock<Lock> waitGroupLock(waitGroup->getLock());
416             waitRes = waitGroup->waitAllLocked(timeout);
417         }
418 
419         waitGroup->release();
420 
421         return waitRes;
422     }
423 
424 private:
425     // Increments wait group refcount by 1.
acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle)426     WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) {
427         AutoLock<Lock> lock(mLock);
428         auto it = mWaitGroups.find(waitGroupHandle);
429         if (it == mWaitGroups.end()) return nullptr;
430 
431         auto waitGroup = it->second;
432         waitGroup->acquire();
433 
434         return waitGroup;
435     }
436 
437     using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>;
438 
genWaitGroupHandleLocked()439     WorkPool::WaitGroupHandle genWaitGroupHandleLocked() {
440         WorkPool::WaitGroupHandle res = mNextWaitGroupHandle;
441         ++mNextWaitGroupHandle;
442         return res;
443     }
444 
445     Lock mLock;
446     uint64_t mNextWaitGroupHandle = 0;
447     WaitGroupStore mWaitGroups;
448     std::vector<std::unique_ptr<WorkPoolThread>> mThreads;
449 };
450 
WorkPool(int numInitialThreads)451 WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { }
452 WorkPool::~WorkPool() = default;
453 
schedule(const std::vector<WorkPool::Task> & tasks)454 WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) {
455     return mImpl->schedule(tasks);
456 }
457 
waitAny(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)458 bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
459     return mImpl->waitAny(waitGroup, timeout);
460 }
461 
waitAll(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)462 bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
463     return mImpl->waitAll(waitGroup, timeout);
464 }
465 
466 } // namespace guest
467 } // namespace gfxstream
468