1 /*
2  * Copyright 2015, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "bcc/Renderscript/RSScriptGroupFusion.h"
18 
19 #include "bcc/Assert.h"
20 #include "bcc/BCCContext.h"
21 #include "bcc/Source.h"
22 #include "bcc/Support/Log.h"
23 #include "bcinfo/MetadataExtractor.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/IR/DataLayout.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 using llvm::Function;
31 using llvm::Module;
32 
33 using std::string;
34 
35 namespace bcc {
36 
37 namespace {
38 
getInvokeFunction(const Source & source,const int slot,Module * newModule)39 const Function* getInvokeFunction(const Source& source, const int slot,
40                                   Module* newModule) {
41   Module* module = const_cast<Module*>(&source.getModule());
42   bcinfo::MetadataExtractor metadata(module);
43   if (!metadata.extract()) {
44     ALOGE("Kernel fusion (module %s slot %d): failed to extract metadata",
45           source.getName().c_str(), slot);
46     return nullptr;
47   }
48   const char* functionName = metadata.getExportFuncNameList()[slot];
49   Function* func = newModule->getFunction(functionName);
50   // Materialize the function so that later the caller can inspect its argument
51   // and return types.
52   newModule->materialize(func);
53   return func;
54 }
55 
56 const Function*
getFunction(Module * mergedModule,const Source * source,const int slot,uint32_t * signature)57 getFunction(Module* mergedModule, const Source* source, const int slot,
58             uint32_t* signature) {
59   bcinfo::MetadataExtractor metadata(&source->getModule());
60   metadata.extract();
61 
62   const char* functionName = metadata.getExportForEachNameList()[slot];
63   if (functionName == nullptr || !functionName[0]) {
64     ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
65           source->getName().c_str(), slot);
66     return nullptr;
67   }
68 
69   if (metadata.getExportForEachInputCountList()[slot] > 1) {
70     ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
71           source->getName().c_str(), functionName);
72     return nullptr;
73   }
74 
75   if (signature != nullptr) {
76     *signature = metadata.getExportForEachSignatureList()[slot];
77   }
78 
79   const Function* function = mergedModule->getFunction(functionName);
80 
81   return function;
82 }
83 
84 // The whitelist of supported signature bits. Context or user data arguments are
85 // not currently supported in kernel fusion. To support them or any new kinds of
86 // arguments in the future, it requires not only listing the signature bits here,
87 // but also implementing additional necessary fusion logic in the getFusedFuncSig(),
88 // getFusedFuncType(), and fuseKernels() functions below.
89 constexpr uint32_t ExpectedSignatureBits =
90         bcinfo::MD_SIG_In |
91         bcinfo::MD_SIG_Out |
92         bcinfo::MD_SIG_X |
93         bcinfo::MD_SIG_Y |
94         bcinfo::MD_SIG_Z |
95         bcinfo::MD_SIG_Kernel;
96 
getFusedFuncSig(const std::vector<Source * > & sources,const std::vector<int> & slots,uint32_t * retSig)97 int getFusedFuncSig(const std::vector<Source*>& sources,
98                     const std::vector<int>& slots,
99                     uint32_t* retSig) {
100   *retSig = 0;
101   uint32_t firstSignature = 0;
102   uint32_t signature = 0;
103   auto slotIter = slots.begin();
104   for (const Source* source : sources) {
105     const int slot = *slotIter++;
106     bcinfo::MetadataExtractor metadata(&source->getModule());
107     metadata.extract();
108 
109     if (metadata.getExportForEachInputCountList()[slot] > 1) {
110       ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs",
111             source->getName().c_str(), slot);
112       return -1;
113     }
114 
115     signature = metadata.getExportForEachSignatureList()[slot];
116     if (signature & ~ExpectedSignatureBits) {
117       ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x",
118             source->getName().c_str(), slot, signature);
119       return -1;
120     }
121 
122     if (firstSignature == 0) {
123       firstSignature = signature;
124     }
125 
126     *retSig |= signature;
127   }
128 
129   if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
130     *retSig &= ~bcinfo::MD_SIG_In;
131   }
132 
133   if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
134     *retSig &= ~bcinfo::MD_SIG_Out;
135   }
136 
137   return 0;
138 }
139 
getFusedFuncType(bcc::BCCContext & Context,const std::vector<Source * > & sources,const std::vector<int> & slots,Module * M,uint32_t * signature)140 llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
141                                      const std::vector<Source*>& sources,
142                                      const std::vector<int>& slots,
143                                      Module* M,
144                                      uint32_t* signature) {
145   int error = getFusedFuncSig(sources, slots, signature);
146 
147   if (error < 0) {
148     return nullptr;
149   }
150 
151   const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);
152 
153   bccAssert (firstF != nullptr);
154 
155   llvm::SmallVector<llvm::Type*, 8> ArgTys;
156 
157   if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
158     ArgTys.push_back(firstF->arg_begin()->getType());
159   }
160 
161   llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
162   if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
163     ArgTys.push_back(I32Ty);
164   }
165   if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
166     ArgTys.push_back(I32Ty);
167   }
168   if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
169     ArgTys.push_back(I32Ty);
170   }
171 
172   const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);
173 
174   bccAssert (lastF != nullptr);
175 
176   llvm::Type* retTy = lastF->getReturnType();
177 
178   return llvm::FunctionType::get(retTy, ArgTys, false);
179 }
180 
181 }  // anonymous namespace
182 
fuseKernels(bcc::BCCContext & Context,const std::vector<Source * > & sources,const std::vector<int> & slots,const std::string & fusedName,Module * mergedModule)183 bool fuseKernels(bcc::BCCContext& Context,
184                  const std::vector<Source *>& sources,
185                  const std::vector<int>& slots,
186                  const std::string& fusedName,
187                  Module* mergedModule) {
188   bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
189 
190   uint32_t fusedFunctionSignature;
191 
192   llvm::FunctionType* fusedType =
193           getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
194 
195   if (fusedType == nullptr) {
196     return false;
197   }
198 
199   Function* fusedKernel =
200           (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));
201 
202   llvm::LLVMContext& ctxt = Context.getLLVMContext();
203 
204   llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
205   llvm::IRBuilder<> builder(block);
206 
207   Function::arg_iterator argIter = fusedKernel->arg_begin();
208 
209   llvm::Value* dataElement = nullptr;
210   if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
211     dataElement = argIter++;
212     dataElement->setName("DataIn");
213   }
214 
215   llvm::Value* X = nullptr;
216   if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
217     X = argIter++;
218     X->setName("x");
219   }
220 
221   llvm::Value* Y = nullptr;
222   if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
223     Y = argIter++;
224     Y->setName("y");
225   }
226 
227   llvm::Value* Z = nullptr;
228   if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
229     Z = argIter++;
230     Z->setName("z");
231   }
232 
233   auto slotIter = slots.begin();
234   for (const Source* source : sources) {
235     int slot = *slotIter;
236 
237     uint32_t inputFunctionSignature;
238     const Function* inputFunction =
239             getFunction(mergedModule, source, slot, &inputFunctionSignature);
240     if (inputFunction == nullptr) {
241       // Either failed to find the kernel function, or the function has multiple inputs.
242       return false;
243     }
244 
245     // Don't try to fuse a non-kernel
246     if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
247       ALOGE("Kernel fusion (module %s function %s): not a kernel",
248             source->getName().c_str(), inputFunction->getName().str().c_str());
249       return false;
250     }
251 
252     std::vector<llvm::Value*> args;
253 
254     if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
255       if (dataElement == nullptr) {
256         ALOGE("Kernel fusion (module %s function %s): expected input, but got null",
257               source->getName().c_str(), inputFunction->getName().str().c_str());
258         return false;
259       }
260 
261       const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
262       llvm::Type* firstArgType = funcTy->getParamType(0);
263 
264       if (dataElement->getType() != firstArgType) {
265         std::string msg;
266         llvm::raw_string_ostream rso(msg);
267         rso << "Mismatching argument type, expected ";
268         firstArgType->print(rso);
269         rso << ", received ";
270         dataElement->getType()->print(rso);
271         ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(),
272               inputFunction->getName().str().c_str(), rso.str().c_str());
273         return false;
274       }
275 
276       args.push_back(dataElement);
277     } else {
278       // Only the first kernel in a batch is allowed to have no input
279       if (slotIter != slots.begin()) {
280         ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input",
281               source->getName().c_str(), inputFunction->getName().str().c_str());
282         return false;
283       }
284     }
285 
286     if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
287       args.push_back(X);
288     }
289 
290     if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
291       args.push_back(Y);
292     }
293 
294     if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
295       args.push_back(Z);
296     }
297 
298     dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
299 
300     slotIter++;
301   }
302 
303   if (fusedKernel->getReturnType()->isVoidTy()) {
304     builder.CreateRetVoid();
305   } else {
306     builder.CreateRet(dataElement);
307   }
308 
309   llvm::NamedMDNode* ExportForEachNameMD =
310     mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");
311 
312   llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
313   llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
314   ExportForEachNameMD->addOperand(nameMDNode);
315 
316   llvm::NamedMDNode* ExportForEachMD =
317     mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
318   llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
319                                                  llvm::utostr_32(fusedFunctionSignature));
320   llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
321   ExportForEachMD->addOperand(sigMDNode);
322 
323   return true;
324 }
325 
renameInvoke(BCCContext & Context,const Source * source,const int slot,const std::string & newName,Module * module)326 bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
327                   const std::string& newName, Module* module) {
328   const llvm::Function* F = getInvokeFunction(*source, slot, module);
329   std::vector<llvm::Type*> params;
330   for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
331     params.push_back(I->getType());
332   }
333   llvm::Type* returnTy = F->getReturnType();
334 
335   llvm::FunctionType* batchFuncTy =
336           llvm::FunctionType::get(returnTy, params, false);
337 
338   llvm::Function* newF =
339           llvm::Function::Create(batchFuncTy,
340                                  llvm::GlobalValue::ExternalLinkage, newName,
341                                  module);
342 
343   llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
344                                                      "entry", newF);
345   llvm::IRBuilder<> builder(block);
346 
347   llvm::Function::arg_iterator argIter = newF->arg_begin();
348   llvm::Value* arg1 = argIter++;
349   builder.CreateCall((llvm::Value*)F, arg1);
350 
351   builder.CreateRetVoid();
352 
353   llvm::NamedMDNode* ExportFuncNameMD =
354           module->getOrInsertNamedMetadata("#rs_export_func");
355   llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
356   llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
357   ExportFuncNameMD->addOperand(nodeMD);
358 
359   return true;
360 }
361 
362 }  // namespace bcc
363