1 /*
2  * Copyright 2015, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RSScriptGroupFusion.h"
18 
19 #include "Assert.h"
20 #include "Log.h"
21 #include "bcc/BCCContext.h"
22 #include "bcc/Source.h"
23 #include "bcinfo/MetadataExtractor.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/IR/DataLayout.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 using llvm::Function;
31 using llvm::Module;
32 
33 using std::string;
34 
35 namespace bcc {
36 
37 namespace {
38 
getInvokeFunction(const Source & source,const int slot,Module * newModule)39 const Function* getInvokeFunction(const Source& source, const int slot,
40                                   Module* newModule) {
41 
42   bcinfo::MetadataExtractor &metadata = *source.getMetadata();
43   const char* functionName = metadata.getExportFuncNameList()[slot];
44   Function* func = newModule->getFunction(functionName);
45   // Materialize the function so that later the caller can inspect its argument
46   // and return types.
47   newModule->materialize(func);
48   return func;
49 }
50 
51 const Function*
getFunction(Module * mergedModule,const Source * source,const int slot,uint32_t * signature)52 getFunction(Module* mergedModule, const Source* source, const int slot,
53             uint32_t* signature) {
54 
55   bcinfo::MetadataExtractor &metadata = *source->getMetadata();
56   const char* functionName = metadata.getExportForEachNameList()[slot];
57   if (functionName == nullptr || !functionName[0]) {
58     ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
59           source->getName().c_str(), slot);
60     return nullptr;
61   }
62 
63   if (metadata.getExportForEachInputCountList()[slot] > 1) {
64     ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
65           source->getName().c_str(), functionName);
66     return nullptr;
67   }
68 
69   if (signature != nullptr) {
70     *signature = metadata.getExportForEachSignatureList()[slot];
71   }
72 
73   const Function* function = mergedModule->getFunction(functionName);
74 
75   return function;
76 }
77 
78 // The whitelist of supported signature bits. Context or user data arguments are
79 // not currently supported in kernel fusion. To support them or any new kinds of
80 // arguments in the future, it requires not only listing the signature bits here,
81 // but also implementing additional necessary fusion logic in the getFusedFuncSig(),
82 // getFusedFuncType(), and fuseKernels() functions below.
83 constexpr uint32_t ExpectedSignatureBits =
84         bcinfo::MD_SIG_In |
85         bcinfo::MD_SIG_Out |
86         bcinfo::MD_SIG_X |
87         bcinfo::MD_SIG_Y |
88         bcinfo::MD_SIG_Z |
89         bcinfo::MD_SIG_Kernel;
90 
getFusedFuncSig(const std::vector<Source * > & sources,const std::vector<int> & slots,uint32_t * retSig)91 int getFusedFuncSig(const std::vector<Source*>& sources,
92                     const std::vector<int>& slots,
93                     uint32_t* retSig) {
94   *retSig = 0;
95   uint32_t firstSignature = 0;
96   uint32_t signature = 0;
97   auto slotIter = slots.begin();
98   for (const Source* source : sources) {
99     const int slot = *slotIter++;
100     bcinfo::MetadataExtractor &metadata = *source->getMetadata();
101 
102     if (metadata.getExportForEachInputCountList()[slot] > 1) {
103       ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs",
104             source->getName().c_str(), slot);
105       return -1;
106     }
107 
108     signature = metadata.getExportForEachSignatureList()[slot];
109     if (signature & ~ExpectedSignatureBits) {
110       ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x",
111             source->getName().c_str(), slot, signature);
112       return -1;
113     }
114 
115     if (firstSignature == 0) {
116       firstSignature = signature;
117     }
118 
119     *retSig |= signature;
120   }
121 
122   if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
123     *retSig &= ~bcinfo::MD_SIG_In;
124   }
125 
126   if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
127     *retSig &= ~bcinfo::MD_SIG_Out;
128   }
129 
130   return 0;
131 }
132 
getFusedFuncType(bcc::BCCContext & Context,const std::vector<Source * > & sources,const std::vector<int> & slots,Module * M,uint32_t * signature)133 llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
134                                      const std::vector<Source*>& sources,
135                                      const std::vector<int>& slots,
136                                      Module* M,
137                                      uint32_t* signature) {
138   int error = getFusedFuncSig(sources, slots, signature);
139 
140   if (error < 0) {
141     return nullptr;
142   }
143 
144   const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);
145 
146   bccAssert (firstF != nullptr);
147 
148   llvm::SmallVector<llvm::Type*, 8> ArgTys;
149 
150   if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
151     ArgTys.push_back(firstF->arg_begin()->getType());
152   }
153 
154   llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
155   if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
156     ArgTys.push_back(I32Ty);
157   }
158   if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
159     ArgTys.push_back(I32Ty);
160   }
161   if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
162     ArgTys.push_back(I32Ty);
163   }
164 
165   const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);
166 
167   bccAssert (lastF != nullptr);
168 
169   llvm::Type* retTy = lastF->getReturnType();
170 
171   return llvm::FunctionType::get(retTy, ArgTys, false);
172 }
173 
174 }  // anonymous namespace
175 
fuseKernels(bcc::BCCContext & Context,const std::vector<Source * > & sources,const std::vector<int> & slots,const std::string & fusedName,Module * mergedModule)176 bool fuseKernels(bcc::BCCContext& Context,
177                  const std::vector<Source *>& sources,
178                  const std::vector<int>& slots,
179                  const std::string& fusedName,
180                  Module* mergedModule) {
181   bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
182 
183   uint32_t fusedFunctionSignature;
184 
185   llvm::FunctionType* fusedType =
186           getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
187 
188   if (fusedType == nullptr) {
189     return false;
190   }
191 
192   Function* fusedKernel =
193           (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));
194 
195   llvm::LLVMContext& ctxt = Context.getLLVMContext();
196 
197   llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
198   llvm::IRBuilder<> builder(block);
199 
200   Function::arg_iterator argIter = fusedKernel->arg_begin();
201 
202   llvm::Value* dataElement = nullptr;
203   if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
204     dataElement = &*(argIter++);
205     dataElement->setName("DataIn");
206   }
207 
208   llvm::Value* X = nullptr;
209   if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
210     X = &*(argIter++);
211     X->setName("x");
212   }
213 
214   llvm::Value* Y = nullptr;
215   if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
216     Y = &*(argIter++);
217     Y->setName("y");
218   }
219 
220   llvm::Value* Z = nullptr;
221   if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
222     Z = &*(argIter++);
223     Z->setName("z");
224   }
225 
226   auto slotIter = slots.begin();
227   for (const Source* source : sources) {
228     int slot = *slotIter;
229 
230     uint32_t inputFunctionSignature;
231     const Function* inputFunction =
232             getFunction(mergedModule, source, slot, &inputFunctionSignature);
233     if (inputFunction == nullptr) {
234       // Either failed to find the kernel function, or the function has multiple inputs.
235       return false;
236     }
237 
238     // Don't try to fuse a non-kernel
239     if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
240       ALOGE("Kernel fusion (module %s function %s): not a kernel",
241             source->getName().c_str(), inputFunction->getName().str().c_str());
242       return false;
243     }
244 
245     std::vector<llvm::Value*> args;
246 
247     if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
248       if (dataElement == nullptr) {
249         ALOGE("Kernel fusion (module %s function %s): expected input, but got null",
250               source->getName().c_str(), inputFunction->getName().str().c_str());
251         return false;
252       }
253 
254       const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
255       llvm::Type* firstArgType = funcTy->getParamType(0);
256 
257       if (dataElement->getType() != firstArgType) {
258         std::string msg;
259         llvm::raw_string_ostream rso(msg);
260         rso << "Mismatching argument type, expected ";
261         firstArgType->print(rso);
262         rso << ", received ";
263         dataElement->getType()->print(rso);
264         ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(),
265               inputFunction->getName().str().c_str(), rso.str().c_str());
266         return false;
267       }
268 
269       args.push_back(dataElement);
270     } else {
271       // Only the first kernel in a batch is allowed to have no input
272       if (slotIter != slots.begin()) {
273         ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input",
274               source->getName().c_str(), inputFunction->getName().str().c_str());
275         return false;
276       }
277     }
278 
279     if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
280       args.push_back(X);
281     }
282 
283     if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
284       args.push_back(Y);
285     }
286 
287     if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
288       args.push_back(Z);
289     }
290 
291     dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
292 
293     slotIter++;
294   }
295 
296   if (fusedKernel->getReturnType()->isVoidTy()) {
297     builder.CreateRetVoid();
298   } else {
299     builder.CreateRet(dataElement);
300   }
301 
302   llvm::NamedMDNode* ExportForEachNameMD =
303     mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");
304 
305   llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
306   llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
307   ExportForEachNameMD->addOperand(nameMDNode);
308 
309   llvm::NamedMDNode* ExportForEachMD =
310     mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
311   llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
312                                                  llvm::utostr(fusedFunctionSignature));
313   llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
314   ExportForEachMD->addOperand(sigMDNode);
315 
316   return true;
317 }
318 
renameInvoke(BCCContext & Context,const Source * source,const int slot,const std::string & newName,Module * module)319 bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
320                   const std::string& newName, Module* module) {
321   const llvm::Function* F = getInvokeFunction(*source, slot, module);
322   std::vector<llvm::Type*> params;
323   for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
324     params.push_back(I->getType());
325   }
326   llvm::Type* returnTy = F->getReturnType();
327 
328   llvm::FunctionType* batchFuncTy =
329           llvm::FunctionType::get(returnTy, params, false);
330 
331   llvm::Function* newF =
332           llvm::Function::Create(batchFuncTy,
333                                  llvm::GlobalValue::ExternalLinkage, newName,
334                                  module);
335 
336   llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
337                                                      "entry", newF);
338   llvm::IRBuilder<> builder(block);
339 
340   llvm::Function::arg_iterator argIter = newF->arg_begin();
341   llvm::Value* arg1 = &*(argIter++);
342   builder.CreateCall((llvm::Value*)F, arg1);
343 
344   builder.CreateRetVoid();
345 
346   llvm::NamedMDNode* ExportFuncNameMD =
347           module->getOrInsertNamedMetadata("#rs_export_func");
348   llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
349   llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
350   ExportFuncNameMD->addOperand(nodeMD);
351 
352   return true;
353 }
354 
355 }  // namespace bcc
356