1 /*
2  * Copyright 2016, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "GlobalMergePass.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include "RSAllocationUtils.h"
29 
30 #define DEBUG_TYPE "rs2spirv-global-merge"
31 
32 using namespace llvm;
33 
34 namespace rs2spirv {
35 
36 namespace {
37 
38 class GlobalMergePass : public ModulePass {
39 public:
40   static char ID;
GlobalMergePass()41   GlobalMergePass() : ModulePass(ID) {}
getPassName() const42   const char *getPassName() const override { return "GlobalMergePass"; }
43 
runOnModule(Module & M)44   bool runOnModule(Module &M) override {
45     DEBUG(dbgs() << "RS2SPIRVGlobalMergePass\n");
46     DEBUG(M.dump());
47 
48     const auto &DL = M.getDataLayout();
49     SmallVector<GlobalVariable *, 8> Globals;
50     const bool CollectRes = collectGlobals(M, Globals);
51     if (!CollectRes)
52       return false; // Module not modified.
53 
54     IntegerType *Int32Ty = Type::getInt32Ty(M.getContext());
55     uint64_t MergedSize = 0;
56     SmallVector<Type *, 8> Tys;
57     Tys.reserve(Globals.size());
58 
59     for (auto *GV : Globals) {
60       auto *Ty = GV->getValueType();
61       MergedSize += DL.getTypeAllocSize(Ty);
62       Tys.push_back(Ty);
63     }
64 
65     auto *MergedTy = StructType::create(M.getContext(), "struct.__GPUBuffer");
66     MergedTy->setBody(Tys, false);
67     DEBUG(MergedTy->dump());
68     auto *MergedGV =
69         new GlobalVariable(M, MergedTy, false, GlobalValue::ExternalLinkage,
70                            nullptr, "__GPUBlock");
71     MergedGV->setInitializer(nullptr); // TODO: Emit initializers for CPU code.
72 
73     Value *Idx[2] = {ConstantInt::get(Int32Ty, 0), nullptr};
74 
75     for (size_t i = 0, e = Globals.size(); i != e; ++i) {
76       auto *G = Globals[i];
77       Idx[1] = ConstantInt::get(Int32Ty, i);
78 
79       // Keep users in a vector - they get implicitly removed
80       // in the loop below, which would invalidate users() iterators.
81       std::vector<User *> Users(G->user_begin(), G->user_end());
82       for (auto *User : Users) {
83         DEBUG(dbgs() << "User: ");
84         DEBUG(User->dump());
85         auto *Inst = dyn_cast<Instruction>(User);
86 
87         // TODO: Consider what should actually happen. Global variables can
88         // appear in ConstantExprs, but this case requires fixing the LLVM-SPIRV
89         // converter, which currently emits ill-formed SPIR-V code.
90         if (!Inst) {
91           errs() << "Found a global variable user that is not an Instruction\n";
92           assert(false);
93           return true; // Module may have been modified.
94         }
95 
96         auto *GEP = GetElementPtrInst::CreateInBounds(MergedTy, MergedGV, Idx,
97                                                       "gpu_gep", Inst);
98         for (unsigned k = 0, k_e = User->getNumOperands(); k != k_e; ++k)
99           if (User->getOperand(k) == G)
100             User->setOperand(k, GEP);
101       }
102 
103       // TODO: Investigate emitting a GlobalAlias for each global variable.
104       G->eraseFromParent();
105     }
106 
107     // Return true, as the pass modifies module.
108     return true;
109   }
110 
111 private:
collectGlobals(Module & M,SmallVectorImpl<GlobalVariable * > & Globals)112   bool collectGlobals(Module &M, SmallVectorImpl<GlobalVariable *> &Globals) {
113     for (auto &GV : M.globals()) {
114       // TODO: Rethink what should happen with global statics.
115       if (GV.isDeclaration() || GV.isThreadLocal() || GV.hasSection())
116         continue;
117 
118       if (isRSAllocation(GV))
119         continue;
120 
121       DEBUG(GV.dump());
122       auto *PT = cast<PointerType>(GV.getType());
123 
124       const unsigned AddressSpace = PT->getAddressSpace();
125       if (AddressSpace != 0) {
126         errs() << "Unknown address space! (" << AddressSpace
127                << ")\nGlobalMergePass failed!\n";
128         return false;
129       }
130 
131       Globals.push_back(&GV);
132     }
133 
134     return !Globals.empty();
135   }
136 };
137 } // namespace
138 
139 char GlobalMergePass::ID = 0;
140 
createGlobalMergePass()141 ModulePass *createGlobalMergePass() { return new GlobalMergePass(); }
142 
143 } // namespace rs2spirv
144