1 //===- FuncConversions.cpp - Standard Function conversions ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Transforms/DialectConversion.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 /// Converts the operand and result types of the Standard's CallOp, used
17 /// together with the FuncOpSignatureConversion.
18 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
19   using OpConversionPattern<CallOp>::OpConversionPattern;
20 
21   /// Hook for derived classes to implement combined matching and rewriting.
22   LogicalResult
matchAndRewrite__anon63855c530111::CallOpSignatureConversion23   matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
24                   ConversionPatternRewriter &rewriter) const override {
25     // Convert the original function results.
26     SmallVector<Type, 1> convertedResults;
27     if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
28                                            convertedResults)))
29       return failure();
30 
31     // Substitute with the new result types from the corresponding FuncType
32     // conversion.
33     rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.callee(),
34                                         convertedResults, operands);
35     return success();
36   }
37 };
38 } // end anonymous namespace
39 
populateCallOpTypeConversionPattern(OwningRewritePatternList & patterns,MLIRContext * ctx,TypeConverter & converter)40 void mlir::populateCallOpTypeConversionPattern(
41     OwningRewritePatternList &patterns, MLIRContext *ctx,
42     TypeConverter &converter) {
43   patterns.insert<CallOpSignatureConversion>(converter, ctx);
44 }
45 
46 namespace {
47 /// Only needed to support partial conversion of functions where this pattern
48 /// ensures that the branch operation arguments matches up with the succesor
49 /// block arguments.
50 class BranchOpInterfaceTypeConversion : public ConversionPattern {
51 public:
BranchOpInterfaceTypeConversion(TypeConverter & typeConverter,MLIRContext * ctx)52   BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
53                                   MLIRContext *ctx)
54       : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
55 
56   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const57   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
58                   ConversionPatternRewriter &rewriter) const final {
59     auto branchOp = dyn_cast<BranchOpInterface>(op);
60     if (!branchOp)
61       return failure();
62 
63     // For a branch operation, only some operands go to the target blocks, so
64     // only rewrite those.
65     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
66     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
67          succIdx < succEnd; ++succIdx) {
68       auto successorOperands = branchOp.getSuccessorOperands(succIdx);
69       if (!successorOperands)
70         continue;
71       for (int idx = successorOperands->getBeginOperandIndex(),
72                eidx = idx + successorOperands->size();
73            idx < eidx; ++idx) {
74         newOperands[idx] = operands[idx];
75       }
76     }
77     rewriter.updateRootInPlace(
78         op, [newOperands, op]() { op->setOperands(newOperands); });
79     return success();
80   }
81 };
82 } // end anonymous namespace
83 
84 namespace {
85 /// Only needed to support partial conversion of functions where this pattern
86 /// ensures that the branch operation arguments matches up with the succesor
87 /// block arguments.
88 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
89 public:
90   using OpConversionPattern<ReturnOp>::OpConversionPattern;
91 
92   LogicalResult
matchAndRewrite(ReturnOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const93   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
94                   ConversionPatternRewriter &rewriter) const final {
95     // For a return, all operands go to the results of the parent, so
96     // rewrite them all.
97     Operation *operation = op.getOperation();
98     rewriter.updateRootInPlace(
99         op, [operands, operation]() { operation->setOperands(operands); });
100     return success();
101   }
102 };
103 } // end anonymous namespace
104 
populateBranchOpInterfaceAndReturnOpTypeConversionPattern(OwningRewritePatternList & patterns,MLIRContext * ctx,TypeConverter & typeConverter)105 void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
106     OwningRewritePatternList &patterns, MLIRContext *ctx,
107     TypeConverter &typeConverter) {
108   patterns.insert<BranchOpInterfaceTypeConversion, ReturnOpTypeConversion>(
109       typeConverter, ctx);
110 }
111