1 /*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "Wrapper.h"

#include "llvm/IR/Module.h"

#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
33
34 using bcinfo::MetadataExtractor;
35
36 namespace android {
37 namespace spirit {
38
AddBuffer(Instruction * elementType,uint32_t binding,Builder & b,Module * m)39 VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
40 Module *m) {
41 auto ArrTy = m->getRuntimeArrayType(elementType);
42 const size_t stride = m->getSize(elementType);
43 ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
44 auto StructTy = m->getStructType(ArrTy);
45 StructTy->decorate(Decoration::BufferBlock);
46 StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
47
48 auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
49
50 VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
51 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
52 bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
53 m->addVariable(bufferVar);
54
55 return bufferVar;
56 }
57
AddWrapper(const char * name,const uint32_t signature,const uint32_t numInput,Builder & b,Module * m)58 bool AddWrapper(const char *name, const uint32_t signature,
59 const uint32_t numInput, Builder &b, Module *m) {
60 FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
61 if (kernel == nullptr) {
62 // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
63 // is always reserved for the root kernel, even though in the most recent RS
64 // apps it does not exist. Simply bypass wrapper generation here, and return
65 // true for this case.
66 // Otherwise, if a non-root kernel function cannot be found, it is a
67 // fatal internal error which is really unexpected.
68 return (strncmp(name, "root", 4) == 0);
69 }
70
71 // The following three cases are not supported
72 if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
73 // Not handling old-style kernel
74 return false;
75 }
76
77 if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
78 // Not handling the user argument
79 return false;
80 }
81
82 if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
83 // Not handling the context argument
84 return false;
85 }
86
87 TypeVoidInst *VoidTy = m->getVoidType();
88 TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
89 FunctionDefinition *Func =
90 b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
91 m->addFunctionDefinition(Func);
92
93 Block *Blk = b.MakeBlock();
94 Func->addBlock(Blk);
95
96 Blk->addInstruction(b.MakeLabel());
97
98 TypeIntInst *UIntTy = m->getUnsignedIntType(32);
99
100 Instruction *XValue = nullptr;
101 Instruction *YValue = nullptr;
102 Instruction *ZValue = nullptr;
103 Instruction *Index = nullptr;
104 VariableInst *InvocationId = nullptr;
105 VariableInst *NumWorkgroups = nullptr;
106
107 if (MetadataExtractor::hasForEachSignatureIn(signature) ||
108 MetadataExtractor::hasForEachSignatureOut(signature) ||
109 MetadataExtractor::hasForEachSignatureX(signature) ||
110 MetadataExtractor::hasForEachSignatureY(signature) ||
111 MetadataExtractor::hasForEachSignatureZ(signature)) {
112 TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
113 InvocationId = m->getInvocationId();
114 auto IID = b.MakeLoad(V3UIntTy, InvocationId);
115 Blk->addInstruction(IID);
116
117 XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
118 Blk->addInstruction(XValue);
119
120 YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
121 Blk->addInstruction(YValue);
122
123 ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
124 Blk->addInstruction(ZValue);
125
126 // TODO: Use SpecConstant for workgroup size
127 auto ConstOne = m->getConstant(UIntTy, 1U);
128 auto GroupSize =
129 m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
130
131 auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
132 Blk->addInstruction(GroupSizeX);
133
134 auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
135 Blk->addInstruction(GroupSizeY);
136
137 NumWorkgroups = m->getNumWorkgroups();
138 auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
139 Blk->addInstruction(NumGroup);
140
141 auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
142 Blk->addInstruction(NumGroupX);
143
144 auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
145 Blk->addInstruction(NumGroupY);
146
147 auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
148 Blk->addInstruction(GlobalSizeX);
149
150 auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
151 Blk->addInstruction(GlobalSizeY);
152
153 auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
154 Blk->addInstruction(RowsAlongZ);
155
156 auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
157 Blk->addInstruction(NumRows);
158
159 auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
160 Blk->addInstruction(NumCellsFromYZ);
161
162 Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
163 Blk->addInstruction(Index);
164 }
165
166 std::vector<IdRef> inputs;
167
168 ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
169
170 for (uint32_t i = 0; i < numInput; i++) {
171 FunctionParameterInst *param = kernel->getParameter(i);
172 Instruction *elementType = param->mResultType.mInstruction;
173 VariableInst *inputBuffer = AddBuffer(elementType, i + 2, b, m);
174
175 TypePointerInst *PtrTy =
176 m->getPointerType(StorageClass::Function, elementType);
177 AccessChainInst *Ptr =
178 b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
179 Blk->addInstruction(Ptr);
180
181 Instruction *input = b.MakeLoad(elementType, Ptr);
182 Blk->addInstruction(input);
183
184 inputs.push_back(IdRef(input));
185 }
186
187 // TODO: Convert from unsigned int to signed int if that is what the kernel
188 // function takes for the coordinate parameters
189 if (MetadataExtractor::hasForEachSignatureX(signature)) {
190 inputs.push_back(XValue);
191 if (MetadataExtractor::hasForEachSignatureY(signature)) {
192 inputs.push_back(YValue);
193 if (MetadataExtractor::hasForEachSignatureZ(signature)) {
194 inputs.push_back(ZValue);
195 }
196 }
197 }
198
199 auto resultType = kernel->getReturnType();
200 auto kernelCall =
201 b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
202 Blk->addInstruction(kernelCall);
203
204 if (MetadataExtractor::hasForEachSignatureOut(signature)) {
205 VariableInst *OutputBuffer = AddBuffer(resultType, 1, b, m);
206 auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
207 AccessChainInst *OutPtr =
208 b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
209 Blk->addInstruction(OutPtr);
210 Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
211 }
212
213 Blk->addInstruction(b.MakeReturn());
214
215 std::string wrapperName("entry_");
216 wrapperName.append(name);
217
218 EntryPointDefinition *entry = b.MakeEntryPointDefinition(
219 ExecutionModel::GLCompute, Func, wrapperName.c_str());
220
221 entry->setLocalSize(1, 1, 1);
222
223 if (Index != nullptr) {
224 entry->addToInterface(InvocationId);
225 entry->addToInterface(NumWorkgroups);
226 }
227
228 m->addEntryPoint(entry);
229
230 return true;
231 }
232
DecorateGlobalBuffer(llvm::Module & LM,Builder & b,Module * m)233 bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
234 Instruction *inst = m->lookupByName("__GPUBlock");
235 if (inst == nullptr) {
236 return true;
237 }
238
239 VariableInst *bufferVar = static_cast<VariableInst *>(inst);
240 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
241 bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
242
243 TypePointerInst *StructPtrTy =
244 static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
245 TypeStructInst *StructTy =
246 static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
247 StructTy->decorate(Decoration::BufferBlock);
248
249 // Decorate each member with proper offsets
250
251 const auto GlobalsB = LM.globals().begin();
252 const auto GlobalsE = LM.globals().end();
253 const auto Found =
254 std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
255 return GV.getName() == "__GPUBlock";
256 });
257
258 if (Found == GlobalsE) {
259 return true; // GPUBlock not found - not an error by itself.
260 }
261
262 const llvm::GlobalVariable &G = *Found;
263
264 bool IsCorrectTy = false;
265 if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
266 if (auto *LStructTy =
267 llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
268 IsCorrectTy = true;
269
270 const auto &DLayout = LM.getDataLayout();
271 const auto *SLayout = DLayout.getStructLayout(LStructTy);
272 assert(SLayout);
273 if (SLayout == nullptr) {
274 std::cerr << "struct layout is null" << std::endl;
275 return false;
276 }
277 for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
278 auto decor = StructTy->memberDecorate(i, Decoration::Offset);
279 if (!decor) {
280 std::cerr << "failed creating member decoration for field " << i
281 << std::endl;
282 return false;
283 }
284 const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
285 decor->addExtraOperand(offset);
286 }
287 }
288 }
289
290 if (!IsCorrectTy) {
291 return false;
292 }
293
294 llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
295 if (!getRSAllocationInfo(LM, RSAllocs)) {
296 // llvm::errs() << "Extracting rs_allocation info failed\n";
297 return true;
298 }
299
300 // TODO: clean up the binding number assignment
301 size_t BindingNum = 3;
302 for (const auto &A : RSAllocs) {
303 Instruction *inst = m->lookupByName(A.VarName.c_str());
304 if (inst == nullptr) {
305 return false;
306 }
307 VariableInst *bufferVar = static_cast<VariableInst *>(inst);
308 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
309 bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
310 }
311
312 return true;
313 }
314
AddHeader(Module * m)315 void AddHeader(Module *m) {
316 m->addCapability(Capability::Shader);
317 // TODO: avoid duplicated capability
318 // m->addCapability(Capability::Addresses);
319 m->setMemoryModel(AddressingModel::Physical32, MemoryModel::GLSL450);
320
321 m->addSource(SourceLanguage::GLSL, 450);
322 m->addSourceExtension("GL_ARB_separate_shader_objects");
323 m->addSourceExtension("GL_ARB_shading_language_420pack");
324 m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
325 m->addSourceExtension("GL_GOOGLE_include_directive");
326 }
327
328 namespace {
329
330 class StorageClassVisitor : public DoNothingVisitor {
331 public:
visit(TypePointerInst * inst)332 void visit(TypePointerInst *inst) override {
333 matchAndReplace(inst->mOperand1);
334 }
335
visit(TypeForwardPointerInst * inst)336 void visit(TypeForwardPointerInst *inst) override {
337 matchAndReplace(inst->mOperand2);
338 }
339
visit(VariableInst * inst)340 void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
341
342 private:
matchAndReplace(StorageClass & storage)343 void matchAndReplace(StorageClass &storage) {
344 if (storage == StorageClass::Function) {
345 storage = StorageClass::Uniform;
346 }
347 }
348 };
349
FixGlobalStorageClass(Module * m)350 void FixGlobalStorageClass(Module *m) {
351 StorageClassVisitor v;
352 m->getGlobalSection()->accept(&v);
353 }
354
355 } // anonymous namespace
356
AddWrappers(llvm::Module & LM,android::spirit::Module * m)357 bool AddWrappers(llvm::Module &LM,
358 android::spirit::Module *m) {
359 rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
360 const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
361 android::spirit::Builder b;
362
363 m->setBuilder(&b);
364
365 FixGlobalStorageClass(m);
366
367 AddHeader(m);
368
369 DecorateGlobalBuffer(LM, b, m);
370
371 const size_t numKernel = metadata.getExportForEachSignatureCount();
372 const char **kernelName = metadata.getExportForEachNameList();
373 const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
374 const uint32_t *inputCount = metadata.getExportForEachInputCountList();
375
376 for (size_t i = 0; i < numKernel; i++) {
377 bool success =
378 AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
379 if (!success) {
380 return false;
381 }
382 }
383
384 m->consolidateAnnotations();
385 return true;
386 }
387
388 class WrapperPass : public Pass {
389 public:
WrapperPass(const llvm::Module & LM)390 WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
391
run(Module * m,int * error)392 Module *run(Module *m, int *error) override {
393 bool success = AddWrappers(mLLVMModule, m);
394 if (error) {
395 *error = success ? 0 : -1;
396 }
397 return m;
398 }
399
400 private:
401 llvm::Module &mLLVMModule;
402 };
403
404 } // namespace spirit
405 } // namespace android
406
407 namespace rs2spirv {
408
CreateWrapperPass(const llvm::Module & LLVMModule)409 android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
410 return new android::spirit::WrapperPass(LLVMModule);
411 }
412
413 } // namespace rs2spirv
414