1 /*
2 * Copyright 2016, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "GlobalMergePass.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include "RSAllocationUtils.h"
29
30 #define DEBUG_TYPE "rs2spirv-global-merge"
31
32 using namespace llvm;
33
34 namespace rs2spirv {
35
36 namespace {
37
38 class GlobalMergePass : public ModulePass {
39 public:
40 static char ID;
GlobalMergePass()41 GlobalMergePass() : ModulePass(ID) {}
getPassName() const42 const char *getPassName() const override { return "GlobalMergePass"; }
43
runOnModule(Module & M)44 bool runOnModule(Module &M) override {
45 DEBUG(dbgs() << "RS2SPIRVGlobalMergePass\n");
46 DEBUG(M.dump());
47
48 const auto &DL = M.getDataLayout();
49 SmallVector<GlobalVariable *, 8> Globals;
50 const bool CollectRes = collectGlobals(M, Globals);
51 if (!CollectRes)
52 return false; // Module not modified.
53
54 IntegerType *Int32Ty = Type::getInt32Ty(M.getContext());
55 uint64_t MergedSize = 0;
56 SmallVector<Type *, 8> Tys;
57 Tys.reserve(Globals.size());
58
59 for (auto *GV : Globals) {
60 auto *Ty = GV->getValueType();
61 MergedSize += DL.getTypeAllocSize(Ty);
62 Tys.push_back(Ty);
63 }
64
65 auto *MergedTy = StructType::create(M.getContext(), "struct.__GPUBuffer");
66 MergedTy->setBody(Tys, false);
67 DEBUG(MergedTy->dump());
68 auto *MergedGV =
69 new GlobalVariable(M, MergedTy, false, GlobalValue::ExternalLinkage,
70 nullptr, "__GPUBlock");
71 MergedGV->setInitializer(nullptr); // TODO: Emit initializers for CPU code.
72
73 Value *Idx[2] = {ConstantInt::get(Int32Ty, 0), nullptr};
74
75 for (size_t i = 0, e = Globals.size(); i != e; ++i) {
76 auto *G = Globals[i];
77 Idx[1] = ConstantInt::get(Int32Ty, i);
78
79 // Keep users in a vector - they get implicitly removed
80 // in the loop below, which would invalidate users() iterators.
81 std::vector<User *> Users(G->user_begin(), G->user_end());
82 for (auto *User : Users) {
83 DEBUG(dbgs() << "User: ");
84 DEBUG(User->dump());
85 auto *Inst = dyn_cast<Instruction>(User);
86
87 // TODO: Consider what should actually happen. Global variables can
88 // appear in ConstantExprs, but this case requires fixing the LLVM-SPIRV
89 // converter, which currently emits ill-formed SPIR-V code.
90 if (!Inst) {
91 errs() << "Found a global variable user that is not an Instruction\n";
92 assert(false);
93 return true; // Module may have been modified.
94 }
95
96 auto *GEP = GetElementPtrInst::CreateInBounds(MergedTy, MergedGV, Idx,
97 "gpu_gep", Inst);
98 for (unsigned k = 0, k_e = User->getNumOperands(); k != k_e; ++k)
99 if (User->getOperand(k) == G)
100 User->setOperand(k, GEP);
101 }
102
103 // TODO: Investigate emitting a GlobalAlias for each global variable.
104 G->eraseFromParent();
105 }
106
107 // Return true, as the pass modifies module.
108 return true;
109 }
110
111 private:
collectGlobals(Module & M,SmallVectorImpl<GlobalVariable * > & Globals)112 bool collectGlobals(Module &M, SmallVectorImpl<GlobalVariable *> &Globals) {
113 for (auto &GV : M.globals()) {
114 // TODO: Rethink what should happen with global statics.
115 if (GV.isDeclaration() || GV.isThreadLocal() || GV.hasSection())
116 continue;
117
118 if (isRSAllocation(GV))
119 continue;
120
121 DEBUG(GV.dump());
122 auto *PT = cast<PointerType>(GV.getType());
123
124 const unsigned AddressSpace = PT->getAddressSpace();
125 if (AddressSpace != 0) {
126 errs() << "Unknown address space! (" << AddressSpace
127 << ")\nGlobalMergePass failed!\n";
128 return false;
129 }
130
131 Globals.push_back(&GV);
132 }
133
134 return !Globals.empty();
135 }
136 };
137 } // namespace
138
139 char GlobalMergePass::ID = 0;
140
createGlobalMergePass()141 ModulePass *createGlobalMergePass() { return new GlobalMergePass(); }
142
143 } // namespace rs2spirv
144