1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17
18 #include <atomic>
19 #include <cassert>
20 #include <condition_variable>
21 #include <functional>
22 #include <iostream>
23 #include <mutex>
24 #include <thread>
25 #include <vector>
26
27 //===----------------------------------------------------------------------===//
28 // Async runtime API.
29 //===----------------------------------------------------------------------===//
30
31 namespace {
32
33 // Forward declare class defined below.
34 class RefCounted;
35
36 // -------------------------------------------------------------------------- //
37 // AsyncRuntime orchestrates all async operations and Async runtime API is built
38 // on top of the default runtime instance.
39 // -------------------------------------------------------------------------- //
40
class AsyncRuntime {
public:
  // The live object counter starts at zero (see the member initializer).
  AsyncRuntime() = default;

  // Every reference counted object registered with this runtime is expected
  // to be destroyed before the runtime itself is destroyed.
  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Returns the current value of the live reference counted object counter.
  int32_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  // `RefCounted` is the only class allowed to update the counter below.
  friend class RefCounted;

  // Bookkeeping of the total number of reference counted objects in this
  // instance of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects{0};
};
68
69 // Returns the default per-process instance of an async runtime.
getDefaultAsyncRuntimeInstance()70 AsyncRuntime *getDefaultAsyncRuntimeInstance() {
71 static auto runtime = std::make_unique<AsyncRuntime>();
72 return runtime.get();
73 }
74
75 // -------------------------------------------------------------------------- //
76 // A base class for all reference counted objects created by the async runtime.
77 // -------------------------------------------------------------------------- //
78
79 class RefCounted {
80 public:
RefCounted(AsyncRuntime * runtime,int32_t refCount=1)81 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
82 : runtime(runtime), refCount(refCount) {
83 runtime->addNumRefCountedObjects();
84 }
85
~RefCounted()86 virtual ~RefCounted() {
87 assert(refCount.load() == 0 && "reference count must be zero");
88 runtime->dropNumRefCountedObjects();
89 }
90
91 RefCounted(const RefCounted &) = delete;
92 RefCounted &operator=(const RefCounted &) = delete;
93
addRef(int32_t count=1)94 void addRef(int32_t count = 1) { refCount.fetch_add(count); }
95
dropRef(int32_t count=1)96 void dropRef(int32_t count = 1) {
97 int32_t previous = refCount.fetch_sub(count);
98 assert(previous >= count && "reference count should not go below zero");
99 if (previous == count)
100 destroy();
101 }
102
103 protected:
destroy()104 virtual void destroy() { delete this; }
105
106 private:
107 AsyncRuntime *runtime;
108 std::atomic<int32_t> refCount;
109 };
110
111 } // namespace
112
113 struct AsyncToken : public RefCounted {
114 // AsyncToken created with a reference count of 2 because it will be returned
115 // to the `async.execute` caller and also will be later on emplaced by the
116 // asynchronously executed task. If the caller immediately will drop its
117 // reference we must ensure that the token will be alive until the
118 // asynchronous operation is completed.
AsyncTokenAsyncToken119 AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
120
121 // Internal state below guarded by a mutex.
122 std::mutex mu;
123 std::condition_variable cv;
124
125 bool ready = false;
126 std::vector<std::function<void()>> awaiters;
127 };
128
129 struct AsyncGroup : public RefCounted {
AsyncGroupAsyncGroup130 AsyncGroup(AsyncRuntime *runtime)
131 : RefCounted(runtime), pendingTokens(0), rank(0) {}
132
133 std::atomic<int> pendingTokens;
134 std::atomic<int> rank;
135
136 // Internal state below guarded by a mutex.
137 std::mutex mu;
138 std::condition_variable cv;
139
140 std::vector<std::function<void()>> awaiters;
141 };
142
143 // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int32_t count)144 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
145 RefCounted *refCounted = static_cast<RefCounted *>(ptr);
146 refCounted->addRef(count);
147 }
148
149 // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int32_t count)150 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
151 RefCounted *refCounted = static_cast<RefCounted *>(ptr);
152 refCounted->dropRef(count);
153 }
154
155 // Create a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()156 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
157 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
158 return token;
159 }
160
161 // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup()162 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
163 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
164 return group;
165 }
166
mlirAsyncRuntimeAddTokenToGroup(AsyncToken * token,AsyncGroup * group)167 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
168 AsyncGroup *group) {
169 std::unique_lock<std::mutex> lockToken(token->mu);
170 std::unique_lock<std::mutex> lockGroup(group->mu);
171
172 // Get the rank of the token inside the group before we drop the reference.
173 int rank = group->rank.fetch_add(1);
174 group->pendingTokens.fetch_add(1);
175
176 auto onTokenReady = [group]() {
177 // Run all group awaiters if it was the last token in the group.
178 if (group->pendingTokens.fetch_sub(1) == 1) {
179 group->cv.notify_all();
180 for (auto &awaiter : group->awaiters)
181 awaiter();
182 }
183 };
184
185 if (token->ready) {
186 // Update group pending tokens immediately and maybe run awaiters.
187 onTokenReady();
188
189 } else {
190 // Update group pending tokens when token will become ready. Because this
191 // will happen asynchronously we must ensure that `group` is alive until
192 // then, and re-ackquire the lock.
193 group->addRef();
194
195 token->awaiters.push_back([group, onTokenReady]() {
196 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
197 {
198 std::unique_lock<std::mutex> lockGroup(group->mu);
199 onTokenReady();
200 }
201 group->dropRef();
202 });
203 }
204
205 return rank;
206 }
207
208 // Switches `async.token` to ready state and runs all awaiters.
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)209 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
210 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
211 {
212 std::unique_lock<std::mutex> lock(token->mu);
213 token->ready = true;
214 token->cv.notify_all();
215 for (auto &awaiter : token->awaiters)
216 awaiter();
217 }
218
219 // Async tokens created with a ref count `2` to keep token alive until the
220 // async task completes. Drop this reference explicitly when token emplaced.
221 token->dropRef();
222 }
223
mlirAsyncRuntimeAwaitToken(AsyncToken * token)224 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
225 std::unique_lock<std::mutex> lock(token->mu);
226 if (!token->ready)
227 token->cv.wait(lock, [token] { return token->ready; });
228 }
229
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)230 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
231 std::unique_lock<std::mutex> lock(group->mu);
232 if (group->pendingTokens != 0)
233 group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
234 }
235
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)236 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
237 (*resume)(handle);
238 }
239
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)240 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
241 CoroHandle handle,
242 CoroResume resume) {
243 std::unique_lock<std::mutex> lock(token->mu);
244 auto execute = [handle, resume]() { (*resume)(handle); };
245 if (token->ready)
246 execute();
247 else
248 token->awaiters.push_back([execute]() { execute(); });
249 }
250
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)251 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
252 CoroHandle handle,
253 CoroResume resume) {
254 std::unique_lock<std::mutex> lock(group->mu);
255 auto execute = [handle, resume]() { (*resume)(handle); };
256 if (group->pendingTokens == 0)
257 execute();
258 else
259 group->awaiters.push_back([execute]() { execute(); });
260 }
261
262 //===----------------------------------------------------------------------===//
263 // Small async runtime support library for testing.
264 //===----------------------------------------------------------------------===//
265
// Prints the id of the calling thread to the standard output and flushes it.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  const std::thread::id currentId = std::this_thread::get_id();
  std::cout << "Current thread id: " << currentId << std::endl;
}
270
271 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
272