1 /*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "module.h"

#include <cstring>
#include <set>

#include "builder.h"
#include "core_defs.h"
#include "instructions.h"
#include "types_generated.h"
#include "word_stream.h"
26
27 namespace android {
28 namespace spirit {
29
30 Module *Module::mInstance = nullptr;
31
getCurrentModule()32 Module *Module::getCurrentModule() {
33 if (mInstance == nullptr) {
34 return mInstance = new Module();
35 }
36 return mInstance;
37 }
38
// Default constructor: a module with no associated builder.
// mNextId starts at 1. Each *Deleter member is bound to the container of the
// matching name (presumably to free its elements on destruction — confirm in
// the header). The new module becomes the one returned by getCurrentModule().
Module::Module()
    : mNextId(1), mCapabilitiesDeleter(mCapabilities),
      mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
      mEntryPointInstsDeleter(mEntryPointInsts),
      mExecutionModesDeleter(mExecutionModes),
      mEntryPointsDeleter(mEntryPoints),
      mFunctionDefinitionsDeleter(mFunctionDefinitions) {
  mInstance = this;
}

// Same as the default constructor, but associates the module with builder `b`.
// The add*/set* methods use it (mBuilder) to create instructions.
Module::Module(Builder *b)
    : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
      mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
      mEntryPointInstsDeleter(mEntryPointInsts),
      mExecutionModesDeleter(mExecutionModes),
      mEntryPointsDeleter(mEntryPoints),
      mFunctionDefinitionsDeleter(mFunctionDefinitions) {
  mInstance = this;
}
58
resolveIds()59 bool Module::resolveIds() {
60 auto &table = mIdTable;
61
62 std::unique_ptr<IVisitor> v0(
63 CreateInstructionVisitor([&table](Instruction *inst) {
64 if (inst->hasResult()) {
65 table.insert(std::make_pair(inst->getId(), inst));
66 }
67 }));
68 v0->visit(this);
69
70 mNextId = mIdTable.rbegin()->first + 1;
71
72 int err = 0;
73 std::unique_ptr<IVisitor> v(
74 CreateInstructionVisitor([&table, &err](Instruction *inst) {
75 for (auto ref : inst->getAllIdRefs()) {
76 if (ref) {
77 auto it = table.find(ref->mId);
78 if (it != table.end()) {
79 ref->mInstruction = it->second;
80 } else {
81 std::cout << "Found no instruction for id " << ref->mId
82 << std::endl;
83 err++;
84 }
85 }
86 }
87 }));
88 v->visit(this);
89 return err == 0;
90 }
91
DeserializeInternal(InputWordStream & IS)92 bool Module::DeserializeInternal(InputWordStream &IS) {
93 if (IS.empty()) {
94 return false;
95 }
96
97 IS >> &mMagicNumber;
98 if (mMagicNumber != 0x07230203) {
99 errs() << "Wrong Magic Number: " << mMagicNumber;
100 return false;
101 }
102
103 if (IS.empty()) {
104 return false;
105 }
106
107 IS >> &mVersion.mWord;
108 if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
109 return false;
110 }
111
112 if (IS.empty()) {
113 return false;
114 }
115
116 IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
117
118 DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
119 DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
120 DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
121
122 mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
123 if (!mMemoryModel) {
124 errs() << "Missing memory model specification.\n";
125 return false;
126 }
127
128 DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
129 DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
130 for (auto entry : mEntryPoints) {
131 mEntryPointInsts.push_back(entry->getInstruction());
132 for (auto mode : mExecutionModes) {
133 entry->applyExecutionMode(mode);
134 }
135 }
136
137 mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
138 mAnnotations.reset(Deserialize<AnnotationSection>(IS));
139 mGlobals.reset(Deserialize<GlobalSection>(IS));
140
141 DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
142
143 if (mFunctionDefinitions.empty()) {
144 errs() << "Missing function definitions.\n";
145 for (int i = 0; i < 4; i++) {
146 uint32_t w;
147 IS >> &w;
148 std::cout << std::hex << w << " ";
149 }
150 std::cout << std::endl;
151 return false;
152 }
153
154 return true;
155 }
156
initialize()157 void Module::initialize() {
158 mMagicNumber = 0x07230203;
159 mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
160 mGeneratorMagicNumber = 0x00070000;
161 mBound = 0;
162 mReserved = 0;
163 mAnnotations.reset(new AnnotationSection());
164 }
165
SerializeHeader(OutputWordStream & OS) const166 void Module::SerializeHeader(OutputWordStream &OS) const {
167 OS << mMagicNumber;
168 OS << mVersion.mWord << mGeneratorMagicNumber;
169 if (mBound == 0) {
170 OS << mIdTable.end()->first + 1;
171 } else {
172 OS << std::max(mBound, mNextId);
173 }
174 OS << mReserved;
175 }
176
Serialize(OutputWordStream & OS) const177 void Module::Serialize(OutputWordStream &OS) const {
178 SerializeHeader(OS);
179 Entity::Serialize(OS);
180 }
181
addCapability(Capability cap)182 Module *Module::addCapability(Capability cap) {
183 mCapabilities.push_back(mBuilder->MakeCapability(cap));
184 return this;
185 }
186
setMemoryModel(AddressingModel am,MemoryModel mm)187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
188 mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
189 return this;
190 }
191
addExtInstImport(const char * extName)192 Module *Module::addExtInstImport(const char *extName) {
193 ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
194 mExtInstImports.push_back(extInst);
195 if (strcmp(extName, "GLSL.std.450") == 0) {
196 mGLExt = extInst;
197 }
198 return this;
199 }
200
addSource(SourceLanguage lang,int version)201 Module *Module::addSource(SourceLanguage lang, int version) {
202 if (!mDebugInfo) {
203 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
204 }
205 mDebugInfo->addSource(lang, version);
206 return this;
207 }
208
addSourceExtension(const char * ext)209 Module *Module::addSourceExtension(const char *ext) {
210 if (!mDebugInfo) {
211 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
212 }
213 mDebugInfo->addSourceExtension(ext);
214 return this;
215 }
216
addString(const char * str)217 Module *Module::addString(const char *str) {
218 if (!mDebugInfo) {
219 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
220 }
221 mDebugInfo->addString(str);
222 return this;
223 }
224
addEntryPoint(EntryPointDefinition * entry)225 Module *Module::addEntryPoint(EntryPointDefinition *entry) {
226 mEntryPoints.push_back(entry);
227 auto newModes = entry->getExecutionModes();
228 mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
229 newModes.end());
230 return this;
231 }
232
getGlobalSection()233 GlobalSection *Module::getGlobalSection() {
234 if (!mGlobals) {
235 mGlobals.reset(new GlobalSection());
236 }
237 return mGlobals.get();
238 }
239
getConstant(TypeIntInst * type,int32_t value)240 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
241 return getGlobalSection()->getConstant(type, value);
242 }
243
getConstant(TypeIntInst * type,uint32_t value)244 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
245 return getGlobalSection()->getConstant(type, value);
246 }
247
getConstant(TypeFloatInst * type,float value)248 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
249 return getGlobalSection()->getConstant(type, value);
250 }
251
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)252 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
253 ConstantInst *components[],
254 size_t width) {
255 return getGlobalSection()->getConstantComposite(type, components, width);
256 }
257
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2)258 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
259 ConstantInst *comp0,
260 ConstantInst *comp1,
261 ConstantInst *comp2) {
262 // TODO: verify that component types are the same and consistent with the
263 // resulting vector type
264 ConstantInst *comps[] = {comp0, comp1, comp2};
265 return getConstantComposite(type, comps, 3);
266 }
267
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2,ConstantInst * comp3)268 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
269 ConstantInst *comp0,
270 ConstantInst *comp1,
271 ConstantInst *comp2,
272 ConstantInst *comp3) {
273 // TODO: verify that component types are the same and consistent with the
274 // resulting vector type
275 ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
276 return getConstantComposite(type, comps, 4);
277 }
278
getVoidType()279 TypeVoidInst *Module::getVoidType() {
280 return getGlobalSection()->getVoidType();
281 }
282
getIntType(int bits,bool isSigned)283 TypeIntInst *Module::getIntType(int bits, bool isSigned) {
284 return getGlobalSection()->getIntType(bits, isSigned);
285 }
286
getUnsignedIntType(int bits)287 TypeIntInst *Module::getUnsignedIntType(int bits) {
288 return getIntType(bits, false);
289 }
290
getFloatType(int bits)291 TypeFloatInst *Module::getFloatType(int bits) {
292 return getGlobalSection()->getFloatType(bits);
293 }
294
getVectorType(Instruction * componentType,int width)295 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
296 return getGlobalSection()->getVectorType(componentType, width);
297 }
298
getPointerType(StorageClass storage,Instruction * pointeeType)299 TypePointerInst *Module::getPointerType(StorageClass storage,
300 Instruction *pointeeType) {
301 return getGlobalSection()->getPointerType(storage, pointeeType);
302 }
303
getRuntimeArrayType(Instruction * elementType)304 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
305 return getGlobalSection()->getRuntimeArrayType(elementType);
306 }
307
getStructType(Instruction * fieldType[],int numField)308 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
309 return getGlobalSection()->getStructType(fieldType, numField);
310 }
311
getStructType(Instruction * fieldType)312 TypeStructInst *Module::getStructType(Instruction *fieldType) {
313 return getStructType(&fieldType, 1);
314 }
315
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)316 TypeFunctionInst *Module::getFunctionType(Instruction *retType,
317 Instruction *const argType[],
318 size_t numArg) {
319 return getGlobalSection()->getFunctionType(retType, argType, numArg);
320 }
321
322 TypeFunctionInst *
getFunctionType(Instruction * retType,const std::vector<Instruction * > & argTypes)323 Module::getFunctionType(Instruction *retType,
324 const std::vector<Instruction *> &argTypes) {
325 return getGlobalSection()->getFunctionType(retType, argTypes.data(),
326 argTypes.size());
327 }
328
getSize(TypeVoidInst *)329 size_t Module::getSize(TypeVoidInst *) { return 0; }
330
getSize(TypeIntInst * intTy)331 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
332
getSize(TypeFloatInst * fpTy)333 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
334
getSize(TypeVectorInst * vTy)335 size_t Module::getSize(TypeVectorInst *vTy) {
336 return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
337 }
338
getSize(TypePointerInst *)339 size_t Module::getSize(TypePointerInst *) {
340 return 4; // TODO: or 8?
341 }
342
getSize(TypeStructInst * structTy)343 size_t Module::getSize(TypeStructInst *structTy) {
344 size_t sz = 0;
345 for (auto ty : structTy->mOperand1) {
346 sz += getSize(ty.mInstruction);
347 }
348 return sz;
349 }
350
getSize(TypeFunctionInst *)351 size_t Module::getSize(TypeFunctionInst *) {
352 return 4; // TODO: or 8? Is this just the size of a pointer?
353 }
354
getSize(Instruction * inst)355 size_t Module::getSize(Instruction *inst) {
356 switch (inst->getOpCode()) {
357 case OpTypeVoid:
358 return getSize(static_cast<TypeVoidInst *>(inst));
359 case OpTypeInt:
360 return getSize(static_cast<TypeIntInst *>(inst));
361 case OpTypeFloat:
362 return getSize(static_cast<TypeFloatInst *>(inst));
363 case OpTypeVector:
364 return getSize(static_cast<TypeVectorInst *>(inst));
365 case OpTypeStruct:
366 return getSize(static_cast<TypeStructInst *>(inst));
367 case OpTypeFunction:
368 return getSize(static_cast<TypeFunctionInst *>(inst));
369 default:
370 return 0;
371 }
372 }
373
addFunctionDefinition(FunctionDefinition * func)374 Module *Module::addFunctionDefinition(FunctionDefinition *func) {
375 mFunctionDefinitions.push_back(func);
376 return this;
377 }
378
lookupByName(const char * name) const379 Instruction *Module::lookupByName(const char *name) const {
380 return mDebugInfo->lookupByName(name);
381 }
382
383 FunctionDefinition *
getFunctionDefinitionFromInstruction(FunctionInst * inst) const384 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
385 for (auto fdef : mFunctionDefinitions) {
386 if (fdef->getInstruction() == inst) {
387 return fdef;
388 }
389 }
390 return nullptr;
391 }
392
393 FunctionDefinition *
lookupFunctionDefinitionByName(const char * name) const394 Module::lookupFunctionDefinitionByName(const char *name) const {
395 FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
396 return getFunctionDefinitionFromInstruction(inst);
397 }
398
lookupNameByInstruction(const Instruction * inst) const399 const char *Module::lookupNameByInstruction(const Instruction *inst) const {
400 return mDebugInfo->lookupNameByInstruction(inst);
401 }
402
getInvocationId()403 VariableInst *Module::getInvocationId() {
404 return getGlobalSection()->getInvocationId();
405 }
406
getNumWorkgroups()407 VariableInst *Module::getNumWorkgroups() {
408 return getGlobalSection()->getNumWorkgroups();
409 }
410
addStructType(TypeStructInst * structType)411 Module *Module::addStructType(TypeStructInst *structType) {
412 getGlobalSection()->addStructType(structType);
413 return this;
414 }
415
addVariable(VariableInst * var)416 Module *Module::addVariable(VariableInst *var) {
417 getGlobalSection()->addVariable(var);
418 return this;
419 }
420
consolidateAnnotations()421 void Module::consolidateAnnotations() {
422 std::vector<Instruction *> annotations(mAnnotations->begin(),
423 mAnnotations->end());
424 std::unique_ptr<IVisitor> v(
425 CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
426 const auto &ann = inst->getAnnotations();
427 annotations.insert(annotations.end(), ann.begin(), ann.end());
428 }));
429 v->visit(this);
430 mAnnotations->clear();
431 mAnnotations->addAnnotations(annotations.begin(), annotations.end());
432 }
433
EntryPointDefinition(Builder * builder,ExecutionModel execModel,FunctionDefinition * func,const char * name)434 EntryPointDefinition::EntryPointDefinition(Builder *builder,
435 ExecutionModel execModel,
436 FunctionDefinition *func,
437 const char *name)
438 : Entity(builder), mFunction(func->getInstruction()),
439 mExecutionModel(execModel) {
440 mName = strndup(name, strlen(name));
441 mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
442 }
443
DeserializeInternal(InputWordStream & IS)444 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
445 if (IS.empty()) {
446 return false;
447 }
448
449 if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
450 return true;
451 }
452
453 return false;
454 }
455
456 EntryPointDefinition *
applyExecutionMode(ExecutionModeInst * mode)457 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
458 if (mode->mOperand1.mInstruction == mFunction) {
459 addExecutionMode(mode);
460 }
461 return this;
462 }
463
addToInterface(VariableInst * var)464 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
465 mInterface.push_back(var);
466 mEntryPointInst->mOperand4.push_back(var);
467 return this;
468 }
469
setLocalSize(uint32_t width,uint32_t height,uint32_t depth)470 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
471 uint32_t height,
472 uint32_t depth) {
473 mLocalSize.mWidth = width;
474 mLocalSize.mHeight = height;
475 mLocalSize.mDepth = depth;
476
477 auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
478 mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
479
480 addExecutionMode(mode);
481
482 return this;
483 }
484
DeserializeInternal(InputWordStream & IS)485 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
486 while (true) {
487 if (auto str = Deserialize<StringInst>(IS)) {
488 mSources.push_back(str);
489 } else if (auto src = Deserialize<SourceInst>(IS)) {
490 mSources.push_back(src);
491 } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
492 mSources.push_back(srcExt);
493 } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
494 mSources.push_back(srcCont);
495 } else {
496 break;
497 }
498 }
499
500 while (true) {
501 if (auto name = Deserialize<NameInst>(IS)) {
502 mNames.push_back(name);
503 } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
504 mNames.push_back(memName);
505 } else {
506 break;
507 }
508 }
509
510 return true;
511 }
512
addSource(SourceLanguage lang,int version)513 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
514 int version) {
515 SourceInst *source = mBuilder->MakeSource(lang, version);
516 mSources.push_back(source);
517 return this;
518 }
519
addSourceExtension(const char * ext)520 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
521 SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
522 mSources.push_back(inst);
523 return this;
524 }
525
addString(const char * str)526 DebugInfoSection *DebugInfoSection::addString(const char *str) {
527 StringInst *source = mBuilder->MakeString(str);
528 mSources.push_back(source);
529 return this;
530 }
531
lookupByName(const char * name) const532 Instruction *DebugInfoSection::lookupByName(const char *name) const {
533 for (auto inst : mNames) {
534 if (inst->getOpCode() == OpName) {
535 NameInst *nameInst = static_cast<NameInst *>(inst);
536 if (nameInst->mOperand2.compare(name) == 0) {
537 return nameInst->mOperand1.mInstruction;
538 }
539 }
540 // Ignore member names
541 }
542 return nullptr;
543 }
544
545 const char *
lookupNameByInstruction(const Instruction * target) const546 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
547 for (auto inst : mNames) {
548 if (inst->getOpCode() == OpName) {
549 NameInst *nameInst = static_cast<NameInst *>(inst);
550 if (nameInst->mOperand1.mInstruction == target) {
551 return nameInst->mOperand2.c_str();
552 }
553 }
554 // Ignore member names
555 }
556 return nullptr;
557 }
558
// Creates an empty annotation section with no builder. mAnnotationsDeleter is
// bound to mAnnotations (presumably to free the annotations on destruction).
AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}

// Creates an empty annotation section associated with builder `b`.
AnnotationSection::AnnotationSection(Builder *b)
    : Entity(b), mAnnotationsDeleter(mAnnotations) {}
563
DeserializeInternal(InputWordStream & IS)564 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
565 while (true) {
566 if (auto decor = Deserialize<DecorateInst>(IS)) {
567 mAnnotations.push_back(decor);
568 } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
569 mAnnotations.push_back(decor);
570 } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
571 mAnnotations.push_back(decor);
572 } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
573 mAnnotations.push_back(decor);
574 } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
575 mAnnotations.push_back(decor);
576 } else {
577 break;
578 }
579 }
580 return true;
581 }
582
// Creates an empty global section with no builder. mGlobalDefsDeleter is bound
// to mGlobalDefs (presumably to free the definitions on destruction).
// NOTE(review): the get*/Make* helpers dereference mBuilder, so a section made
// with this constructor may be unable to create types or constants.
GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}

// Creates an empty global section that creates instructions through `builder`.
GlobalSection::GlobalSection(Builder *builder)
    : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
587
588 namespace {
589
590 template <typename T>
findOrCreate(std::function<bool (T *)> criteria,std::function<T * ()> factory,std::vector<Instruction * > * globals)591 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
592 std::vector<Instruction *> *globals) {
593 T *derived;
594 for (auto inst : *globals) {
595 if (inst->getOpCode() == T::mOpCode) {
596 T *derived = static_cast<T *>(inst);
597 if (criteria(derived)) {
598 return derived;
599 }
600 }
601 }
602 derived = factory();
603 globals->push_back(derived);
604 return derived;
605 }
606
607 } // anonymous namespace
608
// Reads the global section: constant and type instructions (as listed in the
// generated dispatch headers), OpVariable, and OpUndef, in any order. Each
// decoded instruction is appended to mGlobalDefs. The loop stops once none of
// them decodes. Always returns true; an empty section is valid.
bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
  while (true) {
    // For every (OPCODE, INST_CLASS) pair in the generated headers, try to
    // decode that class; on success record it and restart the loop.
#define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
  if (auto typeInst = Deserialize<INST_CLASS>(IS)) {                           \
    mGlobalDefs.push_back(typeInst);                                           \
    continue;                                                                  \
  }
#include "const_inst_dispatches_generated.h"
#include "type_inst_dispatches_generated.h"
#undef HANDLE_INSTRUCTION

    if (auto globalInst = Deserialize<VariableInst>(IS)) {
      // Check if this is function scoped
      if (globalInst->mOperand1 == StorageClass::Function) {
        Module::errs() << "warning: Variable (id = " << globalInst->mResult;
        Module::errs() << ") has function scope in global section.\n";
        // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
        // As a workaround, accept such SPIR-V code here, and fix it up later
        // in the rs2spirv compiler by correcting the storage class.
        // In a stricter deserializer, such code should be rejected, and we
        // should return false here.
      }
      mGlobalDefs.push_back(globalInst);
      continue;
    }

    if (auto globalInst = Deserialize<UndefInst>(IS)) {
      mGlobalDefs.push_back(globalInst);
      continue;
    }
    // Nothing in this section's instruction set decoded; the section ends.
    break;
  }
  return true;
}
643
getConstant(TypeIntInst * type,int32_t value)644 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
645 return findOrCreate<ConstantInst>(
646 [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
647 [=]() -> ConstantInst * {
648 LiteralContextDependentNumber cdn = {.intValue = value};
649 return mBuilder->MakeConstant(type, cdn);
650 },
651 &mGlobalDefs);
652 }
653
getConstant(TypeIntInst * type,uint32_t value)654 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
655 return findOrCreate<ConstantInst>(
656 [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
657 [=]() -> ConstantInst * {
658 LiteralContextDependentNumber cdn = {.intValue = (int)value};
659 return mBuilder->MakeConstant(type, cdn);
660 },
661 &mGlobalDefs);
662 }
663
getConstant(TypeFloatInst * type,float value)664 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
665 return findOrCreate<ConstantInst>(
666 [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
667 [=]() -> ConstantInst * {
668 LiteralContextDependentNumber cdn = {.floatValue = value};
669 return mBuilder->MakeConstant(type, cdn);
670 },
671 &mGlobalDefs);
672 }
673
674 ConstantCompositeInst *
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)675 GlobalSection::getConstantComposite(TypeVectorInst *type,
676 ConstantInst *components[], size_t width) {
677 return findOrCreate<ConstantCompositeInst>(
678 [=](ConstantCompositeInst *c) {
679 if (c->mOperand1.size() != width) {
680 return false;
681 }
682 for (size_t i = 0; i < width; i++) {
683 if (c->mOperand1[i].mInstruction != components[i]) {
684 return false;
685 }
686 }
687 return true;
688 },
689 [=]() -> ConstantCompositeInst * {
690 ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
691 for (size_t i = 0; i < width; i++) {
692 c->mOperand1.push_back(components[i]);
693 }
694 return c;
695 },
696 &mGlobalDefs);
697 }
698
getVoidType()699 TypeVoidInst *GlobalSection::getVoidType() {
700 return findOrCreate<TypeVoidInst>(
701 [=](TypeVoidInst *) -> bool { return true; },
702 [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
703 &mGlobalDefs);
704 }
705
getIntType(int bits,bool isSigned)706 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
707 if (isSigned) {
708 switch (bits) {
709 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED) \
710 case BITS: { \
711 return findOrCreate<TypeIntInst>( \
712 [=](TypeIntInst *intTy) -> bool { \
713 return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED; \
714 }, \
715 [=]() -> TypeIntInst * { \
716 return mBuilder->MakeTypeInt(BITS, SIGNED); \
717 }, \
718 &mGlobalDefs); \
719 }
720 HANDLE_INT_SIZE(Int, 8, 1);
721 HANDLE_INT_SIZE(Int, 16, 1);
722 HANDLE_INT_SIZE(Int, 32, 1);
723 HANDLE_INT_SIZE(Int, 64, 1);
724 default:
725 Module::errs() << "unexpected int type";
726 }
727 } else {
728 switch (bits) {
729 HANDLE_INT_SIZE(UInt, 8, 0);
730 HANDLE_INT_SIZE(UInt, 16, 0);
731 HANDLE_INT_SIZE(UInt, 32, 0);
732 HANDLE_INT_SIZE(UInt, 64, 0);
733 default:
734 Module::errs() << "unexpected int type";
735 }
736 }
737 #undef HANDLE_INT_SIZE
738 return nullptr;
739 }
740
getFloatType(int bits)741 TypeFloatInst *GlobalSection::getFloatType(int bits) {
742 switch (bits) {
743 #define HANDLE_FLOAT_SIZE(BITS) \
744 case BITS: { \
745 return findOrCreate<TypeFloatInst>( \
746 [=](TypeFloatInst *floatTy) -> bool { \
747 return floatTy->mOperand1 == BITS; \
748 }, \
749 [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); }, \
750 &mGlobalDefs); \
751 }
752 HANDLE_FLOAT_SIZE(16);
753 HANDLE_FLOAT_SIZE(32);
754 HANDLE_FLOAT_SIZE(64);
755 default:
756 Module::errs() << "unexpeced floating point type";
757 }
758 #undef HANDLE_FLOAT_SIZE
759 return nullptr;
760 }
761
getVectorType(Instruction * componentType,int width)762 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
763 int width) {
764 // TODO: verify that componentType is basic numeric types
765
766 return findOrCreate<TypeVectorInst>(
767 [=](TypeVectorInst *vecTy) -> bool {
768 return vecTy->mOperand1.mInstruction == componentType &&
769 vecTy->mOperand2 == width;
770 },
771 [=]() -> TypeVectorInst * {
772 return mBuilder->MakeTypeVector(componentType, width);
773 },
774 &mGlobalDefs);
775 }
776
getPointerType(StorageClass storage,Instruction * pointeeType)777 TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
778 Instruction *pointeeType) {
779 return findOrCreate<TypePointerInst>(
780 [=](TypePointerInst *type) -> bool {
781 return type->mOperand1 == storage &&
782 type->mOperand2.mInstruction == pointeeType;
783 },
784 [=]() -> TypePointerInst * {
785 return mBuilder->MakeTypePointer(storage, pointeeType);
786 },
787 &mGlobalDefs);
788 }
789
790 TypeRuntimeArrayInst *
getRuntimeArrayType(Instruction * elemType)791 GlobalSection::getRuntimeArrayType(Instruction *elemType) {
792 return findOrCreate<TypeRuntimeArrayInst>(
793 [=](TypeRuntimeArrayInst * /*type*/) -> bool {
794 // return type->mOperand1.mInstruction == elemType;
795 return false;
796 },
797 [=]() -> TypeRuntimeArrayInst * {
798 return mBuilder->MakeTypeRuntimeArray(elemType);
799 },
800 &mGlobalDefs);
801 }
802
getStructType(Instruction * fieldType[],int numField)803 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
804 int numField) {
805 TypeStructInst *structTy = mBuilder->MakeTypeStruct();
806 for (int i = 0; i < numField; i++) {
807 structTy->mOperand1.push_back(fieldType[i]);
808 }
809 mGlobalDefs.push_back(structTy);
810 return structTy;
811 }
812
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)813 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
814 Instruction *const argType[],
815 size_t numArg) {
816 return findOrCreate<TypeFunctionInst>(
817 [=](TypeFunctionInst *type) -> bool {
818 if (type->mOperand1.mInstruction != retType ||
819 type->mOperand2.size() != numArg) {
820 return false;
821 }
822 for (size_t i = 0; i < numArg; i++) {
823 if (type->mOperand2[i].mInstruction != argType[i]) {
824 return false;
825 }
826 }
827 return true;
828 },
829 [=]() -> TypeFunctionInst * {
830 TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
831 for (size_t i = 0; i < numArg; i++) {
832 funcTy->mOperand2.push_back(argType[i]);
833 }
834 return funcTy;
835 },
836 &mGlobalDefs);
837 }
838
addStructType(TypeStructInst * structType)839 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
840 mGlobalDefs.push_back(structType);
841 return this;
842 }
843
addVariable(VariableInst * var)844 GlobalSection *GlobalSection::addVariable(VariableInst *var) {
845 mGlobalDefs.push_back(var);
846 return this;
847 }
848
getInvocationId()849 VariableInst *GlobalSection::getInvocationId() {
850 if (mInvocationId) {
851 return mInvocationId.get();
852 }
853
854 TypeIntInst *UIntTy = getIntType(32, false);
855 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
856 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
857
858 VariableInst *InvocationId =
859 mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
860 InvocationId->decorate(Decoration::BuiltIn)
861 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
862
863 mInvocationId.reset(InvocationId);
864
865 return InvocationId;
866 }
867
getNumWorkgroups()868 VariableInst *GlobalSection::getNumWorkgroups() {
869 if (mNumWorkgroups) {
870 return mNumWorkgroups.get();
871 }
872
873 TypeIntInst *UIntTy = getIntType(32, false);
874 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
875 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
876
877 VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
878 GNum->decorate(Decoration::BuiltIn)
879 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
880
881 mNumWorkgroups.reset(GNum);
882
883 return GNum;
884 }
885
DeserializeInternal(InputWordStream & IS)886 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
887 if (!Deserialize<FunctionInst>(IS)) {
888 return false;
889 }
890
891 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
892
893 if (!Deserialize<FunctionEndInst>(IS)) {
894 return false;
895 }
896
897 return true;
898 }
899
Deserialize(InputWordStream & IS)900 template <> Instruction *Deserialize(InputWordStream &IS) {
901 Instruction *inst;
902
903 switch ((*IS) & 0xFFFF) {
904 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \
905 case OPCODE: \
906 inst = Deserialize<INST_CLASS>(IS); \
907 break;
908 #include "instruction_dispatches_generated.h"
909 #undef HANDLE_INSTRUCTION
910 default:
911 Module::errs() << "unrecognized instruction";
912 inst = nullptr;
913 }
914
915 return inst;
916 }
917
DeserializeInternal(InputWordStream & IS)918 bool Block::DeserializeInternal(InputWordStream &IS) {
919 Instruction *inst;
920 while (((*IS) & 0xFFFF) != OpFunctionEnd &&
921 (inst = Deserialize<Instruction>(IS))) {
922 mInsts.push_back(inst);
923 if (inst->getOpCode() == OpBranch ||
924 inst->getOpCode() == OpBranchConditional ||
925 inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
926 inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
927 inst->getOpCode() == OpUnreachable) {
928 break;
929 }
930 }
931 return !mInsts.empty();
932 }
933
// Creates an empty function definition. mFunc and mFuncEnd are filled in by
// DeserializeInternal (presumably the constructor used by Deserialize<T>).
// The *Deleter members are bound to mParams and mBlocks, presumably to free
// their elements on destruction.
FunctionDefinition::FunctionDefinition()
    : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}

// Creates a definition from an existing OpFunction / OpFunctionEnd pair,
// associated with `builder`. func and end are stored in mFunc / mFuncEnd, which
// this class manages with reset() (see DeserializeInternal).
FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
                                       FunctionEndInst *end)
    : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
      mBlocksDeleter(mBlocks) {}
941
DeserializeInternal(InputWordStream & IS)942 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
943 mFunc.reset(Deserialize<FunctionInst>(IS));
944 if (!mFunc) {
945 return false;
946 }
947
948 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
949 DeserializeZeroOrMore<Block>(IS, mBlocks);
950
951 mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
952 if (!mFuncEnd) {
953 return false;
954 }
955
956 return true;
957 }
958
getReturnType() const959 Instruction *FunctionDefinition::getReturnType() const {
960 return mFunc->mResultType.mInstruction;
961 }
962
963 } // namespace spirit
964 } // namespace android
965