//===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert SPIR-V dialect to LLVM dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "spirv-to-llvm-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns true if the given type is a signed integer or vector type. 
static bool isSignedIntegerOrVector(Type type) { if (type.isSignedInteger()) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isSignedInteger(); return false; } /// Returns true if the given type is an unsigned integer or vector type static bool isUnsignedIntegerOrVector(Type type) { if (type.isUnsignedInteger()) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isUnsignedInteger(); return false; } /// Returns the bit width of integer, float or vector of float or integer values static unsigned getBitWidth(Type type) { assert((type.isIntOrFloat() || type.isa()) && "bitwidth is not supported for this type"); if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); auto vecType = type.dyn_cast(); auto elementType = vecType.getElementType(); assert(elementType.isIntOrFloat() && "only integers and floats have a bitwidth"); return elementType.getIntOrFloatBitWidth(); } /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth() : type.getIntegerBitWidth(); } /// Creates `IntegerAttribute` with all bits set for given type static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { if (auto vecType = type.dyn_cast()) { auto integerType = vecType.getElementType().cast(); return builder.getIntegerAttr(integerType, -1); } auto integerType = type.cast(); return builder.getIntegerAttr(integerType, -1); } /// Creates `llvm.mlir.constant` with all bits set for the given type. 
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
                                      PatternRewriter &rewriter) {
  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
  // Vectors need a splat attribute; scalars take the integer attribute as is.
  if (auto vecType = srcType.dyn_cast<VectorType>())
    return rewriter.create<LLVM::ConstantOp>(
        loc, dstType, SplatElementsAttr::get(vecType, minusOne));
  return rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
}

/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
static Value createFPConstant(Location loc, Type srcType, Type dstType,
                              PatternRewriter &rewriter, double value) {
  if (auto vecType = srcType.dyn_cast<VectorType>()) {
    auto floatType = vecType.getElementType().cast<FloatType>();
    return rewriter.create<LLVM::ConstantOp>(
        loc, dstType,
        SplatElementsAttr::get(vecType,
                               rewriter.getFloatAttr(floatType, value)));
  }
  auto floatType = srcType.cast<FloatType>();
  return rewriter.create<LLVM::ConstantOp>(
      loc, dstType, rewriter.getFloatAttr(floatType, value));
}

/// Utility function for bitfield ops:
///   - `BitFieldInsert`
///   - `BitFieldSExtract`
///   - `BitFieldUExtract`
/// Truncates or zero-extends the value. If the bitwidth of the value is the
/// same as the `dstType` bitwidth, the value remains unchanged.
static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
                                        PatternRewriter &rewriter) {
  auto srcType = value.getType();
  auto llvmType = dstType.cast<LLVM::LLVMType>();
  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
  // A broadcast value already has an LLVM type; a scalar still has its
  // original SPIR-V (builtin) type.
  unsigned valueBitWidth =
      srcType.isa<LLVM::LLVMType>()
          ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
          : getBitWidth(srcType);

  if (valueBitWidth < targetBitWidth)
    return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
  // If the bit widths of `Count` and `Offset` are greater than the bit width
  // of the target type, they are truncated. Truncation is safe since `Count`
  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
  // both values can be expressed in 8 bits.
  if (valueBitWidth > targetBitWidth)
    return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
  return value;
}

/// Broadcasts the value to vector with `numElements` number of elements.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
                       LLVMTypeConverter &typeConverter,
                       ConversionPatternRewriter &rewriter) {
  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
  auto llvmVectorType = typeConverter.convertType(vectorType);
  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
  // Start from an undefined vector and insert `toBroadcast` at every position.
  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
  for (unsigned i = 0; i < numElements; ++i) {
    auto index = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
    broadcasted = rewriter.create<LLVM::InsertElementOp>(
        loc, llvmVectorType, broadcasted, toBroadcast, index);
  }
  return broadcasted;
}

/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
                                 LLVMTypeConverter &typeConverter,
                                 ConversionPatternRewriter &rewriter) {
  if (auto vectorType = srcType.dyn_cast<VectorType>()) {
    unsigned numElements = vectorType.getNumElements();
    return broadcast(loc, value, numElements, typeConverter, rewriter);
  }
  return value;
}

/// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
/// `BitFieldUExtract`.
/// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
/// a vector type, construct a vector that has:
///  - same number of elements as `Base`
///  - each element has the type that is the same as the type of `Offset` or
///    `Count`
///  - each element has the same value as `Offset` or `Count`
/// Then cast `Offset` and `Count` if their bit width is different
/// from `Base` bit width.
static Value processCountOrOffset(Location loc, Value value, Type srcType,
                                  Type dstType, LLVMTypeConverter &converter,
                                  ConversionPatternRewriter &rewriter) {
  Value broadcasted =
      optionallyBroadcast(loc, value, srcType, converter, rewriter);
  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
}

/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
/// offset to LLVM struct.
Otherwise, the conversion is not supported. static Optional convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter) { if (type != VulkanLayoutUtils::decorateType(type)) return llvm::None; auto elementsVector = llvm::to_vector<8>( llvm::map_range(type.getElementTypes(), [&](Type elementType) { return converter.convertType(elementType).cast(); })); return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector, /*isPacked=*/false); } /// Converts SPIR-V struct with no offset to packed LLVM struct. static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter) { auto elementsVector = llvm::to_vector<8>( llvm::map_range(type.getElementTypes(), [&](Type elementType) { return converter.convertType(elementType).cast(); })); return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector, /*isPacked=*/true); } /// Creates LLVM dialect constant with the given value. static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value) { return rewriter.create( loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } /// Utility for `spv.Load` and `spv.Store` conversion. 
static LogicalResult replaceWithLoadOrStore(Operation *op, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal) { if (auto loadOp = dyn_cast(op)) { auto dstType = typeConverter.convertType(loadOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp( loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal); return success(); } auto storeOp = cast(op); rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), storeOp.ptr(), alignment, isVolatile, isNonTemporal); return success(); } //===----------------------------------------------------------------------===// // Type conversion //===----------------------------------------------------------------------===// /// Converts SPIR-V array type to LLVM array. Natural stride (according to /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected /// when converting ops that manipulate array types. static Optional convertArrayType(spirv::ArrayType type, TypeConverter &converter) { unsigned stride = type.getArrayStride(); Type elementType = type.getElementType(); auto sizeInBytes = elementType.cast().getSizeInBytes(); if (stride != 0 && !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride)) return llvm::None; auto llvmElementType = converter.convertType(elementType).cast(); unsigned numElements = type.getNumElements(); return LLVM::LLVMType::getArrayTy(llvmElementType, numElements); } /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not /// modelled at the moment. static Type convertPointerType(spirv::PointerType type, TypeConverter &converter) { auto pointeeType = converter.convertType(type.getPointeeType()).cast(); return pointeeType.getPointerTo(); } /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is /// no modelling of array stride at the moment. 
static Optional convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter) { if (type.getArrayStride() != 0) return llvm::None; auto elementType = converter.convertType(type.getElementType()).cast(); return LLVM::LLVMType::getArrayTy(elementType, 0); } /// Converts SPIR-V struct to LLVM struct. There is no support of structs with /// member decorations. Also, only natural offset is supported. static Optional convertStructType(spirv::StructType type, LLVMTypeConverter &converter) { SmallVector memberDecorations; type.getMemberDecorations(memberDecorations); if (!memberDecorations.empty()) return llvm::None; if (type.hasOffset()) return convertStructTypeWithOffset(type, converter); return convertStructTypePacked(type, converter); } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// namespace { class AccessChainPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::AccessChainOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.component_ptr().getType()); if (!dstType) return failure(); // To use GEP we need to add a first 0 index to go through the pointer. 
auto indices = llvm::to_vector<4>(op.indices()); Type indexType = op.indices().front().getType(); auto llvmIndexType = typeConverter.convertType(indexType); if (!llvmIndexType) return failure(); Value zero = rewriter.create( op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); rewriter.replaceOpWithNewOp(op, dstType, op.base_ptr(), indices); return success(); } }; class AddressOfPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::AddressOfOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.pointer().getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp( op, dstType.cast(), op.variable()); return success(); } }; class BitFieldInsertPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, typeConverter, rewriter); Value count = processCountOrOffset(loc, op.count(), srcType, dstType, typeConverter, rewriter); // Create a mask with bits set outside [Offset, Offset + Count - 1]. 
Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = rewriter.create(loc, dstType, minusOne, count); Value negated = rewriter.create(loc, dstType, maskShiftedByCount, minusOne); Value maskShiftedByCountAndOffset = rewriter.create(loc, dstType, negated, offset); Value mask = rewriter.create( loc, dstType, maskShiftedByCountAndOffset, minusOne); // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. Value baseAndMask = rewriter.create(loc, dstType, op.base(), mask); Value insertShiftedByOffset = rewriter.create(loc, dstType, op.insert(), offset); rewriter.replaceOpWithNewOp(op, dstType, baseAndMask, insertShiftedByOffset); return success(); } }; /// Converts SPIR-V ConstantOp with scalar or vector type. class ConstantScalarAndVectorPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = constOp.getType(); if (!srcType.isa() && !srcType.isIntOrFloat()) return failure(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); // SPIR-V constant can be a signed/unsigned integer, which has to be // casted to signless integer when converting to LLVM dialect. Removing the // sign bit may have unexpected behaviour. However, it is better to handle // it case-by-case, given that the purpose of the conversion is not to // cover all possible corner cases. 
if (isSignedIntegerOrVector(srcType) || isUnsignedIntegerOrVector(srcType)) { auto *context = rewriter.getContext(); auto signlessType = IntegerType::get(getBitWidth(srcType), context); if (srcType.isa()) { auto dstElementsAttr = constOp.value().cast(); rewriter.replaceOpWithNewOp( constOp, dstType, dstElementsAttr.mapValues( signlessType, [&](const APInt &value) { return value; })); return success(); } auto srcAttr = constOp.value().cast(); auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } rewriter.replaceOpWithNewOp(constOp, dstType, operands, constOp.getAttrs()); return success(); } }; class BitFieldSExtractPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, typeConverter, rewriter); Value count = processCountOrOffset(loc, op.count(), srcType, dstType, typeConverter, rewriter); // Create a constant that holds the size of the `Base`. IntegerType integerType; if (auto vecType = srcType.dyn_cast()) integerType = vecType.getElementType().cast(); else integerType = srcType.cast(); auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = srcType.isa() ? rewriter.create( loc, dstType, SplatElementsAttr::get(srcType.cast(), baseSize)) : rewriter.create(loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit // at Offset + Count - 1 is the most significant bit now. 
Value countPlusOffset = rewriter.create(loc, dstType, count, offset); Value amountToShiftLeft = rewriter.create(loc, dstType, size, countPlusOffset); Value baseShiftedLeft = rewriter.create( loc, dstType, op.base(), amountToShiftLeft); // Shift the result right, filling the bits with the sign bit. Value amountToShiftRight = rewriter.create(loc, dstType, offset, amountToShiftLeft); rewriter.replaceOpWithNewOp(op, dstType, baseShiftedLeft, amountToShiftRight); return success(); } }; class BitFieldUExtractPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, typeConverter, rewriter); Value count = processCountOrOffset(loc, op.count(), srcType, dstType, typeConverter, rewriter); // Create a mask with bits set at [0, Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = rewriter.create(loc, dstType, minusOne, count); Value mask = rewriter.create(loc, dstType, maskShiftedByCount, minusOne); // Shift `Base` by `Offset` and apply the mask on it. 
Value shiftedBase = rewriter.create(loc, dstType, op.base(), offset); rewriter.replaceOpWithNewOp(op, dstType, shiftedBase, mask); return success(); } }; class BranchConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BranchOp branchOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(branchOp, operands, branchOp.getTarget()); return success(); } }; class BranchConditionalConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion< spirv::BranchConditionalOp>::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. ElementsAttr branchWeights = nullptr; if (auto weights = op.branch_weights()) { VectorType weightType = VectorType::get(2, rewriter.getI32Type()); branchWeights = DenseElementsAttr::get(weightType, weights.getValue().getValue()); } rewriter.replaceOpWithNewOp( op, op.condition(), op.getTrueBlockArguments(), op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), op.getFalseBlock()); return success(); } }; /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type /// is an aggregate type (struct or array). Otherwise, converts to /// `llvm.extractelement` that operates on vectors. 
class CompositeExtractPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) return failure(); Type containerType = op.composite().getType(); if (containerType.isa()) { Location loc = op.getLoc(); IntegerAttr value = op.indices()[0].cast(); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( op, dstType, op.composite(), index); return success(); } rewriter.replaceOpWithNewOp( op, dstType, op.composite(), op.indices()); return success(); } }; /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type /// is an aggregate type (struct or array). Otherwise, converts to /// `llvm.insertelement` that operates on vectors. class CompositeInsertPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) return failure(); Type containerType = op.composite().getType(); if (containerType.isa()) { Location loc = op.getLoc(); IntegerAttr value = op.indices()[0].cast(); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( op, dstType, op.composite(), op.object(), index); return success(); } rewriter.replaceOpWithNewOp( op, dstType, op.composite(), op.object(), op.indices()); return success(); } }; /// Converts SPIR-V operations that have straightforward LLVM equivalent /// into LLVM dialect operations. 
template class DirectConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp(operation, dstType, operands, operation.getAttrs()); return success(); } }; /// Converts `spv.ExecutionMode` into a global struct constant that holds /// execution mode information. class ExecutionModePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // First, create the global struct's name that would be associated with // this entry point's execution mode. We set it to be: // __spv__{SPIR-V module name}_{function name}_execution_mode_info ModuleOp module = op->getParentOfType(); std::string moduleName; if (module.getName().hasValue()) moduleName = "_" + module.getName().getValue().str(); else moduleName = ""; std::string executionModeInfoName = llvm::formatv( "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str()); MLIRContext *context = rewriter.getContext(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); // Create a struct type, corresponding to the C struct below. 
// struct { // int32_t executionMode; // int32_t values[]; // optional values // }; auto llvmI32Type = LLVM::LLVMType::getInt32Ty(context); SmallVector fields; fields.push_back(llvmI32Type); ArrayAttr values = op.values(); if (!values.empty()) { auto arrayType = LLVM::LLVMType::getArrayTy(llvmI32Type, values.size()); fields.push_back(arrayType); } auto structType = LLVM::LLVMType::getStructTy(context, fields); // Create `llvm.mlir.global` with initializer region containing one block. auto global = rewriter.create( UnknownLoc::get(context), structType, /*isConstant=*/true, LLVM::Linkage::External, executionModeInfoName, Attribute()); Location loc = global.getLoc(); Region ®ion = global.getInitializerRegion(); Block *block = rewriter.createBlock(®ion); // Initialize the struct and set the execution mode value. rewriter.setInsertionPoint(block, block->begin()); Value structValue = rewriter.create(loc, structType); IntegerAttr executionModeAttr = op.execution_modeAttr(); Value executionMode = rewriter.create(loc, llvmI32Type, executionModeAttr); structValue = rewriter.create( loc, structType, structValue, executionMode, ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}, context)); // Insert extra operands if they exist into execution mode info struct. for (unsigned i = 0, e = values.size(); i < e; ++i) { auto attr = values.getValue()[i]; Value entry = rewriter.create(loc, llvmI32Type, attr); structValue = rewriter.create( loc, structType, structValue, entry, ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1), rewriter.getIntegerAttr(rewriter.getI32Type(), i)}, context)); } rewriter.create(loc, ArrayRef({structValue})); rewriter.eraseOp(op); return success(); } }; /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global /// returns a pointer, whereas in LLVM dialect the global holds an actual value. /// This difference is handled by `spv.mlir.addressof` and /// `llvm.mlir.addressof`ops that both return a pointer. 
class GlobalVariablePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Currently, there is no support of initialization with a constant value in // SPIR-V dialect. Specialization constants are not considered as well. if (op.initializer()) return failure(); auto srcType = op.type().cast(); auto dstType = typeConverter.convertType(srcType.getPointeeType()); if (!dstType) return failure(); // Limit conversion to the current invocation only or `StorageBuffer` // required by SPIR-V runner. // This is okay because multiple invocations are not supported yet. auto storageClass = srcType.getStorageClass(); if (storageClass != spirv::StorageClass::Input && storageClass != spirv::StorageClass::Private && storageClass != spirv::StorageClass::Output && storageClass != spirv::StorageClass::StorageBuffer) { return failure(); } // LLVM dialect spec: "If the global value is a constant, storing into it is // not allowed.". This corresponds to SPIR-V 'Input' storage class that is // read-only. bool isConstant = storageClass == spirv::StorageClass::Input; // SPIR-V spec: "By default, functions and global variables are private to a // module and cannot be accessed by other modules. However, a module may be // written to export or import functions and global (module scope) // variables.". Therefore, map 'Private' storage class to private linkage, // 'Input' and 'Output' to external linkage. auto linkage = storageClass == spirv::StorageClass::Private ? LLVM::Linkage::Private : LLVM::Linkage::External; rewriter.replaceOpWithNewOp( op, dstType.cast(), isConstant, linkage, op.sym_name(), Attribute()); return success(); } }; /// Converts SPIR-V cast ops that do not have straightforward LLVM /// equivalent in LLVM dialect. 
template class IndirectCastPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type fromType = operation.operand().getType(); Type toType = operation.getType(); auto dstType = this->typeConverter.convertType(toType); if (!dstType) return failure(); if (getBitWidth(fromType) < getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } if (getBitWidth(fromType) > getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } return failure(); } }; class FunctionCallPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (callOp.getNumResults() == 0) { rewriter.replaceOpWithNewOp(callOp, llvm::None, operands, callOp.getAttrs()); return success(); } // Function returns a single result. 
auto dstType = typeConverter.convertType(callOp.getType(0)); rewriter.replaceOpWithNewOp(callOp, dstType, operands, callOp.getAttrs()); return success(); } }; /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" template class FComparePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), operation.operand1(), operation.operand2()); return success(); } }; /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" template class IComparePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), operation.operand1(), operation.operand2()); return success(); } }; class InverseSqrtPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value sqrt = rewriter.create(loc, dstType, op.operand()); rewriter.replaceOpWithNewOp(op, dstType, one, sqrt); return success(); } }; /// Converts `spv.Load` and `spv.Store` to LLVM dialect. 
template class LoadStorePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVop op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!op.memory_access().hasValue()) { replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0, /*isVolatile=*/false, /*isNonTemporal=*/ false); return success(); } auto memoryAccess = op.memory_access().getValue(); switch (memoryAccess) { case spirv::MemoryAccess::Aligned: case spirv::MemoryAccess::None: case spirv::MemoryAccess::Nontemporal: case spirv::MemoryAccess::Volatile: { unsigned alignment = memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment, isVolatile, isNonTemporal); return success(); } default: // There is no support of other memory access attributes. return failure(); } } }; /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect. template class NotPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp notOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = notOp.getType(); auto dstType = this->typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = notOp.getLoc(); IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); auto mask = srcType.template isa() ? rewriter.create( loc, dstType, SplatElementsAttr::get( srcType.template cast(), minusOne)) : rewriter.create(loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, notOp.operand(), mask); return success(); } }; /// A template pattern that erases the given `SPIRVOp`. 
template class ErasePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } }; class ReturnPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnOp, ArrayRef(), ArrayRef()); return success(); } }; class ReturnValuePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnValueOp, ArrayRef(), operands); return success(); } }; /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be /// reachable for conversion to succeed. /// The structure of the loop in LLVM dialect will be the following: /// /// +------------------------------------+ /// | | /// | llvm.br ^header | /// +------------------------------------+ /// | /// +----------------+ | /// | | | /// | V V /// | +------------------------------------+ /// | | ^header: | /// | |
| /// | | llvm.cond_br %cond, ^body, ^exit | /// | +------------------------------------+ /// | | /// | |----------------------+ /// | | | /// | V | /// | +------------------------------------+ | /// | | ^body: | | /// | | | | /// | | llvm.br ^continue | | /// | +------------------------------------+ | /// | | | /// | V | /// | +------------------------------------+ | /// | | ^continue: | | /// | | | | /// | | llvm.br ^header | | /// | +------------------------------------+ | /// | | | /// +---------------+ +----------------------+ /// | /// V /// +------------------------------------+ /// | ^exit: | /// | llvm.br ^remaining | /// +------------------------------------+ /// | /// V /// +------------------------------------+ /// | ^remaining: | /// | | /// +------------------------------------+ /// class LoopPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::LoopOp loopOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // There is no support of loop control at the moment. if (loopOp.loop_control() != spirv::LoopControl::None) return failure(); Location loc = loopOp.getLoc(); // Split the current block after `spv.loop`. The remaining ops will be used // in `endBlock`. Block *currentBlock = rewriter.getBlock(); auto position = Block::iterator(loopOp); Block *endBlock = rewriter.splitBlock(currentBlock, position); // Remove entry block and create a branch in the current block going to the // header block. Block *entryBlock = loopOp.getEntryBlock(); assert(entryBlock->getOperations().size() == 1); auto brOp = dyn_cast(entryBlock->getOperations().front()); if (!brOp) return failure(); Block *headerBlock = loopOp.getHeaderBlock(); rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, brOp.getBlockArguments(), headerBlock); rewriter.eraseBlock(entryBlock); // Branch from merge block to end block. 
Block *mergeBlock = loopOp.getMergeBlock(); Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); rewriter.create(loc, terminatorOperands, endBlock); rewriter.inlineRegionBefore(loopOp.body(), endBlock); rewriter.replaceOp(loopOp, endBlock->getArguments()); return success(); } }; /// Converts `spv.selection` with `spv.BranchConditional` in its header block. /// All blocks within selection should be reachable for conversion to succeed. class SelectionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::SelectionOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // There is no support for `Flatten` or `DontFlatten` selection control at // the moment. This are just compiler hints and can be performed during the // optimization passes. if (op.selection_control() != spirv::SelectionControl::None) return failure(); // `spv.selection` should have at least two blocks: one selection header // block and one merge block. If no blocks are present, or control flow // branches straight to merge block (two blocks are present), the op is // redundant and it is erased. if (op.body().getBlocks().size() <= 2) { rewriter.eraseOp(op); return success(); } Location loc = op.getLoc(); // Split the current block after `spv.selection`. The remaining ops will be // used in `continueBlock`. auto *currentBlock = rewriter.getInsertionBlock(); rewriter.setInsertionPointAfter(op); auto position = rewriter.getInsertionPoint(); auto *continueBlock = rewriter.splitBlock(currentBlock, position); // Extract conditional branch information from the header block. By SPIR-V // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch` // op. Note that `spv.Switch op` is not supported at the moment in the // SPIR-V dialect. Remove this block when finished. 
auto *headerBlock = op.getHeaderBlock(); assert(headerBlock->getOperations().size() == 1); auto condBrOp = dyn_cast( headerBlock->getOperations().front()); if (!condBrOp) return failure(); rewriter.eraseBlock(headerBlock); // Branch from merge block to continue block. auto *mergeBlock = op.getMergeBlock(); Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); rewriter.create(loc, terminatorOperands, continueBlock); // Link current block to `true` and `false` blocks within the selection. Block *trueBlock = condBrOp.getTrueBlock(); Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, condBrOp.condition(), trueBlock, condBrOp.trueTargetOperands(), falseBlock, condBrOp.falseTargetOperands()); rewriter.inlineRegionBefore(op.body(), continueBlock); rewriter.replaceOp(op, continueBlock->getArguments()); return success(); } }; /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect /// puts a restriction on `Shift` and `Base` to have the same bit width, /// `Shift` is zero or sign extended to match this specification. Cases when /// `Shift` bit width > `Base` bit width are considered to be illegal. 
template class ShiftPattern : public SPIRVToLLVMConversion {
public:
  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(SPIRVOp operation, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = this->typeConverter.convertType(operation.getType());
    if (!dstType)
      return failure();

    Type op1Type = operation.operand1().getType();
    Type op2Type = operation.operand2().getType();

    // Same operand types: no extension is needed, so the shift is emitted
    // directly on the converted operands.
    if (op1Type == op2Type) {
      rewriter.template replaceOpWithNewOp(operation, dstType, operands);
      return success();
    }

    // Widen `Shift` (operand2) to the converted result type. The extension op
    // is chosen by whether `Shift` is an unsigned integer (or vector of them).
    // Presumably this is zext for unsigned and sext otherwise.
    Location loc = operation.getLoc();
    Value extended;
    if (isUnsignedIntegerOrVector(op2Type)) {
      extended = rewriter.template create(loc, dstType,
                                          operation.operand2());
    } else {
      extended = rewriter.template create(loc, dstType,
                                          operation.operand2());
    }
    Value result = rewriter.template create(
        loc, dstType, operation.operand1(), extended);
    rewriter.replaceOp(operation, result);
    return success();
  }
};

/// Lowers `spv.GLSL.Tan` as sin(x) / cos(x). Fails if the result type cannot
/// be converted.
class TanPattern : public SPIRVToLLVMConversion {
public:
  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = typeConverter.convertType(tanOp.getType());
    if (!dstType)
      return failure();

    Location loc = tanOp.getLoc();
    Value sin = rewriter.create(loc, dstType, tanOp.operand());
    Value cos = rewriter.create(loc, dstType, tanOp.operand());
    rewriter.replaceOpWithNewOp(tanOp, dstType, sin, cos);
    return success();
  }
};

/// Convert `spv.Tanh` to
///
///   exp(2x) - 1
///   -----------
///   exp(2x) + 1
///
/// The constants 2.0 and 1.0 are materialized with the source (scalar or
/// vector) shape, and exp(2x) is computed once and reused for the numerator
/// and the denominator.
class TanhPattern : public SPIRVToLLVMConversion {
public:
  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = tanhOp.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();

    Location loc = tanhOp.getLoc();
    Value two =
        createFPConstant(loc, srcType, dstType, rewriter, 2.0);
    // 2x
    Value multiplied =
        rewriter.create(loc, dstType, two, tanhOp.operand());
    // exp(2x)
    Value exponential = rewriter.create(loc, dstType, multiplied);
    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
    // exp(2x) - 1 and exp(2x) + 1
    Value numerator =
        rewriter.create(loc, dstType, exponential, one);
    Value denominator =
        rewriter.create(loc, dstType, exponential, one);
    rewriter.replaceOpWithNewOp(tanhOp, dstType, numerator,
                                denominator);
    return success();
  }
};

/// Lowers `spv.Variable` to a single-element allocation (size = i32
/// constant 1) of the converted pointer type. If the variable has an
/// initializer, it is then written into that allocation. Initializers are only
/// supported for scalar (int/float) and vector pointee types; any other
/// initialized variable makes the pattern fail.
class VariablePattern : public SPIRVToLLVMConversion {
public:
  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::VariableOp varOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = varOp.getType();
    // Initialization is supported for scalars and vectors only.
    auto pointerTo = srcType.cast().getPointeeType();
    auto init = varOp.initializer();
    if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa())
      return failure();

    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();

    Location loc = varOp.getLoc();
    // Number of elements to allocate.
    Value size = createI32ConstantOf(loc, rewriter, 1);
    if (!init) {
      rewriter.replaceOpWithNewOp(varOp, dstType, size);
      return success();
    }
    Value allocated = rewriter.create(loc, dstType, size);
    // Write the initializer into the new allocation (presumably an
    // `llvm.store`).
    rewriter.create(loc, init, allocated);
    rewriter.replaceOp(varOp, allocated);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// FuncOp conversion
//===----------------------------------------------------------------------===//

/// Converts `spv.func` to an LLVM function. The conversion:
///   - converts the signature;
///   - maps the SPIR-V function control onto the LLVM `passthrough` attribute;
///   - moves the body into the new function and converts its block argument
///     types.
class FuncConversionPattern : public SPIRVToLLVMConversion {
public:
  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert function signature. At the moment LLVMType converter is enough
    // for currently supported types.
    auto funcType = funcOp.getType();
    TypeConverter::SignatureConversion signatureConverter(
        funcType.getNumInputs());
    auto llvmType = typeConverter.convertFunctionSignature(
        funcOp.getType(), /*isVariadic=*/false, signatureConverter);
    if (!llvmType)
      return failure();

    // Create a new `LLVMFuncOp`
    Location loc = funcOp.getLoc();
    StringRef name = funcOp.getName();
    auto newFuncOp = rewriter.create(loc, name, llvmType);

    // Convert SPIR-V Function Control to equivalent LLVM function attribute
    MLIRContext *context = funcOp.getContext();
    switch (funcOp.function_control()) {
    // Each DISPATCH(control, attr) expands to a `case` that sets the
    // `passthrough` attribute to a one-element array holding `attr`.
#define DISPATCH(functionControl, llvmAttr)                                    \
  case functionControl:                                                        \
    newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));     \
    break;

      DISPATCH(spirv::FunctionControl::Inline,
               StringAttr::get("alwaysinline", context));
      DISPATCH(spirv::FunctionControl::DontInline,
               StringAttr::get("noinline", context));
      DISPATCH(spirv::FunctionControl::Pure,
               StringAttr::get("readonly", context));
      DISPATCH(spirv::FunctionControl::Const,
               StringAttr::get("readnone", context));

#undef DISPATCH

    // Default: if `spirv::FunctionControl::None`, then no attributes are
    // needed.
    // NOTE(review): any value not matched above (e.g. combined controls, if
    // representable) also lands here and gets no attribute. Confirm this is
    // intended.
    default:
      break;
    }

    // Move the body into the new function, then rewrite block argument types
    // according to the signature conversion computed above.
    // NOTE(review): on type-conversion failure this returns failure after the
    // new function has already been created and the body moved.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                           &signatureConverter))) {
      return failure();
    }
    rewriter.eraseOp(funcOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ModuleOp conversion
//===----------------------------------------------------------------------===//

/// Replaces `spv.module` with a new module that has the same (optional) name,
/// and moves the SPIR-V module body into it.
class ModuleConversionPattern : public SPIRVToLLVMConversion {
public:
  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {

    auto newModuleOp =
        rewriter.create(spvModuleOp.getLoc(), spvModuleOp.getName());
    // The SPIR-V blocks are inserted before the builder-created body block,
    // so that block ends up last in the region.
    rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());

    // Remove the terminator block that was automatically added by builder
    rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
    rewriter.eraseOp(spvModuleOp);
    return success();
  }
};

/// Replaces the SPIR-V module terminator with a newly built terminator op
/// (presumably the builtin module terminator).
class ModuleEndConversionPattern : public SPIRVToLLVMConversion {
public:
  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {

    rewriter.replaceOpWithNewOp(moduleEndOp);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

/// Registers type conversions on `typeConverter` for four SPIR-V types: array,
/// pointer, runtime array and struct. Each delegates to the corresponding
/// `convert*Type` helper. The lambdas capture `typeConverter` by reference, so
/// it must outlive any use of these conversions.
void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
  typeConverter.addConversion([&](spirv::ArrayType type) {
    return convertArrayType(type, typeConverter);
  });
  typeConverter.addConversion([&](spirv::PointerType type) {
    return convertPointerType(type, typeConverter);
  });
  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
    return convertRuntimeArrayType(type, typeConverter);
  });
  typeConverter.addConversion([&](spirv::StructType type) {
    return convertStructType(type, typeConverter);
  });
}

/// Populates `patterns` with the op-level SPIR-V to LLVM conversion patterns,
/// grouped below by op category. Function- and module-level patterns are
/// registered by the separate `populate*` entry points that follow.
void mlir::populateSPIRVToLLVMConversionPatterns(
    MLIRContext *context, LLVMTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<
      // Arithmetic ops
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,

      // Bitwise ops
      BitFieldInsertPattern,
      BitFieldUExtractPattern,
      BitFieldSExtractPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      NotPattern,

      // Cast ops
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      IndirectCastPattern,
      IndirectCastPattern,
      IndirectCastPattern,

      // Comparison ops
      IComparePattern,
      IComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      FComparePattern,
      IComparePattern,
      IComparePattern,
      IComparePattern,
      IComparePattern,
      IComparePattern,
      IComparePattern,
      IComparePattern,
      IComparePattern,

      // Constant op
      ConstantScalarAndVectorPattern,

      // Control Flow ops
      BranchConversionPattern, BranchConditionalConversionPattern,
      FunctionCallPattern, LoopPattern, SelectionPattern,
      ErasePattern,

      // Entry points and execution mode are handled separately.
      ErasePattern, ExecutionModePattern,

      // GLSL extended instruction set ops
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      DirectConversionPattern,
      InverseSqrtPattern, TanPattern, TanhPattern,

      // Logical ops
      DirectConversionPattern,
      DirectConversionPattern,
      IComparePattern,
      IComparePattern,
      NotPattern,

      // Memory ops
      AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
      LoadStorePattern, LoadStorePattern,
      VariablePattern,

      // Miscellaneous ops
      CompositeExtractPattern, CompositeInsertPattern,
      DirectConversionPattern,
      DirectConversionPattern,

      // Shift ops
      ShiftPattern,
      ShiftPattern,
      ShiftPattern,

      // Return ops
      ReturnPattern, ReturnValuePattern>(context, typeConverter);
}

/// Populates `patterns` with the function-level conversion pattern(s)
/// (presumably the `spv.func` conversion).
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
    MLIRContext *context, LLVMTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert(context, typeConverter);
}

/// Populates `patterns` with the module-level conversion patterns (presumably
/// the `spv.module` and module terminator conversions).
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
    MLIRContext *context, LLVMTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert(
      context, typeConverter);
}

//===----------------------------------------------------------------------===//
// Pre-conversion hooks
//===----------------------------------------------------------------------===//

/// Hook for descriptor set and binding number encoding.
/// Names of the integer attributes that carry a global variable's binding and
/// descriptor set numbers.
static constexpr StringRef kBinding = "binding";
static constexpr StringRef kDescriptorSet = "descriptor_set";

/// Encodes descriptor set and binding numbers into global variable names.
///
/// The transformation applies to every `spirv::GlobalVariableOp` nested in
/// each `spirv::ModuleOp` that sits directly in `module`. If such a variable
/// has both a `descriptor_set` and a `binding` integer attribute:
///   - it is renamed to
///     `[<spv module name>_]<name>_descriptor_set<N>_binding<M>`;
///   - all of its symbol uses within that SPIR-V module are updated;
///   - both attributes are removed.
/// If not every symbol use can be replaced, an error is emitted on the
/// variable; the rename is still applied.
void mlir::encodeBindAttribute(ModuleOp module) {
  auto spvModules = module.getOps<spirv::ModuleOp>();
  for (auto spvModule : spvModules) {
    spvModule.walk([&](spirv::GlobalVariableOp op) {
      IntegerAttr descriptorSet =
          op->getAttrOfType<IntegerAttr>(kDescriptorSet);
      IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);

      // Only variables carrying both numbers are renamed.
      if (!descriptorSet || !binding)
        return;

      // Base name: the variable's symbol, prefixed with the SPIR-V module's
      // name when it has one.
      std::string moduleAndName = op.sym_name().str();
      if (spvModule.getName().hasValue())
        moduleAndName =
            spvModule.getName().getValue().str() + "_" + moduleAndName;

      // formatv renders the integers itself, so no intermediate
      // std::to_string temporaries are needed.
      std::string name = llvm::formatv("{0}_descriptor_set{1}_binding{2}",
                                       moduleAndName, descriptorSet.getInt(),
                                       binding.getInt())
                             .str();

      // Replace all symbol uses and set the new symbol name. Finally, remove
      // descriptor set and binding attributes.
      if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
        op.emitError("unable to replace all symbol uses for ") << name;
      SymbolTable::setSymbolName(op, name);
      op.removeAttr(kDescriptorSet);
      op.removeAttr(kBinding);
    });
  }
}