//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "../PassDetail.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "convert-async-to-llvm" using namespace mlir; using namespace mlir::async; // Prefix for functions outlined from `async.execute` op regions. static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; //===----------------------------------------------------------------------===// // Async Runtime C API declaration. 
//===----------------------------------------------------------------------===// static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; static constexpr const char *kAddTokenToGroup = "mlirAsyncRuntimeAddTokenToGroup"; static constexpr const char *kAwaitAndExecute = "mlirAsyncRuntimeAwaitTokenAndExecute"; static constexpr const char *kAwaitAllAndExecute = "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; namespace { // Async Runtime API function types. struct AsyncAPI { static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); auto count = IntegerType::get(32, ctx); return FunctionType::get({ref, count}, {}, ctx); } static FunctionType createTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({}, {TokenType::get(ctx)}, ctx); } static FunctionType createGroupFunctionType(MLIRContext *ctx) { return FunctionType::get({}, {GroupType::get(ctx)}, ctx); } static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({TokenType::get(ctx)}, {}, ctx); } static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { return FunctionType::get({TokenType::get(ctx)}, {}, ctx); } static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { return FunctionType::get({GroupType::get(ctx)}, {}, ctx); } static FunctionType executeFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); return 
FunctionType::get({hdl, resume}, {}, ctx); } static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { auto i64 = IntegerType::get(64, ctx); return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64}, ctx); } static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); } static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); } // Auxiliary coroutine resume intrinsic wrapper. static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { auto voidTy = LLVM::LLVMType::getVoidTy(ctx); auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); } }; } // namespace // Adds Async Runtime C API declarations to the module. 
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
  // New declarations are inserted right before the module terminator.
  auto builder = OpBuilder::atBlockTerminator(module.getBody());

  // Declares a private function `name` of the given `type`, unless the module
  // already defines a symbol with that name.
  auto addFuncDecl = [&](StringRef name, FunctionType type) {
    if (module.lookupSymbol(name))
      return;
    builder.create<FuncOp>(module.getLoc(), name, type).setPrivate();
  };

  MLIRContext *ctx = module.getContext();
  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
  addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
  addFuncDecl(kAwaitAllAndExecute,
              AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
}

//===----------------------------------------------------------------------===//
// LLVM coroutines intrinsics declarations.
//===----------------------------------------------------------------------===//

static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";

/// Adds an LLVM function declaration to a module.
static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name, LLVM::LLVMType ret, ArrayRef params) { if (module.lookupSymbol(name)) return; LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false); builder.create(module.getLoc(), name, type); } /// Adds coroutine intrinsics declarations to the module. static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); OpBuilder builder(module.getBody()->getTerminator()); auto token = LLVMTokenType::get(ctx); auto voidTy = LLVMType::getVoidTy(ctx); auto i8 = LLVMType::getInt8Ty(ctx); auto i1 = LLVMType::getInt1Ty(ctx); auto i32 = LLVMType::getInt32Ty(ctx); auto i64 = LLVMType::getInt64Ty(ctx); auto i8Ptr = LLVMType::getInt8PtrTy(ctx); addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); } //===----------------------------------------------------------------------===// // Add malloc/free declarations to the module. //===----------------------------------------------------------------------===// static constexpr const char *kMalloc = "malloc"; static constexpr const char *kFree = "free"; /// Adds malloc/free declarations to the module. 
static void addCRuntimeDeclarations(ModuleOp module) { using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); OpBuilder builder(module.getBody()->getTerminator()); auto voidTy = LLVMType::getVoidTy(ctx); auto i64 = LLVMType::getInt64Ty(ctx); auto i8Ptr = LLVMType::getInt8PtrTy(ctx); addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); } //===----------------------------------------------------------------------===// // Coroutine resume function wrapper. //===----------------------------------------------------------------------===// static constexpr const char *kResume = "__resume"; // A function that takes a coroutine handle and calls a `llvm.coro.resume` // intrinsics. We need this function to be able to pass it to the async // runtime execute API. static void addResumeFunction(ModuleOp module) { MLIRContext *ctx = module.getContext(); OpBuilder moduleBuilder(module.getBody()->getTerminator()); Location loc = module.getLoc(); if (module.lookupSymbol(kResume)) return; auto voidTy = LLVM::LLVMType::getVoidTy(ctx); auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); auto resumeOp = moduleBuilder.create( loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false)); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(); OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); blockBuilder.create(loc, TypeRange(), blockBuilder.getSymbolRefAttr(kCoroResume), resumeOp.getArgument(0)); blockBuilder.create(loc, ValueRange()); } //===----------------------------------------------------------------------===// // async.execute op outlining to the coroutine functions. //===----------------------------------------------------------------------===// // Function targeted for coroutine transformation has two additional blocks at // the end: coroutine cleanup and coroutine suspension. 
// // async.await op lowering additionaly creates a resume block for each // operation to enable non-blocking waiting via coroutine suspension. namespace { struct CoroMachinery { Value asyncToken; Value coroHandle; Block *cleanup; Block *suspend; }; } // namespace // Builds an coroutine template compatible with LLVM coroutines lowering. // // - `entry` block sets up the coroutine. // - `cleanup` block cleans up the coroutine state. // - `suspend block after the @llvm.coro.end() defines what value will be // returned to the initial caller of a coroutine. Everything before the // @llvm.coro.end() will be executed at every suspension point. // // Coroutine structure (only the important bits): // // func @async_execute_fn() -> !async.token { // ^entryBlock(): // %token = : !async.token // create async runtime token // %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle // br ^cleanup // // ^cleanup: // llvm.call @llvm.coro.free(...) // delete coroutine state // br ^suspend // // ^suspend: // llvm.call @llvm.coro.end(...) // marks the end of a coroutine // return %token : !async.token // } // // The actual code for the async.execute operation body region will be inserted // before the entry block terminator. // // static CoroMachinery setupCoroMachinery(FuncOp func) { assert(func.getBody().empty() && "Function must have empty body"); MLIRContext *ctx = func.getContext(); auto token = LLVM::LLVMTokenType::get(ctx); auto i1 = LLVM::LLVMType::getInt1Ty(ctx); auto i32 = LLVM::LLVMType::getInt32Ty(ctx); auto i64 = LLVM::LLVMType::getInt64Ty(ctx); auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); Block *entryBlock = func.addEntryBlock(); Location loc = func.getBody().getLoc(); OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); // ------------------------------------------------------------------------ // // Allocate async tokens/values that we will return from a ramp function. 
// ------------------------------------------------------------------------ // auto createToken = builder.create(loc, kCreateToken, TokenType::get(ctx)); // ------------------------------------------------------------------------ // // Initialize coroutine: allocate frame, get coroutine handle. // ------------------------------------------------------------------------ // // Constants for initializing coroutine frame. auto constZero = builder.create(loc, i32, builder.getI32IntegerAttr(0)); auto constFalse = builder.create(loc, i1, builder.getBoolAttr(false)); auto nullPtr = builder.create(loc, i8Ptr); // Get coroutine id: @llvm.coro.id auto coroId = builder.create( loc, token, builder.getSymbolRefAttr(kCoroId), ValueRange({constZero, nullPtr, nullPtr, nullPtr})); // Get coroutine frame size: @llvm.coro.size.i64 auto coroSize = builder.create( loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); // Allocate memory for coroutine frame. auto coroAlloc = builder.create( loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), ValueRange(coroSize.getResult(0))); // Begin a coroutine: @llvm.coro.begin auto coroHdl = builder.create( loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); Block *cleanupBlock = func.addBlock(); Block *suspendBlock = func.addBlock(); // ------------------------------------------------------------------------ // // Coroutine cleanup block: deallocate coroutine frame, free the memory. // ------------------------------------------------------------------------ // builder.setInsertionPointToStart(cleanupBlock); // Get a pointer to the coroutine frame memory: @llvm.coro.free. auto coroMem = builder.create( loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); // Free the memory. builder.create(loc, TypeRange(), builder.getSymbolRefAttr(kFree), ValueRange(coroMem.getResult(0))); // Branch into the suspend block. 
builder.create(loc, suspendBlock); // ------------------------------------------------------------------------ // // Coroutine suspend block: mark the end of a coroutine and return allocated // async token. // ------------------------------------------------------------------------ // builder.setInsertionPointToStart(suspendBlock); // Mark the end of a coroutine: @llvm.coro.end. builder.create(loc, i1, builder.getSymbolRefAttr(kCoroEnd), ValueRange({coroHdl.getResult(0), constFalse})); // Return created `async.token` from the suspend block. This will be the // return value of a coroutine ramp function. builder.create(loc, createToken.getResult(0)); // Branch from the entry block to the cleanup block to create a valid CFG. builder.setInsertionPointToEnd(entryBlock); builder.create(loc, cleanupBlock); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, suspendBlock}; } // Add a LLVM coroutine suspension point to the end of suspended block, to // resume execution in resume block. The caller is responsible for creating the // two suspended/resume blocks with the desired ops contained in each block. // This function merely provides the required control flow logic. // // `coroState` must be a value returned from the call to @llvm.coro.save(...) // intrinsic (saved coroutine state). // // Before: // // ^bb0: // "opBefore"(...) // "op"(...) // ^cleanup: ... // ^suspend: ... // ^resume: // "op"(...) // // After: // // ^bb0: // "opBefore"(...) // %suspend = llmv.call @llvm.coro.suspend(...) // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] // ^resume: // "op"(...) // ^cleanup: ... // ^suspend: ... 
// static void addSuspensionPoint(CoroMachinery coro, Value coroState, Operation *op, Block *suspended, Block *resume, OpBuilder &builder) { Location loc = op->getLoc(); MLIRContext *ctx = op->getContext(); auto i1 = LLVM::LLVMType::getInt1Ty(ctx); auto i8 = LLVM::LLVMType::getInt8Ty(ctx); // Add a coroutine suspension in place of original `op` in the split block. OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToEnd(suspended); auto constFalse = builder.create(loc, i1, builder.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend auto coroSuspend = builder.create( loc, i8, builder.getSymbolRefAttr(kCoroSuspend), ValueRange({coroState, constFalse})); // After a suspension point decide if we should branch into resume, cleanup // or suspend block of the coroutine (see @llvm.coro.suspend return code // documentation). auto constZero = builder.create(loc, i8, builder.getI8IntegerAttr(0)); auto constNegOne = builder.create(loc, i8, builder.getI8IntegerAttr(-1)); Block *resumeOrCleanup = builder.createBlock(resume); // Suspend the coroutine ...? builder.setInsertionPointToEnd(suspended); auto isNegOne = builder.create( loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); builder.create(loc, isNegOne, /*trueDest=*/coro.suspend, /*falseDest=*/resumeOrCleanup); // ... or resume or cleanup the coroutine? builder.setInsertionPointToStart(resumeOrCleanup); auto isZero = builder.create( loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); builder.create(loc, isZero, /*trueDest=*/resume, /*falseDest=*/coro.cleanup); } // Outline the body region attached to the `async.execute` op into a standalone // function. // // Note that this is not reversible transformation. 
static std::pair outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { ModuleOp module = execute->getParentOfType(); MLIRContext *ctx = module.getContext(); Location loc = execute.getLoc(); OpBuilder moduleBuilder(module.getBody()->getTerminator()); // Collect all outlined function inputs. llvm::SetVector functionInputs(execute.dependencies().begin(), execute.dependencies().end()); getUsedValuesDefinedAbove(execute.body(), functionInputs); // Collect types for the outlined function inputs and outputs. auto typesRange = llvm::map_range( functionInputs, [](Value value) { return value.getType(); }); SmallVector inputTypes(typesRange.begin(), typesRange.end()); auto outputTypes = execute.getResultTypes(); auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); auto funcAttrs = ArrayRef(); // TODO: Derive outlined function name from the parent FuncOp (support // multiple nested async.execute operations). FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); symbolTable.insert(func, moduleBuilder.getInsertionPoint()); SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); // Prepare a function for coroutine lowering by adding entry/cleanup/suspend // blocks, adding llvm.coro instrinsics and setting up control flow. CoroMachinery coro = setupCoroMachinery(func); // Suspend async function at the end of an entry block, and resume it using // Async execute API (execution will be resumed in a thread managed by the // async runtime). Block *entryBlock = &func.getBlocks().front(); OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); // A pointer to coroutine resume intrinsic wrapper. 
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); auto resumePtr = builder.create( loc, resumeFnTy.getPointerTo(), kResume); // Save the coroutine state: @llvm.coro.save auto coroSave = builder.create( loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), ValueRange({coro.coroHandle})); // Call async runtime API to execute a coroutine in the managed thread. SmallVector executeArgs = {coro.coroHandle, resumePtr.res()}; builder.create(loc, TypeRange(), kExecute, executeArgs); // Split the entry block before the terminator. auto *terminatorOp = entryBlock->getTerminator(); Block *suspended = terminatorOp->getBlock(); Block *resume = suspended->splitBlock(terminatorOp); addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, resume, builder); // Await on all dependencies before starting to execute the body region. builder.setInsertionPointToStart(resume); for (size_t i = 0; i < execute.dependencies().size(); ++i) builder.create(loc, func.getArgument(i)); // Map from function inputs defined above the execute op to the function // arguments. BlockAndValueMapping valueMapping; valueMapping.map(functionInputs, func.getArguments()); // Clone all operations from the execute operation body into the outlined // function body, and replace all `async.yield` operations with a call // to async runtime to emplace the result token. for (Operation &op : execute.body().getOps()) { if (isa(op)) { builder.create(loc, kEmplaceToken, TypeRange(), coro.asyncToken); continue; } builder.clone(op, valueMapping); } // Replace the original `async.execute` with a call to outlined function. 
OpBuilder callBuilder(execute); auto callOutlinedFunc = callBuilder.create(loc, func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); execute.replaceAllUsesWith(callOutlinedFunc.getResults()); execute.erase(); return {func, coro}; } //===----------------------------------------------------------------------===// // Convert Async dialect types to LLVM types. //===----------------------------------------------------------------------===// namespace { class AsyncRuntimeTypeConverter : public TypeConverter { public: AsyncRuntimeTypeConverter() { addConversion(convertType); } static Type convertType(Type type) { MLIRContext *ctx = type.getContext(); // Convert async tokens and groups to opaque pointers. if (type.isa()) return LLVM::LLVMType::getInt8PtrTy(ctx); return type; } }; } // namespace //===----------------------------------------------------------------------===// // Convert types for all call operations to lowered async types. //===----------------------------------------------------------------------===// namespace { class CallOpOpConversion : public ConversionPattern { public: explicit CallOpOpConversion(MLIRContext *ctx) : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { AsyncRuntimeTypeConverter converter; SmallVector resultTypes; converter.convertTypes(op->getResultTypes(), resultTypes); CallOp call = cast(op); rewriter.replaceOpWithNewOp(op, resultTypes, call.callee(), operands); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` // to the corresponding API calls). 
//===----------------------------------------------------------------------===// namespace { template class RefCountingOpLowering : public ConversionPattern { public: explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName) : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx), apiFunctionName(apiFunctionName) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RefCountingOp refCountingOp = cast(op); auto count = rewriter.create( op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(refCountingOp.count())); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, ValueRange({operands[0], count})); return success(); } private: StringRef apiFunctionName; }; // async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. class AddRefOpLowering : public RefCountingOpLowering { public: explicit AddRefOpLowering(MLIRContext *ctx) : RefCountingOpLowering(ctx, kAddRef) {} }; // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. class DropRefOpLowering : public RefCountingOpLowering { public: explicit DropRefOpLowering(MLIRContext *ctx) : RefCountingOpLowering(ctx, kDropRef) {} }; } // namespace //===----------------------------------------------------------------------===// // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 
//===----------------------------------------------------------------------===// namespace { class CreateGroupOpLowering : public ConversionPattern { public: explicit CreateGroupOpLowering(MLIRContext *ctx) : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto retTy = GroupType::get(op->getContext()); rewriter.replaceOpWithNewOp(op, kCreateGroup, retTy); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // async.add_to_group op lowering to runtime function call. //===----------------------------------------------------------------------===// namespace { class AddToGroupOpLowering : public ConversionPattern { public: explicit AddToGroupOpLowering(MLIRContext *ctx) : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Currently we can only add tokens to the group. auto addToGroup = cast(op); if (!addToGroup.operand().getType().isa()) return failure(); auto i64 = IntegerType::get(64, op->getContext()); rewriter.replaceOpWithNewOp(op, kAddTokenToGroup, i64, operands); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // async.await and async.await_all op lowerings to the corresponding async // runtime function calls. 
//===----------------------------------------------------------------------===// namespace { template class AwaitOpLoweringBase : public ConversionPattern { protected: explicit AwaitOpLoweringBase( MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions, StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) : ConversionPattern(AwaitType::getOperationName(), 1, ctx), outlinedFunctions(outlinedFunctions), blockingAwaitFuncName(blockingAwaitFuncName), coroAwaitFuncName(coroAwaitFuncName) {} public: LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // We can only await on one the `AwaitableType` (for `await` it can be // only a `token`, for `await_all` it is a `group`). auto await = cast(op); if (!await.operand().getType().template isa()) return failure(); // Check if await operation is inside the outlined coroutine function. auto func = await->template getParentOfType(); auto outlined = outlinedFunctions.find(func); const bool isInCoroutine = outlined != outlinedFunctions.end(); Location loc = op->getLoc(); // Inside regular function we convert await operation to the blocking // async API await function call. if (!isInCoroutine) rewriter.create(loc, TypeRange(), blockingAwaitFuncName, ValueRange(operands[0])); // Inside the coroutine we convert await operation into coroutine suspension // point, and resume execution asynchronously. if (isInCoroutine) { const CoroMachinery &coro = outlined->getSecond(); OpBuilder builder(op, rewriter.getListener()); MLIRContext *ctx = op->getContext(); // A pointer to coroutine resume intrinsic wrapper. 
auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); auto resumePtr = builder.create( loc, resumeFnTy.getPointerTo(), kResume); // Save the coroutine state: @llvm.coro.save auto coroSave = builder.create( loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); // Call async runtime API to resume a coroutine in the managed thread when // the async await argument becomes ready. SmallVector awaitAndExecuteArgs = {operands[0], coro.coroHandle, resumePtr.res()}; builder.create(loc, TypeRange(), coroAwaitFuncName, awaitAndExecuteArgs); Block *suspended = op->getBlock(); // Split the entry block before the await operation. Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, builder); } // Original operation was replaced by function call or suspension point. rewriter.eraseOp(op); return success(); } private: const llvm::DenseMap &outlinedFunctions; StringRef blockingAwaitFuncName; StringRef coroAwaitFuncName; }; // Lowering for `async.await` operation (only token operands are supported). class AwaitOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: explicit AwaitOpLowering( MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions) : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {} }; // Lowering for `async.await_all` operation. 
class AwaitAllOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: explicit AwaitAllOpLowering( MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions) : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {} }; } // namespace //===----------------------------------------------------------------------===// namespace { struct ConvertAsyncToLLVMPass : public ConvertAsyncToLLVMBase { void runOnOperation() override; }; void ConvertAsyncToLLVMPass::runOnOperation() { ModuleOp module = getOperation(); SymbolTable symbolTable(module); // Outline all `async.execute` body regions into async functions (coroutines). llvm::DenseMap outlinedFunctions; WalkResult outlineResult = module.walk([&](ExecuteOp execute) { // We currently do not support execute operations that have async value // operands or produce async results. if (!execute.operands().empty() || !execute.results().empty()) { execute.emitOpError("can't outline async.execute op with async value " "operands or returned async results"); return WalkResult::interrupt(); } outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); return WalkResult::advance(); }); // Failed to outline all async execute operations. if (outlineResult.wasInterrupted()) { signalPassFailure(); return; } LLVM_DEBUG({ llvm::dbgs() << "Outlined " << outlinedFunctions.size() << " async functions\n"; }); // Add declarations for all functions required by the coroutines lowering. addResumeFunction(module); addAsyncRuntimeApiDeclarations(module); addCoroutineIntrinsicsDeclarations(module); addCRuntimeDeclarations(module); MLIRContext *ctx = &getContext(); // Convert async dialect types and operations to LLVM dialect. 
AsyncRuntimeTypeConverter converter; OwningRewritePatternList patterns; populateFuncOpTypeConversionPattern(patterns, ctx, converter); patterns.insert(ctx); patterns.insert(ctx); patterns.insert(ctx); patterns.insert(ctx, outlinedFunctions); ConversionTarget target(*ctx); target.addLegalOp(); target.addLegalDialect(); target.addIllegalDialect(); target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); target.addDynamicallyLegalOp( [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } } // namespace std::unique_ptr> mlir::createConvertAsyncToLLVMPass() { return std::make_unique(); }