/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "bcc/Assert.h"
18 #include "bcc/Renderscript/RSUtils.h"
19 #include "bcc/Support/Log.h"
20 
21 #include <algorithm>
22 #include <vector>
23 
24 #include <llvm/IR/CallSite.h>
25 #include <llvm/IR/Type.h>
26 #include <llvm/IR/Instructions.h>
27 #include <llvm/IR/Module.h>
28 #include <llvm/IR/Function.h>
29 #include <llvm/Pass.h>
30 
31 namespace { // anonymous namespace
32 
33 static const bool kDebug = false;
34 
35 /* RSX86_64CallConvPass: This pass scans for calls to Renderscript functions in
36  * the CPU reference driver.  For such calls, it  identifies the
37  * pass-by-reference large-object pointer arguments introduced by the frontend
38  * to conform to the AArch64 calling convention (AAPCS).  These pointer
39  * arguments are converted to pass-by-value to match the calling convention of
40  * the CPU reference driver.
41  */
42 class RSX86_64CallConvPass: public llvm::ModulePass {
43 private:
IsRSFunctionOfInterest(llvm::Function & F)44   bool IsRSFunctionOfInterest(llvm::Function &F) {
45   // Only Renderscript functions that are not defined locally be considered
46     if (!F.empty()) // defined locally
47       return false;
48 
49     // llvm intrinsic or internal function
50     llvm::StringRef FName = F.getName();
51     if (FName.startswith("llvm."))
52       return false;
53 
54     // All other functions need to be checked for large-object parameters.
55     // Disallowed (non-Renderscript) functions are detected by a different pass.
56     return true;
57   }
58 
59   // Test if this argument needs to be converted to pass-by-value.
IsDerefNeeded(llvm::Function * F,llvm::Argument & Arg)60   bool IsDerefNeeded(llvm::Function *F, llvm::Argument &Arg) {
61     unsigned ArgNo = Arg.getArgNo();
62     llvm::Type *ArgTy = Arg.getType();
63 
64     // Do not consider arguments with 'sret' attribute.  Parameters with this
65     // attribute are actually pointers to structure return values.
66     if (Arg.hasStructRetAttr())
67       return false;
68 
69     // Dereference needed only if type is a pointer to a struct
70     if (!ArgTy->isPointerTy() || !ArgTy->getPointerElementType()->isStructTy())
71       return false;
72 
73     // Dereference needed only for certain RS struct objects.
74     llvm::Type *StructTy = ArgTy->getPointerElementType();
75     if (!isRsObjectType(StructTy))
76       return false;
77 
78     // TODO Find a better way to encode exceptions
79     llvm::StringRef FName = F->getName();
80     // rsSetObject's first parameter is a pointer
81     if (FName.find("rsSetObject") != std::string::npos && ArgNo == 0)
82       return false;
83     // rsClearObject's first parameter is a pointer
84     if (FName.find("rsClearObject") != std::string::npos && ArgNo == 0)
85       return false;
86 
87     return true;
88   }
89 
90   // Compute which arguments to this function need be converted to pass-by-value
FillArgsToDeref(llvm::Function * F,std::vector<unsigned> & ArgNums)91   bool FillArgsToDeref(llvm::Function *F, std::vector<unsigned> &ArgNums) {
92     bccAssert(ArgNums.size() == 0);
93 
94     for (auto &Arg: F->getArgumentList()) {
95       if (IsDerefNeeded(F, Arg)) {
96         ArgNums.push_back(Arg.getArgNo());
97 
98         if (kDebug) {
99           ALOGV("Lowering argument %u for function %s\n", Arg.getArgNo(),
100                 F->getName().str().c_str());
101         }
102       }
103     }
104     return ArgNums.size() > 0;
105   }
106 
RedefineFn(llvm::Function * OrigFn,std::vector<unsigned> & ArgsToDeref)107   llvm::Function *RedefineFn(llvm::Function *OrigFn,
108                              std::vector<unsigned> &ArgsToDeref) {
109 
110     llvm::FunctionType *FTy = OrigFn->getFunctionType();
111     std::vector<llvm::Type *> Params(FTy->param_begin(), FTy->param_end());
112 
113     llvm::FunctionType *NewTy = llvm::FunctionType::get(FTy->getReturnType(),
114                                                         Params,
115                                                         FTy->isVarArg());
116     llvm::Function *NewFn = llvm::Function::Create(NewTy,
117                                                    OrigFn->getLinkage(),
118                                                    OrigFn->getName(),
119                                                    OrigFn->getParent());
120 
121     // Add the ByVal attribute to the attribute list corresponding to this
122     // argument.  The list at index (i+1) corresponds to the i-th argument.  The
123     // list at index 0 corresponds to the return value's attribute.
124     for (auto i: ArgsToDeref) {
125       NewFn->addAttribute(i+1, llvm::Attribute::ByVal);
126     }
127 
128     NewFn->copyAttributesFrom(OrigFn);
129     NewFn->takeName(OrigFn);
130 
131     for (auto AI=OrigFn->arg_begin(), AE=OrigFn->arg_end(),
132               NAI=NewFn->arg_begin();
133          AI != AE; ++ AI, ++NAI) {
134       NAI->takeName(AI);
135     }
136 
137     return NewFn;
138   }
139 
ReplaceCallInsn(llvm::CallSite & CS,llvm::Function * NewFn,std::vector<unsigned> & ArgsToDeref)140   void ReplaceCallInsn(llvm::CallSite &CS,
141                        llvm::Function *NewFn,
142                        std::vector<unsigned> &ArgsToDeref) {
143 
144     llvm::CallInst *CI = llvm::cast<llvm::CallInst>(CS.getInstruction());
145     std::vector<llvm::Value *> Args(CS.arg_begin(), CS.arg_end());
146     auto NewCI = llvm::CallInst::Create(NewFn, Args, "", CI);
147 
148     // Add the ByVal attribute to the attribute list corresponding to this
149     // argument.  The list at index (i+1) corresponds to the i-th argument.  The
150     // list at index 0 corresponds to the return value's attribute.
151     for (auto i: ArgsToDeref) {
152       NewCI->addAttribute(i+1, llvm::Attribute::ByVal);
153     }
154     if (CI->isTailCall())
155       NewCI->setTailCall();
156 
157     if (!CI->getType()->isVoidTy())
158       CI->replaceAllUsesWith(NewCI);
159 
160     CI->eraseFromParent();
161   }
162 
163 public:
164   static char ID;
165 
RSX86_64CallConvPass()166   RSX86_64CallConvPass()
167     : ModulePass (ID) {
168   }
169 
getAnalysisUsage(llvm::AnalysisUsage & AU) const170   virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
171     // This pass does not use any other analysis passes, but it does
172     // modify the existing functions in the module (thus altering the CFG).
173   }
174 
runOnModule(llvm::Module & M)175   bool runOnModule(llvm::Module &M) override {
176     // Avoid adding Functions and altering FunctionList while iterating over it
177     // by collecting functions and processing them later.
178     std::vector<llvm::Function *> FunctionsToHandle;
179 
180     auto &FunctionList = M.getFunctionList();
181     for (auto &OrigFn: FunctionList) {
182       if (!IsRSFunctionOfInterest(OrigFn))
183         continue;
184       FunctionsToHandle.push_back(&OrigFn);
185     }
186 
187     for (auto OrigFn: FunctionsToHandle) {
188       std::vector<unsigned> ArgsToDeref;
189       if (!FillArgsToDeref(OrigFn, ArgsToDeref))
190         continue;
191 
192       // Replace all calls to OrigFn and erase it from parent.
193       llvm::Function *NewFn = RedefineFn(OrigFn, ArgsToDeref);
194       while (!OrigFn->use_empty()) {
195         llvm::CallSite CS(OrigFn->user_back());
196         ReplaceCallInsn(CS, NewFn, ArgsToDeref);
197       }
198       OrigFn->eraseFromParent();
199     }
200 
201     return FunctionsToHandle.size() > 0;
202   }
203 
204 };
205 
206 }
207 
208 char RSX86_64CallConvPass::ID = 0;
209 
210 static llvm::RegisterPass<RSX86_64CallConvPass> X("X86-64-calling-conv",
211   "remove AArch64 assumptions from calls in X86-64");
212 
213 namespace bcc {
214 
215 llvm::ModulePass *
createRSX86_64CallConvPass()216 createRSX86_64CallConvPass() {
217   return new RSX86_64CallConvPass();
218 }
219 
220 }
221