1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Async/IR/Async.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 #define DEBUG_TYPE "convert-async-to-llvm"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
// Name prefix for the functions that `outlineExecuteOp` creates from
// `async.execute` op regions.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Symbol names of the Async Runtime C API entry points. Their declarations
// are added by addAsyncRuntimeApiDeclarations; the lowerings below emit calls
// to them.

// Reference counting (lowering of `async.add_ref` / `async.drop_ref`).
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of runtime tokens and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Emplaces the result token (emitted in place of `async.yield`).
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
// Awaits on a token / on all tokens in a group (presumably the blocking
// variants used outside of coroutines).
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Executes a coroutine (via its resume function) in a runtime-managed thread.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Adds a token to a group (lowering of `async.add_to_group`).
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Resume a suspended coroutine once the token / all group tokens are ready.
static constexpr const char *kAwaitAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
50 
51 namespace {
52 // Async Runtime API function types.
53 struct AsyncAPI {
addOrDropRefFunctionType__anon61a0c9350111::AsyncAPI54   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
55     auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
56     auto count = IntegerType::get(32, ctx);
57     return FunctionType::get({ref, count}, {}, ctx);
58   }
59 
createTokenFunctionType__anon61a0c9350111::AsyncAPI60   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
61     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
62   }
63 
createGroupFunctionType__anon61a0c9350111::AsyncAPI64   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
65     return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
66   }
67 
emplaceTokenFunctionType__anon61a0c9350111::AsyncAPI68   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
69     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
70   }
71 
awaitTokenFunctionType__anon61a0c9350111::AsyncAPI72   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
73     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
74   }
75 
awaitGroupFunctionType__anon61a0c9350111::AsyncAPI76   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
77     return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
78   }
79 
executeFunctionType__anon61a0c9350111::AsyncAPI80   static FunctionType executeFunctionType(MLIRContext *ctx) {
81     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
82     auto resume = resumeFunctionType(ctx).getPointerTo();
83     return FunctionType::get({hdl, resume}, {}, ctx);
84   }
85 
addTokenToGroupFunctionType__anon61a0c9350111::AsyncAPI86   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
87     auto i64 = IntegerType::get(64, ctx);
88     return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
89                              ctx);
90   }
91 
awaitAndExecuteFunctionType__anon61a0c9350111::AsyncAPI92   static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
93     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
94     auto resume = resumeFunctionType(ctx).getPointerTo();
95     return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
96   }
97 
awaitAllAndExecuteFunctionType__anon61a0c9350111::AsyncAPI98   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
99     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
100     auto resume = resumeFunctionType(ctx).getPointerTo();
101     return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
102   }
103 
104   // Auxiliary coroutine resume intrinsic wrapper.
resumeFunctionType__anon61a0c9350111::AsyncAPI105   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
106     auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
107     auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
108     return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
109   }
110 };
111 } // namespace
112 
113 // Adds Async Runtime C API declarations to the module.
addAsyncRuntimeApiDeclarations(ModuleOp module)114 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
115   auto builder = OpBuilder::atBlockTerminator(module.getBody());
116 
117   auto addFuncDecl = [&](StringRef name, FunctionType type) {
118     if (module.lookupSymbol(name))
119       return;
120     builder.create<FuncOp>(module.getLoc(), name, type).setPrivate();
121   };
122 
123   MLIRContext *ctx = module.getContext();
124   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
125   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
126   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
127   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
128   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
129   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
130   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
131   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
132   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
133   addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
134   addFuncDecl(kAwaitAllAndExecute,
135               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // LLVM coroutines intrinsics declarations.
140 //===----------------------------------------------------------------------===//
141 
// Names of the LLVM coroutine intrinsics used by this lowering. The
// signatures noted on each line are the ones declared by
// addCoroutineIntrinsicsDeclarations.

// token (i32, i8*, i8*, i8*)
static constexpr const char *kCoroId = "llvm.coro.id";
// i64 ()
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
// i8* (token, i8*)
static constexpr const char *kCoroBegin = "llvm.coro.begin";
// token (i8*)
static constexpr const char *kCoroSave = "llvm.coro.save";
// i8 (token, i1)
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
// i1 (i8*, i1)
static constexpr const char *kCoroEnd = "llvm.coro.end";
// i8* (token, i8*)
static constexpr const char *kCoroFree = "llvm.coro.free";
// void (i8*)
static constexpr const char *kCoroResume = "llvm.coro.resume";
150 
151 /// Adds an LLVM function declaration to a module.
addLLVMFuncDecl(ModuleOp module,OpBuilder & builder,StringRef name,LLVM::LLVMType ret,ArrayRef<LLVM::LLVMType> params)152 static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name,
153                             LLVM::LLVMType ret,
154                             ArrayRef<LLVM::LLVMType> params) {
155   if (module.lookupSymbol(name))
156     return;
157   LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
158   builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type);
159 }
160 
161 /// Adds coroutine intrinsics declarations to the module.
addCoroutineIntrinsicsDeclarations(ModuleOp module)162 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
163   using namespace mlir::LLVM;
164 
165   MLIRContext *ctx = module.getContext();
166   OpBuilder builder(module.getBody()->getTerminator());
167 
168   auto token = LLVMTokenType::get(ctx);
169   auto voidTy = LLVMType::getVoidTy(ctx);
170 
171   auto i8 = LLVMType::getInt8Ty(ctx);
172   auto i1 = LLVMType::getInt1Ty(ctx);
173   auto i32 = LLVMType::getInt32Ty(ctx);
174   auto i64 = LLVMType::getInt64Ty(ctx);
175   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
176 
177   addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
178   addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
179   addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
180   addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
181   addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
182   addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
183   addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
184   addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // Add malloc/free declarations to the module.
189 //===----------------------------------------------------------------------===//
190 
191 static constexpr const char *kMalloc = "malloc";
192 static constexpr const char *kFree = "free";
193 
194 /// Adds malloc/free declarations to the module.
addCRuntimeDeclarations(ModuleOp module)195 static void addCRuntimeDeclarations(ModuleOp module) {
196   using namespace mlir::LLVM;
197 
198   MLIRContext *ctx = module.getContext();
199   OpBuilder builder(module.getBody()->getTerminator());
200 
201   auto voidTy = LLVMType::getVoidTy(ctx);
202   auto i64 = LLVMType::getInt64Ty(ctx);
203   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
204 
205   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
206   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // Coroutine resume function wrapper.
211 //===----------------------------------------------------------------------===//
212 
213 static constexpr const char *kResume = "__resume";
214 
215 // A function that takes a coroutine handle and calls a `llvm.coro.resume`
216 // intrinsics. We need this function to be able to pass it to the async
217 // runtime execute API.
addResumeFunction(ModuleOp module)218 static void addResumeFunction(ModuleOp module) {
219   MLIRContext *ctx = module.getContext();
220 
221   OpBuilder moduleBuilder(module.getBody()->getTerminator());
222   Location loc = module.getLoc();
223 
224   if (module.lookupSymbol(kResume))
225     return;
226 
227   auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
228   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
229 
230   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
231       loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false));
232   resumeOp.setPrivate();
233 
234   auto *block = resumeOp.addEntryBlock();
235   OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
236 
237   blockBuilder.create<LLVM::CallOp>(loc, TypeRange(),
238                                     blockBuilder.getSymbolRefAttr(kCoroResume),
239                                     resumeOp.getArgument(0));
240 
241   blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // async.execute op outlining to the coroutine functions.
246 //===----------------------------------------------------------------------===//
247 
// A function targeted for the coroutine transformation has two additional
// blocks at the end: coroutine cleanup and coroutine suspension.
//
// The async.await op lowering additionally creates a resume block for each
// operation to enable non-blocking waiting via coroutine suspension.
//
// CoroMachinery holds the handles into that skeleton, as returned by
// setupCoroMachinery. Fields are listed in the order used for aggregate
// initialization.
namespace {
struct CoroMachinery {
  // `!async.token` returned by the ramp function (result of the
  // mlirAsyncRuntimeCreateToken call in the entry block).
  Value asyncToken;
  // Coroutine handle: the result of @llvm.coro.begin.
  Value coroHandle;
  // Block that frees the coroutine frame and branches to `suspend`.
  Block *cleanup;
  // Block that calls @llvm.coro.end and returns `asyncToken`.
  Block *suspend;
};
} // namespace
261 
262 // Builds an coroutine template compatible with LLVM coroutines lowering.
263 //
264 //  - `entry` block sets up the coroutine.
265 //  - `cleanup` block cleans up the coroutine state.
266 //  - `suspend block after the @llvm.coro.end() defines what value will be
267 //    returned to the initial caller of a coroutine. Everything before the
268 //    @llvm.coro.end() will be executed at every suspension point.
269 //
270 // Coroutine structure (only the important bits):
271 //
272 //   func @async_execute_fn(<function-arguments>) -> !async.token {
273 //     ^entryBlock(<function-arguments>):
274 //       %token = <async token> : !async.token // create async runtime token
275 //       %hdl = llvm.call @llvm.coro.id(...)   // create a coroutine handle
276 //       br ^cleanup
277 //
278 //     ^cleanup:
279 //       llvm.call @llvm.coro.free(...)        // delete coroutine state
280 //       br ^suspend
281 //
282 //     ^suspend:
283 //       llvm.call @llvm.coro.end(...)         // marks the end of a coroutine
284 //       return %token : !async.token
285 //   }
286 //
287 // The actual code for the async.execute operation body region will be inserted
288 // before the entry block terminator.
289 //
290 //
setupCoroMachinery(FuncOp func)291 static CoroMachinery setupCoroMachinery(FuncOp func) {
292   assert(func.getBody().empty() && "Function must have empty body");
293 
294   MLIRContext *ctx = func.getContext();
295 
296   auto token = LLVM::LLVMTokenType::get(ctx);
297   auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
298   auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
299   auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
300   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
301 
302   Block *entryBlock = func.addEntryBlock();
303   Location loc = func.getBody().getLoc();
304 
305   OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);
306 
307   // ------------------------------------------------------------------------ //
308   // Allocate async tokens/values that we will return from a ramp function.
309   // ------------------------------------------------------------------------ //
310   auto createToken =
311       builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));
312 
313   // ------------------------------------------------------------------------ //
314   // Initialize coroutine: allocate frame, get coroutine handle.
315   // ------------------------------------------------------------------------ //
316 
317   // Constants for initializing coroutine frame.
318   auto constZero =
319       builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
320   auto constFalse =
321       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
322   auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);
323 
324   // Get coroutine id: @llvm.coro.id
325   auto coroId = builder.create<LLVM::CallOp>(
326       loc, token, builder.getSymbolRefAttr(kCoroId),
327       ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
328 
329   // Get coroutine frame size: @llvm.coro.size.i64
330   auto coroSize = builder.create<LLVM::CallOp>(
331       loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
332 
333   // Allocate memory for coroutine frame.
334   auto coroAlloc = builder.create<LLVM::CallOp>(
335       loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
336       ValueRange(coroSize.getResult(0)));
337 
338   // Begin a coroutine: @llvm.coro.begin
339   auto coroHdl = builder.create<LLVM::CallOp>(
340       loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
341       ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
342 
343   Block *cleanupBlock = func.addBlock();
344   Block *suspendBlock = func.addBlock();
345 
346   // ------------------------------------------------------------------------ //
347   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
348   // ------------------------------------------------------------------------ //
349   builder.setInsertionPointToStart(cleanupBlock);
350 
351   // Get a pointer to the coroutine frame memory: @llvm.coro.free.
352   auto coroMem = builder.create<LLVM::CallOp>(
353       loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
354       ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
355 
356   // Free the memory.
357   builder.create<LLVM::CallOp>(loc, TypeRange(),
358                                builder.getSymbolRefAttr(kFree),
359                                ValueRange(coroMem.getResult(0)));
360   // Branch into the suspend block.
361   builder.create<BranchOp>(loc, suspendBlock);
362 
363   // ------------------------------------------------------------------------ //
364   // Coroutine suspend block: mark the end of a coroutine and return allocated
365   // async token.
366   // ------------------------------------------------------------------------ //
367   builder.setInsertionPointToStart(suspendBlock);
368 
369   // Mark the end of a coroutine: @llvm.coro.end.
370   builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
371                                ValueRange({coroHdl.getResult(0), constFalse}));
372 
373   // Return created `async.token` from the suspend block. This will be the
374   // return value of a coroutine ramp function.
375   builder.create<ReturnOp>(loc, createToken.getResult(0));
376 
377   // Branch from the entry block to the cleanup block to create a valid CFG.
378   builder.setInsertionPointToEnd(entryBlock);
379 
380   builder.create<BranchOp>(loc, cleanupBlock);
381 
382   // `async.await` op lowering will create resume blocks for async
383   // continuations, and will conditionally branch to cleanup or suspend blocks.
384 
385   return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
386           suspendBlock};
387 }
388 
389 // Add a LLVM coroutine suspension point to the end of suspended block, to
390 // resume execution in resume block. The caller is responsible for creating the
391 // two suspended/resume blocks with the desired ops contained in each block.
392 // This function merely provides the required control flow logic.
393 //
394 // `coroState` must be a value returned from the call to @llvm.coro.save(...)
395 // intrinsic (saved coroutine state).
396 //
397 // Before:
398 //
399 //   ^bb0:
400 //     "opBefore"(...)
401 //     "op"(...)
402 //   ^cleanup: ...
403 //   ^suspend: ...
404 //   ^resume:
405 //     "op"(...)
406 //
407 // After:
408 //
409 //   ^bb0:
410 //     "opBefore"(...)
411 //     %suspend = llmv.call @llvm.coro.suspend(...)
412 //     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
413 //   ^resume:
414 //     "op"(...)
415 //   ^cleanup: ...
416 //   ^suspend: ...
417 //
addSuspensionPoint(CoroMachinery coro,Value coroState,Operation * op,Block * suspended,Block * resume,OpBuilder & builder)418 static void addSuspensionPoint(CoroMachinery coro, Value coroState,
419                                Operation *op, Block *suspended, Block *resume,
420                                OpBuilder &builder) {
421   Location loc = op->getLoc();
422   MLIRContext *ctx = op->getContext();
423   auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
424   auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
425 
426   // Add a coroutine suspension in place of original `op` in the split block.
427   OpBuilder::InsertionGuard guard(builder);
428   builder.setInsertionPointToEnd(suspended);
429 
430   auto constFalse =
431       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
432 
433   // Suspend a coroutine: @llvm.coro.suspend
434   auto coroSuspend = builder.create<LLVM::CallOp>(
435       loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
436       ValueRange({coroState, constFalse}));
437 
438   // After a suspension point decide if we should branch into resume, cleanup
439   // or suspend block of the coroutine (see @llvm.coro.suspend return code
440   // documentation).
441   auto constZero =
442       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
443   auto constNegOne =
444       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
445 
446   Block *resumeOrCleanup = builder.createBlock(resume);
447 
448   // Suspend the coroutine ...?
449   builder.setInsertionPointToEnd(suspended);
450   auto isNegOne = builder.create<LLVM::ICmpOp>(
451       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
452   builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
453                                  /*falseDest=*/resumeOrCleanup);
454 
455   // ... or resume or cleanup the coroutine?
456   builder.setInsertionPointToStart(resumeOrCleanup);
457   auto isZero = builder.create<LLVM::ICmpOp>(
458       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
459   builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
460                                  /*falseDest=*/coro.cleanup);
461 }
462 
463 // Outline the body region attached to the `async.execute` op into a standalone
464 // function.
465 //
466 // Note that this is not reversible transformation.
467 static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable & symbolTable,ExecuteOp execute)468 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
469   ModuleOp module = execute->getParentOfType<ModuleOp>();
470 
471   MLIRContext *ctx = module.getContext();
472   Location loc = execute.getLoc();
473 
474   OpBuilder moduleBuilder(module.getBody()->getTerminator());
475 
476   // Collect all outlined function inputs.
477   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
478                                               execute.dependencies().end());
479   getUsedValuesDefinedAbove(execute.body(), functionInputs);
480 
481   // Collect types for the outlined function inputs and outputs.
482   auto typesRange = llvm::map_range(
483       functionInputs, [](Value value) { return value.getType(); });
484   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
485   auto outputTypes = execute.getResultTypes();
486 
487   auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
488   auto funcAttrs = ArrayRef<NamedAttribute>();
489 
490   // TODO: Derive outlined function name from the parent FuncOp (support
491   // multiple nested async.execute operations).
492   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
493   symbolTable.insert(func, moduleBuilder.getInsertionPoint());
494 
495   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
496 
497   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
498   // blocks, adding llvm.coro instrinsics and setting up control flow.
499   CoroMachinery coro = setupCoroMachinery(func);
500 
501   // Suspend async function at the end of an entry block, and resume it using
502   // Async execute API (execution will be resumed in a thread managed by the
503   // async runtime).
504   Block *entryBlock = &func.getBlocks().front();
505   OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
506 
507   // A pointer to coroutine resume intrinsic wrapper.
508   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
509   auto resumePtr = builder.create<LLVM::AddressOfOp>(
510       loc, resumeFnTy.getPointerTo(), kResume);
511 
512   // Save the coroutine state: @llvm.coro.save
513   auto coroSave = builder.create<LLVM::CallOp>(
514       loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
515       ValueRange({coro.coroHandle}));
516 
517   // Call async runtime API to execute a coroutine in the managed thread.
518   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
519   builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs);
520 
521   // Split the entry block before the terminator.
522   auto *terminatorOp = entryBlock->getTerminator();
523   Block *suspended = terminatorOp->getBlock();
524   Block *resume = suspended->splitBlock(terminatorOp);
525   addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
526                      resume, builder);
527 
528   // Await on all dependencies before starting to execute the body region.
529   builder.setInsertionPointToStart(resume);
530   for (size_t i = 0; i < execute.dependencies().size(); ++i)
531     builder.create<AwaitOp>(loc, func.getArgument(i));
532 
533   // Map from function inputs defined above the execute op to the function
534   // arguments.
535   BlockAndValueMapping valueMapping;
536   valueMapping.map(functionInputs, func.getArguments());
537 
538   // Clone all operations from the execute operation body into the outlined
539   // function body, and replace all `async.yield` operations with a call
540   // to async runtime to emplace the result token.
541   for (Operation &op : execute.body().getOps()) {
542     if (isa<async::YieldOp>(op)) {
543       builder.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
544       continue;
545     }
546     builder.clone(op, valueMapping);
547   }
548 
549   // Replace the original `async.execute` with a call to outlined function.
550   OpBuilder callBuilder(execute);
551   auto callOutlinedFunc =
552       callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
553                                  functionInputs.getArrayRef());
554   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
555   execute.erase();
556 
557   return {func, coro};
558 }
559 
560 //===----------------------------------------------------------------------===//
561 // Convert Async dialect types to LLVM types.
562 //===----------------------------------------------------------------------===//
563 
564 namespace {
565 class AsyncRuntimeTypeConverter : public TypeConverter {
566 public:
AsyncRuntimeTypeConverter()567   AsyncRuntimeTypeConverter() { addConversion(convertType); }
568 
convertType(Type type)569   static Type convertType(Type type) {
570     MLIRContext *ctx = type.getContext();
571     // Convert async tokens and groups to opaque pointers.
572     if (type.isa<TokenType, GroupType>())
573       return LLVM::LLVMType::getInt8PtrTy(ctx);
574     return type;
575   }
576 };
577 } // namespace
578 
579 //===----------------------------------------------------------------------===//
580 // Convert types for all call operations to lowered async types.
581 //===----------------------------------------------------------------------===//
582 
583 namespace {
584 class CallOpOpConversion : public ConversionPattern {
585 public:
CallOpOpConversion(MLIRContext * ctx)586   explicit CallOpOpConversion(MLIRContext *ctx)
587       : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
588 
589   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const590   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
591                   ConversionPatternRewriter &rewriter) const override {
592     AsyncRuntimeTypeConverter converter;
593 
594     SmallVector<Type, 5> resultTypes;
595     converter.convertTypes(op->getResultTypes(), resultTypes);
596 
597     CallOp call = cast<CallOp>(op);
598     rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
599                                         operands);
600 
601     return success();
602   }
603 };
604 } // namespace
605 
606 //===----------------------------------------------------------------------===//
607 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
608 // to the corresponding API calls).
609 //===----------------------------------------------------------------------===//
610 
611 namespace {
612 
613 template <typename RefCountingOp>
614 class RefCountingOpLowering : public ConversionPattern {
615 public:
RefCountingOpLowering(MLIRContext * ctx,StringRef apiFunctionName)616   explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
617       : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
618         apiFunctionName(apiFunctionName) {}
619 
620   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const621   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
622                   ConversionPatternRewriter &rewriter) const override {
623     RefCountingOp refCountingOp = cast<RefCountingOp>(op);
624 
625     auto count = rewriter.create<ConstantOp>(
626         op->getLoc(), rewriter.getI32Type(),
627         rewriter.getI32IntegerAttr(refCountingOp.count()));
628 
629     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
630                                         ValueRange({operands[0], count}));
631 
632     return success();
633   }
634 
635 private:
636   StringRef apiFunctionName;
637 };
638 
639 // async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
640 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
641 public:
AddRefOpLowering(MLIRContext * ctx)642   explicit AddRefOpLowering(MLIRContext *ctx)
643       : RefCountingOpLowering(ctx, kAddRef) {}
644 };
645 
646 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
647 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
648 public:
DropRefOpLowering(MLIRContext * ctx)649   explicit DropRefOpLowering(MLIRContext *ctx)
650       : RefCountingOpLowering(ctx, kDropRef) {}
651 };
652 
653 } // namespace
654 
655 //===----------------------------------------------------------------------===//
656 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
657 //===----------------------------------------------------------------------===//
658 
659 namespace {
660 class CreateGroupOpLowering : public ConversionPattern {
661 public:
CreateGroupOpLowering(MLIRContext * ctx)662   explicit CreateGroupOpLowering(MLIRContext *ctx)
663       : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
664 
665   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const666   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
667                   ConversionPatternRewriter &rewriter) const override {
668     auto retTy = GroupType::get(op->getContext());
669     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
670     return success();
671   }
672 };
673 } // namespace
674 
675 //===----------------------------------------------------------------------===//
676 // async.add_to_group op lowering to runtime function call.
677 //===----------------------------------------------------------------------===//
678 
679 namespace {
680 class AddToGroupOpLowering : public ConversionPattern {
681 public:
AddToGroupOpLowering(MLIRContext * ctx)682   explicit AddToGroupOpLowering(MLIRContext *ctx)
683       : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
684 
685   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const686   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
687                   ConversionPatternRewriter &rewriter) const override {
688     // Currently we can only add tokens to the group.
689     auto addToGroup = cast<AddToGroupOp>(op);
690     if (!addToGroup.operand().getType().isa<TokenType>())
691       return failure();
692 
693     auto i64 = IntegerType::get(64, op->getContext());
694     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
695     return success();
696   }
697 };
698 } // namespace
699 
700 //===----------------------------------------------------------------------===//
701 // async.await and async.await_all op lowerings to the corresponding async
702 // runtime function calls.
703 //===----------------------------------------------------------------------===//
704 
705 namespace {
706 
707 template <typename AwaitType, typename AwaitableType>
708 class AwaitOpLoweringBase : public ConversionPattern {
709 protected:
AwaitOpLoweringBase(MLIRContext * ctx,const llvm::DenseMap<FuncOp,CoroMachinery> & outlinedFunctions,StringRef blockingAwaitFuncName,StringRef coroAwaitFuncName)710   explicit AwaitOpLoweringBase(
711       MLIRContext *ctx,
712       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
713       StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
714       : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
715         outlinedFunctions(outlinedFunctions),
716         blockingAwaitFuncName(blockingAwaitFuncName),
717         coroAwaitFuncName(coroAwaitFuncName) {}
718 
719 public:
720   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const721   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
722                   ConversionPatternRewriter &rewriter) const override {
723     // We can only await on one the `AwaitableType` (for `await` it can be
724     // only a `token`, for `await_all` it is a `group`).
725     auto await = cast<AwaitType>(op);
726     if (!await.operand().getType().template isa<AwaitableType>())
727       return failure();
728 
729     // Check if await operation is inside the outlined coroutine function.
730     auto func = await->template getParentOfType<FuncOp>();
731     auto outlined = outlinedFunctions.find(func);
732     const bool isInCoroutine = outlined != outlinedFunctions.end();
733 
734     Location loc = op->getLoc();
735 
736     // Inside regular function we convert await operation to the blocking
737     // async API await function call.
738     if (!isInCoroutine)
739       rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
740                               ValueRange(operands[0]));
741 
742     // Inside the coroutine we convert await operation into coroutine suspension
743     // point, and resume execution asynchronously.
744     if (isInCoroutine) {
745       const CoroMachinery &coro = outlined->getSecond();
746 
747       OpBuilder builder(op, rewriter.getListener());
748       MLIRContext *ctx = op->getContext();
749 
750       // A pointer to coroutine resume intrinsic wrapper.
751       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
752       auto resumePtr = builder.create<LLVM::AddressOfOp>(
753           loc, resumeFnTy.getPointerTo(), kResume);
754 
755       // Save the coroutine state: @llvm.coro.save
756       auto coroSave = builder.create<LLVM::CallOp>(
757           loc, LLVM::LLVMTokenType::get(ctx),
758           builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
759 
760       // Call async runtime API to resume a coroutine in the managed thread when
761       // the async await argument becomes ready.
762       SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
763                                                    resumePtr.res()};
764       builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName,
765                              awaitAndExecuteArgs);
766 
767       Block *suspended = op->getBlock();
768 
769       // Split the entry block before the await operation.
770       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
771       addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
772                          builder);
773     }
774 
775     // Original operation was replaced by function call or suspension point.
776     rewriter.eraseOp(op);
777 
778     return success();
779   }
780 
781 private:
782   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
783   StringRef blockingAwaitFuncName;
784   StringRef coroAwaitFuncName;
785 };
786 
787 // Lowering for `async.await` operation (only token operands are supported).
788 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
789   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
790 
791 public:
AwaitOpLowering(MLIRContext * ctx,const llvm::DenseMap<FuncOp,CoroMachinery> & outlinedFunctions)792   explicit AwaitOpLowering(
793       MLIRContext *ctx,
794       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
795       : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
796 };
797 
798 // Lowering for `async.await_all` operation.
799 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
800   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
801 
802 public:
AwaitAllOpLowering(MLIRContext * ctx,const llvm::DenseMap<FuncOp,CoroMachinery> & outlinedFunctions)803   explicit AwaitAllOpLowering(
804       MLIRContext *ctx,
805       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
806       : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
807 };
808 
809 } // namespace
810 
811 //===----------------------------------------------------------------------===//
812 
813 namespace {
814 struct ConvertAsyncToLLVMPass
815     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
816   void runOnOperation() override;
817 };
818 
runOnOperation()819 void ConvertAsyncToLLVMPass::runOnOperation() {
820   ModuleOp module = getOperation();
821   SymbolTable symbolTable(module);
822 
823   // Outline all `async.execute` body regions into async functions (coroutines).
824   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
825 
826   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
827     // We currently do not support execute operations that have async value
828     // operands or produce async results.
829     if (!execute.operands().empty() || !execute.results().empty()) {
830       execute.emitOpError("can't outline async.execute op with async value "
831                           "operands or returned async results");
832       return WalkResult::interrupt();
833     }
834 
835     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
836 
837     return WalkResult::advance();
838   });
839 
840   // Failed to outline all async execute operations.
841   if (outlineResult.wasInterrupted()) {
842     signalPassFailure();
843     return;
844   }
845 
846   LLVM_DEBUG({
847     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
848                  << " async functions\n";
849   });
850 
851   // Add declarations for all functions required by the coroutines lowering.
852   addResumeFunction(module);
853   addAsyncRuntimeApiDeclarations(module);
854   addCoroutineIntrinsicsDeclarations(module);
855   addCRuntimeDeclarations(module);
856 
857   MLIRContext *ctx = &getContext();
858 
859   // Convert async dialect types and operations to LLVM dialect.
860   AsyncRuntimeTypeConverter converter;
861   OwningRewritePatternList patterns;
862 
863   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
864   patterns.insert<CallOpOpConversion>(ctx);
865   patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
866   patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
867   patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
868 
869   ConversionTarget target(*ctx);
870   target.addLegalOp<ConstantOp>();
871   target.addLegalDialect<LLVM::LLVMDialect>();
872   target.addIllegalDialect<AsyncDialect>();
873   target.addDynamicallyLegalOp<FuncOp>(
874       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
875   target.addDynamicallyLegalOp<CallOp>(
876       [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
877 
878   if (failed(applyPartialConversion(module, target, std::move(patterns))))
879     signalPassFailure();
880 }
881 } // namespace
882 
createConvertAsyncToLLVMPass()883 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
884   return std::make_unique<ConvertAsyncToLLVMPass>();
885 }
886