1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Dialect/GPU/GPUDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Builders.h"
16 
17 namespace mlir {
18 
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
20 /// depending on the element type that Op operates upon. The function
21 /// declaration is added in case it was not added before.
22 ///
23 /// If the input values are of f16 type, the value is first casted to f32, the
24 /// function called and then the result casted back.
25 ///
26 /// Example with NVVM:
27 ///   %exp_f32 = std.exp %arg_f32 : f32
28 ///
29 /// will be transformed into
30 ///   llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float
31 template <typename SourceOp>
32 struct OpToFuncCallLowering : public ConvertToLLVMPattern {
33 public:
OpToFuncCallLoweringOpToFuncCallLowering34   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
35                                 StringRef f64Func)
36       : ConvertToLLVMPattern(SourceOp::getOperationName(),
37                              lowering_.getDialect()->getContext(), lowering_),
38         f32Func(f32Func), f64Func(f64Func) {}
39 
40   LogicalResult
matchAndRewriteOpToFuncCallLowering41   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
42                   ConversionPatternRewriter &rewriter) const override {
43     using LLVM::LLVMFuncOp;
44     using LLVM::LLVMType;
45 
46     static_assert(
47         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
48         "expected single result op");
49 
50     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
51                                   SourceOp>::value,
52                   "expected op with same operand and result types");
53 
54     SmallVector<Value, 1> castedOperands;
55     for (Value operand : operands)
56       castedOperands.push_back(maybeCast(operand, rewriter));
57 
58     LLVMType resultType =
59         castedOperands.front().getType().cast<LLVM::LLVMType>();
60     LLVMType funcType = getFunctionType(resultType, castedOperands);
61     StringRef funcName = getFunctionName(funcType.getFunctionResultType());
62     if (funcName.empty())
63       return failure();
64 
65     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
66     auto callOp = rewriter.create<LLVM::CallOp>(
67         op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
68         castedOperands);
69 
70     if (resultType == operands.front().getType()) {
71       rewriter.replaceOp(op, {callOp.getResult(0)});
72       return success();
73     }
74 
75     Value truncated = rewriter.create<LLVM::FPTruncOp>(
76         op->getLoc(), operands.front().getType(), callOp.getResult(0));
77     rewriter.replaceOp(op, {truncated});
78     return success();
79   }
80 
81 private:
maybeCastOpToFuncCallLowering82   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
83     LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
84     if (!type.isHalfTy())
85       return operand;
86 
87     return rewriter.create<LLVM::FPExtOp>(
88         operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()),
89         operand);
90   }
91 
getFunctionTypeOpToFuncCallLowering92   LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
93                                  ArrayRef<Value> operands) const {
94     using LLVM::LLVMType;
95     SmallVector<LLVMType, 1> operandTypes;
96     for (Value operand : operands) {
97       operandTypes.push_back(operand.getType().cast<LLVMType>());
98     }
99     return LLVMType::getFunctionTy(resultType, operandTypes,
100                                    /*isVarArg=*/false);
101   }
102 
getFunctionNameOpToFuncCallLowering103   StringRef getFunctionName(LLVM::LLVMType type) const {
104     if (type.isFloatTy())
105       return f32Func;
106     if (type.isDoubleTy())
107       return f64Func;
108     return "";
109   }
110 
appendOrGetFuncOpOpToFuncCallLowering111   LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName,
112                                      LLVM::LLVMType funcType,
113                                      Operation *op) const {
114     using LLVM::LLVMFuncOp;
115 
116     Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName);
117     if (funcOp)
118       return cast<LLVMFuncOp>(*funcOp);
119 
120     mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
121     return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
122   }
123 
124   const std::string f32Func;
125   const std::string f64Func;
126 };
127 
128 } // namespace mlir
129 
130 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
131