1 //===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 /// \file
10 /// This transformation combines adjacent loads.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Scalar.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/AliasSetTracker.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/TargetFolder.h"
21 #include "llvm/IR/DataLayout.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/Pass.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/MathExtras.h"
29 #include "llvm/Support/raw_ostream.h"
30 
using namespace llvm;

// Category name for this pass's DEBUG() output and statistics; keep it defined
// above the STATISTIC() declarations that follow.
#define DEBUG_TYPE "load-combine"

// Counters: loads visited by runOnBasicBlock, and loads replaced by a combined
// load in combineLoads.
STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
STATISTIC(NumLoadsCombined, "Number of loads combined");

// Human-readable pass name, shared by getPassName() and the pass registration
// macros at the end of the file.
#define LDCOMBINE_NAME "Combine Adjacent Loads"
39 
40 namespace {
41 struct PointerOffsetPair {
42   Value *Pointer;
43   APInt Offset;
44 };
45 
46 struct LoadPOPPair {
47   LoadPOPPair() = default;
LoadPOPPair__anone6fd0b320111::LoadPOPPair48   LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O)
49       : Load(L), POP(P), InsertOrder(O) {}
50   LoadInst *Load;
51   PointerOffsetPair POP;
52   /// \brief The new load needs to be created before the first load in IR order.
53   unsigned InsertOrder;
54 };
55 
56 class LoadCombine : public BasicBlockPass {
57   LLVMContext *C;
58   AliasAnalysis *AA;
59 
60 public:
LoadCombine()61   LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {
62     initializeLoadCombinePass(*PassRegistry::getPassRegistry());
63   }
64 
65   using llvm::Pass::doInitialization;
66   bool doInitialization(Function &) override;
67   bool runOnBasicBlock(BasicBlock &BB) override;
getAnalysisUsage(AnalysisUsage & AU) const68   void getAnalysisUsage(AnalysisUsage &AU) const override {
69     AU.setPreservesCFG();
70     AU.addRequired<AAResultsWrapperPass>();
71     AU.addPreserved<GlobalsAAWrapperPass>();
72   }
73 
getPassName() const74   const char *getPassName() const override { return LDCOMBINE_NAME; }
75   static char ID;
76 
77   typedef IRBuilder<TargetFolder> BuilderTy;
78 
79 private:
80   BuilderTy *Builder;
81 
82   PointerOffsetPair getPointerOffsetPair(LoadInst &);
83   bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
84   bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
85   bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
86 };
87 }
88 
doInitialization(Function & F)89 bool LoadCombine::doInitialization(Function &F) {
90   DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
91   C = &F.getContext();
92   return true;
93 }
94 
getPointerOffsetPair(LoadInst & LI)95 PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
96   auto &DL = LI.getModule()->getDataLayout();
97 
98   PointerOffsetPair POP;
99   POP.Pointer = LI.getPointerOperand();
100   unsigned BitWidth = DL.getPointerSizeInBits(LI.getPointerAddressSpace());
101   POP.Offset = APInt(BitWidth, 0);
102 
103   while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
104     if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
105       APInt LastOffset = POP.Offset;
106       if (!GEP->accumulateConstantOffset(DL, POP.Offset)) {
107         // Can't handle GEPs with variable indices.
108         POP.Offset = LastOffset;
109         return POP;
110       }
111       POP.Pointer = GEP->getPointerOperand();
112     } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) {
113       POP.Pointer = BC->getOperand(0);
114     }
115   }
116   return POP;
117 }
118 
combineLoads(DenseMap<const Value *,SmallVector<LoadPOPPair,8>> & LoadMap)119 bool LoadCombine::combineLoads(
120     DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
121   bool Combined = false;
122   for (auto &Loads : LoadMap) {
123     if (Loads.second.size() < 2)
124       continue;
125     std::sort(Loads.second.begin(), Loads.second.end(),
126               [](const LoadPOPPair &A, const LoadPOPPair &B) {
127                 return A.POP.Offset.slt(B.POP.Offset);
128               });
129     if (aggregateLoads(Loads.second))
130       Combined = true;
131   }
132   return Combined;
133 }
134 
135 /// \brief Try to aggregate loads from a sorted list of loads to be combined.
136 ///
137 /// It is guaranteed that no writes occur between any of the loads. All loads
138 /// have the same base pointer. There are at least two loads.
aggregateLoads(SmallVectorImpl<LoadPOPPair> & Loads)139 bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
140   assert(Loads.size() >= 2 && "Insufficient loads!");
141   LoadInst *BaseLoad = nullptr;
142   SmallVector<LoadPOPPair, 8> AggregateLoads;
143   bool Combined = false;
144   bool ValidPrevOffset = false;
145   APInt PrevOffset;
146   uint64_t PrevSize = 0;
147   for (auto &L : Loads) {
148     if (ValidPrevOffset == false) {
149       BaseLoad = L.Load;
150       PrevOffset = L.POP.Offset;
151       PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
152           L.Load->getType());
153       AggregateLoads.push_back(L);
154       ValidPrevOffset = true;
155       continue;
156     }
157     if (L.Load->getAlignment() > BaseLoad->getAlignment())
158       continue;
159     APInt PrevEnd = PrevOffset + PrevSize;
160     if (L.POP.Offset.sgt(PrevEnd)) {
161       // No other load will be combinable
162       if (combineLoads(AggregateLoads))
163         Combined = true;
164       AggregateLoads.clear();
165       ValidPrevOffset = false;
166       continue;
167     }
168     if (L.POP.Offset != PrevEnd)
169       // This load is offset less than the size of the last load.
170       // FIXME: We may want to handle this case.
171       continue;
172     PrevOffset = L.POP.Offset;
173     PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
174         L.Load->getType());
175     AggregateLoads.push_back(L);
176   }
177   if (combineLoads(AggregateLoads))
178     Combined = true;
179   return Combined;
180 }
181 
182 /// \brief Given a list of combinable load. Combine the maximum number of them.
combineLoads(SmallVectorImpl<LoadPOPPair> & Loads)183 bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
184   // Remove loads from the end while the size is not a power of 2.
185   unsigned TotalSize = 0;
186   for (const auto &L : Loads)
187     TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
188   while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
189     TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
190   if (Loads.size() < 2)
191     return false;
192 
193   DEBUG({
194     dbgs() << "***** Combining Loads ******\n";
195     for (const auto &L : Loads) {
196       dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
197     }
198   });
199 
200   // Find first load. This is where we put the new load.
201   LoadPOPPair FirstLP;
202   FirstLP.InsertOrder = -1u;
203   for (const auto &L : Loads)
204     if (L.InsertOrder < FirstLP.InsertOrder)
205       FirstLP = L;
206 
207   unsigned AddressSpace =
208       FirstLP.POP.Pointer->getType()->getPointerAddressSpace();
209 
210   Builder->SetInsertPoint(FirstLP.Load);
211   Value *Ptr = Builder->CreateConstGEP1_64(
212       Builder->CreatePointerCast(Loads[0].POP.Pointer,
213                                  Builder->getInt8PtrTy(AddressSpace)),
214       Loads[0].POP.Offset.getSExtValue());
215   LoadInst *NewLoad = new LoadInst(
216       Builder->CreatePointerCast(
217           Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
218                                 Ptr->getType()->getPointerAddressSpace())),
219       Twine(Loads[0].Load->getName()) + ".combined", false,
220       Loads[0].Load->getAlignment(), FirstLP.Load);
221 
222   for (const auto &L : Loads) {
223     Builder->SetInsertPoint(L.Load);
224     Value *V = Builder->CreateExtractInteger(
225         L.Load->getModule()->getDataLayout(), NewLoad,
226         cast<IntegerType>(L.Load->getType()),
227         (L.POP.Offset - Loads[0].POP.Offset).getZExtValue(), "combine.extract");
228     L.Load->replaceAllUsesWith(V);
229   }
230 
231   NumLoadsCombined = NumLoadsCombined + Loads.size();
232   return true;
233 }
234 
runOnBasicBlock(BasicBlock & BB)235 bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
236   if (skipBasicBlock(BB))
237     return false;
238 
239   AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
240 
241   IRBuilder<TargetFolder> TheBuilder(
242       BB.getContext(), TargetFolder(BB.getModule()->getDataLayout()));
243   Builder = &TheBuilder;
244 
245   DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;
246   AliasSetTracker AST(*AA);
247 
248   bool Combined = false;
249   unsigned Index = 0;
250   for (auto &I : BB) {
251     if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) {
252       if (combineLoads(LoadMap))
253         Combined = true;
254       LoadMap.clear();
255       AST.clear();
256       continue;
257     }
258     LoadInst *LI = dyn_cast<LoadInst>(&I);
259     if (!LI)
260       continue;
261     ++NumLoadsAnalyzed;
262     if (!LI->isSimple() || !LI->getType()->isIntegerTy())
263       continue;
264     auto POP = getPointerOffsetPair(*LI);
265     if (!POP.Pointer)
266       continue;
267     LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++));
268     AST.add(LI);
269   }
270   if (combineLoads(LoadMap))
271     Combined = true;
272   return Combined;
273 }
274 
275 char LoadCombine::ID = 0;
276 
createLoadCombinePass()277 BasicBlockPass *llvm::createLoadCombinePass() {
278   return new LoadCombine();
279 }
280 
// Register LoadCombine under the command-line name "load-combine" with the
// display name LDCOMBINE_NAME, declaring its dependency on
// AAResultsWrapperPass (the analysis it requires in getAnalysisUsage).
INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false)
284