1 /*
2  * Copyright 2015, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Assert.h"
18 #include "Log.h"
19 #include "RSUtils.h"
20 
21 #include <algorithm>
22 #include <vector>
23 
24 #include <llvm/IR/CallSite.h>
25 #include <llvm/IR/Type.h>
26 #include <llvm/IR/Instructions.h>
27 #include <llvm/IR/Module.h>
28 #include <llvm/IR/Function.h>
29 #include <llvm/Pass.h>
30 
31 namespace { // anonymous namespace
32 
33 static const bool kDebug = false;
34 
35 /* RSX86_64CallConvPass: This pass scans for calls to Renderscript functions in
36  * the CPU reference driver.  For such calls, it  identifies the
37  * pass-by-reference large-object pointer arguments introduced by the frontend
38  * to conform to the AArch64 calling convention (AAPCS).  These pointer
39  * arguments are converted to pass-by-value to match the calling convention of
40  * the CPU reference driver.
41  */
42 class RSX86_64CallConvPass: public llvm::ModulePass {
43 private:
IsRSFunctionOfInterest(llvm::Function & F)44   bool IsRSFunctionOfInterest(llvm::Function &F) {
45   // Only Renderscript functions that are not defined locally be considered
46     if (!F.empty()) // defined locally
47       return false;
48 
49     // llvm intrinsic or internal function
50     llvm::StringRef FName = F.getName();
51     if (FName.startswith("llvm."))
52       return false;
53 
54     // All other functions need to be checked for large-object parameters.
55     // Disallowed (non-Renderscript) functions are detected by a different pass.
56     return true;
57   }
58 
59   // Test if this argument needs to be converted to pass-by-value.
IsDerefNeeded(llvm::Function * F,llvm::Argument & Arg)60   bool IsDerefNeeded(llvm::Function *F, llvm::Argument &Arg) {
61     unsigned ArgNo = Arg.getArgNo();
62     llvm::Type *ArgTy = Arg.getType();
63 
64     // Do not consider arguments with 'sret' attribute.  Parameters with this
65     // attribute are actually pointers to structure return values.
66     if (Arg.hasStructRetAttr())
67       return false;
68 
69     // Dereference needed only if type is a pointer to a struct
70     if (!ArgTy->isPointerTy() || !ArgTy->getPointerElementType()->isStructTy())
71       return false;
72 
73     // Dereference needed only for certain RS struct objects.
74     llvm::Type *StructTy = ArgTy->getPointerElementType();
75     if (!isRsObjectType(StructTy))
76       return false;
77 
78     // TODO Find a better way to encode exceptions
79     llvm::StringRef FName = F->getName();
80     // rsSetObject's first parameter is a pointer
81     if (FName.find("rsSetObject") != std::string::npos && ArgNo == 0)
82       return false;
83     // rsClearObject's first parameter is a pointer
84     if (FName.find("rsClearObject") != std::string::npos && ArgNo == 0)
85       return false;
86     // rsForEachInternal's fifth parameter is a pointer
87     if (FName.find("rsForEachInternal") != std::string::npos && ArgNo == 4)
88       return false;
89 
90     return true;
91   }
92 
93   // Compute which arguments to this function need be converted to pass-by-value
FillArgsToDeref(llvm::Function * F,std::vector<unsigned> & ArgNums)94   bool FillArgsToDeref(llvm::Function *F, std::vector<unsigned> &ArgNums) {
95     bccAssert(ArgNums.size() == 0);
96 
97     for (auto &Arg: F->getArgumentList()) {
98       if (IsDerefNeeded(F, Arg)) {
99         ArgNums.push_back(Arg.getArgNo());
100 
101         if (kDebug) {
102           ALOGV("Lowering argument %u for function %s\n", Arg.getArgNo(),
103                 F->getName().str().c_str());
104         }
105       }
106     }
107     return ArgNums.size() > 0;
108   }
109 
RedefineFn(llvm::Function * OrigFn,std::vector<unsigned> & ArgsToDeref)110   llvm::Function *RedefineFn(llvm::Function *OrigFn,
111                              std::vector<unsigned> &ArgsToDeref) {
112 
113     llvm::FunctionType *FTy = OrigFn->getFunctionType();
114     std::vector<llvm::Type *> Params(FTy->param_begin(), FTy->param_end());
115 
116     llvm::FunctionType *NewTy = llvm::FunctionType::get(FTy->getReturnType(),
117                                                         Params,
118                                                         FTy->isVarArg());
119     llvm::Function *NewFn = llvm::Function::Create(NewTy,
120                                                    OrigFn->getLinkage(),
121                                                    OrigFn->getName(),
122                                                    OrigFn->getParent());
123 
124     // Add the ByVal attribute to the attribute list corresponding to this
125     // argument.  The list at index (i+1) corresponds to the i-th argument.  The
126     // list at index 0 corresponds to the return value's attribute.
127     for (auto i: ArgsToDeref) {
128       NewFn->addAttribute(i+1, llvm::Attribute::ByVal);
129     }
130 
131     NewFn->copyAttributesFrom(OrigFn);
132     NewFn->takeName(OrigFn);
133 
134     for (auto AI=OrigFn->arg_begin(), AE=OrigFn->arg_end(),
135               NAI=NewFn->arg_begin();
136          AI != AE; ++ AI, ++NAI) {
137       NAI->takeName(&*AI);
138     }
139 
140     return NewFn;
141   }
142 
ReplaceCallInsn(llvm::CallSite & CS,llvm::Function * NewFn,std::vector<unsigned> & ArgsToDeref)143   void ReplaceCallInsn(llvm::CallSite &CS,
144                        llvm::Function *NewFn,
145                        std::vector<unsigned> &ArgsToDeref) {
146 
147     llvm::CallInst *CI = llvm::cast<llvm::CallInst>(CS.getInstruction());
148     std::vector<llvm::Value *> Args(CS.arg_begin(), CS.arg_end());
149     auto NewCI = llvm::CallInst::Create(NewFn, Args, "", CI);
150 
151     // Add the ByVal attribute to the attribute list corresponding to this
152     // argument.  The list at index (i+1) corresponds to the i-th argument.  The
153     // list at index 0 corresponds to the return value's attribute.
154     for (auto i: ArgsToDeref) {
155       NewCI->addAttribute(i+1, llvm::Attribute::ByVal);
156     }
157     if (CI->isTailCall())
158       NewCI->setTailCall();
159 
160     if (!CI->getType()->isVoidTy())
161       CI->replaceAllUsesWith(NewCI);
162 
163     CI->eraseFromParent();
164   }
165 
166 public:
167   static char ID;
168 
RSX86_64CallConvPass()169   RSX86_64CallConvPass()
170     : ModulePass (ID) {
171   }
172 
getAnalysisUsage(llvm::AnalysisUsage & AU) const173   virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
174     // This pass does not use any other analysis passes, but it does
175     // modify the existing functions in the module (thus altering the CFG).
176   }
177 
runOnModule(llvm::Module & M)178   bool runOnModule(llvm::Module &M) override {
179     // Avoid adding Functions and altering FunctionList while iterating over it
180     // by collecting functions and processing them later.
181     std::vector<llvm::Function *> FunctionsToHandle;
182 
183     auto &FunctionList = M.getFunctionList();
184     for (auto &OrigFn: FunctionList) {
185       if (!IsRSFunctionOfInterest(OrigFn))
186         continue;
187       FunctionsToHandle.push_back(&OrigFn);
188     }
189 
190     for (auto OrigFn: FunctionsToHandle) {
191       std::vector<unsigned> ArgsToDeref;
192       if (!FillArgsToDeref(OrigFn, ArgsToDeref))
193         continue;
194 
195       // Replace all calls to OrigFn and erase it from parent.
196       llvm::Function *NewFn = RedefineFn(OrigFn, ArgsToDeref);
197       while (!OrigFn->use_empty()) {
198         llvm::CallSite CS(OrigFn->user_back());
199         ReplaceCallInsn(CS, NewFn, ArgsToDeref);
200       }
201       OrigFn->eraseFromParent();
202     }
203 
204     return FunctionsToHandle.size() > 0;
205   }
206 
207 };
208 
209 }
210 
211 char RSX86_64CallConvPass::ID = 0;
212 
213 static llvm::RegisterPass<RSX86_64CallConvPass> X("X86-64-calling-conv",
214   "remove AArch64 assumptions from calls in X86-64");
215 
216 namespace bcc {
217 
218 llvm::ModulePass *
createRSX86_64CallConvPass()219 createRSX86_64CallConvPass() {
220   return new RSX86_64CallConvPass();
221 }
222 
223 }
224