1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "module.h"

#include <iostream>
#include <set>

#include "builder.h"
#include "core_defs.h"
#include "instructions.h"
#include "types_generated.h"
#include "word_stream.h"
26 
27 namespace android {
28 namespace spirit {
29 
30 Module *Module::mInstance = nullptr;
31 
getCurrentModule()32 Module *Module::getCurrentModule() {
33   if (mInstance == nullptr) {
34     return mInstance = new Module();
35   }
36   return mInstance;
37 }
38 
Module()39 Module::Module()
40     : mNextId(1), mCapabilitiesDeleter(mCapabilities),
41       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
42       mEntryPointInstsDeleter(mEntryPointInsts),
43       mExecutionModesDeleter(mExecutionModes),
44       mEntryPointsDeleter(mEntryPoints),
45       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
46   mInstance = this;
47 }
48 
Module(Builder * b)49 Module::Module(Builder *b)
50     : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
51       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
52       mEntryPointInstsDeleter(mEntryPointInsts),
53       mExecutionModesDeleter(mExecutionModes),
54       mEntryPointsDeleter(mEntryPoints),
55       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
56   mInstance = this;
57 }
58 
resolveIds()59 bool Module::resolveIds() {
60   auto &table = mIdTable;
61 
62   std::unique_ptr<IVisitor> v0(
63       CreateInstructionVisitor([&table](Instruction *inst) {
64         if (inst->hasResult()) {
65           table.insert(std::make_pair(inst->getId(), inst));
66         }
67       }));
68   v0->visit(this);
69 
70   mNextId = mIdTable.rbegin()->first + 1;
71 
72   int err = 0;
73   std::unique_ptr<IVisitor> v(
74       CreateInstructionVisitor([&table, &err](Instruction *inst) {
75         for (auto ref : inst->getAllIdRefs()) {
76           if (ref) {
77             auto it = table.find(ref->mId);
78             if (it != table.end()) {
79               ref->mInstruction = it->second;
80             } else {
81               std::cout << "Found no instruction for id " << ref->mId
82                         << std::endl;
83               err++;
84             }
85           }
86         }
87       }));
88   v->visit(this);
89   return err == 0;
90 }
91 
DeserializeInternal(InputWordStream & IS)92 bool Module::DeserializeInternal(InputWordStream &IS) {
93   if (IS.empty()) {
94     return false;
95   }
96 
97   IS >> &mMagicNumber;
98   if (mMagicNumber != 0x07230203) {
99     errs() << "Wrong Magic Number: " << mMagicNumber;
100     return false;
101   }
102 
103   if (IS.empty()) {
104     return false;
105   }
106 
107   IS >> &mVersion.mWord;
108   if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
109     return false;
110   }
111 
112   if (IS.empty()) {
113     return false;
114   }
115 
116   IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
117 
118   DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
119   DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
120   DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
121 
122   mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
123   if (!mMemoryModel) {
124     errs() << "Missing memory model specification.\n";
125     return false;
126   }
127 
128   DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
129   DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
130   for (auto entry : mEntryPoints) {
131     mEntryPointInsts.push_back(entry->getInstruction());
132     for (auto mode : mExecutionModes) {
133       entry->applyExecutionMode(mode);
134     }
135   }
136 
137   mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
138   mAnnotations.reset(Deserialize<AnnotationSection>(IS));
139   mGlobals.reset(Deserialize<GlobalSection>(IS));
140 
141   DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
142 
143   if (mFunctionDefinitions.empty()) {
144     errs() << "Missing function definitions.\n";
145     for (int i = 0; i < 4; i++) {
146       uint32_t w;
147       IS >> &w;
148       std::cout << std::hex << w << " ";
149     }
150     std::cout << std::endl;
151     return false;
152   }
153 
154   return true;
155 }
156 
initialize()157 void Module::initialize() {
158   mMagicNumber = 0x07230203;
159   mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
160   mGeneratorMagicNumber = 0x00070000;
161   mBound = 0;
162   mReserved = 0;
163   mAnnotations.reset(new AnnotationSection());
164 }
165 
SerializeHeader(OutputWordStream & OS) const166 void Module::SerializeHeader(OutputWordStream &OS) const {
167   OS << mMagicNumber;
168   OS << mVersion.mWord << mGeneratorMagicNumber;
169   if (mBound == 0) {
170     OS << mIdTable.end()->first + 1;
171   } else {
172     OS << std::max(mBound, mNextId);
173   }
174   OS << mReserved;
175 }
176 
Serialize(OutputWordStream & OS) const177 void Module::Serialize(OutputWordStream &OS) const {
178   SerializeHeader(OS);
179   Entity::Serialize(OS);
180 }
181 
addCapability(Capability cap)182 Module *Module::addCapability(Capability cap) {
183   mCapabilities.push_back(mBuilder->MakeCapability(cap));
184   return this;
185 }
186 
setMemoryModel(AddressingModel am,MemoryModel mm)187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
188   mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
189   return this;
190 }
191 
addExtInstImport(const char * extName)192 Module *Module::addExtInstImport(const char *extName) {
193   ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
194   mExtInstImports.push_back(extInst);
195   if (strcmp(extName, "GLSL.std.450") == 0) {
196     mGLExt = extInst;
197   }
198   return this;
199 }
200 
addSource(SourceLanguage lang,int version)201 Module *Module::addSource(SourceLanguage lang, int version) {
202   if (!mDebugInfo) {
203     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
204   }
205   mDebugInfo->addSource(lang, version);
206   return this;
207 }
208 
addSourceExtension(const char * ext)209 Module *Module::addSourceExtension(const char *ext) {
210   if (!mDebugInfo) {
211     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
212   }
213   mDebugInfo->addSourceExtension(ext);
214   return this;
215 }
216 
addString(const char * str)217 Module *Module::addString(const char *str) {
218   if (!mDebugInfo) {
219     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
220   }
221   mDebugInfo->addString(str);
222   return this;
223 }
224 
addEntryPoint(EntryPointDefinition * entry)225 Module *Module::addEntryPoint(EntryPointDefinition *entry) {
226   mEntryPoints.push_back(entry);
227   auto newModes = entry->getExecutionModes();
228   mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
229                          newModes.end());
230   return this;
231 }
232 
getGlobalSection()233 GlobalSection *Module::getGlobalSection() {
234   if (!mGlobals) {
235     mGlobals.reset(new GlobalSection());
236   }
237   return mGlobals.get();
238 }
239 
getConstant(TypeIntInst * type,int32_t value)240 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
241   return getGlobalSection()->getConstant(type, value);
242 }
243 
getConstant(TypeIntInst * type,uint32_t value)244 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
245   return getGlobalSection()->getConstant(type, value);
246 }
247 
getConstant(TypeFloatInst * type,float value)248 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
249   return getGlobalSection()->getConstant(type, value);
250 }
251 
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)252 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
253                                                     ConstantInst *components[],
254                                                     size_t width) {
255   return getGlobalSection()->getConstantComposite(type, components, width);
256 }
257 
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2)258 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
259                                                     ConstantInst *comp0,
260                                                     ConstantInst *comp1,
261                                                     ConstantInst *comp2) {
262   // TODO: verify that component types are the same and consistent with the
263   // resulting vector type
264   ConstantInst *comps[] = {comp0, comp1, comp2};
265   return getConstantComposite(type, comps, 3);
266 }
267 
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2,ConstantInst * comp3)268 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
269                                                     ConstantInst *comp0,
270                                                     ConstantInst *comp1,
271                                                     ConstantInst *comp2,
272                                                     ConstantInst *comp3) {
273   // TODO: verify that component types are the same and consistent with the
274   // resulting vector type
275   ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
276   return getConstantComposite(type, comps, 4);
277 }
278 
getVoidType()279 TypeVoidInst *Module::getVoidType() {
280   return getGlobalSection()->getVoidType();
281 }
282 
getIntType(int bits,bool isSigned)283 TypeIntInst *Module::getIntType(int bits, bool isSigned) {
284   return getGlobalSection()->getIntType(bits, isSigned);
285 }
286 
getUnsignedIntType(int bits)287 TypeIntInst *Module::getUnsignedIntType(int bits) {
288   return getIntType(bits, false);
289 }
290 
getFloatType(int bits)291 TypeFloatInst *Module::getFloatType(int bits) {
292   return getGlobalSection()->getFloatType(bits);
293 }
294 
getVectorType(Instruction * componentType,int width)295 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
296   return getGlobalSection()->getVectorType(componentType, width);
297 }
298 
getPointerType(StorageClass storage,Instruction * pointeeType)299 TypePointerInst *Module::getPointerType(StorageClass storage,
300                                         Instruction *pointeeType) {
301   return getGlobalSection()->getPointerType(storage, pointeeType);
302 }
303 
getRuntimeArrayType(Instruction * elementType)304 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
305   return getGlobalSection()->getRuntimeArrayType(elementType);
306 }
307 
getStructType(Instruction * fieldType[],int numField)308 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
309   return getGlobalSection()->getStructType(fieldType, numField);
310 }
311 
getStructType(Instruction * fieldType)312 TypeStructInst *Module::getStructType(Instruction *fieldType) {
313   return getStructType(&fieldType, 1);
314 }
315 
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)316 TypeFunctionInst *Module::getFunctionType(Instruction *retType,
317                                           Instruction *const argType[],
318                                           size_t numArg) {
319   return getGlobalSection()->getFunctionType(retType, argType, numArg);
320 }
321 
322 TypeFunctionInst *
getFunctionType(Instruction * retType,const std::vector<Instruction * > & argTypes)323 Module::getFunctionType(Instruction *retType,
324                         const std::vector<Instruction *> &argTypes) {
325   return getGlobalSection()->getFunctionType(retType, argTypes.data(),
326                                              argTypes.size());
327 }
328 
getSize(TypeVoidInst *)329 size_t Module::getSize(TypeVoidInst *) { return 0; }
330 
getSize(TypeIntInst * intTy)331 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
332 
getSize(TypeFloatInst * fpTy)333 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
334 
getSize(TypeVectorInst * vTy)335 size_t Module::getSize(TypeVectorInst *vTy) {
336   return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
337 }
338 
getSize(TypePointerInst *)339 size_t Module::getSize(TypePointerInst *) {
340   return 4; // TODO: or 8?
341 }
342 
getSize(TypeStructInst * structTy)343 size_t Module::getSize(TypeStructInst *structTy) {
344   size_t sz = 0;
345   for (auto ty : structTy->mOperand1) {
346     sz += getSize(ty.mInstruction);
347   }
348   return sz;
349 }
350 
getSize(TypeFunctionInst *)351 size_t Module::getSize(TypeFunctionInst *) {
352   return 4; // TODO: or 8? Is this just the size of a pointer?
353 }
354 
getSize(Instruction * inst)355 size_t Module::getSize(Instruction *inst) {
356   switch (inst->getOpCode()) {
357   case OpTypeVoid:
358     return getSize(static_cast<TypeVoidInst *>(inst));
359   case OpTypeInt:
360     return getSize(static_cast<TypeIntInst *>(inst));
361   case OpTypeFloat:
362     return getSize(static_cast<TypeFloatInst *>(inst));
363   case OpTypeVector:
364     return getSize(static_cast<TypeVectorInst *>(inst));
365   case OpTypeStruct:
366     return getSize(static_cast<TypeStructInst *>(inst));
367   case OpTypeFunction:
368     return getSize(static_cast<TypeFunctionInst *>(inst));
369   default:
370     return 0;
371   }
372 }
373 
addFunctionDefinition(FunctionDefinition * func)374 Module *Module::addFunctionDefinition(FunctionDefinition *func) {
375   mFunctionDefinitions.push_back(func);
376   return this;
377 }
378 
lookupByName(const char * name) const379 Instruction *Module::lookupByName(const char *name) const {
380   return mDebugInfo->lookupByName(name);
381 }
382 
383 FunctionDefinition *
getFunctionDefinitionFromInstruction(FunctionInst * inst) const384 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
385   for (auto fdef : mFunctionDefinitions) {
386     if (fdef->getInstruction() == inst) {
387       return fdef;
388     }
389   }
390   return nullptr;
391 }
392 
393 FunctionDefinition *
lookupFunctionDefinitionByName(const char * name) const394 Module::lookupFunctionDefinitionByName(const char *name) const {
395   FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
396   return getFunctionDefinitionFromInstruction(inst);
397 }
398 
lookupNameByInstruction(const Instruction * inst) const399 const char *Module::lookupNameByInstruction(const Instruction *inst) const {
400   return mDebugInfo->lookupNameByInstruction(inst);
401 }
402 
getInvocationId()403 VariableInst *Module::getInvocationId() {
404   return getGlobalSection()->getInvocationId();
405 }
406 
getNumWorkgroups()407 VariableInst *Module::getNumWorkgroups() {
408   return getGlobalSection()->getNumWorkgroups();
409 }
410 
addStructType(TypeStructInst * structType)411 Module *Module::addStructType(TypeStructInst *structType) {
412   getGlobalSection()->addStructType(structType);
413   return this;
414 }
415 
addVariable(VariableInst * var)416 Module *Module::addVariable(VariableInst *var) {
417   getGlobalSection()->addVariable(var);
418   return this;
419 }
420 
consolidateAnnotations()421 void Module::consolidateAnnotations() {
422   std::vector<Instruction *> annotations(mAnnotations->begin(),
423                                       mAnnotations->end());
424   std::unique_ptr<IVisitor> v(
425       CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
426         const auto &ann = inst->getAnnotations();
427         annotations.insert(annotations.end(), ann.begin(), ann.end());
428       }));
429   v->visit(this);
430   mAnnotations->clear();
431   mAnnotations->addAnnotations(annotations.begin(), annotations.end());
432 }
433 
EntryPointDefinition(Builder * builder,ExecutionModel execModel,FunctionDefinition * func,const char * name)434 EntryPointDefinition::EntryPointDefinition(Builder *builder,
435                                            ExecutionModel execModel,
436                                            FunctionDefinition *func,
437                                            const char *name)
438     : Entity(builder), mFunction(func->getInstruction()),
439       mExecutionModel(execModel) {
440   mName = strndup(name, strlen(name));
441   mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
442 }
443 
DeserializeInternal(InputWordStream & IS)444 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
445   if (IS.empty()) {
446     return false;
447   }
448 
449   if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
450     return true;
451   }
452 
453   return false;
454 }
455 
456 EntryPointDefinition *
applyExecutionMode(ExecutionModeInst * mode)457 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
458   if (mode->mOperand1.mInstruction == mFunction) {
459     addExecutionMode(mode);
460   }
461   return this;
462 }
463 
addToInterface(VariableInst * var)464 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
465   mInterface.push_back(var);
466   mEntryPointInst->mOperand4.push_back(var);
467   return this;
468 }
469 
setLocalSize(uint32_t width,uint32_t height,uint32_t depth)470 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
471                                                          uint32_t height,
472                                                          uint32_t depth) {
473   mLocalSize.mWidth = width;
474   mLocalSize.mHeight = height;
475   mLocalSize.mDepth = depth;
476 
477   auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
478   mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
479 
480   addExecutionMode(mode);
481 
482   return this;
483 }
484 
DeserializeInternal(InputWordStream & IS)485 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
486   while (true) {
487     if (auto str = Deserialize<StringInst>(IS)) {
488       mSources.push_back(str);
489     } else if (auto src = Deserialize<SourceInst>(IS)) {
490       mSources.push_back(src);
491     } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
492       mSources.push_back(srcExt);
493     } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
494       mSources.push_back(srcCont);
495     } else {
496       break;
497     }
498   }
499 
500   while (true) {
501     if (auto name = Deserialize<NameInst>(IS)) {
502       mNames.push_back(name);
503     } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
504       mNames.push_back(memName);
505     } else {
506       break;
507     }
508   }
509 
510   return true;
511 }
512 
addSource(SourceLanguage lang,int version)513 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
514                                               int version) {
515   SourceInst *source = mBuilder->MakeSource(lang, version);
516   mSources.push_back(source);
517   return this;
518 }
519 
addSourceExtension(const char * ext)520 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
521   SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
522   mSources.push_back(inst);
523   return this;
524 }
525 
addString(const char * str)526 DebugInfoSection *DebugInfoSection::addString(const char *str) {
527   StringInst *source = mBuilder->MakeString(str);
528   mSources.push_back(source);
529   return this;
530 }
531 
lookupByName(const char * name) const532 Instruction *DebugInfoSection::lookupByName(const char *name) const {
533   for (auto inst : mNames) {
534     if (inst->getOpCode() == OpName) {
535       NameInst *nameInst = static_cast<NameInst *>(inst);
536       if (nameInst->mOperand2.compare(name) == 0) {
537         return nameInst->mOperand1.mInstruction;
538       }
539     }
540     // Ignore member names
541   }
542   return nullptr;
543 }
544 
545 const char *
lookupNameByInstruction(const Instruction * target) const546 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
547   for (auto inst : mNames) {
548     if (inst->getOpCode() == OpName) {
549       NameInst *nameInst = static_cast<NameInst *>(inst);
550       if (nameInst->mOperand1.mInstruction == target) {
551         return nameInst->mOperand2.c_str();
552       }
553     }
554     // Ignore member names
555   }
556   return nullptr;
557 }
558 
AnnotationSection()559 AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}
560 
AnnotationSection(Builder * b)561 AnnotationSection::AnnotationSection(Builder *b)
562     : Entity(b), mAnnotationsDeleter(mAnnotations) {}
563 
DeserializeInternal(InputWordStream & IS)564 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
565   while (true) {
566     if (auto decor = Deserialize<DecorateInst>(IS)) {
567       mAnnotations.push_back(decor);
568     } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
569       mAnnotations.push_back(decor);
570     } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
571       mAnnotations.push_back(decor);
572     } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
573       mAnnotations.push_back(decor);
574     } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
575       mAnnotations.push_back(decor);
576     } else {
577       break;
578     }
579   }
580   return true;
581 }
582 
GlobalSection()583 GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}
584 
GlobalSection(Builder * builder)585 GlobalSection::GlobalSection(Builder *builder)
586     : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
587 
588 namespace {
589 
590 template <typename T>
findOrCreate(std::function<bool (T *)> criteria,std::function<T * ()> factory,std::vector<Instruction * > * globals)591 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
592                 std::vector<Instruction *> *globals) {
593   T *derived;
594   for (auto inst : *globals) {
595     if (inst->getOpCode() == T::mOpCode) {
596       T *derived = static_cast<T *>(inst);
597       if (criteria(derived)) {
598         return derived;
599       }
600     }
601   }
602   derived = factory();
603   globals->push_back(derived);
604   return derived;
605 }
606 
607 } // anonymous namespace
608 
// Reads the global section: a run of constant/type definitions, global
// OpVariables and OpUndefs, in any interleaving, appended to mGlobalDefs in
// stream order. Stops at the first instruction that matches none of them.
// Always returns true.
bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
  while (true) {
    // For every (opcode, class) pair listed in the generated headers below,
    // try to deserialize that instruction class; on success, store it and
    // restart the loop.
#define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
  if (auto typeInst = Deserialize<INST_CLASS>(IS)) {                           \
    mGlobalDefs.push_back(typeInst);                                           \
    continue;                                                                  \
  }
#include "const_inst_dispatches_generated.h"
#include "type_inst_dispatches_generated.h"
#undef HANDLE_INSTRUCTION

    if (auto globalInst = Deserialize<VariableInst>(IS)) {
      // Check if this is function scoped
      if (globalInst->mOperand1 == StorageClass::Function) {
        Module::errs() << "warning: Variable (id = " << globalInst->mResult;
        Module::errs() << ") has function scope in global section.\n";
        // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
        // As a workaround, accept such SPIR-V code here, and fix it up later
        // in the rs2spirv compiler by correcting the storage class.
        // In a stricter deserializer, such code should be rejected, and we
        // should return false here.
      }
      mGlobalDefs.push_back(globalInst);
      continue;
    }

    if (auto globalInst = Deserialize<UndefInst>(IS)) {
      mGlobalDefs.push_back(globalInst);
      continue;
    }
    // Nothing matched: the global section has ended.
    break;
  }
  return true;
}
643 
getConstant(TypeIntInst * type,int32_t value)644 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
645   return findOrCreate<ConstantInst>(
646       [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
647       [=]() -> ConstantInst * {
648         LiteralContextDependentNumber cdn = {.intValue = value};
649         return mBuilder->MakeConstant(type, cdn);
650       },
651       &mGlobalDefs);
652 }
653 
getConstant(TypeIntInst * type,uint32_t value)654 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
655   return findOrCreate<ConstantInst>(
656       [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
657       [=]() -> ConstantInst * {
658         LiteralContextDependentNumber cdn = {.intValue = (int)value};
659         return mBuilder->MakeConstant(type, cdn);
660       },
661       &mGlobalDefs);
662 }
663 
getConstant(TypeFloatInst * type,float value)664 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
665   return findOrCreate<ConstantInst>(
666       [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
667       [=]() -> ConstantInst * {
668         LiteralContextDependentNumber cdn = {.floatValue = value};
669         return mBuilder->MakeConstant(type, cdn);
670       },
671       &mGlobalDefs);
672 }
673 
674 ConstantCompositeInst *
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)675 GlobalSection::getConstantComposite(TypeVectorInst *type,
676                                     ConstantInst *components[], size_t width) {
677   return findOrCreate<ConstantCompositeInst>(
678       [=](ConstantCompositeInst *c) {
679         if (c->mOperand1.size() != width) {
680           return false;
681         }
682         for (size_t i = 0; i < width; i++) {
683           if (c->mOperand1[i].mInstruction != components[i]) {
684             return false;
685           }
686         }
687         return true;
688       },
689       [=]() -> ConstantCompositeInst * {
690         ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
691         for (size_t i = 0; i < width; i++) {
692           c->mOperand1.push_back(components[i]);
693         }
694         return c;
695       },
696       &mGlobalDefs);
697 }
698 
getVoidType()699 TypeVoidInst *GlobalSection::getVoidType() {
700   return findOrCreate<TypeVoidInst>(
701       [=](TypeVoidInst *) -> bool { return true; },
702       [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
703       &mGlobalDefs);
704 }
705 
getIntType(int bits,bool isSigned)706 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
707   if (isSigned) {
708     switch (bits) {
709 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED)                                \
710   case BITS: {                                                                 \
711     return findOrCreate<TypeIntInst>(                                          \
712         [=](TypeIntInst *intTy) -> bool {                                      \
713           return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED;       \
714         },                                                                     \
715         [=]() -> TypeIntInst * {                                               \
716           return mBuilder->MakeTypeInt(BITS, SIGNED);                          \
717         },                                                                     \
718         &mGlobalDefs);                                                         \
719   }
720       HANDLE_INT_SIZE(Int, 8, 1);
721       HANDLE_INT_SIZE(Int, 16, 1);
722       HANDLE_INT_SIZE(Int, 32, 1);
723       HANDLE_INT_SIZE(Int, 64, 1);
724     default:
725       Module::errs() << "unexpected int type";
726     }
727   } else {
728     switch (bits) {
729       HANDLE_INT_SIZE(UInt, 8, 0);
730       HANDLE_INT_SIZE(UInt, 16, 0);
731       HANDLE_INT_SIZE(UInt, 32, 0);
732       HANDLE_INT_SIZE(UInt, 64, 0);
733     default:
734       Module::errs() << "unexpected int type";
735     }
736   }
737 #undef HANDLE_INT_SIZE
738   return nullptr;
739 }
740 
getFloatType(int bits)741 TypeFloatInst *GlobalSection::getFloatType(int bits) {
742   switch (bits) {
743 #define HANDLE_FLOAT_SIZE(BITS)                                                \
744   case BITS: {                                                                 \
745     return findOrCreate<TypeFloatInst>(                                        \
746         [=](TypeFloatInst *floatTy) -> bool {                                  \
747           return floatTy->mOperand1 == BITS;                                   \
748         },                                                                     \
749         [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); },    \
750         &mGlobalDefs);                                                         \
751   }
752     HANDLE_FLOAT_SIZE(16);
753     HANDLE_FLOAT_SIZE(32);
754     HANDLE_FLOAT_SIZE(64);
755   default:
756     Module::errs() << "unexpeced floating point type";
757   }
758 #undef HANDLE_FLOAT_SIZE
759   return nullptr;
760 }
761 
getVectorType(Instruction * componentType,int width)762 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
763                                              int width) {
764   // TODO: verify that componentType is basic numeric types
765 
766   return findOrCreate<TypeVectorInst>(
767       [=](TypeVectorInst *vecTy) -> bool {
768         return vecTy->mOperand1.mInstruction == componentType &&
769                vecTy->mOperand2 == width;
770       },
771       [=]() -> TypeVectorInst * {
772         return mBuilder->MakeTypeVector(componentType, width);
773       },
774       &mGlobalDefs);
775 }
776 
getPointerType(StorageClass storage,Instruction * pointeeType)777 TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
778                                                Instruction *pointeeType) {
779   return findOrCreate<TypePointerInst>(
780       [=](TypePointerInst *type) -> bool {
781         return type->mOperand1 == storage &&
782                type->mOperand2.mInstruction == pointeeType;
783       },
784       [=]() -> TypePointerInst * {
785         return mBuilder->MakeTypePointer(storage, pointeeType);
786       },
787       &mGlobalDefs);
788 }
789 
790 TypeRuntimeArrayInst *
getRuntimeArrayType(Instruction * elemType)791 GlobalSection::getRuntimeArrayType(Instruction *elemType) {
792   return findOrCreate<TypeRuntimeArrayInst>(
793       [=](TypeRuntimeArrayInst * /*type*/) -> bool {
794         // return type->mOperand1.mInstruction == elemType;
795         return false;
796       },
797       [=]() -> TypeRuntimeArrayInst * {
798         return mBuilder->MakeTypeRuntimeArray(elemType);
799       },
800       &mGlobalDefs);
801 }
802 
getStructType(Instruction * fieldType[],int numField)803 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
804                                              int numField) {
805   TypeStructInst *structTy = mBuilder->MakeTypeStruct();
806   for (int i = 0; i < numField; i++) {
807     structTy->mOperand1.push_back(fieldType[i]);
808   }
809   mGlobalDefs.push_back(structTy);
810   return structTy;
811 }
812 
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)813 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
814                                                  Instruction *const argType[],
815                                                  size_t numArg) {
816   return findOrCreate<TypeFunctionInst>(
817       [=](TypeFunctionInst *type) -> bool {
818         if (type->mOperand1.mInstruction != retType ||
819             type->mOperand2.size() != numArg) {
820           return false;
821         }
822         for (size_t i = 0; i < numArg; i++) {
823           if (type->mOperand2[i].mInstruction != argType[i]) {
824             return false;
825           }
826         }
827         return true;
828       },
829       [=]() -> TypeFunctionInst * {
830         TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
831         for (size_t i = 0; i < numArg; i++) {
832           funcTy->mOperand2.push_back(argType[i]);
833         }
834         return funcTy;
835       },
836       &mGlobalDefs);
837 }
838 
addStructType(TypeStructInst * structType)839 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
840   mGlobalDefs.push_back(structType);
841   return this;
842 }
843 
addVariable(VariableInst * var)844 GlobalSection *GlobalSection::addVariable(VariableInst *var) {
845   mGlobalDefs.push_back(var);
846   return this;
847 }
848 
getInvocationId()849 VariableInst *GlobalSection::getInvocationId() {
850   if (mInvocationId) {
851     return mInvocationId.get();
852   }
853 
854   TypeIntInst *UIntTy = getIntType(32, false);
855   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
856   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
857 
858   VariableInst *InvocationId =
859       mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
860   InvocationId->decorate(Decoration::BuiltIn)
861       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
862 
863   mInvocationId.reset(InvocationId);
864 
865   return InvocationId;
866 }
867 
getNumWorkgroups()868 VariableInst *GlobalSection::getNumWorkgroups() {
869   if (mNumWorkgroups) {
870     return mNumWorkgroups.get();
871   }
872 
873   TypeIntInst *UIntTy = getIntType(32, false);
874   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
875   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
876 
877   VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
878   GNum->decorate(Decoration::BuiltIn)
879       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
880 
881   mNumWorkgroups.reset(GNum);
882 
883   return GNum;
884 }
885 
DeserializeInternal(InputWordStream & IS)886 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
887   if (!Deserialize<FunctionInst>(IS)) {
888     return false;
889   }
890 
891   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
892 
893   if (!Deserialize<FunctionEndInst>(IS)) {
894     return false;
895   }
896 
897   return true;
898 }
899 
Deserialize(InputWordStream & IS)900 template <> Instruction *Deserialize(InputWordStream &IS) {
901   Instruction *inst;
902 
903   switch ((*IS) & 0xFFFF) {
904 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
905   case OPCODE:                                                                 \
906     inst = Deserialize<INST_CLASS>(IS);                                        \
907     break;
908 #include "instruction_dispatches_generated.h"
909 #undef HANDLE_INSTRUCTION
910   default:
911     Module::errs() << "unrecognized instruction";
912     inst = nullptr;
913   }
914 
915   return inst;
916 }
917 
DeserializeInternal(InputWordStream & IS)918 bool Block::DeserializeInternal(InputWordStream &IS) {
919   Instruction *inst;
920   while (((*IS) & 0xFFFF) != OpFunctionEnd &&
921          (inst = Deserialize<Instruction>(IS))) {
922     mInsts.push_back(inst);
923     if (inst->getOpCode() == OpBranch ||
924         inst->getOpCode() == OpBranchConditional ||
925         inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
926         inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
927         inst->getOpCode() == OpUnreachable) {
928       break;
929     }
930   }
931   return !mInsts.empty();
932 }
933 
FunctionDefinition()934 FunctionDefinition::FunctionDefinition()
935     : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}
936 
FunctionDefinition(Builder * builder,FunctionInst * func,FunctionEndInst * end)937 FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
938                                        FunctionEndInst *end)
939     : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
940       mBlocksDeleter(mBlocks) {}
941 
DeserializeInternal(InputWordStream & IS)942 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
943   mFunc.reset(Deserialize<FunctionInst>(IS));
944   if (!mFunc) {
945     return false;
946   }
947 
948   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
949   DeserializeZeroOrMore<Block>(IS, mBlocks);
950 
951   mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
952   if (!mFuncEnd) {
953     return false;
954   }
955 
956   return true;
957 }
958 
getReturnType() const959 Instruction *FunctionDefinition::getReturnType() const {
960   return mFunc->mResultType.mInstruction;
961 }
962 
963 } // namespace spirit
964 } // namespace android
965