//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert standard ops to SPIR-V ops. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "std-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { if (type.isInteger(1)) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isInteger(1); return false; } /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { if (auto boolAttr = srcAttr.dyn_cast()) return boolAttr; if (auto intAttr = srcAttr.dyn_cast()) return builder.getBoolAttr(intAttr.getValue().getBoolValue()); return BoolAttr(); } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if conversion fails. 
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
                                      Builder builder) {
  // If the source number uses less active bits than the target bitwidth, then
  // it should be safe to convert.
  if (srcAttr.getValue().isIntN(dstType.getWidth()))
    return builder.getIntegerAttr(dstType, srcAttr.getInt());

  // XXX: Try again by interpreting the source number as a signed value.
  // Although integers in the standard dialect are signless, they can represent
  // a signed number. It's the operation that decides how to interpret. This is
  // dangerous, but it seems there is no good way of handling this if we still
  // want to change the bitwidth. Emit a message at least.
  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
    auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
    LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
                            << dstAttr << "' for type '" << dstType << "'\n");
    return dstAttr;
  }

  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
                          << "' illegal: cannot fit into target type '"
                          << dstType << "'\n");
  return IntegerAttr();
}

/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
                                  Builder builder) {
  // Only support converting to float for now.
  if (!dstType.isF32())
    return FloatAttr();

  // Try to convert the source floating-point number to single precision.
  APFloat dstVal = srcAttr.getValue();
  bool losesInfo = false;
  APFloat::opStatus status =
      dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
  // Reject the conversion if it is not exact: the constant would silently
  // change value otherwise.
  if (status != APFloat::opOK || losesInfo) {
    LLVM_DEBUG(llvm::dbgs()
               << srcAttr << " illegal: cannot fit into converted type '"
               << dstType << "'\n");
    return FloatAttr();
  }

  return builder.getF32FloatAttr(dstVal.convertToFloat());
}

/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
/// the sign of `signOperand`.
/// /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod /// if either operand can be negative. Emulate it via spv.UMod. static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder) { assert(lhs.getType() == rhs.getType()); assert(lhs == signOperand || rhs == signOperand); Type type = lhs.getType(); // Calculate the remainder with spv.UMod. Value lhsAbs = builder.create(loc, type, lhs); Value rhsAbs = builder.create(loc, type, rhs); Value abs = builder.create(loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) isPositive = builder.create(loc, lhs, lhsAbs); else isPositive = builder.create(loc, rhs, rhsAbs); Value absNegate = builder.create(loc, type, abs); return builder.create(loc, type, isPositive, abs, absNegate); } /// Returns the offset of the value in `targetBits` representation. /// /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. /// It's assumed to be non-negative. /// /// When accessing an element in the array treating as having elements of /// `targetBits`, multiple values are loaded in the same time. The method /// returns the offset where the `srcIdx` locates in the value. For example, if /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is /// located at (x % 4) * 8. Because there are four elements in one i32, and one /// element has 8 bits. 
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder) { assert(targetBits % sourceBits == 0); IntegerType targetType = builder.getIntegerType(targetBits); IntegerAttr idxAttr = builder.getIntegerAttr(targetType, targetBits / sourceBits); auto idx = builder.create(loc, targetType, idxAttr); IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); auto srcBitsValue = builder.create(loc, targetType, srcBitsAttr); auto m = builder.create(loc, srcIdx, idx); return builder.create(loc, targetType, m, srcBitsValue); } /// Returns an adjusted spirv::AccessChainOp. Based on the /// extension/capabilities, certain integer bitwidths `sourceBits` might not be /// supported. During conversion if a memref of an unsupported type is used, /// load/stores to this memref need to be modified to use a supported higher /// bitwidth `targetBits` and extracting the required bits. For an accessing a /// 1D array (spv.array or spv.rt_array), the last index is modified to load the /// bits needed. The extraction of the actual bits needed are handled /// separately. Note that this only works for a 1-D tensor. static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder) { assert(targetBits % sourceBits == 0); const auto loc = op.getLoc(); IntegerType targetType = builder.getIntegerType(targetBits); IntegerAttr attr = builder.getIntegerAttr(targetType, targetBits / sourceBits); auto idx = builder.create(loc, targetType, attr); auto lastDim = op->getOperand(op.getNumOperands() - 1); auto indices = llvm::to_vector<4>(op.indices()); // There are two elements if this is a 1-D tensor. 
assert(indices.size() == 2); indices.back() = builder.create(loc, lastDim, idx); Type t = typeConverter.convertType(op.component_ptr().getType()); return builder.create(loc, t, op.base_ptr(), indices); } /// Returns the shifted `targetBits`-bit value with the given offset. static Value shiftValue(Location loc, Value value, Value offset, Value mask, int targetBits, OpBuilder &builder) { Type targetType = builder.getIntegerType(targetBits); Value result = builder.create(loc, value, mask); return builder.create(loc, targetType, result, offset); } /// Returns true if the operator is operating on unsigned integers. /// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information /// to the ops themselves. template bool isUnsignedOp() { return false; } #define CHECK_UNSIGNED_OP(SPIRVOp) \ template <> \ bool isUnsignedOp() { \ return true; \ } CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp) CHECK_UNSIGNED_OP(spirv::AtomicUMinOp) CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp) CHECK_UNSIGNED_OP(spirv::ConvertUToFOp) CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp) CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp) CHECK_UNSIGNED_OP(spirv::UConvertOp) CHECK_UNSIGNED_OP(spirv::UDivOp) CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp) CHECK_UNSIGNED_OP(spirv::UGreaterThanOp) CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp) CHECK_UNSIGNED_OP(spirv::ULessThanOp) CHECK_UNSIGNED_OP(spirv::UModOp) #undef CHECK_UNSIGNED_OP /// Returns true if the allocations of type `t` can be lowered to SPIR-V. static bool isAllocationSupported(MemRefType t) { // Currently only support workgroup local memory allocations with static // shape and int or float or vector of int or float element type. 
if (!(t.hasStaticShape() && SPIRVTypeConverter::getMemorySpaceForStorageClass( spirv::StorageClass::Workgroup) == t.getMemorySpace())) return false; Type elementType = t.getElementType(); if (auto vecType = elementType.dyn_cast()) elementType = vecType.getElementType(); return elementType.isIntOrFloat(); } /// Returns the scope to use for atomic operations use for emulating store /// operations of unsupported integer bitwidths, based on the memref /// type. Returns None on failure. static Optional getAtomicOpScope(MemRefType t) { Optional storageClass = SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace()); if (!storageClass) return {}; switch (*storageClass) { case spirv::StorageClass::StorageBuffer: return spirv::Scope::Device; case spirv::StorageClass::Workgroup: return spirv::Scope::Workgroup; default: { } } return {}; } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts an allocation operation to SPIR-V. Currently only supports lowering /// to Workgroup memory when the size is constant. Note that this pattern needs /// to be applied in a pass that runs at least at spv.module scope since it wil /// ladd global variables into the spv.module. class AllocOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(AllocOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType allocType = operation.getType(); if (!isAllocationSupported(allocType)) return operation.emitError("unhandled allocation type"); // Get the SPIR-V type for the allocation. 
Type spirvType = typeConverter.convertType(allocType); // Insert spv.globalVariable for this allocation. Operation *parent = SymbolTable::getNearestSymbolTable(operation->getParentOp()); if (!parent) return failure(); Location loc = operation.getLoc(); spirv::GlobalVariableOp varOp; { OpBuilder::InsertionGuard guard(rewriter); Block &entryBlock = *parent->getRegion(0).begin(); rewriter.setInsertionPointToStart(&entryBlock); auto varOps = entryBlock.getOps(); std::string varName = std::string("__workgroup_mem__") + std::to_string(std::distance(varOps.begin(), varOps.end())); varOp = rewriter.create( loc, TypeAttr::get(spirvType), varName, /*initializer = */ nullptr); } // Get pointer to global variable at the current scope. rewriter.replaceOpWithNewOp(operation, varOp); return success(); } }; /// Removed a deallocation if it is a supported allocation. Currently only /// removes deallocation if the memory space is workgroup memory. class DeallocOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(DeallocOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType deallocType = operation.memref().getType().cast(); if (!isAllocationSupported(deallocType)) return operation.emitError("unhandled deallocation type"); rewriter.eraseOp(operation); return success(); } }; /// Converts unary and binary standard operations to SPIR-V operations. 
template class UnaryAndBinaryOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() <= 2); auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); if (isUnsignedOp() && dstType != operation.getType()) { return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } }; /// Converts std.remi_signed to SPIR-V ops. /// /// This cannot be merged into the template unary/binary pattern due to /// Vulkan restrictions over spv.SRem and spv.SMod. class SignedRemIOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(SignedRemIOp remOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts bitwise standard operations to SPIR-V operations. This is a special /// pattern other than the BinaryOpPatternPattern because if the operands are /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. template class BitwiseOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 2); auto dstType = this->typeConverter.convertType(operation.getResult().getType()); if (!dstType) return failure(); if (isBoolScalarOrVector(operands.front().getType())) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); } else { rewriter.template replaceOpWithNewOp(operation, dstType, operands); } return success(); } }; /// Converts composite std.constant operation to spv.constant. 
class ConstantCompositeOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts scalar std.constant operation to spv.constant. class ConstantScalarOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts floating-point comparison operations to SPIR-V ops. class CmpFOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts integer compare operation on i1 type operands to SPIR-V ops. class BoolCmpIOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.load to spv.Load. class IntLoadOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.load to spv.Load. class LoadOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.return to spv.Return. 
class ReturnOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.select to spv.Select. class SelectOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.store to spv.Store on integers. class IntStoreOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.store to spv.Store. class StoreOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.zexti to spv.Select if the type of source is i1 or vector of /// i1. 
class ZeroExtendI1Pattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(ZeroExtendIOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = operands.front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->typeConverter.convertType(op.getResult().getType()); Location loc = op.getLoc(); Attribute zeroAttr, oneAttr; if (auto vectorType = dstType.dyn_cast()) { zeroAttr = DenseElementsAttr::get(vectorType, 0); oneAttr = DenseElementsAttr::get(vectorType, 1); } else { zeroAttr = IntegerAttr::get(dstType, 0); oneAttr = IntegerAttr::get(dstType, 1); } Value zero = rewriter.create(loc, zeroAttr); Value one = rewriter.create(loc, oneAttr); rewriter.template replaceOpWithNewOp( op, dstType, operands.front(), one, zero); return success(); } }; /// Converts type-casting standard operations to SPIR-V operations. template class TypeCastingOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1); auto srcType = operands.front().getType(); if (isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->typeConverter.convertType(operation.getResult().getType()); if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. rewriter.replaceOp(operation, operands.front()); } else { rewriter.template replaceOpWithNewOp(operation, dstType, operands); } return success(); } }; /// Converts std.xor to SPIR-V operations. 
class XOrOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(XOrOp xorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // SignedRemIOpPattern //===----------------------------------------------------------------------===// LogicalResult SignedRemIOpPattern::matchAndRewrite( SignedRemIOp remOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Value result = emulateSignedRemainder(remOp.getLoc(), operands[0], operands[1], operands[0], rewriter); rewriter.replaceOp(remOp, result); return success(); } //===----------------------------------------------------------------------===// // ConstantOp with composite type. //===----------------------------------------------------------------------===// LogicalResult ConstantCompositeOpPattern::matchAndRewrite( ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto srcType = constOp.getType().dyn_cast(); if (!srcType) return failure(); // std.constant should only have vector or tenor types. assert((srcType.isa())); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); auto dstElementsAttr = constOp.value().dyn_cast(); ShapedType dstAttrType = dstElementsAttr.getType(); if (!dstElementsAttr) return failure(); // If the composite type has more than one dimensions, perform linearization. if (srcType.getRank() > 1) { if (srcType.isa()) { dstAttrType = RankedTensorType::get(srcType.getNumElements(), srcType.getElementType()); dstElementsAttr = dstElementsAttr.reshape(dstAttrType); } else { // TODO: add support for large vectors. return failure(); } } Type srcElemType = srcType.getElementType(); Type dstElemType; // Tensor types are converted to SPIR-V array types; vector types are // converted to SPIR-V vector/array types. 
if (auto arrayType = dstType.dyn_cast()) dstElemType = arrayType.getElementType(); else dstElemType = dstType.cast().getElementType(); // If the source and destination element types are different, perform // attribute conversion. if (srcElemType != dstElemType) { SmallVector elements; if (srcElemType.isa()) { for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { FloatAttr dstAttr = convertFloatAttr( srcAttr.cast(), dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } else if (srcElemType.isInteger(1)) { return failure(); } else { for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { IntegerAttr dstAttr = convertIntegerAttr(srcAttr.cast(), dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } // Unfortunately, we cannot use dialect-specific types for element // attributes; element attributes only works with builtin types. So we need // to prepare another converted builtin types for the destination elements // attribute. if (dstAttrType.isa()) dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); else dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } rewriter.replaceOpWithNewOp(constOp, dstType, dstElementsAttr); return success(); } //===----------------------------------------------------------------------===// // ConstantOp with scalar type. //===----------------------------------------------------------------------===// LogicalResult ConstantScalarOpPattern::matchAndRewrite( ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); if (!srcType.isIntOrIndexOrFloat()) return failure(); Type dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); // Floating-point types. 
if (srcType.isa()) { auto srcAttr = constOp.value().cast(); auto dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, dstType.cast(), rewriter); if (!dstAttr) return failure(); } rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } // Bool type. if (srcType.isInteger(1)) { // std.constant can use 0/1 instead of true/false for i1 values. We need to // handle that here. auto dstAttr = convertBoolAttr(constOp.value(), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } // IndexType or IntegerType. Index values are converted to 32-bit integer // values when converting to SPIR-V. auto srcAttr = constOp.value().cast(); auto dstAttr = convertIntegerAttr(srcAttr, dstType.cast(), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// LogicalResult CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpFOpAdaptor cmpFOpOperands(operands); switch (cmpFOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ cmpFOpOperands.lhs(), \ cmpFOpOperands.rhs()); \ return success(); // Ordered. DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); // Unordered. 
DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); #undef DISPATCH default: break; } return failure(); } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// LogicalResult BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpAdaptor cmpIOpOperands(operands); Type operandType = cmpIOp.lhs().getType(); if (!isBoolScalarOrVector(operandType)) return failure(); switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ cmpIOpOperands.lhs(), \ cmpIOpOperands.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp); #undef DISPATCH default:; } return failure(); } LogicalResult CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpAdaptor cmpIOpOperands(operands); Type operandType = cmpIOp.lhs().getType(); if (isBoolScalarOrVector(operandType)) return failure(); switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ if (isUnsignedOp() && \ operandType != this->typeConverter.convertType(operandType)) { \ return cmpIOp.emitError( \ "bitwidth emulation is not implemented yet on unsigned op"); \ } \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ cmpIOpOperands.lhs(), \ cmpIOpOperands.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); 
DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH } return failure(); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// LogicalResult IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpAdaptor loadOperands(operands); auto loc = loadOp.getLoc(); auto memrefType = loadOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) return failure(); spirv::AccessChainOp accessChainOp = spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), loadOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() .cast() .getElementType(0) .cast() .getElementType(); int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); // If the rewrited load op has the same bit width, use the loading value // directly. if (srcBits == dstBits) { rewriter.replaceOpWithNewOp(loadOp, accessChainOp.getResult()); return success(); } // Assume that getElementPtr() works linearizely. If it's a scalar, the method // still returns a linearized accessing. If the accessing is not linearized, // there will be offset issues. 
assert(accessChainOp.indices().size() == 2); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Value spvLoadOp = rewriter.create( loc, dstType, adjustedPtr, loadOp->getAttrOfType( spirv::attributeName()), loadOp->getAttrOfType("alignment")); // Shift the bits to the rightmost. // ____XXXX________ -> ____________XXXX Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); Value result = rewriter.create( loc, spvLoadOp.getType(), spvLoadOp, offset); // Apply the mask to extract corresponding bits. Value mask = rewriter.create( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); result = rewriter.create(loc, dstType, result, mask); // Apply sign extension on the loading value unconditionally. The signedness // semantic is carried in the operator itself, we relies other pattern to // handle the casting. IntegerAttr shiftValueAttr = rewriter.getIntegerAttr(dstType, dstBits - srcBits); Value shiftValue = rewriter.create(loc, dstType, shiftValueAttr); result = rewriter.create(loc, dstType, result, shiftValue); result = rewriter.create(loc, dstType, result, shiftValue); rewriter.replaceOp(loadOp, result); assert(accessChainOp.use_empty()); rewriter.eraseOp(accessChainOp); return success(); } LogicalResult LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpAdaptor loadOperands(operands); auto memrefType = loadOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto loadPtr = spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); rewriter.replaceOpWithNewOp(loadOp, loadPtr); return success(); } //===----------------------------------------------------------------------===// // ReturnOp 
//===----------------------------------------------------------------------===// LogicalResult ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands()) { return failure(); } rewriter.replaceOpWithNewOp(returnOp); return success(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// LogicalResult SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { SelectOpAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), selectOperands.true_value(), selectOperands.false_value()); return success(); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// LogicalResult IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) return failure(); auto loc = storeOp.getLoc(); spirv::AccessChainOp accessChainOp = spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() .cast() .getElementType(0) .cast() .getElementType(); int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); if (srcBits == dstBits) { rewriter.replaceOpWithNewOp( storeOp, accessChainOp.getResult(), storeOperands.value()); return success(); } // Since there are multi threads in the processing, the emulation will be done // with atomic operations. 
E.g., if the storing value is i8, rewrite the // StoreOp to // 1) load a 32-bit integer // 2) clear 8 bits in the loading value // 3) store 32-bit value back // 4) load a 32-bit integer // 5) modify 8 bits in the loading value // 6) store 32-bit value back // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step // 4 to step 6 are done by AtomicOr as another atomic step. assert(accessChainOp.indices().size() == 2); Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); // Create a mask to clear the destination. E.g., if it is the second i8 in // i32, 0xFFFF00FF is created. Value mask = rewriter.create( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); Value clearBitsMask = rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); Value storeVal = shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Optional scope = getAtomicOpScope(memrefType); if (!scope) return failure(); Value result = rewriter.create( loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, clearBitsMask); result = rewriter.create( loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, storeVal); // The AtomicOrOp has no side effect. Since it is already inserted, we can // just remove the original StoreOp. Note that rewriter.replaceOp() // doesn't work because it only accepts that the numbers of result are the // same. 
rewriter.eraseOp(storeOp); assert(accessChainOp.use_empty()); rewriter.eraseOp(accessChainOp); return success(); } LogicalResult StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto storePtr = spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(), rewriter); rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value()); return success(); } //===----------------------------------------------------------------------===// // XorOp //===----------------------------------------------------------------------===// LogicalResult XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { assert(operands.size() == 2); if (isBoolScalarOrVector(operands.front().getType())) return failure(); auto dstType = typeConverter.convertType(xorOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(xorOp, dstType, operands); return success(); } //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateStandardToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert< // Unary and binary patterns BitwiseOpPattern, BitwiseOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, 
UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, SignedRemIOpPattern, XOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern, // Constant patterns ConstantCompositeOpPattern, ConstantScalarOpPattern, // Memory patterns AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern, StoreOpPattern, ReturnOpPattern, SelectOpPattern, // Type cast patterns ZeroExtendI1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern>(context, typeConverter); } } // namespace mlir