1 /*
2  * Copyright 2016-2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "GlobalMergePass.h"
18 
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DataLayout.h"
21 #include "llvm/IR/GlobalVariable.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 #include "Context.h"
30 #include "RSAllocationUtils.h"
31 
32 #include <functional>
33 
34 #define DEBUG_TYPE "rs2spirv-global-merge"
35 
36 using namespace llvm;
37 
38 namespace rs2spirv {
39 
40 namespace {
41 
42 class GlobalMergePass : public ModulePass {
43 public:
44   static char ID;
GlobalMergePass(bool CPU=false)45   GlobalMergePass(bool CPU = false) : ModulePass(ID), mForCPU(CPU) {}
getPassName() const46   const char *getPassName() const override { return "GlobalMergePass"; }
47 
runOnModule(Module & M)48   bool runOnModule(Module &M) override {
49     DEBUG(dbgs() << "RS2SPIRVGlobalMergePass\n");
50 
51     SmallVector<GlobalVariable *, 8> Globals;
52     if (!collectGlobals(M, Globals)) {
53       return false; // Module not modified.
54     }
55 
56     SmallVector<Type *, 8> Tys;
57     Tys.reserve(Globals.size());
58 
59     Context &RS2SPIRVCtxt = Context::getInstance();
60 
61     uint32_t index = 0;
62     for (GlobalVariable *GV : Globals) {
63       Tys.push_back(GV->getValueType());
64       const char *name = GV->getName().data();
65       RS2SPIRVCtxt.addExportVarIndex(name, index);
66       index++;
67     }
68 
69     LLVMContext &LLVMCtxt = M.getContext();
70 
71     StructType *MergedTy = StructType::create(LLVMCtxt, "struct.__GPUBuffer");
72     MergedTy->setBody(Tys, false);
73 
74     // Size calculation has to consider data layout
75     const DataLayout &DL = M.getDataLayout();
76     const uint64_t BufferSize = DL.getTypeAllocSize(MergedTy);
77     RS2SPIRVCtxt.setGlobalSize(BufferSize);
78 
79     Type *BufferVarTy = mForCPU ? static_cast<Type *>(PointerType::getUnqual(
80                                       Type::getInt8Ty(M.getContext())))
81                                 : static_cast<Type *>(MergedTy);
82     GlobalVariable *MergedGV =
83         new GlobalVariable(M, BufferVarTy, false, GlobalValue::ExternalLinkage,
84                            nullptr, "__GPUBlock");
85 
86     // For CPU, create a constant struct for initial values, which has each of
87     // its fields initialized to the original value of the corresponding global
88     // variable.
89     // During the script initialization, the driver should copy these initial
90     // values to the global buffer.
91     if (mForCPU) {
92       CreateInitFunction(LLVMCtxt, M, MergedGV, MergedTy, BufferSize, Globals);
93     }
94 
95     const bool forCPU = mForCPU;
96     IntegerType *const Int32Ty = Type::getInt32Ty(LLVMCtxt);
97     ConstantInt *const Zero = ConstantInt::get(Int32Ty, 0);
98     Value *Idx[] = {Zero, nullptr};
99 
100     auto InstMaker = [forCPU, MergedGV, MergedTy,
101                       &Idx](Instruction *InsertBefore) {
102       Value *Base = MergedGV;
103       if (forCPU) {
104         LoadInst *Load = new LoadInst(MergedGV, "", InsertBefore);
105         DEBUG(Load->dump());
106         Base = new BitCastInst(Load, PointerType::getUnqual(MergedTy), "",
107                                InsertBefore);
108         DEBUG(Base->dump());
109       }
110       GetElementPtrInst *GEP = GetElementPtrInst::CreateInBounds(
111           MergedTy, Base, Idx, "", InsertBefore);
112       DEBUG(GEP->dump());
113       return GEP;
114     };
115 
116     for (size_t i = 0, e = Globals.size(); i != e; ++i) {
117       GlobalVariable *G = Globals[i];
118       Idx[1] = ConstantInt::get(Int32Ty, i);
119       ReplaceAllUsesWithNewInstructions(G, std::cref(InstMaker));
120       G->eraseFromParent();
121     }
122 
123     // Return true, as the pass modifies module.
124     return true;
125   }
126 
127 private:
128   // In the User of Value Old, replaces all references of Old with Value New
ReplaceUse(User * U,Value * Old,Value * New)129   static inline void ReplaceUse(User *U, Value *Old, Value *New) {
130     for (unsigned i = 0, n = U->getNumOperands(); i < n; ++i) {
131       if (U->getOperand(i) == Old) {
132         U->getOperandUse(i) = New;
133       }
134     }
135   }
136 
137   // Replaces each use of V with new instructions created by
138   // funcCreateAndInsert and inserted right before that use. In the cases where
139   // the use is not an instruction, but a constant expression, recursively
140   // replaces that constant expression with a newly constructed equivalent
141   // instruction, before replacing V in that new instruction.
ReplaceAllUsesWithNewInstructions(Value * V,std::function<Instruction * (Instruction *)> funcCreateAndInsert)142   static inline void ReplaceAllUsesWithNewInstructions(
143       Value *V,
144       std::function<Instruction *(Instruction *)> funcCreateAndInsert) {
145     SmallVector<User *, 8> Users(V->user_begin(), V->user_end());
146     for (User *U : Users) {
147       if (Instruction *Inst = dyn_cast<Instruction>(U)) {
148         DEBUG(dbgs() << "\nBefore replacement:\n");
149         DEBUG(Inst->dump());
150         DEBUG(dbgs() << "----\n");
151 
152         ReplaceUse(U, V, funcCreateAndInsert(Inst));
153 
154         DEBUG(Inst->dump());
155       } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
156         auto InstMaker([CE, V, &funcCreateAndInsert](Instruction *UserOfU) {
157           Instruction *Inst = CE->getAsInstruction();
158           Inst->insertBefore(UserOfU);
159           ReplaceUse(Inst, V, funcCreateAndInsert(Inst));
160 
161           DEBUG(Inst->dump());
162           return Inst;
163         });
164         ReplaceAllUsesWithNewInstructions(U, InstMaker);
165       } else {
166         DEBUG(U->dump());
167         llvm_unreachable("Expecting only Instruction or ConstantExpr");
168       }
169     }
170   }
171 
172   static inline void
CreateInitFunction(LLVMContext & LLVMCtxt,Module & M,GlobalVariable * MergedGV,StructType * MergedTy,const uint64_t BufferSize,const SmallVectorImpl<GlobalVariable * > & Globals)173   CreateInitFunction(LLVMContext &LLVMCtxt, Module &M, GlobalVariable *MergedGV,
174                      StructType *MergedTy, const uint64_t BufferSize,
175                      const SmallVectorImpl<GlobalVariable *> &Globals) {
176     SmallVector<Constant *, 8> Initializers;
177     Initializers.reserve(Globals.size());
178     for (size_t i = 0, e = Globals.size(); i != e; ++i) {
179       GlobalVariable *G = Globals[i];
180       Initializers.push_back(G->getInitializer());
181     }
182     ArrayRef<Constant *> ArrInit(Initializers.begin(), Initializers.end());
183     Constant *MergedInitializer = ConstantStruct::get(MergedTy, ArrInit);
184     GlobalVariable *MergedInit =
185         new GlobalVariable(M, MergedTy, true, GlobalValue::InternalLinkage,
186                            MergedInitializer, "__GPUBlock0");
187 
188     Function *UserInit = M.getFunction("init");
189     // If there is no user-defined init() function, make the new global
190     // initialization function the init().
191     StringRef FName(UserInit ? ".rsov.global_init" : "init");
192     Function *Func;
193     FunctionType *FTy = FunctionType::get(Type::getVoidTy(LLVMCtxt), false);
194     Func = Function::Create(FTy, GlobalValue::ExternalLinkage, FName, &M);
195     BasicBlock *Blk = BasicBlock::Create(LLVMCtxt, "entry", Func);
196     IRBuilder<> LLVMIRBuilder(Blk);
197     LoadInst *Load = LLVMIRBuilder.CreateLoad(MergedGV);
198     LLVMIRBuilder.CreateMemCpy(Load, MergedInit, BufferSize, 0);
199     LLVMIRBuilder.CreateRetVoid();
200 
201     // If there is a user-defined init() function, add a call to the global
202     // initialization function in the beginning of that function.
203     if (UserInit) {
204       BasicBlock &EntryBlk = UserInit->getEntryBlock();
205       CallInst::Create(Func, {}, "", &EntryBlk.front());
206     }
207   }
208 
collectGlobals(Module & M,SmallVectorImpl<GlobalVariable * > & Globals)209   bool collectGlobals(Module &M, SmallVectorImpl<GlobalVariable *> &Globals) {
210     for (GlobalVariable &GV : M.globals()) {
211       assert(!GV.hasComdat() && "global variable has a comdat section");
212       assert(!GV.hasSection() && "global variable has a non-default section");
213       assert(!GV.isDeclaration() && "global variable is only a declaration");
214       assert(!GV.isThreadLocal() && "global variable is thread-local");
215       assert(GV.getType()->getAddressSpace() == 0 &&
216              "global variable has non-default address space");
217 
218       // TODO: Constants accessed by kernels should be handled differently
219       if (GV.isConstant()) {
220         continue;
221       }
222 
223       // Global Allocations are handled differently in separate passes
224       if (isRSAllocation(GV)) {
225         continue;
226       }
227 
228       Globals.push_back(&GV);
229     }
230 
231     return !Globals.empty();
232   }
233 
234   bool mForCPU;
235 };
236 
237 } // namespace
238 
// Out-of-class definition of the pass identifier that the constructor hands
// to ModulePass.
// NOTE(review): LLVM presumably keys passes on the address of ID, so the
// value 0 itself carries no meaning - confirm.
char GlobalMergePass::ID = 0;
240 
createGlobalMergePass(bool CPU)241 ModulePass *createGlobalMergePass(bool CPU) { return new GlobalMergePass(CPU); }
242 
243 } // namespace rs2spirv
244