1 //===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
10 
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Dialect/GPU/GPUDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "llvm/ADT/StringSwitch.h"
15 
16 namespace mlir {
17 
18 // Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
19 // that Op operates on.  Op is assumed to return an `std.index` value and
20 // XOp, YOp and ZOp are assumed to return an `llvm.i32` value.  Depending on
21 // `indexBitwidth`, sign-extend or truncate the resulting value to match the
22 // bitwidth expected by the consumers of the value.
23 template <typename Op, typename XOp, typename YOp, typename ZOp>
24 struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern {
25 private:
26   enum dimension { X = 0, Y = 1, Z = 2, invalid };
27   unsigned indexBitwidth;
28 
dimensionToIndexGPUIndexIntrinsicOpLowering29   static dimension dimensionToIndex(Op op) {
30     return StringSwitch<dimension>(op.dimension())
31         .Case("x", X)
32         .Case("y", Y)
33         .Case("z", Z)
34         .Default(invalid);
35   }
36 
37 public:
GPUIndexIntrinsicOpLoweringGPUIndexIntrinsicOpLowering38   explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter)
39       : ConvertToLLVMPattern(Op::getOperationName(),
40                              typeConverter.getDialect()->getContext(),
41                              typeConverter),
42         indexBitwidth(typeConverter.getIndexTypeBitwidth()) {}
43 
44   // Convert the kernel arguments to an LLVM type, preserve the rest.
45   LogicalResult
matchAndRewriteGPUIndexIntrinsicOpLowering46   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
47                   ConversionPatternRewriter &rewriter) const override {
48     auto loc = op->getLoc();
49     MLIRContext *context = rewriter.getContext();
50     Value newOp;
51     switch (dimensionToIndex(cast<Op>(op))) {
52     case X:
53       newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(context));
54       break;
55     case Y:
56       newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(context));
57       break;
58     case Z:
59       newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(context));
60       break;
61     default:
62       return failure();
63     }
64 
65     if (indexBitwidth > 32) {
66       newOp = rewriter.create<LLVM::SExtOp>(
67           loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
68     } else if (indexBitwidth < 32) {
69       newOp = rewriter.create<LLVM::TruncOp>(
70           loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp);
71     }
72 
73     rewriter.replaceOp(op, {newOp});
74     return success();
75   }
76 };
77 
78 } // namespace mlir
79 
80 #endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
81