1 //===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ 9 #define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Dialect/GPU/GPUDialect.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "llvm/ADT/StringSwitch.h" 15 16 namespace mlir { 17 18 // Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension 19 // that Op operates on. Op is assumed to return an `std.index` value and 20 // XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on 21 // `indexBitwidth`, sign-extend or truncate the resulting value to match the 22 // bitwidth expected by the consumers of the value. 23 template <typename Op, typename XOp, typename YOp, typename ZOp> 24 struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern { 25 private: 26 enum dimension { X = 0, Y = 1, Z = 2, invalid }; 27 unsigned indexBitwidth; 28 dimensionToIndexGPUIndexIntrinsicOpLowering29 static dimension dimensionToIndex(Op op) { 30 return StringSwitch<dimension>(op.dimension()) 31 .Case("x", X) 32 .Case("y", Y) 33 .Case("z", Z) 34 .Default(invalid); 35 } 36 37 public: GPUIndexIntrinsicOpLoweringGPUIndexIntrinsicOpLowering38 explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter) 39 : ConvertToLLVMPattern(Op::getOperationName(), 40 typeConverter.getDialect()->getContext(), 41 typeConverter), 42 indexBitwidth(typeConverter.getIndexTypeBitwidth()) {} 43 44 // Convert the kernel arguments to an LLVM type, preserve the rest. 45 LogicalResult matchAndRewriteGPUIndexIntrinsicOpLowering46 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 47 ConversionPatternRewriter &rewriter) const override { 48 auto loc = op->getLoc(); 49 MLIRContext *context = rewriter.getContext(); 50 Value newOp; 51 switch (dimensionToIndex(cast<Op>(op))) { 52 case X: 53 newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context)); 54 break; 55 case Y: 56 newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(context)); 57 break; 58 case Z: 59 newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(context)); 60 break; 61 default: 62 return failure(); 63 } 64 65 if (indexBitwidth > 32) { 66 newOp = rewriter.create<LLVM::SExtOp>( 67 loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp); 68 } else if (indexBitwidth < 32) { 69 newOp = rewriter.create<LLVM::TruncOp>( 70 loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp); 71 } 72 73 rewriter.replaceOp(op, {newOp}); 74 return success(); 75 } 76 }; 77 78 } // namespace mlir 79 80 #endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ 81