//===- SPIRVEntry.cpp - Base Class for SPIR-V Entities -----------*- C++ -*-===// // // The LLVM/SPIRV Translator // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), // to deal with the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: // // Redistributions of source code must retain the above copyright notice, // this list of conditions and the following disclaimers. // Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimers in the documentation // and/or other materials provided with the distribution. // Neither the names of Advanced Micro Devices, Inc., nor the names of its // contributors may be used to endorse or promote products derived from this // Software without specific prior written permission. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH // THE SOFTWARE. // //===----------------------------------------------------------------------===// /// \file /// /// This file implements base class for SPIR-V entities. /// //===----------------------------------------------------------------------===// #include "SPIRVEntry.h" #include "SPIRVDebug.h" #include "SPIRVType.h" #include "SPIRVFunction.h" #include "SPIRVBasicBlock.h" #include "SPIRVInstruction.h" #include "SPIRVDecorate.h" #include "SPIRVStream.h" #include #include #include #include #include #include using namespace SPIRV; namespace SPIRV{ template SPIRVEntry* create() { return new T(); } SPIRVEntry * SPIRVEntry::create(Op OpCode) { typedef SPIRVEntry *(*SPIRVFactoryTy)(); struct TableEntry { Op Opn; SPIRVFactoryTy Factory; operator std::pair() { return std::make_pair(Opn, Factory); } }; static TableEntry Table[] = { #define _SPIRV_OP(x,...) {Op##x, &SPIRV::create}, #include "SPIRVOpCodeEnum.h" #undef _SPIRV_OP }; typedef std::map OpToFactoryMapTy; static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table), std::end(Table)); OpToFactoryMapTy::const_iterator Loc = OpToFactoryMap.find(OpCode); if (Loc != OpToFactoryMap.end()) return Loc->second(); SPIRVDBG(spvdbgs() << "No factory for OpCode " << (unsigned)OpCode << '\n';) assert (0 && "Not implemented"); return 0; } std::unique_ptr SPIRVEntry::create_unique(Op OC) { return std::unique_ptr(create(OC)); } std::unique_ptr SPIRVEntry::create_unique(SPIRVExtInstSetKind Set, unsigned ExtOp) { return std::unique_ptr(new SPIRVExtInst(Set, ExtOp)); } SPIRVErrorLog & SPIRVEntry::getErrorLog()const { return Module->getErrorLog(); } bool SPIRVEntry::exist(SPIRVId TheId)const { return Module->exist(TheId); } SPIRVEntry * SPIRVEntry::getOrCreate(SPIRVId TheId)const { SPIRVEntry *Entry = nullptr; bool Found = Module->exist(TheId, &Entry); if (!Found) return Module->addForward(TheId, nullptr); return Entry; } SPIRVValue * SPIRVEntry::getValue(SPIRVId TheId)const { return get(TheId); } SPIRVType * SPIRVEntry::getValueType(SPIRVId TheId)const { return get(TheId)->getType(); } SPIRVEncoder SPIRVEntry::getEncoder(spv_ostream &O)const{ return SPIRVEncoder(O); } SPIRVDecoder SPIRVEntry::getDecoder(std::istream& I){ return SPIRVDecoder(I, *Module); } void SPIRVEntry::setWordCount(SPIRVWord TheWordCount){ WordCount = TheWordCount; } void SPIRVEntry::setName(const std::string& TheName) { Name = TheName; SPIRVDBG(spvdbgs() << "Set name for obj " << Id << " " << Name << '\n'); } void SPIRVEntry::setModule(SPIRVModule *TheModule) { assert(TheModule && "Invalid module"); if (TheModule == Module) return; assert(Module == NULL && "Cannot change owner of entry"); Module = TheModule; } void SPIRVEntry::encode(spv_ostream &O) const { assert (0 && "Not implemented"); } void SPIRVEntry::encodeName(spv_ostream &O) const { if (!Name.empty()) O << SPIRVName(this, Name); } void SPIRVEntry::encodeAll(spv_ostream &O) const { encodeWordCountOpCode(O); encode(O); encodeChildren(O); } void SPIRVEntry::encodeChildren(spv_ostream &O)const { } void SPIRVEntry::encodeWordCountOpCode(spv_ostream &O) const { #ifdef _SPIRV_SUPPORT_TEXT_FMT if (SPIRVUseTextFormat) { getEncoder(O) << WordCount << OpCode; return; } #endif getEncoder(O) << mkWord(WordCount, OpCode); } // Read words from SPIRV binary and create members for SPIRVEntry. // The word count and op code has already been read before calling this // function for creating the SPIRVEntry. Therefore the input stream only // contains the remaining part of the words for the SPIRVEntry. void SPIRVEntry::decode(std::istream &I) { assert (0 && "Not implemented"); } std::vector SPIRVEntry::getValues(const std::vector& IdVec)const { std::vector ValueVec; for (auto i:IdVec) ValueVec.push_back(getValue(i)); return ValueVec; } std::vector SPIRVEntry::getValueTypes(const std::vector& IdVec)const { std::vector TypeVec; for (auto i:IdVec) TypeVec.push_back(getValue(i)->getType()); return TypeVec; } std::vector SPIRVEntry::getIds(const std::vector ValueVec)const { std::vector IdVec; for (auto i:ValueVec) IdVec.push_back(i->getId()); return IdVec; } SPIRVEntry * SPIRVEntry::getEntry(SPIRVId TheId) const { return Module->getEntry(TheId); } void SPIRVEntry::validateFunctionControlMask(SPIRVWord TheFCtlMask) const { SPIRVCK(isValidFunctionControlMask(TheFCtlMask), InvalidFunctionControlMask, ""); } void SPIRVEntry::validateValues(const std::vector &Ids)const { for (auto I:Ids) getValue(I)->validate(); } void SPIRVEntry::validateBuiltin(SPIRVWord TheSet, SPIRVWord Index)const { (void) TheSet; (void) Index; assert(TheSet != SPIRVWORD_MAX && Index != SPIRVWORD_MAX && "Invalid builtin"); } void SPIRVEntry::addDecorate(const SPIRVDecorate *Dec){ Decorates.insert(std::make_pair(Dec->getDecorateKind(), Dec)); Module->addDecorate(Dec); SPIRVDBG(spvdbgs() << "[addDecorate] " << *Dec << '\n';) } void SPIRVEntry::addDecorate(Decoration Kind) { addDecorate(new SPIRVDecorate(Kind, this)); } void SPIRVEntry::addDecorate(Decoration Kind, SPIRVWord Literal) { addDecorate(new SPIRVDecorate(Kind, this, Literal)); } void SPIRVEntry::eraseDecorate(Decoration Dec){ Decorates.erase(Dec); } void SPIRVEntry::takeDecorates(SPIRVEntry *E){ Decorates = std::move(E->Decorates); SPIRVDBG(spvdbgs() << "[takeDecorates] " << Id << '\n';) } void SPIRVEntry::setLine(SPIRVLine *L){ Line = L; L->setTargetId(Id); SPIRVDBG(spvdbgs() << "[setLine] " << *L << '\n';) } void SPIRVEntry::takeLine(SPIRVEntry *E){ Line = E->Line; if (Line == nullptr) return; Line->setTargetId(Id); E->Line = nullptr; } void SPIRVEntry::addMemberDecorate(const SPIRVMemberDecorate *Dec){ assert(canHaveMemberDecorates() && MemberDecorates.find(Dec->getPair()) == MemberDecorates.end()); MemberDecorates[Dec->getPair()] = Dec; Module->addDecorate(Dec); SPIRVDBG(spvdbgs() << "[addMemberDecorate] " << *Dec << '\n';) } void SPIRVEntry::addMemberDecorate(SPIRVWord MemberNumber, Decoration Kind) { addMemberDecorate(new SPIRVMemberDecorate(Kind, MemberNumber, this)); } void SPIRVEntry::addMemberDecorate(SPIRVWord MemberNumber, Decoration Kind, SPIRVWord Literal) { addMemberDecorate(new SPIRVMemberDecorate(Kind, MemberNumber, this, Literal)); } void SPIRVEntry::eraseMemberDecorate(SPIRVWord MemberNumber, Decoration Dec){ MemberDecorates.erase(std::make_pair(MemberNumber, Dec)); } void SPIRVEntry::takeMemberDecorates(SPIRVEntry *E){ MemberDecorates = std::move(E->MemberDecorates); SPIRVDBG(spvdbgs() << "[takeMemberDecorates] " << Id << '\n';) } void SPIRVEntry::takeAnnotations(SPIRVForward *E){ Module->setName(this, E->getName()); takeDecorates(E); takeMemberDecorates(E); takeLine(E); if (OpCode == OpFunction) static_cast(this)->takeExecutionModes(E); } // Check if an entry has Kind of decoration and get the literal of the // first decoration of such kind at Index. bool SPIRVEntry::hasDecorate(Decoration Kind, size_t Index, SPIRVWord *Result)const { DecorateMapType::const_iterator Loc = Decorates.find(Kind); if (Loc == Decorates.end()) return false; if (Result) *Result = Loc->second->getLiteral(Index); return true; } // Get literals of all decorations of Kind at Index. std::set SPIRVEntry::getDecorate(Decoration Kind, size_t Index) const { auto Range = Decorates.equal_range(Kind); std::set Value; for (auto I = Range.first, E = Range.second; I != E; ++I) { assert(Index < I->second->getLiteralCount() && "Invalid index"); Value.insert(I->second->getLiteral(Index)); } return Value; } bool SPIRVEntry::hasLinkageType() const { return OpCode == OpFunction || OpCode == OpVariable; } void SPIRVEntry::encodeDecorate(spv_ostream &O) const { for (auto& i:Decorates) O << *i.second; } SPIRVLinkageTypeKind SPIRVEntry::getLinkageType() const { assert(hasLinkageType()); DecorateMapType::const_iterator Loc = Decorates.find(DecorationLinkageAttributes); if (Loc == Decorates.end()) return LinkageTypeInternal; return static_cast(Loc->second)->getLinkageType(); } void SPIRVEntry::setLinkageType(SPIRVLinkageTypeKind LT) { assert(isValid(LT)); assert(hasLinkageType()); addDecorate(new SPIRVDecorateLinkageAttr(this, Name, LT)); } void SPIRVEntry::updateModuleVersion() const { if (!Module) return; Module->setMinSPIRVVersion(getRequiredSPIRVVersion()); } spv_ostream & operator<<(spv_ostream &O, const SPIRVEntry &E) { E.validate(); E.encodeAll(O); O << SPIRVNL(); return O; } std::istream & operator>>(std::istream &I, SPIRVEntry &E) { E.decode(I); return I; } SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule, SPIRVExecutionModelKind TheExecModel, SPIRVId TheId, const std::string &TheName) :SPIRVAnnotation(TheModule->get(TheId), getSizeInWords(TheName) + 3), ExecModel(TheExecModel), Name(TheName){ } void SPIRVEntryPoint::encode(spv_ostream &O) const { getEncoder(O) << ExecModel << Target << Name; } void SPIRVEntryPoint::decode(std::istream &I) { getDecoder(I) >> ExecModel >> Target >> Name; Module->setName(getOrCreateTarget(), Name); Module->addEntryPoint(ExecModel, Target); } void SPIRVExecutionMode::encode(spv_ostream &O) const { getEncoder(O) << Target << ExecMode << WordLiterals; } void SPIRVExecutionMode::decode(std::istream &I) { getDecoder(I) >> Target >> ExecMode; switch(ExecMode) { case ExecutionModeLocalSize: case ExecutionModeLocalSizeHint: WordLiterals.resize(3); break; case ExecutionModeInvocations: case ExecutionModeOutputVertices: case ExecutionModeVecTypeHint: WordLiterals.resize(1); break; default: // Do nothing. Keep this to avoid VS2013 warning. break; } getDecoder(I) >> WordLiterals; getOrCreateTarget()->addExecutionMode(this); } SPIRVForward * SPIRVAnnotationGeneric::getOrCreateTarget()const { SPIRVEntry *Entry = nullptr; bool Found = Module->exist(Target, &Entry); assert((!Found || Entry->getOpCode() == OpForward) && "Annotations only allowed on forward"); if (!Found) Entry = Module->addForward(Target, nullptr); return static_cast(Entry); } SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string& TheStr) :SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr){ } void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; } void SPIRVName::decode(std::istream &I) { getDecoder(I) >> Target >> Str; Module->setName(getOrCreateTarget(), Str); } void SPIRVName::validate() const { assert(WordCount == getSizeInWords(Str) + 2 && "Incorrect word count"); } _SPIRV_IMP_ENCDEC2(SPIRVString, Id, Str) _SPIRV_IMP_ENCDEC3(SPIRVMemberName, Target, MemberNumber, Str) void SPIRVLine::encode(spv_ostream &O) const { getEncoder(O) << Target << FileName << Line << Column; } void SPIRVLine::decode(std::istream &I) { getDecoder(I) >> Target >> FileName >> Line >> Column; Module->addLine(getOrCreateTarget(), get(FileName), Line, Column); } void SPIRVLine::validate() const { assert(OpCode == OpLine); assert(WordCount == 5); assert(get(Target)); assert(get(FileName)->getOpCode() == OpString); assert(Line != SPIRVWORD_MAX); assert(Column != SPIRVWORD_MAX); } void SPIRVMemberName::validate() const { assert(OpCode == OpMemberName); assert(WordCount == getSizeInWords(Str) + FixedWC); assert(get(Target)->getOpCode() == OpTypeStruct); assert(MemberNumber < get(Target)->getStructMemberCount()); } SPIRVExtInstImport::SPIRVExtInstImport(SPIRVModule *TheModule, SPIRVId TheId, const std::string &TheStr): SPIRVEntry(TheModule, 2 + getSizeInWords(TheStr), OC, TheId), Str(TheStr){ validate(); } void SPIRVExtInstImport::encode(spv_ostream &O) const { getEncoder(O) << Id << Str; } void SPIRVExtInstImport::decode(std::istream &I) { getDecoder(I) >> Id >> Str; Module->importBuiltinSetWithId(Str, Id); } void SPIRVExtInstImport::validate() const { SPIRVEntry::validate(); assert(!Str.empty() && "Invalid builtin set"); } void SPIRVMemoryModel::encode(spv_ostream &O) const { getEncoder(O) << Module->getAddressingModel() << Module->getMemoryModel(); } void SPIRVMemoryModel::decode(std::istream &I) { SPIRVAddressingModelKind AddrModel; SPIRVMemoryModelKind MemModel; getDecoder(I) >> AddrModel >> MemModel; Module->setAddressingModel(AddrModel); Module->setMemoryModel(MemModel); } void SPIRVMemoryModel::validate() const { auto AM = Module->getAddressingModel(); auto MM = Module->getMemoryModel(); SPIRVCK(isValid(AM), InvalidAddressingModel, "Actual is "+AM ); SPIRVCK(isValid(MM), InvalidMemoryModel, "Actual is "+MM); } void SPIRVSource::encode(spv_ostream &O) const { SPIRVWord Ver = SPIRVWORD_MAX; auto Language = Module->getSourceLanguage(&Ver); getEncoder(O) << Language << Ver; } void SPIRVSource::decode(std::istream &I) { SourceLanguage Lang = SourceLanguageUnknown; SPIRVWord Ver = SPIRVWORD_MAX; getDecoder(I) >> Lang >> Ver; Module->setSourceLanguage(Lang, Ver); } SPIRVSourceExtension::SPIRVSourceExtension(SPIRVModule *M, const std::string &SS) :SPIRVEntryNoId(M, 1 + getSizeInWords(SS)), S(SS){} void SPIRVSourceExtension::encode(spv_ostream &O) const { getEncoder(O) << S; } void SPIRVSourceExtension::decode(std::istream &I) { getDecoder(I) >> S; Module->getSourceExtension().insert(S); } SPIRVExtension::SPIRVExtension(SPIRVModule *M, const std::string &SS) :SPIRVEntryNoId(M, 1 + getSizeInWords(SS)), S(SS){} void SPIRVExtension::encode(spv_ostream &O) const { getEncoder(O) << S; } void SPIRVExtension::decode(std::istream &I) { getDecoder(I) >> S; Module->getExtension().insert(S); } SPIRVCapability::SPIRVCapability(SPIRVModule *M, SPIRVCapabilityKind K) :SPIRVEntryNoId(M, 2), Kind(K){ updateModuleVersion(); } void SPIRVCapability::encode(spv_ostream &O) const { getEncoder(O) << Kind; } void SPIRVCapability::decode(std::istream &I) { getDecoder(I) >> Kind; Module->addCapability(Kind); } } // namespace SPIRV