1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "Wrapper.h"

#include "llvm/IR/Module.h"

#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
33 
34 using bcinfo::MetadataExtractor;
35 
36 namespace android {
37 namespace spirit {
38 
AddBuffer(Instruction * elementType,uint32_t binding,Builder & b,Module * m)39 VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
40                         Module *m) {
41   auto ArrTy = m->getRuntimeArrayType(elementType);
42   const size_t stride = m->getSize(elementType);
43   ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
44   auto StructTy = m->getStructType(ArrTy);
45   StructTy->decorate(Decoration::BufferBlock);
46   StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
47 
48   auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
49 
50   VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
51   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
52   bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
53   m->addVariable(bufferVar);
54 
55   return bufferVar;
56 }
57 
AddWrapper(const char * name,const uint32_t signature,const uint32_t numInput,Builder & b,Module * m)58 bool AddWrapper(const char *name, const uint32_t signature,
59                 const uint32_t numInput, Builder &b, Module *m) {
60   FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
61   if (kernel == nullptr) {
62     // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
63     // is always reserved for the root kernel, even though in the most recent RS
64     // apps it does not exist. Simply bypass wrapper generation here, and return
65     // true for this case.
66     // Otherwise, if a non-root kernel function cannot be found, it is a
67     // fatal internal error which is really unexpected.
68     return (strncmp(name, "root", 4) == 0);
69   }
70 
71   // The following three cases are not supported
72   if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
73     // Not handling old-style kernel
74     return false;
75   }
76 
77   if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
78     // Not handling the user argument
79     return false;
80   }
81 
82   if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
83     // Not handling the context argument
84     return false;
85   }
86 
87   TypeVoidInst *VoidTy = m->getVoidType();
88   TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
89   FunctionDefinition *Func =
90       b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
91   m->addFunctionDefinition(Func);
92 
93   Block *Blk = b.MakeBlock();
94   Func->addBlock(Blk);
95 
96   Blk->addInstruction(b.MakeLabel());
97 
98   TypeIntInst *UIntTy = m->getUnsignedIntType(32);
99 
100   Instruction *XValue = nullptr;
101   Instruction *YValue = nullptr;
102   Instruction *ZValue = nullptr;
103   Instruction *Index = nullptr;
104   VariableInst *InvocationId = nullptr;
105   VariableInst *NumWorkgroups = nullptr;
106 
107   if (MetadataExtractor::hasForEachSignatureIn(signature) ||
108       MetadataExtractor::hasForEachSignatureOut(signature) ||
109       MetadataExtractor::hasForEachSignatureX(signature) ||
110       MetadataExtractor::hasForEachSignatureY(signature) ||
111       MetadataExtractor::hasForEachSignatureZ(signature)) {
112     TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
113     InvocationId = m->getInvocationId();
114     auto IID = b.MakeLoad(V3UIntTy, InvocationId);
115     Blk->addInstruction(IID);
116 
117     XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
118     Blk->addInstruction(XValue);
119 
120     YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
121     Blk->addInstruction(YValue);
122 
123     ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
124     Blk->addInstruction(ZValue);
125 
126     // TODO: Use SpecConstant for workgroup size
127     auto ConstOne = m->getConstant(UIntTy, 1U);
128     auto GroupSize =
129         m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
130 
131     auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
132     Blk->addInstruction(GroupSizeX);
133 
134     auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
135     Blk->addInstruction(GroupSizeY);
136 
137     NumWorkgroups = m->getNumWorkgroups();
138     auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
139     Blk->addInstruction(NumGroup);
140 
141     auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
142     Blk->addInstruction(NumGroupX);
143 
144     auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
145     Blk->addInstruction(NumGroupY);
146 
147     auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
148     Blk->addInstruction(GlobalSizeX);
149 
150     auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
151     Blk->addInstruction(GlobalSizeY);
152 
153     auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
154     Blk->addInstruction(RowsAlongZ);
155 
156     auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
157     Blk->addInstruction(NumRows);
158 
159     auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
160     Blk->addInstruction(NumCellsFromYZ);
161 
162     Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
163     Blk->addInstruction(Index);
164   }
165 
166   std::vector<IdRef> inputs;
167 
168   ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
169 
170   for (uint32_t i = 0; i < numInput; i++) {
171     FunctionParameterInst *param = kernel->getParameter(i);
172     Instruction *elementType = param->mResultType.mInstruction;
173     VariableInst *inputBuffer = AddBuffer(elementType, i + 2, b, m);
174 
175     TypePointerInst *PtrTy =
176         m->getPointerType(StorageClass::Function, elementType);
177     AccessChainInst *Ptr =
178         b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
179     Blk->addInstruction(Ptr);
180 
181     Instruction *input = b.MakeLoad(elementType, Ptr);
182     Blk->addInstruction(input);
183 
184     inputs.push_back(IdRef(input));
185   }
186 
187   // TODO: Convert from unsigned int to signed int if that is what the kernel
188   // function takes for the coordinate parameters
189   if (MetadataExtractor::hasForEachSignatureX(signature)) {
190     inputs.push_back(XValue);
191     if (MetadataExtractor::hasForEachSignatureY(signature)) {
192       inputs.push_back(YValue);
193       if (MetadataExtractor::hasForEachSignatureZ(signature)) {
194         inputs.push_back(ZValue);
195       }
196     }
197   }
198 
199   auto resultType = kernel->getReturnType();
200   auto kernelCall =
201       b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
202   Blk->addInstruction(kernelCall);
203 
204   if (MetadataExtractor::hasForEachSignatureOut(signature)) {
205     VariableInst *OutputBuffer = AddBuffer(resultType, 1, b, m);
206     auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
207     AccessChainInst *OutPtr =
208         b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
209     Blk->addInstruction(OutPtr);
210     Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
211   }
212 
213   Blk->addInstruction(b.MakeReturn());
214 
215   std::string wrapperName("entry_");
216   wrapperName.append(name);
217 
218   EntryPointDefinition *entry = b.MakeEntryPointDefinition(
219       ExecutionModel::GLCompute, Func, wrapperName.c_str());
220 
221   entry->setLocalSize(1, 1, 1);
222 
223   if (Index != nullptr) {
224     entry->addToInterface(InvocationId);
225     entry->addToInterface(NumWorkgroups);
226   }
227 
228   m->addEntryPoint(entry);
229 
230   return true;
231 }
232 
DecorateGlobalBuffer(llvm::Module & LM,Builder & b,Module * m)233 bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
234   Instruction *inst = m->lookupByName("__GPUBlock");
235   if (inst == nullptr) {
236     return true;
237   }
238 
239   VariableInst *bufferVar = static_cast<VariableInst *>(inst);
240   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
241   bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
242 
243   TypePointerInst *StructPtrTy =
244       static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
245   TypeStructInst *StructTy =
246       static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
247   StructTy->decorate(Decoration::BufferBlock);
248 
249   // Decorate each member with proper offsets
250 
251   const auto GlobalsB = LM.globals().begin();
252   const auto GlobalsE = LM.globals().end();
253   const auto Found =
254       std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
255         return GV.getName() == "__GPUBlock";
256       });
257 
258   if (Found == GlobalsE) {
259     return true; // GPUBlock not found - not an error by itself.
260   }
261 
262   const llvm::GlobalVariable &G = *Found;
263 
264   bool IsCorrectTy = false;
265   if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
266     if (auto *LStructTy =
267             llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
268       IsCorrectTy = true;
269 
270       const auto &DLayout = LM.getDataLayout();
271       const auto *SLayout = DLayout.getStructLayout(LStructTy);
272       assert(SLayout);
273       if (SLayout == nullptr) {
274         std::cerr << "struct layout is null" << std::endl;
275         return false;
276       }
277       for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
278         auto decor = StructTy->memberDecorate(i, Decoration::Offset);
279         if (!decor) {
280           std::cerr << "failed creating member decoration for field " << i
281                     << std::endl;
282           return false;
283         }
284         const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
285         decor->addExtraOperand(offset);
286       }
287     }
288   }
289 
290   if (!IsCorrectTy) {
291     return false;
292   }
293 
294   llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
295   if (!getRSAllocationInfo(LM, RSAllocs)) {
296     // llvm::errs() << "Extracting rs_allocation info failed\n";
297     return true;
298   }
299 
300   // TODO: clean up the binding number assignment
301   size_t BindingNum = 3;
302   for (const auto &A : RSAllocs) {
303     Instruction *inst = m->lookupByName(A.VarName.c_str());
304     if (inst == nullptr) {
305       return false;
306     }
307     VariableInst *bufferVar = static_cast<VariableInst *>(inst);
308     bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
309     bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
310   }
311 
312   return true;
313 }
314 
AddHeader(Module * m)315 void AddHeader(Module *m) {
316   m->addCapability(Capability::Shader);
317   // TODO: avoid duplicated capability
318   // m->addCapability(Capability::Addresses);
319   m->setMemoryModel(AddressingModel::Physical32, MemoryModel::GLSL450);
320 
321   m->addSource(SourceLanguage::GLSL, 450);
322   m->addSourceExtension("GL_ARB_separate_shader_objects");
323   m->addSourceExtension("GL_ARB_shading_language_420pack");
324   m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
325   m->addSourceExtension("GL_GOOGLE_include_directive");
326 }
327 
328 namespace {
329 
330 class StorageClassVisitor : public DoNothingVisitor {
331 public:
visit(TypePointerInst * inst)332   void visit(TypePointerInst *inst) override {
333     matchAndReplace(inst->mOperand1);
334   }
335 
visit(TypeForwardPointerInst * inst)336   void visit(TypeForwardPointerInst *inst) override {
337     matchAndReplace(inst->mOperand2);
338   }
339 
visit(VariableInst * inst)340   void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
341 
342 private:
matchAndReplace(StorageClass & storage)343   void matchAndReplace(StorageClass &storage) {
344     if (storage == StorageClass::Function) {
345       storage = StorageClass::Uniform;
346     }
347   }
348 };
349 
FixGlobalStorageClass(Module * m)350 void FixGlobalStorageClass(Module *m) {
351   StorageClassVisitor v;
352   m->getGlobalSection()->accept(&v);
353 }
354 
355 } // anonymous namespace
356 
AddWrappers(llvm::Module & LM,android::spirit::Module * m)357 bool AddWrappers(llvm::Module &LM,
358                  android::spirit::Module *m) {
359   rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
360   const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
361   android::spirit::Builder b;
362 
363   m->setBuilder(&b);
364 
365   FixGlobalStorageClass(m);
366 
367   AddHeader(m);
368 
369   DecorateGlobalBuffer(LM, b, m);
370 
371   const size_t numKernel = metadata.getExportForEachSignatureCount();
372   const char **kernelName = metadata.getExportForEachNameList();
373   const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
374   const uint32_t *inputCount = metadata.getExportForEachInputCountList();
375 
376   for (size_t i = 0; i < numKernel; i++) {
377     bool success =
378         AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
379     if (!success) {
380       return false;
381     }
382   }
383 
384   m->consolidateAnnotations();
385   return true;
386 }
387 
388 class WrapperPass : public Pass {
389 public:
WrapperPass(const llvm::Module & LM)390   WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
391 
run(Module * m,int * error)392   Module *run(Module *m, int *error) override {
393     bool success = AddWrappers(mLLVMModule, m);
394     if (error) {
395       *error = success ? 0 : -1;
396     }
397     return m;
398   }
399 
400 private:
401   llvm::Module &mLLVMModule;
402 };
403 
404 } // namespace spirit
405 } // namespace android
406 
407 namespace rs2spirv {
408 
CreateWrapperPass(const llvm::Module & LLVMModule)409 android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
410   return new android::spirit::WrapperPass(LLVMModule);
411 }
412 
413 } // namespace rs2spirv
414