1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Take a module and rewrite:
10 // 1. `malloc` -> `polly_mallocManaged`
11 // 2. `free` -> `polly_freeManaged`
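// 3. zero-initialized global arrays (by default only private/internal ones)
//    -> global pointers that are initialized by a constructor call to
//       `polly_mallocManaged`.
// 4. (only with -polly-acc-rewrite-allocas) allocas whose address may be
//    captured -> `polly_mallocManaged`, released with `polly_freeManaged`
//    before each return.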
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "polly/CodeGen/IRBuilder.h"
19 #include "polly/CodeGen/PPCGCodeGeneration.h"
20 #include "polly/DependenceInfo.h"
21 #include "polly/LinkAllPasses.h"
22 #include "polly/Options.h"
23 #include "polly/ScopDetection.h"
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/Analysis/CaptureTracking.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Transforms/Utils/ModuleUtils.h"
28 
29 using namespace polly;
30 
31 static cl::opt<bool> RewriteAllocas(
32     "polly-acc-rewrite-allocas",
33     cl::desc(
34         "Ask the managed memory rewriter to also rewrite alloca instructions"),
35     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
36 
37 static cl::opt<bool> IgnoreLinkageForGlobals(
38     "polly-acc-rewrite-ignore-linkage-for-globals",
39     cl::desc(
40         "By default, we only rewrite globals with internal linkage. This flag "
41         "enables rewriting of globals regardless of linkage"),
42     cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
43 
44 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
45 namespace {
46 
getOrCreatePollyMallocManaged(Module & M)47 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
48   const char *Name = "polly_mallocManaged";
49   Function *F = M.getFunction(Name);
50 
51   // If F is not available, declare it.
52   if (!F) {
53     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
54     PollyIRBuilder Builder(M.getContext());
55     // TODO: How do I get `size_t`? I assume from DataLayout?
56     FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
57                                          {Builder.getInt64Ty()}, false);
58     F = Function::Create(Ty, Linkage, Name, &M);
59   }
60 
61   return F;
62 }
63 
getOrCreatePollyFreeManaged(Module & M)64 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
65   const char *Name = "polly_freeManaged";
66   Function *F = M.getFunction(Name);
67 
68   // If F is not available, declare it.
69   if (!F) {
70     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
71     PollyIRBuilder Builder(M.getContext());
72     // TODO: How do I get `size_t`? I assume from DataLayout?
73     FunctionType *Ty =
74         FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
75     F = Function::Create(Ty, Linkage, Name, &M);
76   }
77 
78   return F;
79 }
80 
81 // Expand a constant expression `Cur`, which is used at instruction `Parent`
82 // at index `index`.
83 // Since a constant expression can expand to multiple instructions, store all
84 // the expands into a set called `Expands`.
85 // Note that this goes inorder on the constant expression tree.
86 // A * ((B * D) + C)
87 // will be processed with first A, then B * D, then B, then D, and then C.
88 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
89 // have something like this:
90 //    *
91 //   /  \
92 //   \  /
93 //    (D)
94 //
95 // For the purposes of this expansion, we expand the two occurences of D
96 // separately. Therefore, we expand the DAG into the tree:
97 //  *
98 // / \
99 // D  D
100 // TODO: We don't _have_to do this, but this is the simplest solution.
101 // We can write a solution that keeps track of which constants have been
102 // already expanded.
expandConstantExpr(ConstantExpr * Cur,PollyIRBuilder & Builder,Instruction * Parent,int index,SmallPtrSet<Instruction *,4> & Expands)103 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
104                                Instruction *Parent, int index,
105                                SmallPtrSet<Instruction *, 4> &Expands) {
106   assert(Cur && "invalid constant expression passed");
107   Instruction *I = Cur->getAsInstruction();
108   assert(I && "unable to convert ConstantExpr to Instruction");
109 
110   LLVM_DEBUG(dbgs() << "Expanding ConstantExpression: (" << *Cur
111                     << ") in Instruction: (" << *I << ")\n";);
112 
113   // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
114   // they should mutate `I`.
115   Cur = nullptr;
116 
117   Expands.insert(I);
118   Parent->setOperand(index, I);
119 
120   // The things that `Parent` uses (its operands) should be created
121   // before `Parent`.
122   Builder.SetInsertPoint(Parent);
123   Builder.Insert(I);
124 
125   for (unsigned i = 0; i < I->getNumOperands(); i++) {
126     Value *Op = I->getOperand(i);
127     assert(isa<Constant>(Op) && "constant must have a constant operand");
128 
129     if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
130       expandConstantExpr(CExprOp, Builder, I, i, Expands);
131   }
132 }
133 
134 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
135 // `ConstantExpr`s that are used in the `Inst`.
136 // Note that `replaceAllUsesWith` is insufficient for this purpose because it
137 // does not rewrite values in `ConstantExpr`s.
rewriteOldValToNew(Instruction * Inst,Value * OldVal,Value * NewVal,PollyIRBuilder & Builder)138 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
139                                PollyIRBuilder &Builder) {
140 
141   // This contains a set of instructions in which OldVal must be replaced.
142   // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
143   // from `Inst`s arguments.
144   // We need to go through this process because `replaceAllUsesWith` does not
145   // actually edit `ConstantExpr`s.
146   SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
147 
148   // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
149   for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
150     Value *Operand = Inst->getOperand(i);
151     if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
152       expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
153   }
154 
155   // Now visit each instruction and use `replaceUsesOfWith`. We know that
156   // will work because `I` cannot have any `ConstantExpr` within it.
157   for (Instruction *I : InstsToVisit)
158     I->replaceUsesOfWith(OldVal, NewVal);
159 }
160 
161 // Given a value `Current`, return all Instructions that may contain `Current`
162 // in an expression.
163 // We need this auxiliary function, because if we have a
164 // `Constant` that is a user of `V`, we need to recurse into the
165 // `Constant`s uses to gather the root instruciton.
getInstructionUsersOfValue(Value * V,SmallVector<Instruction *,4> & Owners)166 static void getInstructionUsersOfValue(Value *V,
167                                        SmallVector<Instruction *, 4> &Owners) {
168   if (auto *I = dyn_cast<Instruction>(V)) {
169     Owners.push_back(I);
170   } else {
171     // Anything that is a `User` must be a constant or an instruction.
172     auto *C = cast<Constant>(V);
173     for (Use &CUse : C->uses())
174       getInstructionUsersOfValue(CUse.getUser(), Owners);
175   }
176 }
177 
178 static void
replaceGlobalArray(Module & M,const DataLayout & DL,GlobalVariable & Array,SmallPtrSet<GlobalVariable *,4> & ReplacedGlobals)179 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
180                    SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
181   // We only want arrays.
182   ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
183   if (!ArrayTy)
184     return;
185   Type *ElemTy = ArrayTy->getElementType();
186   PointerType *ElemPtrTy = ElemTy->getPointerTo();
187 
188   // We only wish to replace arrays that are visible in the module they
189   // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
190   // modules.
191   const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
192                                        Array.hasInternalLinkage() ||
193                                        IgnoreLinkageForGlobals;
194   if (!OnlyVisibleInsideModule) {
195     LLVM_DEBUG(
196         dbgs() << "Not rewriting (" << Array
197                << ") to managed memory "
198                   "because it could be visible externally. To force rewrite, "
199                   "use -polly-acc-rewrite-ignore-linkage-for-globals.\n");
200     return;
201   }
202 
203   if (!Array.hasInitializer() ||
204       !isa<ConstantAggregateZero>(Array.getInitializer())) {
205     LLVM_DEBUG(dbgs() << "Not rewriting (" << Array
206                       << ") to managed memory "
207                          "because it has an initializer which is "
208                          "not a zeroinitializer.\n");
209     return;
210   }
211 
212   // At this point, we have committed to replacing this array.
213   ReplacedGlobals.insert(&Array);
214 
215   std::string NewName = Array.getName().str();
216   NewName += ".toptr";
217   GlobalVariable *ReplacementToArr =
218       cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
219   ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
220 
221   Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
222   std::string FnName = Array.getName().str();
223   FnName += ".constructor";
224   PollyIRBuilder Builder(M.getContext());
225   FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
226   const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
227   Function *F = Function::Create(Ty, Linkage, FnName, &M);
228   BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
229   Builder.SetInsertPoint(Start);
230 
231   const uint64_t ArraySizeInt = DL.getTypeAllocSize(ArrayTy);
232   Value *ArraySize = Builder.getInt64(ArraySizeInt);
233   ArraySize->setName("array.size");
234 
235   Value *AllocatedMemRaw =
236       Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
237   Value *AllocatedMemTyped =
238       Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
239   Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
240   Builder.CreateRetVoid();
241 
242   const int Priority = 0;
243   appendToGlobalCtors(M, F, Priority, ReplacementToArr);
244 
245   SmallVector<Instruction *, 4> ArrayUserInstructions;
246   // Get all instructions that use array. We need to do this weird thing
247   // because `Constant`s that contain this array neeed to be expanded into
248   // instructions so that we can replace their parameters. `Constant`s cannot
249   // be edited easily, so we choose to convert all `Constant`s to
250   // `Instruction`s and handle all of the uses of `Array` uniformly.
251   for (Use &ArrayUse : Array.uses())
252     getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
253 
254   for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
255 
256     Builder.SetInsertPoint(UserOfArrayInst);
257     // <ty>** -> <ty>*
258     Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load");
259     // <ty>* -> [ty]*
260     Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
261         ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
262     rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
263   }
264 }
265 
266 // We return all `allocas` that may need to be converted to a call to
267 // cudaMallocManaged.
getAllocasToBeManaged(Function & F,SmallSet<AllocaInst *,4> & Allocas)268 static void getAllocasToBeManaged(Function &F,
269                                   SmallSet<AllocaInst *, 4> &Allocas) {
270   for (BasicBlock &BB : F) {
271     for (Instruction &I : BB) {
272       auto *Alloca = dyn_cast<AllocaInst>(&I);
273       if (!Alloca)
274         continue;
275       LLVM_DEBUG(dbgs() << "Checking if (" << *Alloca << ") may be captured: ");
276 
277       if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
278                                /* StoreCaptures */ true)) {
279         Allocas.insert(Alloca);
280         LLVM_DEBUG(dbgs() << "YES (captured).\n");
281       } else {
282         LLVM_DEBUG(dbgs() << "NO (not captured).\n");
283       }
284     }
285   }
286 }
287 
rewriteAllocaAsManagedMemory(AllocaInst * Alloca,const DataLayout & DL)288 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
289                                          const DataLayout &DL) {
290   LLVM_DEBUG(dbgs() << "rewriting: (" << *Alloca << ") to managed mem.\n");
291   Module *M = Alloca->getModule();
292   assert(M && "Alloca does not have a module");
293 
294   PollyIRBuilder Builder(M->getContext());
295   Builder.SetInsertPoint(Alloca);
296 
297   Function *MallocManagedFn =
298       getOrCreatePollyMallocManaged(*Alloca->getModule());
299   const uint64_t Size =
300       DL.getTypeAllocSize(Alloca->getType()->getElementType());
301   Value *SizeVal = Builder.getInt64(Size);
302   Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
303   Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
304 
305   Function *F = Alloca->getFunction();
306   assert(F && "Alloca has invalid function");
307 
308   Bitcasted->takeName(Alloca);
309   Alloca->replaceAllUsesWith(Bitcasted);
310   Alloca->eraseFromParent();
311 
312   for (BasicBlock &BB : *F) {
313     ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
314     if (!Return)
315       continue;
316     Builder.SetInsertPoint(Return);
317 
318     Function *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
319     Builder.CreateCall(FreeManagedFn, {RawManagedMem});
320   }
321 }
322 
323 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
324 //
325 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
326 // actually does replace it in `ConstantExpr`. The caveat is that if there is
327 // a use that is *outside* a function (say, at global declarations), we fail.
328 // So, this is meant to be used on values which we know will only be used
329 // within functions.
330 //
331 // This process works by looking through the uses of `Old`. If it finds a
332 // `ConstantExpr`, it recursively looks for the owning instruction.
333 // Then, it expands all the `ConstantExpr` to instructions and replaces
334 // `Old` with `New` in the expanded instructions.
replaceAllUsesAndConstantUses(Value * Old,Value * New,PollyIRBuilder & Builder)335 static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
336                                           PollyIRBuilder &Builder) {
337   SmallVector<Instruction *, 4> UserInstructions;
338   // Get all instructions that use array. We need to do this weird thing
339   // because `Constant`s that contain this array neeed to be expanded into
340   // instructions so that we can replace their parameters. `Constant`s cannot
341   // be edited easily, so we choose to convert all `Constant`s to
342   // `Instruction`s and handle all of the uses of `Array` uniformly.
343   for (Use &ArrayUse : Old->uses())
344     getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
345 
346   for (Instruction *I : UserInstructions)
347     rewriteOldValToNew(I, Old, New, Builder);
348 }
349 
350 class ManagedMemoryRewritePass : public ModulePass {
351 public:
352   static char ID;
353   GPUArch Architecture;
354   GPURuntime Runtime;
355 
ManagedMemoryRewritePass()356   ManagedMemoryRewritePass() : ModulePass(ID) {}
runOnModule(Module & M)357   bool runOnModule(Module &M) override {
358     const DataLayout &DL = M.getDataLayout();
359 
360     Function *Malloc = M.getFunction("malloc");
361 
362     if (Malloc) {
363       PollyIRBuilder Builder(M.getContext());
364       Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
365       assert(PollyMallocManaged && "unable to create polly_mallocManaged");
366 
367       replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
368       Malloc->eraseFromParent();
369     }
370 
371     Function *Free = M.getFunction("free");
372 
373     if (Free) {
374       PollyIRBuilder Builder(M.getContext());
375       Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
376       assert(PollyFreeManaged && "unable to create polly_freeManaged");
377 
378       replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
379       Free->eraseFromParent();
380     }
381 
382     SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
383     for (GlobalVariable &Global : M.globals())
384       replaceGlobalArray(M, DL, Global, GlobalsToErase);
385     for (GlobalVariable *G : GlobalsToErase)
386       G->eraseFromParent();
387 
388     // Rewrite allocas to cudaMallocs if we are asked to do so.
389     if (RewriteAllocas) {
390       SmallSet<AllocaInst *, 4> AllocasToBeManaged;
391       for (Function &F : M.functions())
392         getAllocasToBeManaged(F, AllocasToBeManaged);
393 
394       for (AllocaInst *Alloca : AllocasToBeManaged)
395         rewriteAllocaAsManagedMemory(Alloca, DL);
396     }
397 
398     return true;
399   }
400 };
401 } // namespace
402 char ManagedMemoryRewritePass::ID = 42;
403 
createManagedMemoryRewritePassPass(GPUArch Arch,GPURuntime Runtime)404 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
405                                                 GPURuntime Runtime) {
406   ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
407   pass->Runtime = Runtime;
408   pass->Architecture = Arch;
409   return pass;
410 }
411 
412 INITIALIZE_PASS_BEGIN(
413     ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
414     "Polly - Rewrite all allocations in heap & data section to managed memory",
415     false, false)
416 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
417 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
418 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
419 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
420 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
421 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
422 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
423 INITIALIZE_PASS_END(
424     ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
425     "Polly - Rewrite all allocations in heap & data section to managed memory",
426     false, false)
427