1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Async/IR/Async.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23
24 #define DEBUG_TYPE "convert-async-to-llvm"
25
26 using namespace mlir;
27 using namespace mlir::async;
28
29 // Prefix for functions outlined from `async.execute` op regions.
// Symbol name given to every function outlined from an `async.execute` region
// (the symbol table uniques it on insertion).
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
31
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35
// Symbol names of the Async Runtime C API functions that the lowering emits
// calls to.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Token and group construction.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Await functions used outside of coroutines.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Coroutine scheduling: these take a coroutine handle and a resume function.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
50
51 namespace {
52 // Async Runtime API function types.
53 struct AsyncAPI {
addOrDropRefFunctionType__anon61a0c9350111::AsyncAPI54 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
55 auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
56 auto count = IntegerType::get(32, ctx);
57 return FunctionType::get({ref, count}, {}, ctx);
58 }
59
createTokenFunctionType__anon61a0c9350111::AsyncAPI60 static FunctionType createTokenFunctionType(MLIRContext *ctx) {
61 return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
62 }
63
createGroupFunctionType__anon61a0c9350111::AsyncAPI64 static FunctionType createGroupFunctionType(MLIRContext *ctx) {
65 return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
66 }
67
emplaceTokenFunctionType__anon61a0c9350111::AsyncAPI68 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
69 return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
70 }
71
awaitTokenFunctionType__anon61a0c9350111::AsyncAPI72 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
73 return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
74 }
75
awaitGroupFunctionType__anon61a0c9350111::AsyncAPI76 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
77 return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
78 }
79
executeFunctionType__anon61a0c9350111::AsyncAPI80 static FunctionType executeFunctionType(MLIRContext *ctx) {
81 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
82 auto resume = resumeFunctionType(ctx).getPointerTo();
83 return FunctionType::get({hdl, resume}, {}, ctx);
84 }
85
addTokenToGroupFunctionType__anon61a0c9350111::AsyncAPI86 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
87 auto i64 = IntegerType::get(64, ctx);
88 return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
89 ctx);
90 }
91
awaitAndExecuteFunctionType__anon61a0c9350111::AsyncAPI92 static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
93 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
94 auto resume = resumeFunctionType(ctx).getPointerTo();
95 return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
96 }
97
awaitAllAndExecuteFunctionType__anon61a0c9350111::AsyncAPI98 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
99 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
100 auto resume = resumeFunctionType(ctx).getPointerTo();
101 return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
102 }
103
104 // Auxiliary coroutine resume intrinsic wrapper.
resumeFunctionType__anon61a0c9350111::AsyncAPI105 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
106 auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
107 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
108 return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
109 }
110 };
111 } // namespace
112
113 // Adds Async Runtime C API declarations to the module.
addAsyncRuntimeApiDeclarations(ModuleOp module)114 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
115 auto builder = OpBuilder::atBlockTerminator(module.getBody());
116
117 auto addFuncDecl = [&](StringRef name, FunctionType type) {
118 if (module.lookupSymbol(name))
119 return;
120 builder.create<FuncOp>(module.getLoc(), name, type).setPrivate();
121 };
122
123 MLIRContext *ctx = module.getContext();
124 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
125 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
126 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
127 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
128 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
129 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
130 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
131 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
132 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
133 addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
134 addFuncDecl(kAwaitAllAndExecute,
135 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
136 }
137
138 //===----------------------------------------------------------------------===//
139 // LLVM coroutines intrinsics declarations.
140 //===----------------------------------------------------------------------===//
141
// Symbol names of the LLVM coroutine intrinsics referenced by the lowering.

// Coroutine frame setup and teardown.
static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";

// Suspension and resumption.
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
static constexpr const char *kCoroResume = "llvm.coro.resume";
150
151 /// Adds an LLVM function declaration to a module.
addLLVMFuncDecl(ModuleOp module,OpBuilder & builder,StringRef name,LLVM::LLVMType ret,ArrayRef<LLVM::LLVMType> params)152 static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name,
153 LLVM::LLVMType ret,
154 ArrayRef<LLVM::LLVMType> params) {
155 if (module.lookupSymbol(name))
156 return;
157 LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
158 builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type);
159 }
160
161 /// Adds coroutine intrinsics declarations to the module.
addCoroutineIntrinsicsDeclarations(ModuleOp module)162 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
163 using namespace mlir::LLVM;
164
165 MLIRContext *ctx = module.getContext();
166 OpBuilder builder(module.getBody()->getTerminator());
167
168 auto token = LLVMTokenType::get(ctx);
169 auto voidTy = LLVMType::getVoidTy(ctx);
170
171 auto i8 = LLVMType::getInt8Ty(ctx);
172 auto i1 = LLVMType::getInt1Ty(ctx);
173 auto i32 = LLVMType::getInt32Ty(ctx);
174 auto i64 = LLVMType::getInt64Ty(ctx);
175 auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
176
177 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
178 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
179 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
180 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
181 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
182 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
183 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
184 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
185 }
186
187 //===----------------------------------------------------------------------===//
188 // Add malloc/free declarations to the module.
189 //===----------------------------------------------------------------------===//
190
// Symbol names of the C allocation functions used for the coroutine frame.
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
193
194 /// Adds malloc/free declarations to the module.
addCRuntimeDeclarations(ModuleOp module)195 static void addCRuntimeDeclarations(ModuleOp module) {
196 using namespace mlir::LLVM;
197
198 MLIRContext *ctx = module.getContext();
199 OpBuilder builder(module.getBody()->getTerminator());
200
201 auto voidTy = LLVMType::getVoidTy(ctx);
202 auto i64 = LLVMType::getInt64Ty(ctx);
203 auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
204
205 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
206 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
207 }
208
209 //===----------------------------------------------------------------------===//
210 // Coroutine resume function wrapper.
211 //===----------------------------------------------------------------------===//
212
// Symbol name of the generated wrapper around the `llvm.coro.resume` intrinsic.
static constexpr const char *kResume = "__resume";
214
215 // A function that takes a coroutine handle and calls a `llvm.coro.resume`
216 // intrinsics. We need this function to be able to pass it to the async
217 // runtime execute API.
addResumeFunction(ModuleOp module)218 static void addResumeFunction(ModuleOp module) {
219 MLIRContext *ctx = module.getContext();
220
221 OpBuilder moduleBuilder(module.getBody()->getTerminator());
222 Location loc = module.getLoc();
223
224 if (module.lookupSymbol(kResume))
225 return;
226
227 auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
228 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
229
230 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
231 loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false));
232 resumeOp.setPrivate();
233
234 auto *block = resumeOp.addEntryBlock();
235 OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
236
237 blockBuilder.create<LLVM::CallOp>(loc, TypeRange(),
238 blockBuilder.getSymbolRefAttr(kCoroResume),
239 resumeOp.getArgument(0));
240
241 blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
242 }
243
244 //===----------------------------------------------------------------------===//
245 // async.execute op outlining to the coroutine functions.
246 //===----------------------------------------------------------------------===//
247
248 // Function targeted for coroutine transformation has two additional blocks at
249 // the end: coroutine cleanup and coroutine suspension.
250 //
251 // async.await op lowering additionally creates a resume block for each
252 // operation to enable non-blocking waiting via coroutine suspension.
253 namespace {
254 struct CoroMachinery {
255 Value asyncToken;
256 Value coroHandle;
257 Block *cleanup;
258 Block *suspend;
259 };
260 } // namespace
261
262 // Builds a coroutine template compatible with LLVM coroutines lowering.
263 //
264 // - `entry` block sets up the coroutine.
265 // - `cleanup` block cleans up the coroutine state.
266 // - `suspend` block after the @llvm.coro.end() defines what value will be
267 // returned to the initial caller of a coroutine. Everything before the
268 // @llvm.coro.end() will be executed at every suspension point.
269 //
270 // Coroutine structure (only the important bits):
271 //
272 // func @async_execute_fn(<function-arguments>) -> !async.token {
273 // ^entryBlock(<function-arguments>):
274 // %token = <async token> : !async.token // create async runtime token
275 // %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle
276 // br ^cleanup
277 //
278 // ^cleanup:
279 // llvm.call @llvm.coro.free(...) // delete coroutine state
280 // br ^suspend
281 //
282 // ^suspend:
283 // llvm.call @llvm.coro.end(...) // marks the end of a coroutine
284 // return %token : !async.token
285 // }
286 //
287 // The actual code for the async.execute operation body region will be inserted
288 // before the entry block terminator.
289 //
290 //
setupCoroMachinery(FuncOp func)291 static CoroMachinery setupCoroMachinery(FuncOp func) {
292 assert(func.getBody().empty() && "Function must have empty body");
293
294 MLIRContext *ctx = func.getContext();
295
296 auto token = LLVM::LLVMTokenType::get(ctx);
297 auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
298 auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
299 auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
300 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
301
302 Block *entryBlock = func.addEntryBlock();
303 Location loc = func.getBody().getLoc();
304
305 OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);
306
307 // ------------------------------------------------------------------------ //
308 // Allocate async tokens/values that we will return from a ramp function.
309 // ------------------------------------------------------------------------ //
310 auto createToken =
311 builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));
312
313 // ------------------------------------------------------------------------ //
314 // Initialize coroutine: allocate frame, get coroutine handle.
315 // ------------------------------------------------------------------------ //
316
317 // Constants for initializing coroutine frame.
318 auto constZero =
319 builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
320 auto constFalse =
321 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
322 auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);
323
324 // Get coroutine id: @llvm.coro.id
325 auto coroId = builder.create<LLVM::CallOp>(
326 loc, token, builder.getSymbolRefAttr(kCoroId),
327 ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
328
329 // Get coroutine frame size: @llvm.coro.size.i64
330 auto coroSize = builder.create<LLVM::CallOp>(
331 loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
332
333 // Allocate memory for coroutine frame.
334 auto coroAlloc = builder.create<LLVM::CallOp>(
335 loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
336 ValueRange(coroSize.getResult(0)));
337
338 // Begin a coroutine: @llvm.coro.begin
339 auto coroHdl = builder.create<LLVM::CallOp>(
340 loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
341 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
342
343 Block *cleanupBlock = func.addBlock();
344 Block *suspendBlock = func.addBlock();
345
346 // ------------------------------------------------------------------------ //
347 // Coroutine cleanup block: deallocate coroutine frame, free the memory.
348 // ------------------------------------------------------------------------ //
349 builder.setInsertionPointToStart(cleanupBlock);
350
351 // Get a pointer to the coroutine frame memory: @llvm.coro.free.
352 auto coroMem = builder.create<LLVM::CallOp>(
353 loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
354 ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
355
356 // Free the memory.
357 builder.create<LLVM::CallOp>(loc, TypeRange(),
358 builder.getSymbolRefAttr(kFree),
359 ValueRange(coroMem.getResult(0)));
360 // Branch into the suspend block.
361 builder.create<BranchOp>(loc, suspendBlock);
362
363 // ------------------------------------------------------------------------ //
364 // Coroutine suspend block: mark the end of a coroutine and return allocated
365 // async token.
366 // ------------------------------------------------------------------------ //
367 builder.setInsertionPointToStart(suspendBlock);
368
369 // Mark the end of a coroutine: @llvm.coro.end.
370 builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
371 ValueRange({coroHdl.getResult(0), constFalse}));
372
373 // Return created `async.token` from the suspend block. This will be the
374 // return value of a coroutine ramp function.
375 builder.create<ReturnOp>(loc, createToken.getResult(0));
376
377 // Branch from the entry block to the cleanup block to create a valid CFG.
378 builder.setInsertionPointToEnd(entryBlock);
379
380 builder.create<BranchOp>(loc, cleanupBlock);
381
382 // `async.await` op lowering will create resume blocks for async
383 // continuations, and will conditionally branch to cleanup or suspend blocks.
384
385 return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
386 suspendBlock};
387 }
388
389 // Add a LLVM coroutine suspension point to the end of suspended block, to
390 // resume execution in resume block. The caller is responsible for creating the
391 // two suspended/resume blocks with the desired ops contained in each block.
392 // This function merely provides the required control flow logic.
393 //
394 // `coroState` must be a value returned from the call to @llvm.coro.save(...)
395 // intrinsic (saved coroutine state).
396 //
397 // Before:
398 //
399 // ^bb0:
400 // "opBefore"(...)
401 // "op"(...)
402 // ^cleanup: ...
403 // ^suspend: ...
404 // ^resume:
405 // "op"(...)
406 //
407 // After:
408 //
409 // ^bb0:
410 // "opBefore"(...)
411 // %suspend = llvm.call @llvm.coro.suspend(...)
412 // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
413 // ^resume:
414 // "op"(...)
415 // ^cleanup: ...
416 // ^suspend: ...
417 //
addSuspensionPoint(CoroMachinery coro,Value coroState,Operation * op,Block * suspended,Block * resume,OpBuilder & builder)418 static void addSuspensionPoint(CoroMachinery coro, Value coroState,
419 Operation *op, Block *suspended, Block *resume,
420 OpBuilder &builder) {
421 Location loc = op->getLoc();
422 MLIRContext *ctx = op->getContext();
423 auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
424 auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
425
426 // Add a coroutine suspension in place of original `op` in the split block.
427 OpBuilder::InsertionGuard guard(builder);
428 builder.setInsertionPointToEnd(suspended);
429
430 auto constFalse =
431 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
432
433 // Suspend a coroutine: @llvm.coro.suspend
434 auto coroSuspend = builder.create<LLVM::CallOp>(
435 loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
436 ValueRange({coroState, constFalse}));
437
438 // After a suspension point decide if we should branch into resume, cleanup
439 // or suspend block of the coroutine (see @llvm.coro.suspend return code
440 // documentation).
441 auto constZero =
442 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
443 auto constNegOne =
444 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
445
446 Block *resumeOrCleanup = builder.createBlock(resume);
447
448 // Suspend the coroutine ...?
449 builder.setInsertionPointToEnd(suspended);
450 auto isNegOne = builder.create<LLVM::ICmpOp>(
451 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
452 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
453 /*falseDest=*/resumeOrCleanup);
454
455 // ... or resume or cleanup the coroutine?
456 builder.setInsertionPointToStart(resumeOrCleanup);
457 auto isZero = builder.create<LLVM::ICmpOp>(
458 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
459 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
460 /*falseDest=*/coro.cleanup);
461 }
462
463 // Outline the body region attached to the `async.execute` op into a standalone
464 // function.
465 //
466 // Note that this is not reversible transformation.
467 static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable & symbolTable,ExecuteOp execute)468 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
469 ModuleOp module = execute->getParentOfType<ModuleOp>();
470
471 MLIRContext *ctx = module.getContext();
472 Location loc = execute.getLoc();
473
474 OpBuilder moduleBuilder(module.getBody()->getTerminator());
475
476 // Collect all outlined function inputs.
477 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
478 execute.dependencies().end());
479 getUsedValuesDefinedAbove(execute.body(), functionInputs);
480
481 // Collect types for the outlined function inputs and outputs.
482 auto typesRange = llvm::map_range(
483 functionInputs, [](Value value) { return value.getType(); });
484 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
485 auto outputTypes = execute.getResultTypes();
486
487 auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
488 auto funcAttrs = ArrayRef<NamedAttribute>();
489
490 // TODO: Derive outlined function name from the parent FuncOp (support
491 // multiple nested async.execute operations).
492 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
493 symbolTable.insert(func, moduleBuilder.getInsertionPoint());
494
495 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
496
497 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
498 // blocks, adding llvm.coro instrinsics and setting up control flow.
499 CoroMachinery coro = setupCoroMachinery(func);
500
501 // Suspend async function at the end of an entry block, and resume it using
502 // Async execute API (execution will be resumed in a thread managed by the
503 // async runtime).
504 Block *entryBlock = &func.getBlocks().front();
505 OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
506
507 // A pointer to coroutine resume intrinsic wrapper.
508 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
509 auto resumePtr = builder.create<LLVM::AddressOfOp>(
510 loc, resumeFnTy.getPointerTo(), kResume);
511
512 // Save the coroutine state: @llvm.coro.save
513 auto coroSave = builder.create<LLVM::CallOp>(
514 loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
515 ValueRange({coro.coroHandle}));
516
517 // Call async runtime API to execute a coroutine in the managed thread.
518 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
519 builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs);
520
521 // Split the entry block before the terminator.
522 auto *terminatorOp = entryBlock->getTerminator();
523 Block *suspended = terminatorOp->getBlock();
524 Block *resume = suspended->splitBlock(terminatorOp);
525 addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
526 resume, builder);
527
528 // Await on all dependencies before starting to execute the body region.
529 builder.setInsertionPointToStart(resume);
530 for (size_t i = 0; i < execute.dependencies().size(); ++i)
531 builder.create<AwaitOp>(loc, func.getArgument(i));
532
533 // Map from function inputs defined above the execute op to the function
534 // arguments.
535 BlockAndValueMapping valueMapping;
536 valueMapping.map(functionInputs, func.getArguments());
537
538 // Clone all operations from the execute operation body into the outlined
539 // function body, and replace all `async.yield` operations with a call
540 // to async runtime to emplace the result token.
541 for (Operation &op : execute.body().getOps()) {
542 if (isa<async::YieldOp>(op)) {
543 builder.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
544 continue;
545 }
546 builder.clone(op, valueMapping);
547 }
548
549 // Replace the original `async.execute` with a call to outlined function.
550 OpBuilder callBuilder(execute);
551 auto callOutlinedFunc =
552 callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
553 functionInputs.getArrayRef());
554 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
555 execute.erase();
556
557 return {func, coro};
558 }
559
560 //===----------------------------------------------------------------------===//
561 // Convert Async dialect types to LLVM types.
562 //===----------------------------------------------------------------------===//
563
564 namespace {
565 class AsyncRuntimeTypeConverter : public TypeConverter {
566 public:
AsyncRuntimeTypeConverter()567 AsyncRuntimeTypeConverter() { addConversion(convertType); }
568
convertType(Type type)569 static Type convertType(Type type) {
570 MLIRContext *ctx = type.getContext();
571 // Convert async tokens and groups to opaque pointers.
572 if (type.isa<TokenType, GroupType>())
573 return LLVM::LLVMType::getInt8PtrTy(ctx);
574 return type;
575 }
576 };
577 } // namespace
578
579 //===----------------------------------------------------------------------===//
580 // Convert types for all call operations to lowered async types.
581 //===----------------------------------------------------------------------===//
582
583 namespace {
584 class CallOpOpConversion : public ConversionPattern {
585 public:
CallOpOpConversion(MLIRContext * ctx)586 explicit CallOpOpConversion(MLIRContext *ctx)
587 : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
588
589 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const590 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
591 ConversionPatternRewriter &rewriter) const override {
592 AsyncRuntimeTypeConverter converter;
593
594 SmallVector<Type, 5> resultTypes;
595 converter.convertTypes(op->getResultTypes(), resultTypes);
596
597 CallOp call = cast<CallOp>(op);
598 rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
599 operands);
600
601 return success();
602 }
603 };
604 } // namespace
605
606 //===----------------------------------------------------------------------===//
607 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
608 // to the corresponding API calls).
609 //===----------------------------------------------------------------------===//
610
611 namespace {
612
613 template <typename RefCountingOp>
614 class RefCountingOpLowering : public ConversionPattern {
615 public:
RefCountingOpLowering(MLIRContext * ctx,StringRef apiFunctionName)616 explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
617 : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
618 apiFunctionName(apiFunctionName) {}
619
620 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const621 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
622 ConversionPatternRewriter &rewriter) const override {
623 RefCountingOp refCountingOp = cast<RefCountingOp>(op);
624
625 auto count = rewriter.create<ConstantOp>(
626 op->getLoc(), rewriter.getI32Type(),
627 rewriter.getI32IntegerAttr(refCountingOp.count()));
628
629 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
630 ValueRange({operands[0], count}));
631
632 return success();
633 }
634
635 private:
636 StringRef apiFunctionName;
637 };
638
639 // async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
640 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
641 public:
AddRefOpLowering(MLIRContext * ctx)642 explicit AddRefOpLowering(MLIRContext *ctx)
643 : RefCountingOpLowering(ctx, kAddRef) {}
644 };
645
646 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
647 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
648 public:
DropRefOpLowering(MLIRContext * ctx)649 explicit DropRefOpLowering(MLIRContext *ctx)
650 : RefCountingOpLowering(ctx, kDropRef) {}
651 };
652
653 } // namespace
654
655 //===----------------------------------------------------------------------===//
656 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
657 //===----------------------------------------------------------------------===//
658
659 namespace {
660 class CreateGroupOpLowering : public ConversionPattern {
661 public:
CreateGroupOpLowering(MLIRContext * ctx)662 explicit CreateGroupOpLowering(MLIRContext *ctx)
663 : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
664
665 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const666 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
667 ConversionPatternRewriter &rewriter) const override {
668 auto retTy = GroupType::get(op->getContext());
669 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
670 return success();
671 }
672 };
673 } // namespace
674
675 //===----------------------------------------------------------------------===//
676 // async.add_to_group op lowering to runtime function call.
677 //===----------------------------------------------------------------------===//
678
679 namespace {
680 class AddToGroupOpLowering : public ConversionPattern {
681 public:
AddToGroupOpLowering(MLIRContext * ctx)682 explicit AddToGroupOpLowering(MLIRContext *ctx)
683 : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
684
685 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const686 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
687 ConversionPatternRewriter &rewriter) const override {
688 // Currently we can only add tokens to the group.
689 auto addToGroup = cast<AddToGroupOp>(op);
690 if (!addToGroup.operand().getType().isa<TokenType>())
691 return failure();
692
693 auto i64 = IntegerType::get(64, op->getContext());
694 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
695 return success();
696 }
697 };
698 } // namespace
699
700 //===----------------------------------------------------------------------===//
701 // async.await and async.await_all op lowerings to the corresponding async
702 // runtime function calls.
703 //===----------------------------------------------------------------------===//
704
705 namespace {
706
707 template <typename AwaitType, typename AwaitableType>
708 class AwaitOpLoweringBase : public ConversionPattern {
709 protected:
AwaitOpLoweringBase(MLIRContext * ctx,const llvm::DenseMap<FuncOp,CoroMachinery> & outlinedFunctions,StringRef blockingAwaitFuncName,StringRef coroAwaitFuncName)710 explicit AwaitOpLoweringBase(
711 MLIRContext *ctx,
712 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
713 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
714 : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
715 outlinedFunctions(outlinedFunctions),
716 blockingAwaitFuncName(blockingAwaitFuncName),
717 coroAwaitFuncName(coroAwaitFuncName) {}
718
719 public:
720 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const721 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
722 ConversionPatternRewriter &rewriter) const override {
723 // We can only await on one the `AwaitableType` (for `await` it can be
724 // only a `token`, for `await_all` it is a `group`).
725 auto await = cast<AwaitType>(op);
726 if (!await.operand().getType().template isa<AwaitableType>())
727 return failure();
728
729 // Check if await operation is inside the outlined coroutine function.
730 auto func = await->template getParentOfType<FuncOp>();
731 auto outlined = outlinedFunctions.find(func);
732 const bool isInCoroutine = outlined != outlinedFunctions.end();
733
734 Location loc = op->getLoc();
735
736 // Inside regular function we convert await operation to the blocking
737 // async API await function call.
738 if (!isInCoroutine)
739 rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
740 ValueRange(operands[0]));
741
742 // Inside the coroutine we convert await operation into coroutine suspension
743 // point, and resume execution asynchronously.
744 if (isInCoroutine) {
745 const CoroMachinery &coro = outlined->getSecond();
746
747 OpBuilder builder(op, rewriter.getListener());
748 MLIRContext *ctx = op->getContext();
749
750 // A pointer to coroutine resume intrinsic wrapper.
751 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
752 auto resumePtr = builder.create<LLVM::AddressOfOp>(
753 loc, resumeFnTy.getPointerTo(), kResume);
754
755 // Save the coroutine state: @llvm.coro.save
756 auto coroSave = builder.create<LLVM::CallOp>(
757 loc, LLVM::LLVMTokenType::get(ctx),
758 builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
759
760 // Call async runtime API to resume a coroutine in the managed thread when
761 // the async await argument becomes ready.
762 SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
763 resumePtr.res()};
764 builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName,
765 awaitAndExecuteArgs);
766
767 Block *suspended = op->getBlock();
768
769 // Split the entry block before the await operation.
770 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
771 addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
772 builder);
773 }
774
775 // Original operation was replaced by function call or suspension point.
776 rewriter.eraseOp(op);
777
778 return success();
779 }
780
781 private:
782 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
783 StringRef blockingAwaitFuncName;
784 StringRef coroAwaitFuncName;
785 };
786
787 // Lowering for `async.await` operation (only token operands are supported).
788 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
789 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
790
791 public:
AwaitOpLowering(MLIRContext * ctx,const llvm::DenseMap<FuncOp,CoroMachinery> & outlinedFunctions)792 explicit AwaitOpLowering(
793 MLIRContext *ctx,
794 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
795 : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
796 };
797
798 // Lowering for `async.await_all` operation.
799 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
800 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
801
802 public:
AwaitAllOpLowering(MLIRContext * ctx,const llvm::DenseMap<FuncOp,CoroMachinery> & outlinedFunctions)803 explicit AwaitAllOpLowering(
804 MLIRContext *ctx,
805 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
806 : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
807 };
808
809 } // namespace
810
811 //===----------------------------------------------------------------------===//
812
813 namespace {
814 struct ConvertAsyncToLLVMPass
815 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
816 void runOnOperation() override;
817 };
818
runOnOperation()819 void ConvertAsyncToLLVMPass::runOnOperation() {
820 ModuleOp module = getOperation();
821 SymbolTable symbolTable(module);
822
823 // Outline all `async.execute` body regions into async functions (coroutines).
824 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
825
826 WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
827 // We currently do not support execute operations that have async value
828 // operands or produce async results.
829 if (!execute.operands().empty() || !execute.results().empty()) {
830 execute.emitOpError("can't outline async.execute op with async value "
831 "operands or returned async results");
832 return WalkResult::interrupt();
833 }
834
835 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
836
837 return WalkResult::advance();
838 });
839
840 // Failed to outline all async execute operations.
841 if (outlineResult.wasInterrupted()) {
842 signalPassFailure();
843 return;
844 }
845
846 LLVM_DEBUG({
847 llvm::dbgs() << "Outlined " << outlinedFunctions.size()
848 << " async functions\n";
849 });
850
851 // Add declarations for all functions required by the coroutines lowering.
852 addResumeFunction(module);
853 addAsyncRuntimeApiDeclarations(module);
854 addCoroutineIntrinsicsDeclarations(module);
855 addCRuntimeDeclarations(module);
856
857 MLIRContext *ctx = &getContext();
858
859 // Convert async dialect types and operations to LLVM dialect.
860 AsyncRuntimeTypeConverter converter;
861 OwningRewritePatternList patterns;
862
863 populateFuncOpTypeConversionPattern(patterns, ctx, converter);
864 patterns.insert<CallOpOpConversion>(ctx);
865 patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
866 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
867 patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
868
869 ConversionTarget target(*ctx);
870 target.addLegalOp<ConstantOp>();
871 target.addLegalDialect<LLVM::LLVMDialect>();
872 target.addIllegalDialect<AsyncDialect>();
873 target.addDynamicallyLegalOp<FuncOp>(
874 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
875 target.addDynamicallyLegalOp<CallOp>(
876 [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
877
878 if (failed(applyPartialConversion(module, target, std::move(patterns))))
879 signalPassFailure();
880 }
881 } // namespace
882
createConvertAsyncToLLVMPass()883 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
884 return std::make_unique<ConvertAsyncToLLVMPass>();
885 }
886