1 // Copyright (C) 2019 The Android Open Source Project
2 // Copyright (C) 2019 Google Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
#include "aemu/base/threads/AndroidWorkPool.h"

#include "aemu/base/threads/AndroidFunctorThread.h"
#include "aemu/base/synchronization/AndroidLock.h"
#include "aemu/base/synchronization/AndroidConditionVariable.h"
#include "aemu/base/synchronization/AndroidMessageChannel.h"

#include <atomic>
#include <limits>
#include <memory>
#include <unordered_map>
#include <sys/time.h>
26
27 using gfxstream::guest::AutoLock;
28 using gfxstream::guest::ConditionVariable;
29 using gfxstream::guest::FunctorThread;
30 using gfxstream::guest::Lock;
31 using gfxstream::guest::MessageChannel;
32
33 namespace gfxstream {
34 namespace guest {
35
36 class WaitGroup { // intrusive refcounted
37 public:
38
WaitGroup(int numTasksRemaining)39 WaitGroup(int numTasksRemaining) :
40 mNumTasksInitial(numTasksRemaining),
41 mNumTasksRemaining(numTasksRemaining) { }
42
43 ~WaitGroup() = default;
44
getLock()45 gfxstream::guest::Lock& getLock() { return mLock; }
46
acquire()47 void acquire() {
48 if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) {
49 ALOGE("%s: goofed, refcount0 acquire\n", __func__);
50 abort();
51 }
52 }
53
release()54 bool release() {
55 if (0 == mRefCount) {
56 ALOGE("%s: goofed, refcount0 release\n", __func__);
57 abort();
58 }
59 if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) {
60 std::atomic_thread_fence(std::memory_order_acquire);
61 delete this;
62 return true;
63 }
64 return false;
65 }
66
67 // wait on all of or any of the associated tasks to complete.
waitAllLocked(WorkPool::TimeoutUs timeout)68 bool waitAllLocked(WorkPool::TimeoutUs timeout) {
69 return conditionalTimeoutLocked(
70 [this] { return mNumTasksRemaining > 0; },
71 timeout);
72 }
73
waitAnyLocked(WorkPool::TimeoutUs timeout)74 bool waitAnyLocked(WorkPool::TimeoutUs timeout) {
75 return conditionalTimeoutLocked(
76 [this] { return mNumTasksRemaining == mNumTasksInitial; },
77 timeout);
78 }
79
80 // broadcasts to all waiters that there has been a new job that has completed
decrementBroadcast()81 bool decrementBroadcast() {
82 AutoLock<Lock> lock(mLock);
83 bool done =
84 (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst));
85 std::atomic_thread_fence(std::memory_order_acquire);
86 mCv.broadcast();
87 return done;
88 }
89
90 private:
91
doWait(WorkPool::TimeoutUs timeout)92 bool doWait(WorkPool::TimeoutUs timeout) {
93 if (timeout == ~0ULL) {
94 ALOGV("%s: uncond wait\n", __func__);
95 mCv.wait(&mLock);
96 return true;
97 } else {
98 return mCv.timedWait(&mLock, getDeadline(timeout));
99 }
100 }
101
getDeadline(WorkPool::TimeoutUs relative)102 struct timespec getDeadline(WorkPool::TimeoutUs relative) {
103 struct timeval deadlineUs;
104 struct timespec deadlineNs;
105 gettimeofday(&deadlineUs, 0);
106
107 auto prevDeadlineUs = deadlineUs.tv_usec;
108
109 deadlineUs.tv_usec += relative;
110
111 // Wrap around
112 if (prevDeadlineUs > deadlineUs.tv_usec) {
113 ++deadlineUs.tv_sec;
114 }
115
116 deadlineNs.tv_sec = deadlineUs.tv_sec;
117 deadlineNs.tv_nsec = deadlineUs.tv_usec * 1000LL;
118 return deadlineNs;
119 }
120
currTimeUs()121 uint64_t currTimeUs() {
122 struct timeval tv;
123 gettimeofday(&tv, 0);
124 return (uint64_t)(tv.tv_sec * 1000000LL + tv.tv_usec);
125 }
126
conditionalTimeoutLocked(std::function<bool ()> conditionFunc,WorkPool::TimeoutUs timeout)127 bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) {
128 uint64_t currTime = currTimeUs();
129 WorkPool::TimeoutUs currTimeout = timeout;
130
131 while (conditionFunc()) {
132 doWait(currTimeout);
133 if (conditionFunc()) {
134 // Decrement timeout for wakeups
135 uint64_t nextTime = currTimeUs();
136 WorkPool::TimeoutUs waited =
137 nextTime - currTime;
138 currTime = nextTime;
139
140 if (currTimeout > waited) {
141 currTimeout -= waited;
142 } else {
143 return conditionFunc();
144 }
145 }
146 }
147
148 return true;
149 }
150
151 std::atomic<int> mRefCount = { 1 };
152 int mNumTasksInitial;
153 std::atomic<int> mNumTasksRemaining;
154
155 Lock mLock;
156 ConditionVariable mCv;
157 };
158
159 class WorkPoolThread {
160 public:
161 // State diagram for each work pool thread
162 //
163 // Unacquired: (Start state) When no one else has claimed the thread.
164 // Acquired: When the thread has been claimed for work,
165 // but work has not been issued to it yet.
166 // Scheduled: When the thread is running tasks from the acquirer.
167 // Exiting: cleanup
168 //
169 // Messages:
170 //
171 // Acquire
172 // Run
173 // Exit
174 //
175 // Transitions:
176 //
177 // Note: While task is being run, messages will come back with a failure value.
178 //
179 // Unacquired:
180 // message Acquire -> Acquired. effect: return success value
181 // message Run -> Unacquired. effect: return failure value
182 // message Exit -> Exiting. effect: return success value
183 //
184 // Acquired:
185 // message Acquire -> Acquired. effect: return failure value
186 // message Run -> Scheduled. effect: run the task, return success
187 // message Exit -> Exiting. effect: return success value
188 //
189 // Scheduled:
190 // implicit effect: after task is run, transition back to Unacquired.
191 // message Acquire -> Scheduled. effect: return failure value
192 // message Run -> Scheduled. effect: return failure value
193 // message Exit -> queue up exit message, then transition to Exiting after that is done.
194 // effect: return success value
195 //
196 enum State {
197 Unacquired = 0,
198 Acquired = 1,
199 Scheduled = 2,
200 Exiting = 3,
201 };
202
__anon141b1e2d0302null203 WorkPoolThread() : mThread([this] { threadFunc(); }) {
204 mThread.start();
205 }
206
~WorkPoolThread()207 ~WorkPoolThread() {
208 exit();
209 mThread.wait();
210 }
211
acquire()212 bool acquire() {
213 AutoLock<Lock> lock(mLock);
214 switch (mState) {
215 case State::Unacquired:
216 mState = State::Acquired;
217 return true;
218 case State::Acquired:
219 case State::Scheduled:
220 case State::Exiting:
221 return false;
222 default:
223 return false;
224 }
225 }
226
run(WorkPool::WaitGroupHandle waitGroupHandle,WaitGroup * waitGroup,WorkPool::Task task)227 bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) {
228 AutoLock<Lock> lock(mLock);
229 switch (mState) {
230 case State::Unacquired:
231 return false;
232 case State::Acquired: {
233 mState = State::Scheduled;
234 mToCleanupWaitGroupHandle = waitGroupHandle;
235 waitGroup->acquire();
236 mToCleanupWaitGroup = waitGroup;
237 mShouldCleanupWaitGroup = false;
238 TaskInfo msg = {
239 Command::Run,
240 waitGroup, task,
241 };
242 mRunMessages.send(msg);
243 return true;
244 }
245 case State::Scheduled:
246 case State::Exiting:
247 return false;
248 default:
249 return false;
250 }
251 }
252
shouldCleanupWaitGroup(WorkPool::WaitGroupHandle * waitGroupHandle,WaitGroup ** waitGroup)253 bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) {
254 AutoLock<Lock> lock(mLock);
255 bool res = mShouldCleanupWaitGroup;
256 *waitGroupHandle = mToCleanupWaitGroupHandle;
257 *waitGroup = mToCleanupWaitGroup;
258 mShouldCleanupWaitGroup = false;
259 return res;
260 }
261
262 private:
263 enum Command {
264 Run = 0,
265 Exit = 1,
266 };
267
268 struct TaskInfo {
269 Command cmd;
270 WaitGroup* waitGroup = nullptr;
271 WorkPool::Task task = {};
272 };
273
exit()274 bool exit() {
275 AutoLock<Lock> lock(mLock);
276 TaskInfo msg { Command::Exit, };
277 mRunMessages.send(msg);
278 return true;
279 }
280
threadFunc()281 void threadFunc() {
282 TaskInfo taskInfo;
283 bool done = false;
284
285 while (!done) {
286 mRunMessages.receive(&taskInfo);
287 switch (taskInfo.cmd) {
288 case Command::Run:
289 doRun(taskInfo);
290 break;
291 case Command::Exit: {
292 AutoLock<Lock> lock(mLock);
293 mState = State::Exiting;
294 break;
295 }
296 }
297 AutoLock<Lock> lock(mLock);
298 done = mState == State::Exiting;
299 }
300 }
301
302 // Assumption: the wait group refcount is >= 1 when entering
303 // this function (before decrement)..
304 // at least it doesn't get to 0
doRun(TaskInfo & msg)305 void doRun(TaskInfo& msg) {
306 WaitGroup* waitGroup = msg.waitGroup;
307
308 if (msg.task) msg.task();
309
310 bool lastTask =
311 waitGroup->decrementBroadcast();
312
313 AutoLock<Lock> lock(mLock);
314 mState = State::Unacquired;
315
316 if (lastTask) {
317 mShouldCleanupWaitGroup = true;
318 }
319
320 waitGroup->release();
321 }
322
323 FunctorThread mThread;
324 Lock mLock;
325 State mState = State::Unacquired;
326 MessageChannel<TaskInfo, 4> mRunMessages;
327 WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0;
328 WaitGroup* mToCleanupWaitGroup = nullptr;
329 bool mShouldCleanupWaitGroup = false;
330 };
331
332 class WorkPool::Impl {
333 public:
Impl(int numInitialThreads)334 Impl(int numInitialThreads) : mThreads(numInitialThreads) {
335 for (size_t i = 0; i < mThreads.size(); ++i) {
336 mThreads[i].reset(new WorkPoolThread);
337 }
338 }
339
340 ~Impl() = default;
341
schedule(const std::vector<WorkPool::Task> & tasks)342 WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) {
343
344 if (tasks.empty()) abort();
345
346 AutoLock<Lock> lock(mLock);
347
348 // Sweep old wait groups
349 for (size_t i = 0; i < mThreads.size(); ++i) {
350 WaitGroupHandle handle;
351 WaitGroup* waitGroup;
352 bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup);
353 if (cleanup) {
354 mWaitGroups.erase(handle);
355 waitGroup->release();
356 }
357 }
358
359 WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked();
360 WaitGroup* waitGroup =
361 new WaitGroup(tasks.size());
362
363 mWaitGroups[resHandle] = waitGroup;
364
365 std::vector<size_t> threadIndices;
366
367 while (threadIndices.size() < tasks.size()) {
368 for (size_t i = 0; i < mThreads.size(); ++i) {
369 if (!mThreads[i]->acquire()) continue;
370 threadIndices.push_back(i);
371 if (threadIndices.size() == tasks.size()) break;
372 }
373 if (threadIndices.size() < tasks.size()) {
374 mThreads.resize(mThreads.size() + 1);
375 mThreads[mThreads.size() - 1].reset(new WorkPoolThread);
376 }
377 }
378
379 // every thread here is acquired
380 for (size_t i = 0; i < threadIndices.size(); ++i) {
381 mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]);
382 }
383
384 return resHandle;
385 }
386
waitAny(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)387 bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
388 AutoLock<Lock> lock(mLock);
389 auto it = mWaitGroups.find(waitGroupHandle);
390 if (it == mWaitGroups.end()) return true;
391
392 auto waitGroup = it->second;
393 waitGroup->acquire();
394 lock.unlock();
395
396 bool waitRes = false;
397
398 {
399 AutoLock<Lock> waitGroupLock(waitGroup->getLock());
400 waitRes = waitGroup->waitAnyLocked(timeout);
401 }
402
403 waitGroup->release();
404
405 return waitRes;
406 }
407
waitAll(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)408 bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
409 auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle);
410 if (!waitGroup) return true;
411
412 bool waitRes = false;
413
414 {
415 AutoLock<Lock> waitGroupLock(waitGroup->getLock());
416 waitRes = waitGroup->waitAllLocked(timeout);
417 }
418
419 waitGroup->release();
420
421 return waitRes;
422 }
423
424 private:
425 // Increments wait group refcount by 1.
acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle)426 WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) {
427 AutoLock<Lock> lock(mLock);
428 auto it = mWaitGroups.find(waitGroupHandle);
429 if (it == mWaitGroups.end()) return nullptr;
430
431 auto waitGroup = it->second;
432 waitGroup->acquire();
433
434 return waitGroup;
435 }
436
437 using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>;
438
genWaitGroupHandleLocked()439 WorkPool::WaitGroupHandle genWaitGroupHandleLocked() {
440 WorkPool::WaitGroupHandle res = mNextWaitGroupHandle;
441 ++mNextWaitGroupHandle;
442 return res;
443 }
444
445 Lock mLock;
446 uint64_t mNextWaitGroupHandle = 0;
447 WaitGroupStore mWaitGroups;
448 std::vector<std::unique_ptr<WorkPoolThread>> mThreads;
449 };
450
WorkPool(int numInitialThreads)451 WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { }
452 WorkPool::~WorkPool() = default;
453
schedule(const std::vector<WorkPool::Task> & tasks)454 WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) {
455 return mImpl->schedule(tasks);
456 }
457
waitAny(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)458 bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
459 return mImpl->waitAny(waitGroup, timeout);
460 }
461
waitAll(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)462 bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
463 return mImpl->waitAll(waitGroup, timeout);
464 }
465
466 } // namespace guest
467 } // namespace gfxstream
468