1 //===-- AArch64AddressTypePromotion.cpp --- Promote type for addr accesses -==//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// This pass tries to promote (widen) the computations that produce a
// sign-extended value used in memory accesses, so that they are performed
// directly on the wide type.
// E.g.
// a = add nsw i32 b, 3
// d = sext i32 a to i64
// e = getelementptr ..., i64 d
//
// =>
// f = sext i32 b to i64
// a = add nsw i64 f, 3
// e = getelementptr ..., i64 a
//
// This is legal to do when the computations are marked with the nsw flag:
// a signed no-wrap guarantee is what makes sext(x op y) equal to
// sext(x) op sext(y). The nuw flag alone is not sufficient: e.g.,
// 'add nuw i32 2147483647, 1' yields 0x80000000, whose sign extension is
// -2^31, whereas the 64-bit sum of the sign-extended operands is +2^31.
// Moreover, the current heuristic is simple: it does not create new sext
// operations, i.e., it gives up when a sext would have forked (e.g., if
// a = add i32 b, c, two sexts are required to promote the computation).
27 //
28 // FIXME: This pass may be useful for other targets too.
29 // ===---------------------------------------------------------------------===//
30
31 #include "AArch64.h"
32 #include "llvm/ADT/DenseMap.h"
33 #include "llvm/ADT/SmallPtrSet.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/Instructions.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Operator.h"
41 #include "llvm/Pass.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Support/raw_ostream.h"
45
46 using namespace llvm;
47
48 #define DEBUG_TYPE "aarch64-type-promotion"
49
50 static cl::opt<bool>
51 EnableAddressTypePromotion("aarch64-type-promotion", cl::Hidden,
52 cl::desc("Enable the type promotion pass"),
53 cl::init(true));
54 static cl::opt<bool>
55 EnableMerge("aarch64-type-promotion-merge", cl::Hidden,
56 cl::desc("Enable merging of redundant sexts when one is dominating"
57 " the other."),
58 cl::init(true));
59
60 //===----------------------------------------------------------------------===//
61 // AArch64AddressTypePromotion
62 //===----------------------------------------------------------------------===//
63
64 namespace llvm {
65 void initializeAArch64AddressTypePromotionPass(PassRegistry &);
66 }
67
68 namespace {
69 class AArch64AddressTypePromotion : public FunctionPass {
70
71 public:
72 static char ID;
AArch64AddressTypePromotion()73 AArch64AddressTypePromotion()
74 : FunctionPass(ID), Func(nullptr), ConsideredSExtType(nullptr) {
75 initializeAArch64AddressTypePromotionPass(*PassRegistry::getPassRegistry());
76 }
77
getPassName() const78 const char *getPassName() const override {
79 return "AArch64 Address Type Promotion";
80 }
81
82 /// Iterate over the functions and promote the computation of interesting
83 // sext instructions.
84 bool runOnFunction(Function &F) override;
85
86 private:
87 /// The current function.
88 Function *Func;
89 /// Filter out all sexts that does not have this type.
90 /// Currently initialized with Int64Ty.
91 Type *ConsideredSExtType;
92
93 // This transformation requires dominator info.
getAnalysisUsage(AnalysisUsage & AU) const94 void getAnalysisUsage(AnalysisUsage &AU) const override {
95 AU.setPreservesCFG();
96 AU.addRequired<DominatorTreeWrapperPass>();
97 AU.addPreserved<DominatorTreeWrapperPass>();
98 FunctionPass::getAnalysisUsage(AU);
99 }
100
101 typedef SmallPtrSet<Instruction *, 32> SetOfInstructions;
102 typedef SmallVector<Instruction *, 16> Instructions;
103 typedef DenseMap<Value *, Instructions> ValueToInsts;
104
105 /// Check if it is profitable to move a sext through this instruction.
106 /// Currently, we consider it is profitable if:
107 /// - Inst is used only once (no need to insert truncate).
108 /// - Inst has only one operand that will require a sext operation (we do
109 /// do not create new sext operation).
110 bool shouldGetThrough(const Instruction *Inst);
111
112 /// Check if it is possible and legal to move a sext through this
113 /// instruction.
114 /// Current heuristic considers that we can get through:
115 /// - Arithmetic operation marked with the nsw or nuw flag.
116 /// - Other sext operation.
117 /// - Truncate operation if it was just dropping sign extended bits.
118 bool canGetThrough(const Instruction *Inst);
119
120 /// Move sext operations through safe to sext instructions.
121 bool propagateSignExtension(Instructions &SExtInsts);
122
123 /// Is this sext should be considered for code motion.
124 /// We look for sext with ConsideredSExtType and uses in at least one
125 // GetElementPtrInst.
126 bool shouldConsiderSExt(const Instruction *SExt) const;
127
128 /// Collect all interesting sext operations, i.e., the ones with the right
129 /// type and used in memory accesses.
130 /// More precisely, a sext instruction is considered as interesting if it
131 /// is used in a "complex" getelementptr or it exits at least another
132 /// sext instruction that sign extended the same initial value.
133 /// A getelementptr is considered as "complex" if it has more than 2
134 // operands.
135 void analyzeSExtension(Instructions &SExtInsts);
136
137 /// Merge redundant sign extension operations in common dominator.
138 void mergeSExts(ValueToInsts &ValToSExtendedUses,
139 SetOfInstructions &ToRemove);
140 };
141 } // end anonymous namespace.
142
143 char AArch64AddressTypePromotion::ID = 0;
144
145 INITIALIZE_PASS_BEGIN(AArch64AddressTypePromotion, "aarch64-type-promotion",
146 "AArch64 Type Promotion Pass", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)147 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
148 INITIALIZE_PASS_END(AArch64AddressTypePromotion, "aarch64-type-promotion",
149 "AArch64 Type Promotion Pass", false, false)
150
151 FunctionPass *llvm::createAArch64AddressTypePromotionPass() {
152 return new AArch64AddressTypePromotion();
153 }
154
canGetThrough(const Instruction * Inst)155 bool AArch64AddressTypePromotion::canGetThrough(const Instruction *Inst) {
156 if (isa<SExtInst>(Inst))
157 return true;
158
159 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
160 if (BinOp && isa<OverflowingBinaryOperator>(BinOp) &&
161 (BinOp->hasNoUnsignedWrap() || BinOp->hasNoSignedWrap()))
162 return true;
163
164 // sext(trunc(sext)) --> sext
165 if (isa<TruncInst>(Inst) && isa<SExtInst>(Inst->getOperand(0))) {
166 const Instruction *Opnd = cast<Instruction>(Inst->getOperand(0));
167 // Check that the truncate just drop sign extended bits.
168 if (Inst->getType()->getIntegerBitWidth() >=
169 Opnd->getOperand(0)->getType()->getIntegerBitWidth() &&
170 Inst->getOperand(0)->getType()->getIntegerBitWidth() <=
171 ConsideredSExtType->getIntegerBitWidth())
172 return true;
173 }
174
175 return false;
176 }
177
shouldGetThrough(const Instruction * Inst)178 bool AArch64AddressTypePromotion::shouldGetThrough(const Instruction *Inst) {
179 // If the type of the sext is the same as the considered one, this sext
180 // will become useless.
181 // Otherwise, we will have to do something to preserve the original value,
182 // unless it is used once.
183 if (isa<SExtInst>(Inst) &&
184 (Inst->getType() == ConsideredSExtType || Inst->hasOneUse()))
185 return true;
186
187 // If the Inst is used more that once, we may need to insert truncate
188 // operations and we don't do that at the moment.
189 if (!Inst->hasOneUse())
190 return false;
191
192 // This truncate is used only once, thus if we can get thourgh, it will become
193 // useless.
194 if (isa<TruncInst>(Inst))
195 return true;
196
197 // If both operands are not constant, a new sext will be created here.
198 // Current heuristic is: each step should be profitable.
199 // Therefore we don't allow to increase the number of sext even if it may
200 // be profitable later on.
201 if (isa<BinaryOperator>(Inst) && isa<ConstantInt>(Inst->getOperand(1)))
202 return true;
203
204 return false;
205 }
206
shouldSExtOperand(const Instruction * Inst,int OpIdx)207 static bool shouldSExtOperand(const Instruction *Inst, int OpIdx) {
208 if (isa<SelectInst>(Inst) && OpIdx == 0)
209 return false;
210 return true;
211 }
212
213 bool
shouldConsiderSExt(const Instruction * SExt) const214 AArch64AddressTypePromotion::shouldConsiderSExt(const Instruction *SExt) const {
215 if (SExt->getType() != ConsideredSExtType)
216 return false;
217
218 for (const User *U : SExt->users()) {
219 if (isa<GetElementPtrInst>(U))
220 return true;
221 }
222
223 return false;
224 }
225
226 // Input:
227 // - SExtInsts contains all the sext instructions that are used directly in
228 // GetElementPtrInst, i.e., access to memory.
229 // Algorithm:
230 // - For each sext operation in SExtInsts:
231 // Let var be the operand of sext.
232 // while it is profitable (see shouldGetThrough), legal, and safe
233 // (see canGetThrough) to move sext through var's definition:
234 // * promote the type of var's definition.
235 // * fold var into sext uses.
236 // * move sext above var's definition.
237 // * update sext operand to use the operand of var that should be sign
238 // extended (by construction there is only one).
239 //
240 // E.g.,
241 // a = ... i32 c, 3
242 // b = sext i32 a to i64 <- is it legal/safe/profitable to get through 'a'
243 // ...
244 // = b
245 // => Yes, update the code
246 // b = sext i32 c to i64
247 // a = ... i64 b, 3
248 // ...
249 // = a
250 // Iterate on 'c'.
251 bool
propagateSignExtension(Instructions & SExtInsts)252 AArch64AddressTypePromotion::propagateSignExtension(Instructions &SExtInsts) {
253 DEBUG(dbgs() << "*** Propagate Sign Extension ***\n");
254
255 bool LocalChange = false;
256 SetOfInstructions ToRemove;
257 ValueToInsts ValToSExtendedUses;
258 while (!SExtInsts.empty()) {
259 // Get through simple chain.
260 Instruction *SExt = SExtInsts.pop_back_val();
261
262 DEBUG(dbgs() << "Consider:\n" << *SExt << '\n');
263
264 // If this SExt has already been merged continue.
265 if (SExt->use_empty() && ToRemove.count(SExt)) {
266 DEBUG(dbgs() << "No uses => marked as delete\n");
267 continue;
268 }
269
270 // Now try to get through the chain of definitions.
271 while (auto *Inst = dyn_cast<Instruction>(SExt->getOperand(0))) {
272 DEBUG(dbgs() << "Try to get through:\n" << *Inst << '\n');
273 if (!canGetThrough(Inst) || !shouldGetThrough(Inst)) {
274 // We cannot get through something that is not an Instruction
275 // or not safe to SExt.
276 DEBUG(dbgs() << "Cannot get through\n");
277 break;
278 }
279
280 LocalChange = true;
281 // If this is a sign extend, it becomes useless.
282 if (isa<SExtInst>(Inst) || isa<TruncInst>(Inst)) {
283 DEBUG(dbgs() << "SExt or trunc, mark it as to remove\n");
284 // We cannot use replaceAllUsesWith here because we may trigger some
285 // assertion on the type as all involved sext operation may have not
286 // been moved yet.
287 while (!Inst->use_empty()) {
288 Use &U = *Inst->use_begin();
289 Instruction *User = dyn_cast<Instruction>(U.getUser());
290 assert(User && "User of sext is not an Instruction!");
291 User->setOperand(U.getOperandNo(), SExt);
292 }
293 ToRemove.insert(Inst);
294 SExt->setOperand(0, Inst->getOperand(0));
295 SExt->moveBefore(Inst);
296 continue;
297 }
298
299 // Get through the Instruction:
300 // 1. Update its type.
301 // 2. Replace the uses of SExt by Inst.
302 // 3. Sign extend each operand that needs to be sign extended.
303
304 // Step #1.
305 Inst->mutateType(SExt->getType());
306 // Step #2.
307 SExt->replaceAllUsesWith(Inst);
308 // Step #3.
309 Instruction *SExtForOpnd = SExt;
310
311 DEBUG(dbgs() << "Propagate SExt to operands\n");
312 for (int OpIdx = 0, EndOpIdx = Inst->getNumOperands(); OpIdx != EndOpIdx;
313 ++OpIdx) {
314 DEBUG(dbgs() << "Operand:\n" << *(Inst->getOperand(OpIdx)) << '\n');
315 if (Inst->getOperand(OpIdx)->getType() == SExt->getType() ||
316 !shouldSExtOperand(Inst, OpIdx)) {
317 DEBUG(dbgs() << "No need to propagate\n");
318 continue;
319 }
320 // Check if we can statically sign extend the operand.
321 Value *Opnd = Inst->getOperand(OpIdx);
322 if (const ConstantInt *Cst = dyn_cast<ConstantInt>(Opnd)) {
323 DEBUG(dbgs() << "Statically sign extend\n");
324 Inst->setOperand(OpIdx, ConstantInt::getSigned(SExt->getType(),
325 Cst->getSExtValue()));
326 continue;
327 }
328 // UndefValue are typed, so we have to statically sign extend them.
329 if (isa<UndefValue>(Opnd)) {
330 DEBUG(dbgs() << "Statically sign extend\n");
331 Inst->setOperand(OpIdx, UndefValue::get(SExt->getType()));
332 continue;
333 }
334
335 // Otherwise we have to explicity sign extend it.
336 assert(SExtForOpnd &&
337 "Only one operand should have been sign extended");
338
339 SExtForOpnd->setOperand(0, Opnd);
340
341 DEBUG(dbgs() << "Move before:\n" << *Inst << "\nSign extend\n");
342 // Move the sign extension before the insertion point.
343 SExtForOpnd->moveBefore(Inst);
344 Inst->setOperand(OpIdx, SExtForOpnd);
345 // If more sext are required, new instructions will have to be created.
346 SExtForOpnd = nullptr;
347 }
348 if (SExtForOpnd == SExt) {
349 DEBUG(dbgs() << "Sign extension is useless now\n");
350 ToRemove.insert(SExt);
351 break;
352 }
353 }
354
355 // If the use is already of the right type, connect its uses to its argument
356 // and delete it.
357 // This can happen for an Instruction all uses of which are sign extended.
358 if (!ToRemove.count(SExt) &&
359 SExt->getType() == SExt->getOperand(0)->getType()) {
360 DEBUG(dbgs() << "Sign extension is useless, attach its use to "
361 "its argument\n");
362 SExt->replaceAllUsesWith(SExt->getOperand(0));
363 ToRemove.insert(SExt);
364 } else
365 ValToSExtendedUses[SExt->getOperand(0)].push_back(SExt);
366 }
367
368 if (EnableMerge)
369 mergeSExts(ValToSExtendedUses, ToRemove);
370
371 // Remove all instructions marked as ToRemove.
372 for (Instruction *I: ToRemove)
373 I->eraseFromParent();
374 return LocalChange;
375 }
376
mergeSExts(ValueToInsts & ValToSExtendedUses,SetOfInstructions & ToRemove)377 void AArch64AddressTypePromotion::mergeSExts(ValueToInsts &ValToSExtendedUses,
378 SetOfInstructions &ToRemove) {
379 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
380
381 for (auto &Entry : ValToSExtendedUses) {
382 Instructions &Insts = Entry.second;
383 Instructions CurPts;
384 for (Instruction *Inst : Insts) {
385 if (ToRemove.count(Inst))
386 continue;
387 bool inserted = false;
388 for (auto &Pt : CurPts) {
389 if (DT.dominates(Inst, Pt)) {
390 DEBUG(dbgs() << "Replace all uses of:\n" << *Pt << "\nwith:\n"
391 << *Inst << '\n');
392 Pt->replaceAllUsesWith(Inst);
393 ToRemove.insert(Pt);
394 Pt = Inst;
395 inserted = true;
396 break;
397 }
398 if (!DT.dominates(Pt, Inst))
399 // Give up if we need to merge in a common dominator as the
400 // expermients show it is not profitable.
401 continue;
402
403 DEBUG(dbgs() << "Replace all uses of:\n" << *Inst << "\nwith:\n"
404 << *Pt << '\n');
405 Inst->replaceAllUsesWith(Pt);
406 ToRemove.insert(Inst);
407 inserted = true;
408 break;
409 }
410 if (!inserted)
411 CurPts.push_back(Inst);
412 }
413 }
414 }
415
analyzeSExtension(Instructions & SExtInsts)416 void AArch64AddressTypePromotion::analyzeSExtension(Instructions &SExtInsts) {
417 DEBUG(dbgs() << "*** Analyze Sign Extensions ***\n");
418
419 DenseMap<Value *, Instruction *> SeenChains;
420
421 for (auto &BB : *Func) {
422 for (auto &II : BB) {
423 Instruction *SExt = &II;
424
425 // Collect all sext operation per type.
426 if (!isa<SExtInst>(SExt) || !shouldConsiderSExt(SExt))
427 continue;
428
429 DEBUG(dbgs() << "Found:\n" << (*SExt) << '\n');
430
431 // Cases where we actually perform the optimization:
432 // 1. SExt is used in a getelementptr with more than 2 operand =>
433 // likely we can merge some computation if they are done on 64 bits.
434 // 2. The beginning of the SExt chain is SExt several time. =>
435 // code sharing is possible.
436
437 bool insert = false;
438 // #1.
439 for (const User *U : SExt->users()) {
440 const Instruction *Inst = dyn_cast<GetElementPtrInst>(U);
441 if (Inst && Inst->getNumOperands() > 2) {
442 DEBUG(dbgs() << "Interesting use in GetElementPtrInst\n" << *Inst
443 << '\n');
444 insert = true;
445 break;
446 }
447 }
448
449 // #2.
450 // Check the head of the chain.
451 Instruction *Inst = SExt;
452 Value *Last;
453 do {
454 int OpdIdx = 0;
455 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
456 if (BinOp && isa<ConstantInt>(BinOp->getOperand(0)))
457 OpdIdx = 1;
458 Last = Inst->getOperand(OpdIdx);
459 Inst = dyn_cast<Instruction>(Last);
460 } while (Inst && canGetThrough(Inst) && shouldGetThrough(Inst));
461
462 DEBUG(dbgs() << "Head of the chain:\n" << *Last << '\n');
463 DenseMap<Value *, Instruction *>::iterator AlreadySeen =
464 SeenChains.find(Last);
465 if (insert || AlreadySeen != SeenChains.end()) {
466 DEBUG(dbgs() << "Insert\n");
467 SExtInsts.push_back(SExt);
468 if (AlreadySeen != SeenChains.end() && AlreadySeen->second != nullptr) {
469 DEBUG(dbgs() << "Insert chain member\n");
470 SExtInsts.push_back(AlreadySeen->second);
471 SeenChains[Last] = nullptr;
472 }
473 } else {
474 DEBUG(dbgs() << "Record its chain membership\n");
475 SeenChains[Last] = SExt;
476 }
477 }
478 }
479 }
480
runOnFunction(Function & F)481 bool AArch64AddressTypePromotion::runOnFunction(Function &F) {
482 if (!EnableAddressTypePromotion || F.isDeclaration())
483 return false;
484 Func = &F;
485 ConsideredSExtType = Type::getInt64Ty(Func->getContext());
486
487 DEBUG(dbgs() << "*** " << getPassName() << ": " << Func->getName() << '\n');
488
489 Instructions SExtInsts;
490 analyzeSExtension(SExtInsts);
491 return propagateSignExtension(SExtInsts);
492 }
493