//===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Serialization.h"

#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "spirv-serialization"

using namespace mlir;

// NOTE(review): template argument lists (e.g. on SmallVectorImpl, ArrayRef,
// function_ref, dyn_cast) appear to be missing throughout this copy of the
// file -- confirm against the upstream source before building.

/// Encodes a SPIR-V instruction with the given opcode `op` and `operands`
/// into the given `binary` vector.
///
/// The first word pushed combines the word count (the operands plus the
/// leading opcode word itself) with the opcode; the operands follow verbatim.
/// Always returns success().
static LogicalResult encodeInstructionInto(SmallVectorImpl &binary,
                                           spirv::Opcode op,
                                           ArrayRef operands) {
  // One extra word for the combined <word-count, opcode> prefix.
  uint32_t wordCount = 1 + operands.size();
  binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
  binary.append(operands.begin(), operands.end());
  return success();
}

/// A pre-order depth-first visitor function for processing basic blocks.
///
/// Visits the basic blocks starting from the given `headerBlock` in pre-order
/// depth-first manner and calls `blockHandler` on each block. Skips handling
/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
/// will not be invoked on `headerBlock` but still handles all of
/// `headerBlock`'s successors. Stops and returns failure() as soon as
/// `blockHandler` fails for any block.
///
/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
/// of blocks in a function must satisfy the rule that blocks appear before
/// all blocks they dominate." This can be achieved by a pre-order CFG
/// traversal algorithm. To make the serialization output more logical and
/// readable to humans, we perform depth-first CFG traversal and delay the
/// serialization of the merge block and the continue block, if they exist,
/// until after all other blocks have been processed.
static LogicalResult
visitInPrettyBlockOrder(Block *headerBlock,
                        function_ref blockHandler,
                        bool skipHeader = false, BlockRange skipBlocks = {}) {
  // Pre-seeding the visited set with `skipBlocks` makes the traversal treat
  // them as already done, so they are never handed to `blockHandler`.
  llvm::df_iterator_default_set doneBlocks;
  doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());

  for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
    if (skipHeader && block == headerBlock)
      continue;
    if (failed(blockHandler(block)))
      return failure();
  }
  return success();
}

/// Returns the merge block if the given `op` is a structured control flow op.
/// Otherwise returns nullptr.
static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
  if (auto selectionOp = dyn_cast(op))
    return selectionOp.getMergeBlock();
  if (auto loopOp = dyn_cast(op))
    return loopOp.getMergeBlock();
  return nullptr;
}

/// Given a predecessor `block` for a block with arguments, returns the block
/// that should be used as the parent block for SPIR-V OpPhi instructions
/// corresponding to the block arguments.
static Block *getPhiIncomingBlock(Block *block) {
  // If the predecessor block in question is the entry block for a spv.loop,
  // we jump to this spv.loop from its enclosing block.
  if (block->isEntryBlock()) {
    if (auto loopOp = dyn_cast(block->getParentOp())) {
      // Then the incoming parent block for OpPhi should be the merge block of
      // the structured control flow op before this loop.
Operation *op = loopOp.getOperation();
      // Walk backwards over the ops preceding the loop in its block.
      while ((op = op->getPrevNode()) != nullptr)
        if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
          return incomingBlock;
      // Or the enclosing block itself if no structured control flow op
      // exists before this loop.
      return loopOp->getBlock();
    }
  }

  // Otherwise, we jump from the given predecessor block. Try to see if there is
  // a structured control flow op inside it; the last one found (scanning in
  // reverse) provides the merge block.
  for (Operation &op : llvm::reverse(block->getOperations())) {
    if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
      return incomingBlock;
  }
  return block;
}

namespace {

/// A SPIR-V module serializer.
///
/// A SPIR-V binary module is a single linear stream of instructions; each
/// instruction is composed of 32-bit words with the layout:
///
///   | <word-count>|<opcode> |  <operand>   |  <operand>   | ... |
///   | <------ word -------> | <-- word --> | <-- word --> | ... |
///
/// For the first word, the 16 high-order bits are the word count of the
/// instruction, the 16 low-order bits are the opcode enumerant. The
/// instructions then belong to different sections, which must be laid out in
/// the particular order as specified in "2.4 Logical Layout of a Module" of
/// the SPIR-V spec.
class Serializer {
public:
  /// Creates a serializer for the given SPIR-V `module`.
  explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);

  /// Serializes the remembered SPIR-V module.
  LogicalResult serialize();

  /// Collects the final SPIR-V `binary`.
  void collect(SmallVectorImpl &binary);

#ifndef NDEBUG
  /// (For debugging) prints each value and its corresponding result <id>.
  void printValueIDMap(raw_ostream &os);
#endif

private:
  // Note that there are two main categories of methods in this class:
  // * process*() methods are meant to fully serialize a SPIR-V module entity
  //   (header, type, op, etc.). They update internal vectors containing
  //   different binary sections. They are not meant to be called except by
  //   the top-level serialization loop.
  // * prepare*() methods are meant to be helpers that prepare for serializing
  //   certain entity. They may or may not update internal vectors containing
  //   different binary sections. They are meant to be called among themselves
  //   or by other process*() methods for subtasks.

  //===--------------------------------------------------------------------===//
  // <id>
  //===--------------------------------------------------------------------===//

  // Note that it is illegal to use id <0> in SPIR-V binary module. Various
  // methods in this class, if using SPIR-V word (uint32_t) as interface,
  // check or return id <0> to indicate error in processing.

  /// Consumes the next unused <id>. This method will never return 0
  /// (`nextID` starts at 1).
  uint32_t getNextID() { return nextID++; }

  //===--------------------------------------------------------------------===//
  // Module structure
  //===--------------------------------------------------------------------===//

  /// Returns the <id> recorded for the specialization constant `constName`,
  /// or 0 if none has been recorded.
  uint32_t getSpecConstID(StringRef constName) const {
    return specConstIDMap.lookup(constName);
  }

  /// Returns the <id> recorded for the global variable `varName`, or 0 if
  /// none has been recorded.
  uint32_t getVariableID(StringRef varName) const {
    return globalVarIDMap.lookup(varName);
  }

  /// Returns the <id> recorded for the function `fnName`, or 0 if none has
  /// been recorded.
  uint32_t getFunctionID(StringRef fnName) const {
    return funcIDMap.lookup(fnName);
  }

  /// Gets the <id> for the function with the given name. Assigns the next
  /// available <id> if the function has not been given one yet.
  uint32_t getOrCreateFunctionID(StringRef fnName);

  void processCapability();

  void processDebugInfo();

  void processExtension();

  void processMemoryModel();

  LogicalResult processConstantOp(spirv::ConstantOp op);

  LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);

  LogicalResult
  processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);

  /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
  /// value to use with other operations. The SPIR-V spec recommends that
  /// OpUndef be generated at module level. The serialization generates an
  /// OpUndef for each type needed at module level.
  LogicalResult processUndefOp(spirv::UndefOp op);

  /// Emits OpName for the given `resultID`.
  LogicalResult processName(uint32_t resultID, StringRef name);

  /// Processes a SPIR-V function op.
  LogicalResult processFuncOp(spirv::FuncOp op);

  LogicalResult processVariableOp(spirv::VariableOp op);

  /// Processes a SPIR-V GlobalVariableOp.
  LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);

  /// Processes attributes that translate to decorations on the result <id>.
  LogicalResult processDecoration(Location loc, uint32_t resultID,
                                  NamedAttribute attr);

  /// Fallback for types without a dedicated specialization: reports an
  /// "unhandled decoration" error at `loc`.
  template
  LogicalResult processTypeDecoration(Location loc, DType type,
                                      uint32_t resultId) {
    return emitError(loc, "unhandled decoration for type:") << type;
  }

  /// Processes a member decoration.
  LogicalResult processMemberDecoration(
      uint32_t structID,
      const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);

  //===--------------------------------------------------------------------===//
  // Types
  //===--------------------------------------------------------------------===//

  /// Returns the <id> recorded for `type`, or 0 if none has been recorded.
  uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); }

  /// The MLIR NoneType stands in for the SPIR-V void type.
  Type getVoidType() { return mlirBuilder.getNoneType(); }

  bool isVoidType(Type type) const { return type.isa(); }

  /// Returns true if the given type is a pointer type to a struct in some
  /// interface storage class.
  bool isInterfaceStructPtrType(Type type) const;

  /// Main dispatch method for serializing a type. The result <id> of the
  /// serialized type will be returned as `typeID`.
  LogicalResult processType(Location loc, Type type, uint32_t &typeID);
  LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID,
                                llvm::SetVector &serializationCtx);

  /// Method for preparing basic SPIR-V type serialization. Returns the type's
  /// opcode and operands for the instruction via `typeEnum` and `operands`.
  LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
                                 spirv::Opcode &typeEnum,
                                 SmallVectorImpl &operands,
                                 bool &deferSerialization,
                                 llvm::SetVector &serializationCtx);

  LogicalResult prepareFunctionType(Location loc, FunctionType type,
                                    spirv::Opcode &typeEnum,
                                    SmallVectorImpl &operands);

  //===--------------------------------------------------------------------===//
  // Constant
  //===--------------------------------------------------------------------===//

  /// Returns the <id> recorded for the constant `value`, or 0 if none has
  /// been recorded.
  uint32_t getConstantID(Attribute value) const {
    return constIDMap.lookup(value);
  }

  /// Main dispatch method for processing a constant with the given `constType`
  /// and `valueAttr`. `constType` is needed here because we can interpret the
  /// `valueAttr` as a different type than the type of `valueAttr` itself; for
  /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
  /// constants.
  uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);

  /// Prepares array attribute serialization. This method emits corresponding
  /// OpConstant* and returns the result <id> associated with it. Returns 0 if
  /// failed.
  uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);

  /// Prepares bool/int/float DenseElementsAttr serialization. This method
  /// iterates the DenseElementsAttr to construct the constant array, and
  /// returns the result <id> associated with it. Returns 0 if failed. Note
  /// that the size of `index` must match the rank.
  /// TODO: Consider to enhance splat elements cases. For splat cases,
  /// we don't need to loop over all elements, especially when the splat value
  /// is zero. We can use OpConstantNull when the value is zero.
  uint32_t prepareDenseElementsConstant(Location loc, Type constType,
                                        DenseElementsAttr valueAttr, int dim,
                                        MutableArrayRef index);

  /// Prepares scalar attribute serialization. This method emits corresponding
  /// OpConstant* and returns the result <id> associated with it. Returns 0 if
  /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
  /// true, then the constant will be serialized as a specialization constant.
  uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
                                 bool isSpec = false);

  uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
                               bool isSpec = false);

  uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
                              bool isSpec = false);

  uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
                             bool isSpec = false);

  //===--------------------------------------------------------------------===//
  // Control flow
  //===--------------------------------------------------------------------===//

  /// Returns the result <id> for the given block, or 0 if none has been
  /// recorded.
  uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); }

  /// Returns the result <id> for the given block. If no <id> has been assigned,
  /// assigns the next available <id>.
  uint32_t getOrCreateBlockID(Block *block);

  /// Processes the given `block` and emits SPIR-V instructions for all ops
  /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
  /// `actionBeforeTerminator` is a callback that will be invoked before
  /// handling the terminator op. It can be used to inject the Op*Merge
  /// instruction if this is a SPIR-V selection/loop header block.
  LogicalResult processBlock(Block *block, bool omitLabel = false,
                             function_ref actionBeforeTerminator = nullptr);

  /// Emits OpPhi instructions for the given block if it has block arguments.
  LogicalResult emitPhiForBlockArguments(Block *block);

  LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);

  LogicalResult processLoopOp(spirv::LoopOp loopOp);

  LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);

  LogicalResult processBranchOp(spirv::BranchOp branchOp);

  //===--------------------------------------------------------------------===//
  // Operations
  //===--------------------------------------------------------------------===//

  LogicalResult encodeExtensionInstruction(Operation *op,
                                           StringRef extensionSetName,
                                           uint32_t opcode,
                                           ArrayRef operands);

  /// Returns the <id> recorded for `val`, or 0 if none has been recorded.
  uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }

  LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);

  LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);

  /// Main dispatch method for serializing an operation.
  LogicalResult processOperation(Operation *op);

  /// Method to dispatch to the serialization function for an operation in
  /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec.
  /// This is auto-generated from ODS. Dispatch is handled for all operations
  /// in SPIR-V dialect that have hasOpcode == 1.
  LogicalResult dispatchToAutogenSerialization(Operation *op);

  /// Method to serialize an operation in the SPIR-V dialect that is a mirror of
  /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
  /// 1 and autogenSerialization == 1 in ODS. This generic fallback reports an
  /// "unsupported op serialization" error.
  template
  LogicalResult processOp(OpTy op) {
    return op.emitError("unsupported op serialization");
  }

  //===--------------------------------------------------------------------===//
  // Utilities
  //===--------------------------------------------------------------------===//

  /// Emits an OpDecorate instruction to decorate the given `target` with the
  /// given `decoration`.
  LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
                               ArrayRef params = {});

  /// Emits an OpLine instruction with the given `loc` location information into
  /// the given `binary` vector.
  LogicalResult emitDebugLine(SmallVectorImpl &binary, Location loc);

private:
  /// The SPIR-V module to be serialized.
  spirv::ModuleOp module;

  /// An MLIR builder for getting MLIR constructs.
  mlir::Builder mlirBuilder;

  /// A flag which indicates if the debuginfo should be emitted.
  bool emitDebugInfo = false;

  /// A flag which indicates if the last processed instruction was a merge
  /// instruction.
  /// According to SPIR-V spec: "If a branch merge instruction is used, the last
  /// OpLine in the block must be before its merge instruction".
  bool lastProcessedWasMergeInst = false;

  /// The <id> of the OpString instruction, which specifies a file name, for
  /// use by other debug instructions.
  uint32_t fileID = 0;

  /// The next available result <id>.
  uint32_t nextID = 1;

  // The following are for different SPIR-V instruction sections. They follow
  // the logical layout of a SPIR-V module.

  SmallVector capabilities;
  SmallVector extensions;
  SmallVector extendedSets;
  SmallVector memoryModel;
  SmallVector entryPoints;
  SmallVector executionModes;
  SmallVector debug;
  SmallVector names;
  SmallVector decorations;
  SmallVector typesGlobalValues;
  SmallVector functions;

  /// Recursive struct references are serialized as OpTypePointer instructions
  /// to the recursive struct type. However, the OpTypePointer instruction
  /// cannot be emitted before the recursive struct's OpTypeStruct.
  /// RecursiveStructPointerInfo stores the data needed to emit such
  /// OpTypePointer instructions after forward references to such types.
  struct RecursiveStructPointerInfo {
    uint32_t pointerTypeID;
    spirv::StorageClass storageClass;
  };

  // Maps spirv::StructType to its recursive reference member info.
  // NOTE(review): the template argument list of this declaration looks
  // truncated in this copy of the file -- confirm against upstream.
  DenseMap> recursiveStructInfos;

  /// `functionHeader` contains all the instructions that must be in the first
  /// block in the function, and `functionBody` contains the rest. After
  /// processing FuncOp, the encoded instructions of a function are appended to
  /// `functions`. An example of instructions in `functionHeader` in order:
  ///   OpFunction ...
  ///   OpFunctionParameter ...
  ///   OpFunctionParameter ...
  ///   OpLabel ...
  ///   OpVariable ...
  ///   OpVariable ...
  SmallVector functionHeader;
  SmallVector functionBody;

  /// Map from type used in SPIR-V module to their <id>s.
  DenseMap typeIDMap;

  /// Map from constant values to their <id>s.
  DenseMap constIDMap;

  /// Map from specialization constant names to their <id>s.
  llvm::StringMap specConstIDMap;

  /// Map from GlobalVariableOps name to <id>s.
  llvm::StringMap globalVarIDMap;

  /// Map from FuncOps name to <id>s.
  llvm::StringMap funcIDMap;

  /// Map from blocks to their <id>s.
  DenseMap blockIDMap;

  /// Map from the Type to the <id> that represents undef value of that type.
  DenseMap undefValIDMap;

  /// Map from results of normal operations to their <id>s.
  DenseMap valueIDMap;

  /// Map from extended instruction set name to <id>s.
  llvm::StringMap extendedInstSetIDMap;

  /// Map from values used in OpPhi instructions to their offset in the
  /// `functions` section.
  ///
  /// NOTE(review): processFuncOp patches these offsets into `functionBody`
  /// (before it is appended to `functions`), so "`functions`" in this comment
  /// presumably means the per-function body buffer -- confirm.
  ///
  /// When processing a block with arguments, we need to emit OpPhi
  /// instructions to record the predecessor block <id>s and the values they
  /// send to the block in question. But it's not guaranteed all values are
  /// visited and thus assigned result <id>s. So we need this list to capture
  /// the offsets into `functions` where a value is used so that we can fix it
  /// up later after processing all the blocks in a function.
  ///
  /// More concretely, say if we are visiting the following blocks:
  ///
  /// ```mlir
  /// ^phi(%arg0: i32):
  ///   ...
  /// ^parent1:
  ///   ...
  ///   spv.Branch ^phi(%val0: i32)
  /// ^parent2:
  ///   ...
  ///   spv.Branch ^phi(%val1: i32)
  /// ```
  ///
  /// When we are serializing the `^phi` block, we need to emit at the beginning
  /// of the block OpPhi instructions which has the following parameters:
  ///
  /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
  ///                               id-for-%val1 id-for-^parent2
  ///
  /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
  /// all the blocks twice and use the first visit to assign an <id> to each
  /// value. But it's paying the overheads just for OpPhi emission. Instead,
  /// we still visit the blocks once for emission. When we emit the OpPhi
  /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
  /// At the same time, we record their offsets in the emitted binary (which is
  /// placed inside `functions`) here. And then after emitting all blocks, we
  /// replace the dummy <id> 0 with the real result <id> by overwriting
  /// `functions[offset]`.
  DenseMap> deferredPhiValues;
};
} // namespace

Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
    : module(module), mlirBuilder(module.getContext()),
      emitDebugInfo(emitDebugInfo) {}

/// Verifies the module, emits the capability, extension, memory-model and
/// debug-info sections, then serializes every op in the module body. Returns
/// failure() if verification or any op fails.
LogicalResult Serializer::serialize() {
  LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");

  if (failed(module.verify()))
    return failure();

  // TODO: handle the other sections
  processCapability();
  processExtension();
  processMemoryModel();
  processDebugInfo();

  // Iterate over the module body to serialize it.
Assumptions are that there is // only one basic block in the moduleOp for (auto &op : module.getBlock()) { if (failed(processOperation(&op))) { return failure(); } } LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); return success(); } void Serializer::collect(SmallVectorImpl &binary) { auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + extensions.size() + extendedSets.size() + memoryModel.size() + entryPoints.size() + executionModes.size() + decorations.size() + typesGlobalValues.size() + functions.size(); binary.clear(); binary.reserve(moduleSize); spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); binary.append(capabilities.begin(), capabilities.end()); binary.append(extensions.begin(), extensions.end()); binary.append(extendedSets.begin(), extendedSets.end()); binary.append(memoryModel.begin(), memoryModel.end()); binary.append(entryPoints.begin(), entryPoints.end()); binary.append(executionModes.begin(), executionModes.end()); binary.append(debug.begin(), debug.end()); binary.append(names.begin(), names.end()); binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); } #ifndef NDEBUG void Serializer::printValueIDMap(raw_ostream &os) { os << "\n= Value Map =\n\n"; for (auto valueIDPair : valueIDMap) { Value val = valueIDPair.first; os << " " << val << " " << "id = " << valueIDPair.second << ' '; if (auto *op = val.getDefiningOp()) { os << "from op '" << op->getName() << "'"; } else if (auto arg = val.dyn_cast()) { Block *block = arg.getOwner(); os << "from argument of block " << block << ' '; os << " in op '" << block->getParentOp()->getName() << "'"; } os << '\n'; } } #endif //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// uint32_t 
Serializer::getOrCreateFunctionID(StringRef fnName) { auto funcID = funcIDMap.lookup(fnName); if (!funcID) { funcID = getNextID(); funcIDMap[fnName] = funcID; } return funcID; } void Serializer::processCapability() { for (auto cap : module.vce_triple()->getCapabilities()) encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, {static_cast(cap)}); } void Serializer::processDebugInfo() { if (!emitDebugInfo) return; auto fileLoc = module.getLoc().dyn_cast(); auto fileName = fileLoc ? fileLoc.getFilename() : ""; fileID = getNextID(); SmallVector operands; operands.push_back(fileID); spirv::encodeStringLiteralInto(operands, fileName); encodeInstructionInto(debug, spirv::Opcode::OpString, operands); // TODO: Encode more debug instructions. } void Serializer::processExtension() { llvm::SmallVector extName; for (spirv::Extension ext : module.vce_triple()->getExtensions()) { extName.clear(); spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); } } void Serializer::processMemoryModel() { uint32_t mm = module->getAttrOfType("memory_model").getInt(); uint32_t am = module->getAttrOfType("addressing_model").getInt(); encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); } LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { valueIDMap[op.getResult()] = resultID; return success(); } return failure(); } LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), /*isSpec=*/true)) { // Emit the OpDecorate instruction for SpecId. 
if (auto specID = op->getAttrOfType("spec_id")) { auto val = static_cast(specID.getInt()); emitDecoration(resultID, spirv::Decoration::SpecId, {val}); } specConstIDMap[op.sym_name()] = resultID; return processName(resultID, op.sym_name()); } return failure(); } LogicalResult Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { uint32_t typeID = 0; if (failed(processType(op.getLoc(), op.type(), typeID))) { return failure(); } auto resultID = getNextID(); SmallVector operands; operands.push_back(typeID); operands.push_back(resultID); auto constituents = op.constituents(); for (auto index : llvm::seq(0, constituents.size())) { auto constituent = constituents[index].dyn_cast(); auto constituentName = constituent.getValue(); auto constituentID = getSpecConstID(constituentName); if (!constituentID) { return op.emitError("unknown result for specialization constant ") << constituentName; } operands.push_back(constituentID); } encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantComposite, operands); specConstIDMap[op.sym_name()] = resultID; return processName(resultID, op.sym_name()); } LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { auto undefType = op.getType(); auto &id = undefValIDMap[undefType]; if (!id) { id = getNextID(); uint32_t typeID = 0; if (failed(processType(op.getLoc(), undefType, typeID)) || failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, {typeID, id}))) { return failure(); } } valueIDMap[op.getResult()] = id; return success(); } LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, NamedAttribute attr) { auto attrName = attr.first.strref(); auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); auto decoration = spirv::symbolizeDecoration(decorationName); if (!decoration) { return emitError( loc, "non-argument attributes expected to have snake-case-ified " "decoration name, unhandled attribute with name : ") << attrName; } SmallVector 
args;
  switch (decoration.getValue()) {
  // Decorations carrying a single integer literal.
  case spirv::Decoration::Binding:
  case spirv::Decoration::DescriptorSet:
  case spirv::Decoration::Location:
    if (auto intAttr = attr.second.dyn_cast()) {
      args.push_back(intAttr.getValue().getZExtValue());
      break;
    }
    return emitError(loc, "expected integer attribute for ") << attrName;
  // BuiltIn is spelled as a string and encoded as its enum value.
  case spirv::Decoration::BuiltIn:
    if (auto strAttr = attr.second.dyn_cast()) {
      auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
      if (enumVal) {
        args.push_back(static_cast(enumVal.getValue()));
        break;
      }
      return emitError(loc, "invalid ")
             << attrName << " attribute " << strAttr.getValue();
    }
    return emitError(loc, "expected string attribute for ") << attrName;
  case spirv::Decoration::Aliased:
  case spirv::Decoration::Flat:
  case spirv::Decoration::NonReadable:
  case spirv::Decoration::NonWritable:
  case spirv::Decoration::NoPerspective:
  case spirv::Decoration::Restrict:
    // For unit attributes, the args list has no values so we do nothing
    if (auto unitAttr = attr.second.dyn_cast())
      break;
    return emitError(loc, "expected unit attribute for ") << attrName;
  default:
    return emitError(loc, "unhandled decoration ") << decorationName;
  }
  return emitDecoration(resultID, decoration.getValue(), args);
}

/// Emits an OpName instruction associating `name` with `resultID` into the
/// `names` section. `name` must not be empty (asserted). Fails if the string
/// literal cannot be encoded.
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
  assert(!name.empty() && "unexpected empty string for OpName");

  SmallVector nameOperands;
  nameOperands.push_back(resultID);
  if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
    return failure();
  }
  return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
}

namespace {
/// Array types: emits an ArrayStride decoration when the type has a non-zero
/// stride; otherwise there is nothing to decorate.
template <>
LogicalResult Serializer::processTypeDecoration(
    Location loc, spirv::ArrayType type, uint32_t resultID) {
  if (unsigned stride = type.getArrayStride()) {
    // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
    return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
  }
  return success();
}

/// Runtime array types: same ArrayStride handling as for array types.
template <>
LogicalResult Serializer::processTypeDecoration(
    Location Loc,
spirv::RuntimeArrayType type, uint32_t resultID) { if (unsigned stride = type.getArrayStride()) { // OpDecorate %arrayTypeSSA ArrayStride strideLiteral return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); } return success(); } LogicalResult Serializer::processMemberDecoration( uint32_t structID, const spirv::StructType::MemberDecorationInfo &memberDecoration) { SmallVector args( {structID, memberDecoration.memberIndex, static_cast(memberDecoration.decoration)}); if (memberDecoration.hasValue) { args.push_back(memberDecoration.decorationValue); } return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); } } // namespace LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); assert(functionHeader.empty() && functionBody.empty()); uint32_t fnTypeID = 0; // Generate type of the function. processType(op.getLoc(), op.getType(), fnTypeID); // Add the function definition. SmallVector operands; uint32_t resTypeID = 0; auto resultTypes = op.getType().getResults(); if (resultTypes.size() > 1) { return op.emitError("cannot serialize function with multiple return types"); } if (failed(processType(op.getLoc(), (resultTypes.empty() ? getVoidType() : resultTypes[0]), resTypeID))) { return failure(); } operands.push_back(resTypeID); auto funcID = getOrCreateFunctionID(op.getName()); operands.push_back(funcID); operands.push_back(static_cast(op.function_control())); operands.push_back(fnTypeID); encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); // Add function name. if (failed(processName(funcID, op.getName()))) { return failure(); } // Declare the parameters. 
for (auto arg : op.getArguments()) { uint32_t argTypeID = 0; if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { return failure(); } auto argValueID = getNextID(); valueIDMap[arg] = argValueID; encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, {argTypeID, argValueID}); } // Process the body. if (op.isExternal()) { return op.emitError("external function is unhandled"); } // Some instructions (e.g., OpVariable) in a function must be in the first // block in the function. These instructions will be put in functionHeader. // Thus, we put the label in functionHeader first, and omit it from the first // block. encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, {getOrCreateBlockID(&op.front())}); processBlock(&op.front(), /*omitLabel=*/true); if (failed(visitInPrettyBlockOrder( &op.front(), [&](Block *block) { return processBlock(block); }, /*skipHeader=*/true))) { return failure(); } // There might be OpPhi instructions who have value references needing to fix. for (auto deferredValue : deferredPhiValues) { Value value = deferredValue.first; uint32_t id = getValueID(value); LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value << " to id = " << id << '\n'); assert(id && "OpPhi references undefined value!"); for (size_t offset : deferredValue.second) functionBody[offset] = id; } deferredPhiValues.clear(); LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() << "' --\n"); // Insert OpFunctionEnd. 
if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}))) { return failure(); } functions.append(functionHeader.begin(), functionHeader.end()); functions.append(functionBody.begin(), functionBody.end()); functionHeader.clear(); functionBody.clear(); return success(); } LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { SmallVector operands; SmallVector elidedAttrs; uint32_t resultID = 0; uint32_t resultTypeID = 0; if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { return failure(); } operands.push_back(resultTypeID); resultID = getNextID(); valueIDMap[op.getResult()] = resultID; operands.push_back(resultID); auto attr = op.getAttr(spirv::attributeName()); if (attr) { operands.push_back(static_cast( attr.cast().getValue().getZExtValue())); } elidedAttrs.push_back(spirv::attributeName()); for (auto arg : op.getODSOperands(0)) { auto argID = getValueID(arg); if (!argID) { return emitError(op.getLoc(), "operand 0 has a use before def"); } operands.push_back(argID); } emitDebugLine(functionHeader, op.getLoc()); encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); for (auto attr : op.getAttrs()) { if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return attr.first == elided; })) { continue; } if (failed(processDecoration(op.getLoc(), resultID, attr))) { return failure(); } } return success(); } LogicalResult Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { // Get TypeID. 
// The <id> of the variable's type becomes the first operand of OpVariable.
// NOTE(review): several template argument lists appear to be missing in this
// text (e.g. `SmallVector operands`, `.cast()`, `static_cast(...)`); confirm
// against the upstream file.
  uint32_t resultTypeID = 0;
  SmallVector elidedAttrs;
  if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
    return failure();
  }

  // When the variable type is an interface struct pointer, the pointee type
  // additionally receives a Block decoration.
  if (isInterfaceStructPtrType(varOp.type())) {
    auto structType = varOp.type()
                          .cast()
                          .getPointeeType()
                          .cast();
    if (failed(
            emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
      return varOp.emitError("cannot decorate ")
             << structType << " with Block decoration";
    }
  }

  elidedAttrs.push_back("type");
  SmallVector operands;
  operands.push_back(resultTypeID);
  auto resultID = getNextID();

  // Encode the name.
  auto varName = varOp.sym_name();
  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
  if (failed(processName(resultID, varName))) {
    return failure();
  }
  // Remember the <id> under the symbol name so later lookups by name resolve.
  globalVarIDMap[varName] = resultID;
  operands.push_back(resultID);

  // Encode StorageClass.
  operands.push_back(static_cast(varOp.storageClass()));

  // Encode initialization.
  if (auto initializer = varOp.initializer()) {
    auto initializerID = getVariableID(initializer.getValue());
    if (!initializerID) {
      return emitError(varOp.getLoc(),
                       "invalid usage of undefined variable as initializer");
    }
    operands.push_back(initializerID);
    elidedAttrs.push_back("initializer");
  }

  emitDebugLine(typesGlobalValues, varOp.getLoc());
  if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable,
                                   operands))) {
    // NOTE(review): `elidedAttrs` is a local and the function returns on the
    // next line, so this push_back has no visible effect.
    elidedAttrs.push_back("initializer");
    return failure();
  }

  // Encode decorations.
  for (auto attr : varOp.getAttrs()) {
    if (llvm::any_of(elidedAttrs,
                     [&](StringRef elided) { return attr.first == elided; })) {
      continue;
    }
    if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
      return failure();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//

// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
// PushConstant Storage Classes must be explicitly laid out."
//
// Returns true only when `type` passes the `dyn_cast` below, its storage class
// is one of the four listed in the switch, and its pointee type passes the
// `isa` check; returns false in every other case.
// NOTE(review): the template arguments of `dyn_cast()` / `isa()` are missing in
// this text; going by the function name they are presumably the SPIR-V pointer
// and struct types -- confirm.
bool Serializer::isInterfaceStructPtrType(Type type) const {
  if (auto ptrType = type.dyn_cast()) {
    switch (ptrType.getStorageClass()) {
    case spirv::StorageClass::PhysicalStorageBuffer:
    case spirv::StorageClass::PushConstant:
    case spirv::StorageClass::StorageBuffer:
    case spirv::StorageClass::Uniform:
      return ptrType.getPointeeType().isa();
    default:
      break;
    }
  }
  return false;
}

/// Entry point for type serialization: obtains the <id> for `type` in `typeID`
/// by delegating to processTypeImpl with a fresh serialization context.
LogicalResult Serializer::processType(Location loc, Type type,
                                      uint32_t &typeID) {
  // Maintains a set of names for nested identified struct types. This is used
  // to properly serialize recursive references.
// The context is default-constructed (empty) for every top-level request.
  llvm::SetVector serializationCtx;
  return processTypeImpl(loc, type, typeID, serializationCtx);
}

/// Produces the <id> for `type` in `typeID`, emitting the type instruction
/// into `typesGlobalValues` if the type has not been seen before.
///
/// If `getTypeID` already knows the type, its <id> is returned right away.
/// Otherwise a fresh <id> is allocated and the opcode/operands are prepared by
/// prepareFunctionType or prepareBasicType. When the preparation step sets
/// `deferSerialization`, nothing is emitted and the type is not recorded in
/// `typeIDMap`. `serializationCtx` is forwarded to prepareBasicType.
///
/// NOTE(review): several template argument lists appear to be missing in this
/// text (e.g. `llvm::SetVector`, `type.isa()`, `type.cast()`); confirm against
/// the upstream file.
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
                            llvm::SetVector &serializationCtx) {
  typeID = getTypeID(type);
  if (typeID) {
    return success();
  }
  typeID = getNextID();
  SmallVector operands;

  // The result <id> is always the first operand of a type instruction.
  operands.push_back(typeID);
  auto typeEnum = spirv::Opcode::OpTypeVoid;
  bool deferSerialization = false;

  // If the first alternative does not succeed, prepareBasicType is tried with
  // the same `operands` vector.
  if ((type.isa() &&
       succeeded(prepareFunctionType(loc, type.cast(), typeEnum,
                                     operands))) ||
      succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
                                 deferSerialization, serializationCtx))) {
    if (deferSerialization)
      return success();

    typeIDMap[type] = typeID;

    if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands)))
      return failure();

    if (recursiveStructInfos.count(type) != 0) {
      // This recursive struct type is emitted already, now the OpTypePointer
      // instructions referring to recursive references are emitted as well.
      for (auto &ptrInfo : recursiveStructInfos[type]) {
        // TODO: This might not work if more than 1 recursive reference is
        // present in the struct.
// Operands of the deferred OpTypePointer, in push order: the pointer's result
// <id>, its storage class, and the <id> of the type that was just emitted.
        SmallVector ptrOperands;
        ptrOperands.push_back(ptrInfo.pointerTypeID);
        ptrOperands.push_back(static_cast(ptrInfo.storageClass));
        ptrOperands.push_back(typeIDMap[type]);

        if (failed(encodeInstructionInto(
                typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands)))
          return failure();
      }

      recursiveStructInfos[type].clear();
    }

    return success();
  }

  return failure();
}

/// Selects the SPIR-V type opcode for `type` and appends its operands.
///
/// On success `typeEnum` holds the opcode and `operands` has been extended
/// with the operands that follow the result <id>. Nested types are resolved
/// through processTypeImpl. `resultID` is the <id> already reserved for this
/// type; it is used for names, decorations and forward pointers emitted here.
/// `deferSerialization` is reset to false on entry and set to true only on the
/// recursive-pointer path, where the caller must not emit the instruction yet.
/// `serializationCtx` holds the identifiers of the identified structs that are
/// currently being serialized. An error is emitted for types matched by none
/// of the branches.
///
/// NOTE(review): the template arguments of the `dyn_cast()` calls (and of
/// `static_cast`, `SmallVector`, `llvm::seq`) are missing in this text, so the
/// exact type tested by each branch cannot be read off here; the comments
/// below describe each branch by the opcode it selects.
LogicalResult Serializer::prepareBasicType(
    Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
    SmallVectorImpl &operands, bool &deferSerialization,
    llvm::SetVector &serializationCtx) {
  deferSerialization = false;

  if (isVoidType(type)) {
    typeEnum = spirv::Opcode::OpTypeVoid;
    return success();
  }

  // OpTypeBool for width 1, otherwise OpTypeInt with width and signedness.
  if (auto intType = type.dyn_cast()) {
    if (intType.getWidth() == 1) {
      typeEnum = spirv::Opcode::OpTypeBool;
      return success();
    }

    typeEnum = spirv::Opcode::OpTypeInt;
    operands.push_back(intType.getWidth());
    // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
    // to preserve or validate.
    // 0 indicates unsigned, or no signedness semantics
    // 1 indicates signed semantics."
    operands.push_back(intType.isSigned() ? 1 : 0);
    return success();
  }

  // OpTypeFloat: the width is the only operand.
  if (auto floatType = type.dyn_cast()) {
    typeEnum = spirv::Opcode::OpTypeFloat;
    operands.push_back(floatType.getWidth());
    return success();
  }

  // OpTypeVector: element type <id> followed by the literal element count.
  if (auto vectorType = type.dyn_cast()) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
                               serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeVector;
    operands.push_back(elementTypeID);
    operands.push_back(vectorType.getNumElements());
    return success();
  }

  // OpTypeArray: element type <id> followed by the <id> of an i32 constant
  // holding the element count.
  if (auto arrayType = type.dyn_cast()) {
    typeEnum = spirv::Opcode::OpTypeArray;
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
                               serializationCtx))) {
      return failure();
    }
    operands.push_back(elementTypeID);
    // NOTE(review): if prepareConstantInt returns 0 the count operand is
    // silently left out and the function still proceeds.
    if (auto elementCountID = prepareConstantInt(
            loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
      operands.push_back(elementCountID);
    }
    return processTypeDecoration(loc, arrayType, resultID);
  }

  // OpTypePointer: storage class followed by the pointee type <id>.
  if (auto ptrType = type.dyn_cast()) {
    uint32_t pointeeTypeID = 0;
    spirv::StructType pointeeStruct =
        ptrType.getPointeeType().dyn_cast();

    if (pointeeStruct && pointeeStruct.isIdentified() &&
        serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
      // A recursive reference to an enclosing struct is found.
      //
      // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
      // class as operands.
      SmallVector forwardPtrOperands;
      forwardPtrOperands.push_back(resultID);
      forwardPtrOperands.push_back(
          static_cast(ptrType.getStorageClass()));

      encodeInstructionInto(typesGlobalValues,
                            spirv::Opcode::OpTypeForwardPointer,
                            forwardPtrOperands);

      // 2. Find the pointee (enclosing) struct.
      auto structType = spirv::StructType::getIdentified(
          module.getContext(), pointeeStruct.getIdentifier());

      if (!structType)
        return failure();

      // 3. Mark the OpTypePointer that is supposed to be emitted by this call
      // as deferred.
      deferSerialization = true;

      // 4. Record the info needed to emit the deferred OpTypePointer
      // instruction when the enclosing struct is completely serialized.
      recursiveStructInfos[structType].push_back(
          {resultID, ptrType.getStorageClass()});
    } else {
      if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
                                 serializationCtx)))
        return failure();
    }

    // On the recursive path `pointeeTypeID` is still 0 at this point.
    typeEnum = spirv::Opcode::OpTypePointer;
    operands.push_back(static_cast(ptrType.getStorageClass()));
    operands.push_back(pointeeTypeID);
    return success();
  }

  // OpTypeRuntimeArray: the element type <id> is the only operand.
  if (auto runtimeArrayType = type.dyn_cast()) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
                               elementTypeID, serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeRuntimeArray;
    operands.push_back(elementTypeID);
    return processTypeDecoration(loc, runtimeArrayType, resultID);
  }

  // OpTypeStruct: one member type <id> per element, plus member decorations.
  if (auto structType = type.dyn_cast()) {
    if (structType.isIdentified()) {
      // NOTE(review): the result of processName is not checked here.
      processName(resultID, structType.getIdentifier());
      // While the members are processed, the identifier stays in the context
      // so that pointers back to this struct are detected as recursive.
      serializationCtx.insert(structType.getIdentifier());
    }

    bool hasOffset = structType.hasOffset();
    for (auto elementIndex :
         llvm::seq(0, structType.getNumElements())) {
      uint32_t elementTypeID = 0;
      if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
                                 elementTypeID, serializationCtx))) {
        return failure();
      }
      operands.push_back(elementTypeID);
      if (hasOffset) {
        // Decorate each struct member with an offset
        spirv::StructType::MemberDecorationInfo offsetDecoration{
            elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
            static_cast(structType.getMemberOffset(elementIndex))};
        if (failed(processMemberDecoration(resultID, offsetDecoration))) {
          return emitError(loc, "cannot decorate ")
                 << elementIndex << "-th member of " << structType
                 << " with its offset";
        }
      }
    }
    SmallVector memberDecorations;
    structType.getMemberDecorations(memberDecorations);

    for (auto &memberDecoration : memberDecorations) {
      if (failed(processMemberDecoration(resultID, memberDecoration))) {
        return emitError(loc, "cannot decorate ")
               << static_cast(memberDecoration.memberIndex)
               << "-th member of " << structType << " with "
               << stringifyDecoration(memberDecoration.decoration);
      }
    }

    typeEnum = spirv::Opcode::OpTypeStruct;

    // Only the success path removes the identifier again; the early returns
    // above leave it in the context.
    if (structType.isIdentified())
      serializationCtx.remove(structType.getIdentifier());

    return success();
  }

  // OpTypeCooperativeMatrixNV: element type <id>, then the <id>s of three
  // 32-bit integer constants for scope, rows and columns.
  if (auto cooperativeMatrixType =
          type.dyn_cast()) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
                               elementTypeID, serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
    // Wraps a literal into a 32-bit integer constant and returns its <id>.
    auto getConstantOp = [&](uint32_t id) {
      auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
      return prepareConstantInt(loc, attr);
    };
    operands.push_back(elementTypeID);
    operands.push_back(
        getConstantOp(static_cast(cooperativeMatrixType.getScope())));
    operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
    operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
    return success();
  }

  // OpTypeMatrix: column type <id> followed by the literal column count.
  if (auto matrixType = type.dyn_cast()) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
                               serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeMatrix;
    operands.push_back(elementTypeID);
    operands.push_back(matrixType.getNumColumns());
    return success();
  }

  // TODO: Handle other types.
  return emitError(loc, "unhandled type in serialization: ") << type;
}

/// Prepares an OpTypeFunction: sets `typeEnum` and appends the return type
/// <id> followed by one <id> per input type to `operands`. A function type
/// without results uses the type returned by getVoidType() as return type.
/// Asserts that there is at most one result.
LogicalResult
Serializer::prepareFunctionType(Location loc, FunctionType type,
                                spirv::Opcode &typeEnum,
                                SmallVectorImpl &operands) {
  typeEnum = spirv::Opcode::OpTypeFunction;
  assert(type.getNumResults() <= 1 &&
         "serialization supports only a single return value");
  uint32_t resultID = 0;
  if (failed(processType(
          loc, type.getNumResults() == 1 ?
type.getResult(0) : getVoidType(),
          resultID))) {
    return failure();
  }
  operands.push_back(resultID);
  // One operand per input type, in declaration order.
  for (auto &res : type.getInputs()) {
    uint32_t argTypeID = 0;
    if (failed(processType(loc, res, argTypeID))) {
      return failure();
    }
    operands.push_back(argTypeID);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//

/// Returns the <id> of a constant with value `valueAttr` and type `constType`,
/// emitting the necessary instructions, or 0 on failure.
///
/// Scalars are handled by prepareConstantScalar. For anything else, a
/// previously recorded <id> from getConstantID is reused when available;
/// otherwise the value is dispatched to prepareDenseElementsConstant or
/// prepareArrayConstant and the resulting <id> is recorded in `constIDMap`.
/// An error is emitted when neither composite path produced an <id>.
///
/// NOTE(review): the template arguments of `dyn_cast()` and `SmallVector` are
/// missing in this text; confirm against the upstream file.
uint32_t Serializer::prepareConstant(Location loc, Type constType,
                                     Attribute valueAttr) {
  if (auto id = prepareConstantScalar(loc, valueAttr)) {
    return id;
  }

  // This is a composite literal. We need to handle each component separately
  // and then emit an OpConstantComposite for the whole.

  if (auto id = getConstantID(valueAttr)) {
    return id;
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID))) {
    return 0;
  }

  uint32_t resultID = 0;
  if (auto attr = valueAttr.dyn_cast()) {
    // NOTE(review): the result of this inner dyn_cast is used without a null
    // check.
    int rank = attr.getType().dyn_cast().getRank();
    // Scratch multi-dimensional index with one entry per dimension.
    SmallVector index(rank);
    resultID = prepareDenseElementsConstant(loc, constType, attr,
                                            /*dim=*/0, index);
  } else if (auto arrayAttr = valueAttr.dyn_cast()) {
    resultID = prepareArrayConstant(loc, constType, arrayAttr);
  }

  if (resultID == 0) {
    emitError(loc, "cannot serialize attribute: ") << valueAttr;
    return 0;
  }

  constIDMap[valueAttr] = resultID;
  return resultID;
}

/// Emits an OpConstantComposite for the array attribute `attr` and returns its
/// result <id>, or 0 if the type or any element fails to serialize.
///
/// Each element is serialized through prepareConstant using the element type
/// obtained from `constType`. The result <id> is allocated before the elements
/// are processed.
uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
                                          ArrayAttr attr) {
  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID))) {
    return 0;
  }

  uint32_t resultID = getNextID();
  SmallVector operands = {typeID, resultID};
  // Two leading words (type and result) plus one <id> per element.
  operands.reserve(attr.size() + 2);
  auto elementType = constType.cast().getElementType();
  for (Attribute elementAttr : attr) {
    if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
      operands.push_back(elementID);
    } else {
      return 0;
    }
  }
  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
  encodeInstructionInto(typesGlobalValues, opcode, operands);

  return resultID;
}
// TODO: Turn the below function into iterative function, instead of // recursive function. uint32_t Serializer::prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef index) { auto shapedType = valueAttr.getType().dyn_cast(); assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { return attr.getType().getElementType().isInteger(1) ? prepareConstantBool(loc, attr.getValue(index)) : prepareConstantInt(loc, attr.getValue(index)); } if (auto attr = valueAttr.dyn_cast()) { return prepareConstantFp(loc, attr.getValue(index)); } return 0; } uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; operands.reserve(shapedType.getDimSize(dim) + 2); auto elementType = constType.cast().getElementType(0); for (int i = 0; i < shapedType.getDimSize(dim); ++i) { index[dim] = i; if (auto elementID = prepareDenseElementsConstant( loc, elementType, valueAttr, dim + 1, index)) { operands.push_back(elementID); } else { return 0; } } spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; encodeInstructionInto(typesGlobalValues, opcode, operands); return resultID; } uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, bool isSpec) { if (auto floatAttr = valueAttr.dyn_cast()) { return prepareConstantFp(loc, floatAttr, isSpec); } if (auto boolAttr = valueAttr.dyn_cast()) { return prepareConstantBool(loc, boolAttr, isSpec); } if (auto intAttr = valueAttr.dyn_cast()) { return prepareConstantInt(loc, intAttr, isSpec); } return 0; } uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. 
if (auto id = getConstantID(boolAttr)) { return id; } } // Process the type for this bool literal uint32_t typeID = 0; if (failed(processType(loc, boolAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); auto opcode = boolAttr.getValue() ? (isSpec ? spirv::Opcode::OpSpecConstantTrue : spirv::Opcode::OpConstantTrue) : (isSpec ? spirv::Opcode::OpSpecConstantFalse : spirv::Opcode::OpConstantFalse); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); if (!isSpec) { constIDMap[boolAttr] = resultID; } return resultID; } uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(intAttr)) { return id; } } // Process the type for this integer literal uint32_t typeID = 0; if (failed(processType(loc, intAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APInt value = intAttr.getValue(); unsigned bitwidth = value.getBitWidth(); bool isSigned = value.isSignedIntN(bitwidth); auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; // According to SPIR-V spec, "When the type's bit width is less than 32-bits, // the literal's value appears in the low-order bits of the word, and the // high-order bits must be 0 for a floating-point type, or 0 for an integer // type with Signedness of 0, or sign extended when Signedness is 1." if (bitwidth == 32 || bitwidth == 16) { uint32_t word = 0; if (isSigned) { word = static_cast(value.getSExtValue()); } else { word = static_cast(value.getZExtValue()); } encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } // According to SPIR-V spec: "When the type's bit width is larger than one // word, the literal’s low-order words appear first." 
else if (bitwidth == 64) { struct DoubleWord { uint32_t word1; uint32_t word2; } words; if (isSigned) { words = llvm::bit_cast(value.getSExtValue()); } else { words = llvm::bit_cast(value.getZExtValue()); } encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss, /*isSigned=*/false); emitError(loc, "cannot serialize ") << bitwidth << "-bit integer literal: " << rss.str(); return 0; } if (!isSpec) { constIDMap[intAttr] = resultID; } return resultID; } uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec) { if (!isSpec) { // We can de-duplicate normal constants, but not specialization constants. if (auto id = getConstantID(floatAttr)) { return id; } } // Process the type for this float literal uint32_t typeID = 0; if (failed(processType(loc, floatAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APFloat value = floatAttr.getValue(); APInt intValue = value.bitcastToAPInt(); auto opcode = isSpec ? 
spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;

  // The literal encoding depends on the float semantics: single precision
  // takes one word, double precision takes two words, and half precision takes
  // one word holding the zero-extended bit pattern. Any other semantics is
  // rejected with an error.
  // NOTE(review): the template arguments of `llvm::bit_cast` and `static_cast`
  // are missing in this text; confirm against the upstream file.
  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
    uint32_t word = llvm::bit_cast(value.convertToFloat());
    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
    struct DoubleWord {
      uint32_t word1;
      uint32_t word2;
    } words = llvm::bit_cast(value.convertToDouble());
    encodeInstructionInto(typesGlobalValues, opcode,
                          {typeID, resultID, words.word1, words.word2});
  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
    uint32_t word =
        static_cast(value.bitcastToAPInt().getZExtValue());
    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
  } else {
    std::string valueStr;
    llvm::raw_string_ostream rss(valueStr);
    value.print(rss);

    emitError(loc, "cannot serialize ")
        << floatAttr.getType() << "-typed float literal: " << rss.str();
    return 0;
  }

  if (!isSpec) {
    constIDMap[floatAttr] = resultID;
  }
  return resultID;
}

//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//

/// Returns the <id> already associated with `block`, or allocates a new one,
/// records it in `blockIDMap` and returns it.
uint32_t Serializer::getOrCreateBlockID(Block *block) {
  if (uint32_t id = getBlockID(block))
    return id;
  return blockIDMap[block] = getNextID();
}

/// Serializes `block` into `functionBody`.
///
/// Unless `omitLabel` is set, an OpLabel with the block's <id> is emitted
/// first. OpPhi instructions for the block arguments follow, then every
/// operation of the block. `actionBeforeTerminator`, when provided, is invoked
/// immediately before the last operation of the block is processed.
///
/// NOTE(review): the template argument of `function_ref` is missing in this
/// text; confirm against the upstream file.
LogicalResult
Serializer::processBlock(Block *block, bool omitLabel,
                         function_ref actionBeforeTerminator) {
  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
  LLVM_DEBUG(block->print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');
  if (!omitLabel) {
    uint32_t blockID = getOrCreateBlockID(block);
    LLVM_DEBUG(llvm::dbgs()
               << "[block] " << block << " (id = " << blockID << ")\n");

    // Emit OpLabel for this block.
    encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
  }

  // Emit OpPhi instructions for block arguments, if any.
  if (failed(emitPhiForBlockArguments(block)))
    return failure();

  // Process each op in this block except the terminator.
// `std::prev(block->end())` stops the range before the last operation, so the
// block is expected to contain at least one operation.
  for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
    if (failed(processOperation(&op)))
      return failure();
  }

  // Process the terminator.
  if (actionBeforeTerminator)
    actionBeforeTerminator();
  if (failed(processOperation(&block->back())))
    return failure();

  return success();
}

/// Emits one OpPhi instruction into `functionBody` for every argument of
/// `block` and maps each argument to the <id> of its OpPhi.
///
/// Blocks without arguments and entry blocks are skipped. Every predecessor
/// must end in a terminator matched by the `dyn_cast` below; otherwise an
/// error is emitted on that terminator.
///
/// NOTE(review): the template arguments of `SmallVector`, `dyn_cast` and
/// `llvm::seq` are missing in this text; confirm against the upstream file.
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  // Nothing to do if this block has no arguments or it's the entry block, which
  // always has the same arguments as the function signature.
  if (block->args_empty() || block->isEntryBlock())
    return success();

  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
  // A SPIR-V OpPhi instruction is of the syntax:
  //   OpPhi | result type <id> | result <id> | (value <id>, parent block <id>) pair
  // So we need to collect all predecessor blocks and the arguments they send
  // to this block.
  SmallVector, 4> predecessors;
  for (Block *predecessor : block->getPredecessors()) {
    // Fetch the terminator before `predecessor` is remapped below.
    auto *terminator = predecessor->getTerminator();
    // The predecessor here is the immediate one according to MLIR's IR
    // structure. It does not directly map to the incoming parent block for the
    // OpPhi instructions at SPIR-V binary level. This is because structured
    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
    // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
    // jumping to the OpPhi's block then resides in the previous structured
    // control flow op's merge block.
    predecessor = getPhiIncomingBlock(predecessor);
    if (auto branchOp = dyn_cast(terminator)) {
      predecessors.emplace_back(predecessor, branchOp.operand_begin());
    } else {
      return terminator->emitError("unimplemented terminator for Phi creation");
    }
  }

  // Then create OpPhi instruction for each of the block argument.
  for (auto argIndex : llvm::seq(0, block->getNumArguments())) {
    BlockArgument arg = block->getArgument(argIndex);

    // Get the type <id> and result <id> for this OpPhi instruction.
// The argument's type provides the result type of the OpPhi.
    uint32_t phiTypeID = 0;
    if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    // Prepare the (value <id>, parent block <id>) pairs.
    SmallVector phiArgs;
    phiArgs.push_back(phiTypeID);
    phiArgs.push_back(phiID);

    for (auto predIndex : llvm::seq(0, predecessors.size())) {
      // The value this predecessor passes for argument `argIndex`.
      Value value = *(predecessors[predIndex].second + argIndex);
      uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      // Each pair is a value <id> ...
      uint32_t valueId = getValueID(value);
      if (valueId == 0) {
        // The op generating this value hasn't been visited yet so we don't have
        // an <id> assigned yet. Record this to fix up later.
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        // Offset of the placeholder word inside functionBody: the words
        // already emitted, plus one for the opcode word that precedes the
        // operands, plus the operands queued so far.
        deferredPhiValues[value].push_back(functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      phiArgs.push_back(valueId);
      // ... and a parent block <id>.
      phiArgs.push_back(predBlockId);
    }

    encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
    valueIDMap[arg] = phiID;
  }

  return success();
}

/// Serializes a structured selection: the header block is emitted into the
/// current SPIR-V block with an OpSelectionMerge before its terminator, the
/// remaining blocks except the merge block follow, and an OpLabel carrying the
/// merge block's <id> is emitted last.
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
  // Assign <id>s to all blocks so that branches inside the SelectionOp can
  // resolve properly.
  auto &body = selectionOp.body();
  for (Block &block : body)
    getOrCreateBlockID(&block);

  auto *headerBlock = selectionOp.getHeaderBlock();
  auto *mergeBlock = selectionOp.getMergeBlock();
  auto mergeID = getBlockID(mergeBlock);
  auto loc = selectionOp.getLoc();

  // Emit the selection header block, which dominates all other blocks, first.
  // We need to emit an OpSelectionMerge instruction before the selection header
  // block's terminator.
// Callback handed to processBlock below; it also sets
// `lastProcessedWasMergeInst`.
// NOTE(review): the template argument of `static_cast` is missing in this
// text; confirm against the upstream file.
  auto emitSelectionMerge = [&]() {
    emitDebugLine(functionBody, loc);
    lastProcessedWasMergeInst = true;
    encodeInstructionInto(
        functionBody, spirv::Opcode::OpSelectionMerge,
        {mergeID, static_cast(selectionOp.selection_control())});
  };
  // For structured selection, we cannot have blocks in the selection construct
  // branching to the selection header block. Entering the selection (and
  // reaching the selection header) must be from the block containing the
  // spv.selection op. If there are ops ahead of the spv.selection op in the
  // block, we can "merge" them into the selection header. So here we don't need
  // to emit a separate block; just continue with the existing block.
  if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge)))
    return failure();

  // Process all blocks with a depth-first visitor starting from the header
  // block. The selection header block and merge block are skipped by this
  // visitor.
  if (failed(visitInPrettyBlockOrder(
          headerBlock, [&](Block *block) { return processBlock(block); },
          /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
    return failure();

  // There is nothing to do for the merge block in the selection, which just
  // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
  // instruction to start a new SPIR-V block for ops following this SelectionOp.
  // The block should use the <id> for the merge block.
  return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
}

/// Serializes a structured loop: branches into the loop header, emits the
/// header with an OpLoopMerge before its terminator, then the remaining body
/// blocks, then the continue block, and finally an OpLabel carrying the merge
/// block's <id>.
LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
  // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
  // properly. We don't need to assign for the entry block, which is just for
  // satisfying MLIR region's structural requirement.
// `std::next(body.begin(), 1)` starts the range at the second block, skipping
// the entry block.
  auto &body = loopOp.body();
  for (Block &block :
       llvm::make_range(std::next(body.begin(), 1), body.end())) {
    getOrCreateBlockID(&block);
  }
  auto *headerBlock = loopOp.getHeaderBlock();
  auto *continueBlock = loopOp.getContinueBlock();
  auto *mergeBlock = loopOp.getMergeBlock();
  auto headerID = getBlockID(headerBlock);
  auto continueID = getBlockID(continueBlock);
  auto mergeID = getBlockID(mergeBlock);
  auto loc = loopOp.getLoc();

  // This LoopOp is in some MLIR block with preceding and following ops. In the
  // binary format, it should reside in separate SPIR-V blocks from its
  // preceding and following ops. So we need to emit unconditional branches to
  // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
  // afterwards.
  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});

  // LoopOp's entry block is just there for satisfying MLIR's structural
  // requirements so we omit it and start serialization from the loop header
  // block.

  // Emit the loop header block, which dominates all other blocks, first. We
  // need to emit an OpLoopMerge instruction before the loop header block's
  // terminator.
  // NOTE(review): the template argument of `static_cast` is missing in this
  // text; confirm against the upstream file.
  auto emitLoopMerge = [&]() {
    emitDebugLine(functionBody, loc);
    lastProcessedWasMergeInst = true;
    encodeInstructionInto(
        functionBody, spirv::Opcode::OpLoopMerge,
        {mergeID, continueID, static_cast(loopOp.loop_control())});
  };
  if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
    return failure();

  // Process all blocks with a depth-first visitor starting from the header
  // block. The loop header block, loop continue block, and loop merge block are
  // skipped by this visitor and handled later in this function.
  if (failed(visitInPrettyBlockOrder(
          headerBlock, [&](Block *block) { return processBlock(block); },
          /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
    return failure();

  // We have handled all other blocks. Now get to the loop continue block.
// The continue block is emitted with its own OpLabel (omitLabel defaults off).
  if (failed(processBlock(continueBlock)))
    return failure();

  // There is nothing to do for the merge block in the loop, which just contains
  // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
  // start a new SPIR-V block for ops following this LoopOp. The block should
  // use the <id> for the merge block.
  return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
}

/// Serializes a conditional branch as OpBranchConditional with the condition
/// <id>, the true and false target block <id>s and, when branch weights are
/// present, one extra operand per weight.
///
/// NOTE(review): the condition <id> returned by getValueID is not checked for
/// 0 here. The template arguments of `SmallVector` and `val.cast()` are
/// missing in this text; confirm against the upstream file.
LogicalResult Serializer::processBranchConditionalOp(
    spirv::BranchConditionalOp condBranchOp) {
  auto conditionID = getValueID(condBranchOp.condition());
  auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
  auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
  SmallVector arguments{conditionID, trueLabelID, falseLabelID};

  if (auto weights = condBranchOp.branch_weights()) {
    for (auto val : weights->getValue())
      arguments.push_back(val.cast().getInt());
  }

  emitDebugLine(functionBody, condBranchOp.getLoc());
  return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
                               arguments);
}

/// Serializes an unconditional branch as OpBranch to the target block's <id>,
/// allocating that <id> if the block has none yet.
LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
  emitDebugLine(functionBody, branchOp.getLoc());
  return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
                               {getOrCreateBlockID(branchOp.getTarget())});
}

//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//

/// Encodes an OpExtInst into `functionBody` for opcode `extensionOpcode` of
/// the extended instruction set `extensionSetName`.
///
/// `operands` must start with the result type <id> and the result <id>; the
/// set <id> and the extension opcode are inserted right after those two. The
/// first use of a set also emits its OpExtInstImport into `extendedSets`.
/// `op` is only used to report errors.
LogicalResult Serializer::encodeExtensionInstruction(
    Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
    ArrayRef operands) {
  // Check if the extension has been imported.
auto &setID = extendedInstSetIDMap[extensionSetName]; if (!setID) { setID = getNextID(); SmallVector importOperands; importOperands.push_back(setID); if (failed( spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || failed(encodeInstructionInto( extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { return failure(); } } // The first two operands are the result type and result . The set // and the opcode need to be insert after this. if (operands.size() < 2) { return op->emitError("extended instructions must have a result encoding"); } SmallVector extInstOperands; extInstOperands.reserve(operands.size() + 2); extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); extInstOperands.push_back(setID); extInstOperands.push_back(extensionOpcode); extInstOperands.append(std::next(operands.begin(), 2), operands.end()); return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, extInstOperands); } LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { auto varName = addressOfOp.variable(); auto variableID = getVariableID(varName); if (!variableID) { return addressOfOp.emitError("unknown result for variable ") << varName; } valueIDMap[addressOfOp.pointer()] = variableID; return success(); } LogicalResult Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { auto constName = referenceOfOp.spec_const(); auto constID = getSpecConstID(constName); if (!constID) { return referenceOfOp.emitError( "unknown result for specialization constant ") << constName; } valueIDMap[referenceOfOp.reference()] = constID; return success(); } LogicalResult Serializer::processOperation(Operation *opInst) { LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); // First dispatch the ops that do not directly mirror an instruction from // the SPIR-V spec. 
return TypeSwitch(opInst) .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) .Case([&](spirv::BranchConditionalOp op) { return processBranchConditionalOp(op); }) .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) .Case([&](spirv::GlobalVariableOp op) { return processGlobalVariableOp(op); }) .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) .Case([&](spirv::ModuleEndOp) { return success(); }) .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) .Case([&](spirv::SpecConstantCompositeOp op) { return processSpecConstantCompositeOp(op); }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) // Then handle all the ops that directly mirror SPIR-V instructions with // auto-generated methods. .Default( [&](Operation *op) { return dispatchToAutogenSerialization(op); }); } namespace { template <> LogicalResult Serializer::processOp(spirv::EntryPointOp op) { SmallVector operands; // Add the ExecutionModel. operands.push_back(static_cast(op.execution_model())); // Add the function . auto funcID = getFunctionID(op.fn()); if (!funcID) { return op.emitError("missing for function ") << op.fn() << "; function needs to be defined before spv.EntryPoint is " "serialized"; } operands.push_back(funcID); // Add the name of the function. spirv::encodeStringLiteralInto(operands, op.fn()); // Add the interface values. if (auto interface = op.interface()) { for (auto var : interface.getValue()) { auto id = getVariableID(var.cast().getValue()); if (!id) { return op.emitError("referencing undefined global variable." "spv.EntryPoint is at the end of spv.module. 
All " "referenced variables should already be defined"); } operands.push_back(id); } } return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); } template <> LogicalResult Serializer::processOp(spirv::ControlBarrierOp op) { StringRef argNames[] = {"execution_scope", "memory_scope", "memory_semantics"}; SmallVector operands; for (auto argName : argNames) { auto argIntAttr = op->getAttrOfType(argName); auto operand = prepareConstantInt(op.getLoc(), argIntAttr); if (!operand) { return failure(); } operands.push_back(operand); } return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, operands); } template <> LogicalResult Serializer::processOp(spirv::ExecutionModeOp op) { SmallVector operands; // Add the function . auto funcID = getFunctionID(op.fn()); if (!funcID) { return op.emitError("missing for function ") << op.fn() << "; function needs to be serialized before ExecutionModeOp is " "serialized"; } operands.push_back(funcID); // Add the ExecutionMode. operands.push_back(static_cast(op.execution_mode())); // Serialize values if any. auto values = op.values(); if (values) { for (auto &intVal : values.getValue()) { operands.push_back(static_cast( intVal.cast().getValue().getZExtValue())); } } return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, operands); } template <> LogicalResult Serializer::processOp(spirv::MemoryBarrierOp op) { StringRef argNames[] = {"memory_scope", "memory_semantics"}; SmallVector operands; for (auto argName : argNames) { auto argIntAttr = op->getAttrOfType(argName); auto operand = prepareConstantInt(op.getLoc(), argIntAttr); if (!operand) { return failure(); } operands.push_back(operand); } return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands); } template <> LogicalResult Serializer::processOp(spirv::FunctionCallOp op) { auto funcName = op.callee(); uint32_t resTypeID = 0; Type resultTy = op.getNumResults() ? 
*op.result_type_begin() : getVoidType(); if (failed(processType(op.getLoc(), resultTy, resTypeID))) return failure(); auto funcID = getOrCreateFunctionID(funcName); auto funcCallID = getNextID(); SmallVector operands{resTypeID, funcCallID, funcID}; for (auto value : op.arguments()) { auto valueID = getValueID(value); assert(valueID && "cannot find a value for spv.FunctionCall"); operands.push_back(valueID); } if (!resultTy.isa()) valueIDMap[op.getResult(0)] = funcCallID; return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); } template <> LogicalResult Serializer::processOp(spirv::CopyMemoryOp op) { SmallVector operands; SmallVector elidedAttrs; for (Value operand : op->getOperands()) { auto id = getValueID(operand); assert(id && "use before def!"); operands.push_back(id); } if (auto attr = op.getAttr("memory_access")) { operands.push_back(static_cast( attr.cast().getValue().getZExtValue())); } elidedAttrs.push_back("memory_access"); if (auto attr = op.getAttr("alignment")) { operands.push_back(static_cast( attr.cast().getValue().getZExtValue())); } elidedAttrs.push_back("alignment"); if (auto attr = op.getAttr("source_memory_access")) { operands.push_back(static_cast( attr.cast().getValue().getZExtValue())); } elidedAttrs.push_back("source_memory_access"); if (auto attr = op.getAttr("source_alignment")) { operands.push_back(static_cast( attr.cast().getValue().getZExtValue())); } elidedAttrs.push_back("source_alignment"); emitDebugLine(functionBody, op.getLoc()); encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); return success(); } // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. 
#define GET_SERIALIZATION_FNS #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" } // namespace LogicalResult Serializer::emitDecoration(uint32_t target, spirv::Decoration decoration, ArrayRef params) { uint32_t wordCount = 3 + params.size(); decorations.push_back( spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); decorations.push_back(target); decorations.push_back(static_cast(decoration)); decorations.append(params.begin(), params.end()); return success(); } LogicalResult Serializer::emitDebugLine(SmallVectorImpl &binary, Location loc) { if (!emitDebugInfo) return success(); if (lastProcessedWasMergeInst) { lastProcessedWasMergeInst = false; return success(); } auto fileLoc = loc.dyn_cast(); if (fileLoc) encodeInstructionInto(binary, spirv::Opcode::OpLine, {fileID, fileLoc.getLine(), fileLoc.getColumn()}); return success(); } LogicalResult spirv::serialize(spirv::ModuleOp module, SmallVectorImpl &binary, bool emitDebugInfo) { if (!module.vce_triple().hasValue()) return module.emitError( "module must have 'vce_triple' attribute to be serializeable"); Serializer serializer(module, emitDebugInfo); if (failed(serializer.serialize())) return failure(); LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs())); serializer.collect(binary); return success(); }