1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "Wrapper.h"

#include "llvm/IR/Module.h"

#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
33 
34 using bcinfo::MetadataExtractor;
35 
36 namespace android {
37 namespace spirit {
38 
AddBuffer(Instruction * elementType,uint32_t binding,Builder & b,Module * m)39 VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
40                         Module *m) {
41   auto ArrTy = m->getRuntimeArrayType(elementType);
42   const size_t stride = m->getSize(elementType);
43   ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
44   auto StructTy = m->getStructType(ArrTy);
45   StructTy->decorate(Decoration::BufferBlock);
46   StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
47 
48   auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
49 
50   VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
51   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
52   bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
53   m->addVariable(bufferVar);
54 
55   return bufferVar;
56 }
57 
AddWrapper(const char * name,const uint32_t signature,const uint32_t numInput,Builder & b,Module * m)58 bool AddWrapper(const char *name, const uint32_t signature,
59                 const uint32_t numInput, Builder &b, Module *m) {
60   FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
61   if (kernel == nullptr) {
62     // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
63     // is always reserved for the root kernel, even though in the most recent RS
64     // apps it does not exist. Simply bypass wrapper generation here, and return
65     // true for this case.
66     // Otherwise, if a non-root kernel function cannot be found, it is a
67     // fatal internal error which is really unexpected.
68     return (strncmp(name, "root", 4) == 0);
69   }
70 
71   // The following three cases are not supported
72   if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
73     // Not handling old-style kernel
74     return false;
75   }
76 
77   if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
78     // Not handling the user argument
79     return false;
80   }
81 
82   if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
83     // Not handling the context argument
84     return false;
85   }
86 
87   TypeVoidInst *VoidTy = m->getVoidType();
88   TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
89   FunctionDefinition *Func =
90       b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
91   m->addFunctionDefinition(Func);
92 
93   Block *Blk = b.MakeBlock();
94   Func->addBlock(Blk);
95 
96   Blk->addInstruction(b.MakeLabel());
97 
98   TypeIntInst *UIntTy = m->getUnsignedIntType(32);
99 
100   Instruction *XValue = nullptr;
101   Instruction *YValue = nullptr;
102   Instruction *ZValue = nullptr;
103   Instruction *Index = nullptr;
104   VariableInst *InvocationId = nullptr;
105   VariableInst *NumWorkgroups = nullptr;
106 
107   if (MetadataExtractor::hasForEachSignatureIn(signature) ||
108       MetadataExtractor::hasForEachSignatureOut(signature) ||
109       MetadataExtractor::hasForEachSignatureX(signature) ||
110       MetadataExtractor::hasForEachSignatureY(signature) ||
111       MetadataExtractor::hasForEachSignatureZ(signature)) {
112     TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
113     InvocationId = m->getInvocationId();
114     auto IID = b.MakeLoad(V3UIntTy, InvocationId);
115     Blk->addInstruction(IID);
116 
117     XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
118     Blk->addInstruction(XValue);
119 
120     YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
121     Blk->addInstruction(YValue);
122 
123     ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
124     Blk->addInstruction(ZValue);
125 
126     // TODO: Use SpecConstant for workgroup size
127     auto ConstOne = m->getConstant(UIntTy, 1U);
128     auto GroupSize =
129         m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
130 
131     auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
132     Blk->addInstruction(GroupSizeX);
133 
134     auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
135     Blk->addInstruction(GroupSizeY);
136 
137     NumWorkgroups = m->getNumWorkgroups();
138     auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
139     Blk->addInstruction(NumGroup);
140 
141     auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
142     Blk->addInstruction(NumGroupX);
143 
144     auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
145     Blk->addInstruction(NumGroupY);
146 
147     auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
148     Blk->addInstruction(GlobalSizeX);
149 
150     auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
151     Blk->addInstruction(GlobalSizeY);
152 
153     auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
154     Blk->addInstruction(RowsAlongZ);
155 
156     auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
157     Blk->addInstruction(NumRows);
158 
159     auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
160     Blk->addInstruction(NumCellsFromYZ);
161 
162     Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
163     Blk->addInstruction(Index);
164   }
165 
166   std::vector<IdRef> inputs;
167 
168   ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
169 
170   for (uint32_t i = 0; i < numInput; i++) {
171     FunctionParameterInst *param = kernel->getParameter(i);
172     Instruction *elementType = param->mResultType.mInstruction;
173     VariableInst *inputBuffer = AddBuffer(elementType, i + 3, b, m);
174 
175     TypePointerInst *PtrTy =
176         m->getPointerType(StorageClass::Function, elementType);
177     AccessChainInst *Ptr =
178         b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
179     Blk->addInstruction(Ptr);
180 
181     Instruction *input = b.MakeLoad(elementType, Ptr);
182     Blk->addInstruction(input);
183 
184     inputs.push_back(IdRef(input));
185   }
186 
187   // TODO: Convert from unsigned int to signed int if that is what the kernel
188   // function takes for the coordinate parameters
189   if (MetadataExtractor::hasForEachSignatureX(signature)) {
190     inputs.push_back(XValue);
191     if (MetadataExtractor::hasForEachSignatureY(signature)) {
192       inputs.push_back(YValue);
193       if (MetadataExtractor::hasForEachSignatureZ(signature)) {
194         inputs.push_back(ZValue);
195       }
196     }
197   }
198 
199   auto resultType = kernel->getReturnType();
200   auto kernelCall =
201       b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
202   Blk->addInstruction(kernelCall);
203 
204   if (MetadataExtractor::hasForEachSignatureOut(signature)) {
205     VariableInst *OutputBuffer = AddBuffer(resultType, 2, b, m);
206     auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
207     AccessChainInst *OutPtr =
208         b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
209     Blk->addInstruction(OutPtr);
210     Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
211   }
212 
213   Blk->addInstruction(b.MakeReturn());
214 
215   std::string wrapperName("entry_");
216   wrapperName.append(name);
217 
218   EntryPointDefinition *entry = b.MakeEntryPointDefinition(
219       ExecutionModel::GLCompute, Func, wrapperName.c_str());
220 
221   entry->setLocalSize(1, 1, 1);
222 
223   if (Index != nullptr) {
224     entry->addToInterface(InvocationId);
225     entry->addToInterface(NumWorkgroups);
226   }
227 
228   m->addEntryPoint(entry);
229 
230   return true;
231 }
232 
DecorateGlobalBuffer(llvm::Module & LM,Builder & b,Module * m)233 bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
234   Instruction *inst = m->lookupByName("__GPUBlock");
235   if (inst == nullptr) {
236     return true;
237   }
238 
239   VariableInst *bufferVar = static_cast<VariableInst *>(inst);
240   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
241   bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
242 
243   TypePointerInst *StructPtrTy =
244       static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
245   TypeStructInst *StructTy =
246       static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
247   StructTy->decorate(Decoration::BufferBlock);
248 
249   // Decorate each member with proper offsets
250 
251   const auto GlobalsB = LM.globals().begin();
252   const auto GlobalsE = LM.globals().end();
253   const auto Found =
254       std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
255         return GV.getName() == "__GPUBlock";
256       });
257 
258   if (Found == GlobalsE) {
259     return true; // GPUBlock not found - not an error by itself.
260   }
261 
262   const llvm::GlobalVariable &G = *Found;
263 
264   rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
265   bool IsCorrectTy = false;
266   if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
267     if (auto *LStructTy =
268             llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
269       IsCorrectTy = true;
270 
271       const auto &DLayout = LM.getDataLayout();
272       const auto *SLayout = DLayout.getStructLayout(LStructTy);
273       assert(SLayout);
274       if (SLayout == nullptr) {
275         std::cerr << "struct layout is null" << std::endl;
276         return false;
277       }
278       std::vector<uint32_t> offsets;
279       for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
280         auto decor = StructTy->memberDecorate(i, Decoration::Offset);
281         if (!decor) {
282           std::cerr << "failed creating member decoration for field " << i
283                     << std::endl;
284           return false;
285         }
286         const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
287         decor->addExtraOperand(offset);
288         offsets.push_back(offset);
289       }
290       std::stringstream ssOffsets;
291       // TODO: define this string in a central place
292       ssOffsets << ".rsov.ExportedVars:";
293       for(uint32_t slot = 0; slot < Ctxt.getNumExportVar(); slot++) {
294         const uint32_t index = Ctxt.getExportVarIndex(slot);
295         const uint32_t offset = offsets[index];
296         ssOffsets << offset << ';';
297       }
298       m->addString(ssOffsets.str().c_str());
299 
300       std::stringstream ssGlobalSize;
301       ssGlobalSize << ".rsov.GlobalSize:" << Ctxt.getGlobalSize();
302       m->addString(ssGlobalSize.str().c_str());
303     }
304   }
305 
306   if (!IsCorrectTy) {
307     return false;
308   }
309 
310   llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
311   if (!getRSAllocationInfo(LM, RSAllocs)) {
312     // llvm::errs() << "Extracting rs_allocation info failed\n";
313     return true;
314   }
315 
316   // TODO: clean up the binding number assignment
317   size_t BindingNum = 3;
318   for (const auto &A : RSAllocs) {
319     Instruction *inst = m->lookupByName(A.VarName.c_str());
320     if (inst == nullptr) {
321       return false;
322     }
323     VariableInst *bufferVar = static_cast<VariableInst *>(inst);
324     bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
325     bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
326   }
327 
328   return true;
329 }
330 
AddHeader(Module * m)331 void AddHeader(Module *m) {
332   m->addCapability(Capability::Shader);
333   m->setMemoryModel(AddressingModel::Logical, MemoryModel::GLSL450);
334 
335   m->addSource(SourceLanguage::GLSL, 450);
336   m->addSourceExtension("GL_ARB_separate_shader_objects");
337   m->addSourceExtension("GL_ARB_shading_language_420pack");
338   m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
339   m->addSourceExtension("GL_GOOGLE_include_directive");
340 }
341 
342 namespace {
343 
344 class StorageClassVisitor : public DoNothingVisitor {
345 public:
visit(TypePointerInst * inst)346   void visit(TypePointerInst *inst) override {
347     matchAndReplace(inst->mOperand1);
348   }
349 
visit(TypeForwardPointerInst * inst)350   void visit(TypeForwardPointerInst *inst) override {
351     matchAndReplace(inst->mOperand2);
352   }
353 
visit(VariableInst * inst)354   void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
355 
356 private:
matchAndReplace(StorageClass & storage)357   void matchAndReplace(StorageClass &storage) {
358     if (storage == StorageClass::Function) {
359       storage = StorageClass::Uniform;
360     }
361   }
362 };
363 
FixGlobalStorageClass(Module * m)364 void FixGlobalStorageClass(Module *m) {
365   StorageClassVisitor v;
366   m->getGlobalSection()->accept(&v);
367 }
368 
369 } // anonymous namespace
370 
AddWrappers(llvm::Module & LM,android::spirit::Module * m)371 bool AddWrappers(llvm::Module &LM,
372                  android::spirit::Module *m) {
373   rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
374   const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
375   android::spirit::Builder b;
376 
377   m->setBuilder(&b);
378 
379   FixGlobalStorageClass(m);
380 
381   AddHeader(m);
382 
383   DecorateGlobalBuffer(LM, b, m);
384 
385   const size_t numKernel = metadata.getExportForEachSignatureCount();
386   const char **kernelName = metadata.getExportForEachNameList();
387   const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
388   const uint32_t *inputCount = metadata.getExportForEachInputCountList();
389 
390   for (size_t i = 0; i < numKernel; i++) {
391     bool success =
392         AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
393     if (!success) {
394       return false;
395     }
396   }
397 
398   m->consolidateAnnotations();
399   return true;
400 }
401 
402 class WrapperPass : public Pass {
403 public:
WrapperPass(const llvm::Module & LM)404   WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
405 
run(Module * m,int * error)406   Module *run(Module *m, int *error) override {
407     bool success = AddWrappers(mLLVMModule, m);
408     if (error) {
409       *error = success ? 0 : -1;
410     }
411     return m;
412   }
413 
414 private:
415   llvm::Module &mLLVMModule;
416 };
417 
418 } // namespace spirit
419 } // namespace android
420 
421 namespace rs2spirv {
422 
CreateWrapperPass(const llvm::Module & LLVMModule)423 android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
424   return new android::spirit::WrapperPass(LLVMModule);
425 }
426 
427 } // namespace rs2spirv
428