1 //===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass lowers coroutine intrinsics that hide the details of the exact
10 // calling convention for coroutine resume and destroy functions and details of
11 // the structure of the coroutine frame.
12 //===----------------------------------------------------------------------===//
13 
14 #include "CoroInternal.h"
15 #include "llvm/IR/CallSite.h"
16 #include "llvm/IR/IRBuilder.h"
17 #include "llvm/IR/InstIterator.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/Pass.h"
20 
21 using namespace llvm;
22 
23 #define DEBUG_TYPE "coro-early"
24 
25 namespace {
26 // Created on demand if CoroEarly pass has work to do.
27 class Lowerer : public coro::LowererBase {
28   IRBuilder<> Builder;
29   PointerType *const AnyResumeFnPtrTy;
30   Constant *NoopCoro = nullptr;
31 
32   void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind);
33   void lowerCoroPromise(CoroPromiseInst *Intrin);
34   void lowerCoroDone(IntrinsicInst *II);
35   void lowerCoroNoop(IntrinsicInst *II);
36 
37 public:
Lowerer(Module & M)38   Lowerer(Module &M)
39       : LowererBase(M), Builder(Context),
40         AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
41                                            /*isVarArg=*/false)
42                              ->getPointerTo()) {}
43   bool lowerEarlyIntrinsics(Function &F);
44 };
45 }
46 
47 // Replace a direct call to coro.resume or coro.destroy with an indirect call to
48 // an address returned by coro.subfn.addr intrinsic. This is done so that
49 // CGPassManager recognizes devirtualization when CoroElide pass replaces a call
50 // to coro.subfn.addr with an appropriate function address.
lowerResumeOrDestroy(CallSite CS,CoroSubFnInst::ResumeKind Index)51 void Lowerer::lowerResumeOrDestroy(CallSite CS,
52                                    CoroSubFnInst::ResumeKind Index) {
53   Value *ResumeAddr =
54       makeSubFnCall(CS.getArgOperand(0), Index, CS.getInstruction());
55   CS.setCalledFunction(ResumeAddr);
56   CS.setCallingConv(CallingConv::Fast);
57 }
58 
59 // Coroutine promise field is always at the fixed offset from the beginning of
60 // the coroutine frame. i8* coro.promise(i8*, i1 from) intrinsic adds an offset
61 // to a passed pointer to move from coroutine frame to coroutine promise and
62 // vice versa. Since we don't know exactly which coroutine frame it is, we build
63 // a coroutine frame mock up starting with two function pointers, followed by a
64 // properly aligned coroutine promise field.
65 // TODO: Handle the case when coroutine promise alloca has align override.
lowerCoroPromise(CoroPromiseInst * Intrin)66 void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
67   Value *Operand = Intrin->getArgOperand(0);
68   unsigned Alignement = Intrin->getAlignment();
69   Type *Int8Ty = Builder.getInt8Ty();
70 
71   auto *SampleStruct =
72       StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty});
73   const DataLayout &DL = TheModule.getDataLayout();
74   int64_t Offset = alignTo(
75       DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignement);
76   if (Intrin->isFromPromise())
77     Offset = -Offset;
78 
79   Builder.SetInsertPoint(Intrin);
80   Value *Replacement =
81       Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset);
82 
83   Intrin->replaceAllUsesWith(Replacement);
84   Intrin->eraseFromParent();
85 }
86 
87 // When a coroutine reaches final suspend point, it zeros out ResumeFnAddr in
88 // the coroutine frame (it is UB to resume from a final suspend point).
89 // The llvm.coro.done intrinsic is used to check whether a coroutine is
90 // suspended at the final suspend point or not.
lowerCoroDone(IntrinsicInst * II)91 void Lowerer::lowerCoroDone(IntrinsicInst *II) {
92   Value *Operand = II->getArgOperand(0);
93 
94   // ResumeFnAddr is the first pointer sized element of the coroutine frame.
95   auto *FrameTy = Int8Ptr;
96   PointerType *FramePtrTy = FrameTy->getPointerTo();
97 
98   Builder.SetInsertPoint(II);
99   auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
100   auto *Gep = Builder.CreateConstInBoundsGEP1_32(FrameTy, BCI, 0);
101   auto *Load = Builder.CreateLoad(Gep);
102   auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
103 
104   II->replaceAllUsesWith(Cond);
105   II->eraseFromParent();
106 }
107 
lowerCoroNoop(IntrinsicInst * II)108 void Lowerer::lowerCoroNoop(IntrinsicInst *II) {
109   if (!NoopCoro) {
110     LLVMContext &C = Builder.getContext();
111     Module &M = *II->getModule();
112 
113     // Create a noop.frame struct type.
114     StructType *FrameTy = StructType::create(C, "NoopCoro.Frame");
115     auto *FramePtrTy = FrameTy->getPointerTo();
116     auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
117                                    /*IsVarArgs=*/false);
118     auto *FnPtrTy = FnTy->getPointerTo();
119     FrameTy->setBody({FnPtrTy, FnPtrTy});
120 
121     // Create a Noop function that does nothing.
122     Function *NoopFn =
123         Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
124                          "NoopCoro.ResumeDestroy", &M);
125     NoopFn->setCallingConv(CallingConv::Fast);
126     auto *Entry = BasicBlock::Create(C, "entry", NoopFn);
127     ReturnInst::Create(C, Entry);
128 
129     // Create a constant struct for the frame.
130     Constant* Values[] = {NoopFn, NoopFn};
131     Constant* NoopCoroConst = ConstantStruct::get(FrameTy, Values);
132     NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true,
133                                 GlobalVariable::PrivateLinkage, NoopCoroConst,
134                                 "NoopCoro.Frame.Const");
135   }
136 
137   Builder.SetInsertPoint(II);
138   auto *NoopCoroVoidPtr = Builder.CreateBitCast(NoopCoro, Int8Ptr);
139   II->replaceAllUsesWith(NoopCoroVoidPtr);
140   II->eraseFromParent();
141 }
142 
143 // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
144 // as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
145 // NoDuplicate attribute will be removed from coro.begin otherwise, it will
146 // interfere with inlining.
setCannotDuplicate(CoroIdInst * CoroId)147 static void setCannotDuplicate(CoroIdInst *CoroId) {
148   for (User *U : CoroId->users())
149     if (auto *CB = dyn_cast<CoroBeginInst>(U))
150       CB->setCannotDuplicate();
151 }
152 
lowerEarlyIntrinsics(Function & F)153 bool Lowerer::lowerEarlyIntrinsics(Function &F) {
154   bool Changed = false;
155   CoroIdInst *CoroId = nullptr;
156   SmallVector<CoroFreeInst *, 4> CoroFrees;
157   for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
158     Instruction &I = *IB++;
159     if (auto CS = CallSite(&I)) {
160       switch (CS.getIntrinsicID()) {
161       default:
162         continue;
163       case Intrinsic::coro_free:
164         CoroFrees.push_back(cast<CoroFreeInst>(&I));
165         break;
166       case Intrinsic::coro_suspend:
167         // Make sure that final suspend point is not duplicated as CoroSplit
168         // pass expects that there is at most one final suspend point.
169         if (cast<CoroSuspendInst>(&I)->isFinal())
170           CS.setCannotDuplicate();
171         break;
172       case Intrinsic::coro_end:
173         // Make sure that fallthrough coro.end is not duplicated as CoroSplit
174         // pass expects that there is at most one fallthrough coro.end.
175         if (cast<CoroEndInst>(&I)->isFallthrough())
176           CS.setCannotDuplicate();
177         break;
178       case Intrinsic::coro_noop:
179         lowerCoroNoop(cast<IntrinsicInst>(&I));
180         break;
181       case Intrinsic::coro_id:
182         // Mark a function that comes out of the frontend that has a coro.id
183         // with a coroutine attribute.
184         if (auto *CII = cast<CoroIdInst>(&I)) {
185           if (CII->getInfo().isPreSplit()) {
186             F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
187             setCannotDuplicate(CII);
188             CII->setCoroutineSelf();
189             CoroId = cast<CoroIdInst>(&I);
190           }
191         }
192         break;
193       case Intrinsic::coro_resume:
194         lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex);
195         break;
196       case Intrinsic::coro_destroy:
197         lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex);
198         break;
199       case Intrinsic::coro_promise:
200         lowerCoroPromise(cast<CoroPromiseInst>(&I));
201         break;
202       case Intrinsic::coro_done:
203         lowerCoroDone(cast<IntrinsicInst>(&I));
204         break;
205       }
206       Changed = true;
207     }
208   }
209   // Make sure that all CoroFree reference the coro.id intrinsic.
210   // Token type is not exposed through coroutine C/C++ builtins to plain C, so
211   // we allow specifying none and fixing it up here.
212   if (CoroId)
213     for (CoroFreeInst *CF : CoroFrees)
214       CF->setArgOperand(0, CoroId);
215   return Changed;
216 }
217 
218 //===----------------------------------------------------------------------===//
219 //                              Top Level Driver
220 //===----------------------------------------------------------------------===//
221 
222 namespace {
223 
224 struct CoroEarly : public FunctionPass {
225   static char ID; // Pass identification, replacement for typeid.
CoroEarly__anon0737b4120211::CoroEarly226   CoroEarly() : FunctionPass(ID) {
227     initializeCoroEarlyPass(*PassRegistry::getPassRegistry());
228   }
229 
230   std::unique_ptr<Lowerer> L;
231 
232   // This pass has work to do only if we find intrinsics we are going to lower
233   // in the module.
doInitialization__anon0737b4120211::CoroEarly234   bool doInitialization(Module &M) override {
235     if (coro::declaresIntrinsics(
236             M, {"llvm.coro.id", "llvm.coro.destroy", "llvm.coro.done",
237                 "llvm.coro.end", "llvm.coro.noop", "llvm.coro.free",
238                 "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.suspend"}))
239       L = llvm::make_unique<Lowerer>(M);
240     return false;
241   }
242 
runOnFunction__anon0737b4120211::CoroEarly243   bool runOnFunction(Function &F) override {
244     if (!L)
245       return false;
246 
247     return L->lowerEarlyIntrinsics(F);
248   }
249 
getAnalysisUsage__anon0737b4120211::CoroEarly250   void getAnalysisUsage(AnalysisUsage &AU) const override {
251     AU.setPreservesCFG();
252   }
getPassName__anon0737b4120211::CoroEarly253   StringRef getPassName() const override {
254     return "Lower early coroutine intrinsics";
255   }
256 };
257 }
258 
// Storage for the pass identifier declared in CoroEarly.
char CoroEarly::ID = 0;
// Register the pass under the command-line name "coro-early". The two trailing
// flags are presumably INITIALIZE_PASS's cfg-only / is-analysis arguments;
// verify against llvm/PassSupport.h.
INITIALIZE_PASS(CoroEarly, "coro-early", "Lower early coroutine intrinsics",
                false, false)
262 
createCoroEarlyPass()263 Pass *llvm::createCoroEarlyPass() { return new CoroEarly(); }
264