1 //===- TestConvertCallOp.cpp - Test LLVM Conversion of Standard CallOp ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 class TestTypeProducerOpConverter
22     : public ConvertOpToLLVMPattern<test::TestTypeProducerOp> {
23 public:
24   using ConvertOpToLLVMPattern<
25       test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
26 
27   LogicalResult
matchAndRewrite(test::TestTypeProducerOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const28   matchAndRewrite(test::TestTypeProducerOp op, ArrayRef<Value> operands,
29                   ConversionPatternRewriter &rewriter) const override {
30     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
31     return success();
32   }
33 };
34 
35 class TestConvertCallOp
36     : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
37 public:
getDependentDialects(DialectRegistry & registry) const38   void getDependentDialects(DialectRegistry &registry) const final {
39     registry.insert<LLVM::LLVMDialect>();
40   }
41 
runOnOperation()42   void runOnOperation() override {
43     ModuleOp m = getOperation();
44 
45     // Populate type conversions.
46     LLVMTypeConverter type_converter(m.getContext());
47     type_converter.addConversion([&](test::TestType type) {
48       return LLVM::LLVMType::getInt8PtrTy(m.getContext());
49     });
50 
51     // Populate patterns.
52     OwningRewritePatternList patterns;
53     populateStdToLLVMConversionPatterns(type_converter, patterns);
54     patterns.insert<TestTypeProducerOpConverter>(type_converter);
55 
56     // Set target.
57     ConversionTarget target(getContext());
58     target.addLegalDialect<LLVM::LLVMDialect>();
59     target.addIllegalDialect<test::TestDialect>();
60     target.addIllegalDialect<StandardOpsDialect>();
61 
62     if (failed(applyPartialConversion(m, target, std::move(patterns))))
63       signalPassFailure();
64   }
65 };
66 
67 } // namespace
68 
69 namespace mlir {
70 namespace test {
registerConvertCallOpPass()71 void registerConvertCallOpPass() {
72   PassRegistration<TestConvertCallOp>(
73       "test-convert-call-op",
74       "Tests conversion of `std.call` to `llvm.call` in "
75       "presence of custom types");
76 }
77 } // namespace test
78 } // namespace mlir
79