1 //===--- PartiallyInlineLibCalls.cpp - Partially inline libcalls ----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass tries to partially inline the fast path of well-known library
11 // functions, such as using square-root instructions for cases where sqrt()
12 // does not need to set errno.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Analysis/TargetTransformInfo.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/Pass.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Analysis/TargetLibraryInfo.h"
22 #include "llvm/Transforms/Scalar.h"
23 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
24 
25 using namespace llvm;
26 
27 #define DEBUG_TYPE "partially-inline-libcalls"
28 
29 namespace {
30   class PartiallyInlineLibCalls : public FunctionPass {
31   public:
32     static char ID;
33 
PartiallyInlineLibCalls()34     PartiallyInlineLibCalls() :
35       FunctionPass(ID) {
36       initializePartiallyInlineLibCallsPass(*PassRegistry::getPassRegistry());
37     }
38 
39     void getAnalysisUsage(AnalysisUsage &AU) const override;
40     bool runOnFunction(Function &F) override;
41 
42   private:
43     /// Optimize calls to sqrt.
44     bool optimizeSQRT(CallInst *Call, Function *CalledFunc,
45                       BasicBlock &CurrBB, Function::iterator &BB);
46   };
47 
48   char PartiallyInlineLibCalls::ID = 0;
49 }
50 
// Registration of the pass under the name "partially-inline-libcalls" with the
// description below. NOTE(review): the two trailing booleans are presumably the
// CFG-only / is-analysis flags of INITIALIZE_PASS — confirm in PassSupport.h.
INITIALIZE_PASS(PartiallyInlineLibCalls, "partially-inline-libcalls",
                "Partially inline calls to library functions", false, false)
53 
getAnalysisUsage(AnalysisUsage & AU) const54 void PartiallyInlineLibCalls::getAnalysisUsage(AnalysisUsage &AU) const {
55   AU.addRequired<TargetLibraryInfoWrapperPass>();
56   AU.addRequired<TargetTransformInfoWrapperPass>();
57   FunctionPass::getAnalysisUsage(AU);
58 }
59 
runOnFunction(Function & F)60 bool PartiallyInlineLibCalls::runOnFunction(Function &F) {
61   bool Changed = false;
62   Function::iterator CurrBB;
63   TargetLibraryInfo *TLI =
64       &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
65   const TargetTransformInfo *TTI =
66       &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
67   for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE;) {
68     CurrBB = BB++;
69 
70     for (BasicBlock::iterator II = CurrBB->begin(), IE = CurrBB->end();
71          II != IE; ++II) {
72       CallInst *Call = dyn_cast<CallInst>(&*II);
73       Function *CalledFunc;
74 
75       if (!Call || !(CalledFunc = Call->getCalledFunction()))
76         continue;
77 
78       // Skip if function either has local linkage or is not a known library
79       // function.
80       LibFunc::Func LibFunc;
81       if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() ||
82           !TLI->getLibFunc(CalledFunc->getName(), LibFunc))
83         continue;
84 
85       switch (LibFunc) {
86       case LibFunc::sqrtf:
87       case LibFunc::sqrt:
88         if (TTI->haveFastSqrt(Call->getType()) &&
89             optimizeSQRT(Call, CalledFunc, *CurrBB, BB))
90           break;
91         continue;
92       default:
93         continue;
94       }
95 
96       Changed = true;
97       break;
98     }
99   }
100 
101   return Changed;
102 }
103 
optimizeSQRT(CallInst * Call,Function * CalledFunc,BasicBlock & CurrBB,Function::iterator & BB)104 bool PartiallyInlineLibCalls::optimizeSQRT(CallInst *Call,
105                                            Function *CalledFunc,
106                                            BasicBlock &CurrBB,
107                                            Function::iterator &BB) {
108   // There is no need to change the IR, since backend will emit sqrt
109   // instruction if the call has already been marked read-only.
110   if (Call->onlyReadsMemory())
111     return false;
112 
113   // The call must have the expected result type.
114   if (!Call->getType()->isFloatingPointTy())
115     return false;
116 
117   // Do the following transformation:
118   //
119   // (before)
120   // dst = sqrt(src)
121   //
122   // (after)
123   // v0 = sqrt_noreadmem(src) # native sqrt instruction.
124   // if (v0 is a NaN)
125   //   v1 = sqrt(src)         # library call.
126   // dst = phi(v0, v1)
127   //
128 
129   // Move all instructions following Call to newly created block JoinBB.
130   // Create phi and replace all uses.
131   BasicBlock *JoinBB = llvm::SplitBlock(&CurrBB, Call->getNextNode());
132   IRBuilder<> Builder(JoinBB, JoinBB->begin());
133   PHINode *Phi = Builder.CreatePHI(Call->getType(), 2);
134   Call->replaceAllUsesWith(Phi);
135 
136   // Create basic block LibCallBB and insert a call to library function sqrt.
137   BasicBlock *LibCallBB = BasicBlock::Create(CurrBB.getContext(), "call.sqrt",
138                                              CurrBB.getParent(), JoinBB);
139   Builder.SetInsertPoint(LibCallBB);
140   Instruction *LibCall = Call->clone();
141   Builder.Insert(LibCall);
142   Builder.CreateBr(JoinBB);
143 
144   // Add attribute "readnone" so that backend can use a native sqrt instruction
145   // for this call. Insert a FP compare instruction and a conditional branch
146   // at the end of CurrBB.
147   Call->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
148   CurrBB.getTerminator()->eraseFromParent();
149   Builder.SetInsertPoint(&CurrBB);
150   Value *FCmp = Builder.CreateFCmpOEQ(Call, Call);
151   Builder.CreateCondBr(FCmp, JoinBB, LibCallBB);
152 
153   // Add phi operands.
154   Phi->addIncoming(Call, &CurrBB);
155   Phi->addIncoming(LibCall, LibCallBB);
156 
157   BB = JoinBB->getIterator();
158   return true;
159 }
160 
createPartiallyInlineLibCallsPass()161 FunctionPass *llvm::createPartiallyInlineLibCallsPass() {
162   return new PartiallyInlineLibCalls();
163 }
164