1 /*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "bcc/Renderscript/RSScriptGroupFusion.h"
18
19 #include "bcc/Assert.h"
20 #include "bcc/BCCContext.h"
21 #include "bcc/Source.h"
22 #include "bcc/Support/Log.h"
23 #include "bcinfo/MetadataExtractor.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/IR/DataLayout.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/Support/raw_ostream.h"
29
30 using llvm::Function;
31 using llvm::Module;
32
33 using std::string;
34
35 namespace bcc {
36
37 namespace {
38
getInvokeFunction(const Source & source,const int slot,Module * newModule)39 const Function* getInvokeFunction(const Source& source, const int slot,
40 Module* newModule) {
41 Module* module = const_cast<Module*>(&source.getModule());
42 bcinfo::MetadataExtractor metadata(module);
43 if (!metadata.extract()) {
44 ALOGE("Kernel fusion (module %s slot %d): failed to extract metadata",
45 source.getName().c_str(), slot);
46 return nullptr;
47 }
48 const char* functionName = metadata.getExportFuncNameList()[slot];
49 Function* func = newModule->getFunction(functionName);
50 // Materialize the function so that later the caller can inspect its argument
51 // and return types.
52 newModule->materialize(func);
53 return func;
54 }
55
56 const Function*
getFunction(Module * mergedModule,const Source * source,const int slot,uint32_t * signature)57 getFunction(Module* mergedModule, const Source* source, const int slot,
58 uint32_t* signature) {
59 bcinfo::MetadataExtractor metadata(&source->getModule());
60 metadata.extract();
61
62 const char* functionName = metadata.getExportForEachNameList()[slot];
63 if (functionName == nullptr || !functionName[0]) {
64 ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
65 source->getName().c_str(), slot);
66 return nullptr;
67 }
68
69 if (metadata.getExportForEachInputCountList()[slot] > 1) {
70 ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
71 source->getName().c_str(), functionName);
72 return nullptr;
73 }
74
75 if (signature != nullptr) {
76 *signature = metadata.getExportForEachSignatureList()[slot];
77 }
78
79 const Function* function = mergedModule->getFunction(functionName);
80
81 return function;
82 }
83
84 // The whitelist of supported signature bits. Context or user data arguments are
85 // not currently supported in kernel fusion. To support them or any new kinds of
86 // arguments in the future, it requires not only listing the signature bits here,
87 // but also implementing additional necessary fusion logic in the getFusedFuncSig(),
88 // getFusedFuncType(), and fuseKernels() functions below.
89 constexpr uint32_t ExpectedSignatureBits =
90 bcinfo::MD_SIG_In |
91 bcinfo::MD_SIG_Out |
92 bcinfo::MD_SIG_X |
93 bcinfo::MD_SIG_Y |
94 bcinfo::MD_SIG_Z |
95 bcinfo::MD_SIG_Kernel;
96
getFusedFuncSig(const std::vector<Source * > & sources,const std::vector<int> & slots,uint32_t * retSig)97 int getFusedFuncSig(const std::vector<Source*>& sources,
98 const std::vector<int>& slots,
99 uint32_t* retSig) {
100 *retSig = 0;
101 uint32_t firstSignature = 0;
102 uint32_t signature = 0;
103 auto slotIter = slots.begin();
104 for (const Source* source : sources) {
105 const int slot = *slotIter++;
106 bcinfo::MetadataExtractor metadata(&source->getModule());
107 metadata.extract();
108
109 if (metadata.getExportForEachInputCountList()[slot] > 1) {
110 ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs",
111 source->getName().c_str(), slot);
112 return -1;
113 }
114
115 signature = metadata.getExportForEachSignatureList()[slot];
116 if (signature & ~ExpectedSignatureBits) {
117 ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x",
118 source->getName().c_str(), slot, signature);
119 return -1;
120 }
121
122 if (firstSignature == 0) {
123 firstSignature = signature;
124 }
125
126 *retSig |= signature;
127 }
128
129 if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
130 *retSig &= ~bcinfo::MD_SIG_In;
131 }
132
133 if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
134 *retSig &= ~bcinfo::MD_SIG_Out;
135 }
136
137 return 0;
138 }
139
getFusedFuncType(bcc::BCCContext & Context,const std::vector<Source * > & sources,const std::vector<int> & slots,Module * M,uint32_t * signature)140 llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
141 const std::vector<Source*>& sources,
142 const std::vector<int>& slots,
143 Module* M,
144 uint32_t* signature) {
145 int error = getFusedFuncSig(sources, slots, signature);
146
147 if (error < 0) {
148 return nullptr;
149 }
150
151 const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);
152
153 bccAssert (firstF != nullptr);
154
155 llvm::SmallVector<llvm::Type*, 8> ArgTys;
156
157 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
158 ArgTys.push_back(firstF->arg_begin()->getType());
159 }
160
161 llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
162 if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
163 ArgTys.push_back(I32Ty);
164 }
165 if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
166 ArgTys.push_back(I32Ty);
167 }
168 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
169 ArgTys.push_back(I32Ty);
170 }
171
172 const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);
173
174 bccAssert (lastF != nullptr);
175
176 llvm::Type* retTy = lastF->getReturnType();
177
178 return llvm::FunctionType::get(retTy, ArgTys, false);
179 }
180
181 } // anonymous namespace
182
fuseKernels(bcc::BCCContext & Context,const std::vector<Source * > & sources,const std::vector<int> & slots,const std::string & fusedName,Module * mergedModule)183 bool fuseKernels(bcc::BCCContext& Context,
184 const std::vector<Source *>& sources,
185 const std::vector<int>& slots,
186 const std::string& fusedName,
187 Module* mergedModule) {
188 bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
189
190 uint32_t fusedFunctionSignature;
191
192 llvm::FunctionType* fusedType =
193 getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
194
195 if (fusedType == nullptr) {
196 return false;
197 }
198
199 Function* fusedKernel =
200 (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));
201
202 llvm::LLVMContext& ctxt = Context.getLLVMContext();
203
204 llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
205 llvm::IRBuilder<> builder(block);
206
207 Function::arg_iterator argIter = fusedKernel->arg_begin();
208
209 llvm::Value* dataElement = nullptr;
210 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
211 dataElement = argIter++;
212 dataElement->setName("DataIn");
213 }
214
215 llvm::Value* X = nullptr;
216 if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
217 X = argIter++;
218 X->setName("x");
219 }
220
221 llvm::Value* Y = nullptr;
222 if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
223 Y = argIter++;
224 Y->setName("y");
225 }
226
227 llvm::Value* Z = nullptr;
228 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
229 Z = argIter++;
230 Z->setName("z");
231 }
232
233 auto slotIter = slots.begin();
234 for (const Source* source : sources) {
235 int slot = *slotIter;
236
237 uint32_t inputFunctionSignature;
238 const Function* inputFunction =
239 getFunction(mergedModule, source, slot, &inputFunctionSignature);
240 if (inputFunction == nullptr) {
241 // Either failed to find the kernel function, or the function has multiple inputs.
242 return false;
243 }
244
245 // Don't try to fuse a non-kernel
246 if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
247 ALOGE("Kernel fusion (module %s function %s): not a kernel",
248 source->getName().c_str(), inputFunction->getName().str().c_str());
249 return false;
250 }
251
252 std::vector<llvm::Value*> args;
253
254 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
255 if (dataElement == nullptr) {
256 ALOGE("Kernel fusion (module %s function %s): expected input, but got null",
257 source->getName().c_str(), inputFunction->getName().str().c_str());
258 return false;
259 }
260
261 const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
262 llvm::Type* firstArgType = funcTy->getParamType(0);
263
264 if (dataElement->getType() != firstArgType) {
265 std::string msg;
266 llvm::raw_string_ostream rso(msg);
267 rso << "Mismatching argument type, expected ";
268 firstArgType->print(rso);
269 rso << ", received ";
270 dataElement->getType()->print(rso);
271 ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(),
272 inputFunction->getName().str().c_str(), rso.str().c_str());
273 return false;
274 }
275
276 args.push_back(dataElement);
277 } else {
278 // Only the first kernel in a batch is allowed to have no input
279 if (slotIter != slots.begin()) {
280 ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input",
281 source->getName().c_str(), inputFunction->getName().str().c_str());
282 return false;
283 }
284 }
285
286 if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
287 args.push_back(X);
288 }
289
290 if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
291 args.push_back(Y);
292 }
293
294 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
295 args.push_back(Z);
296 }
297
298 dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
299
300 slotIter++;
301 }
302
303 if (fusedKernel->getReturnType()->isVoidTy()) {
304 builder.CreateRetVoid();
305 } else {
306 builder.CreateRet(dataElement);
307 }
308
309 llvm::NamedMDNode* ExportForEachNameMD =
310 mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");
311
312 llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
313 llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
314 ExportForEachNameMD->addOperand(nameMDNode);
315
316 llvm::NamedMDNode* ExportForEachMD =
317 mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
318 llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
319 llvm::utostr_32(fusedFunctionSignature));
320 llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
321 ExportForEachMD->addOperand(sigMDNode);
322
323 return true;
324 }
325
renameInvoke(BCCContext & Context,const Source * source,const int slot,const std::string & newName,Module * module)326 bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
327 const std::string& newName, Module* module) {
328 const llvm::Function* F = getInvokeFunction(*source, slot, module);
329 std::vector<llvm::Type*> params;
330 for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
331 params.push_back(I->getType());
332 }
333 llvm::Type* returnTy = F->getReturnType();
334
335 llvm::FunctionType* batchFuncTy =
336 llvm::FunctionType::get(returnTy, params, false);
337
338 llvm::Function* newF =
339 llvm::Function::Create(batchFuncTy,
340 llvm::GlobalValue::ExternalLinkage, newName,
341 module);
342
343 llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
344 "entry", newF);
345 llvm::IRBuilder<> builder(block);
346
347 llvm::Function::arg_iterator argIter = newF->arg_begin();
348 llvm::Value* arg1 = argIter++;
349 builder.CreateCall((llvm::Value*)F, arg1);
350
351 builder.CreateRetVoid();
352
353 llvm::NamedMDNode* ExportFuncNameMD =
354 module->getOrInsertNamedMetadata("#rs_export_func");
355 llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
356 llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
357 ExportFuncNameMD->addOperand(nodeMD);
358
359 return true;
360 }
361
362 } // namespace bcc
363