1 //===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/BuiltinOps.h"
12
13 using namespace mlir;
14
15 //===----------------------------------------------------------------------===//
16 // ValueDecomposer
17 //===----------------------------------------------------------------------===//
18
decomposeValue(OpBuilder & builder,Location loc,Type type,Value value,SmallVectorImpl<Value> & results)19 void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
20 Type type, Value value,
21 SmallVectorImpl<Value> &results) {
22 for (auto &conversion : decomposeValueConversions)
23 if (conversion(builder, loc, type, value, results))
24 return;
25 results.push_back(value);
26 }
27
28 //===----------------------------------------------------------------------===//
29 // DecomposeCallGraphTypesOpConversionPattern
30 //===----------------------------------------------------------------------===//
31
32 namespace {
33 /// Base OpConversionPattern class to make a ValueDecomposer available to
34 /// inherited patterns.
35 template <typename SourceOp>
36 class DecomposeCallGraphTypesOpConversionPattern
37 : public OpConversionPattern<SourceOp> {
38 public:
DecomposeCallGraphTypesOpConversionPattern(TypeConverter & typeConverter,MLIRContext * context,ValueDecomposer & decomposer,PatternBenefit benefit=1)39 DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
40 MLIRContext *context,
41 ValueDecomposer &decomposer,
42 PatternBenefit benefit = 1)
43 : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
44 decomposer(decomposer) {}
45
46 protected:
47 ValueDecomposer &decomposer;
48 };
49 } // namespace
50
51 //===----------------------------------------------------------------------===//
52 // DecomposeCallGraphTypesForFuncArgs
53 //===----------------------------------------------------------------------===//
54
55 namespace {
56 /// Expand function arguments according to the provided TypeConverter and
57 /// ValueDecomposer.
58 struct DecomposeCallGraphTypesForFuncArgs
59 : public DecomposeCallGraphTypesOpConversionPattern<FuncOp> {
60 using DecomposeCallGraphTypesOpConversionPattern::
61 DecomposeCallGraphTypesOpConversionPattern;
62
63 LogicalResult
matchAndRewrite__anond23ea8500211::DecomposeCallGraphTypesForFuncArgs64 matchAndRewrite(FuncOp op, ArrayRef<Value> operands,
65 ConversionPatternRewriter &rewriter) const final {
66 auto functionType = op.getType();
67
68 // Convert function arguments using the provided TypeConverter.
69 TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
70 for (auto argType : llvm::enumerate(functionType.getInputs())) {
71 SmallVector<Type, 2> decomposedTypes;
72 getTypeConverter()->convertType(argType.value(), decomposedTypes);
73 if (!decomposedTypes.empty())
74 conversion.addInputs(argType.index(), decomposedTypes);
75 }
76
77 // If the SignatureConversion doesn't apply, bail out.
78 if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
79 &conversion)))
80 return failure();
81
82 // Update the signature of the function.
83 SmallVector<Type, 2> newResultTypes;
84 getTypeConverter()->convertTypes(functionType.getResults(), newResultTypes);
85 rewriter.updateRootInPlace(op, [&] {
86 op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
87 newResultTypes));
88 });
89 return success();
90 }
91 };
92 } // namespace
93
94 //===----------------------------------------------------------------------===//
95 // DecomposeCallGraphTypesForReturnOp
96 //===----------------------------------------------------------------------===//
97
98 namespace {
99 /// Expand return operands according to the provided TypeConverter and
100 /// ValueDecomposer.
101 struct DecomposeCallGraphTypesForReturnOp
102 : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
103 using DecomposeCallGraphTypesOpConversionPattern::
104 DecomposeCallGraphTypesOpConversionPattern;
105 LogicalResult
matchAndRewrite__anond23ea8500411::DecomposeCallGraphTypesForReturnOp106 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
107 ConversionPatternRewriter &rewriter) const final {
108 SmallVector<Value, 2> newOperands;
109 for (Value operand : operands)
110 decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
111 operand, newOperands);
112 rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
113 return success();
114 }
115 };
116 } // namespace
117
118 //===----------------------------------------------------------------------===//
119 // DecomposeCallGraphTypesForCallOp
120 //===----------------------------------------------------------------------===//
121
122 namespace {
123 /// Expand call op operands and results according to the provided TypeConverter
124 /// and ValueDecomposer.
125 struct DecomposeCallGraphTypesForCallOp
126 : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
127 using DecomposeCallGraphTypesOpConversionPattern::
128 DecomposeCallGraphTypesOpConversionPattern;
129
130 LogicalResult
matchAndRewrite__anond23ea8500511::DecomposeCallGraphTypesForCallOp131 matchAndRewrite(CallOp op, ArrayRef<Value> operands,
132 ConversionPatternRewriter &rewriter) const final {
133
134 // Create the operands list of the new `CallOp`.
135 SmallVector<Value, 2> newOperands;
136 for (Value operand : operands)
137 decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
138 operand, newOperands);
139
140 // Create the new result types for the new `CallOp` and track the indices in
141 // the new call op's results that correspond to the old call op's results.
142 //
143 // expandedResultIndices[i] = "list of new result indices that old result i
144 // expanded to".
145 SmallVector<Type, 2> newResultTypes;
146 SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
147 for (Type resultType : op.getResultTypes()) {
148 unsigned oldSize = newResultTypes.size();
149 getTypeConverter()->convertType(resultType, newResultTypes);
150 auto &resultMapping = expandedResultIndices.emplace_back();
151 for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
152 resultMapping.push_back(i);
153 }
154
155 CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCallee(),
156 newResultTypes, newOperands);
157
158 // Build a replacement value for each result to replace its uses. If a
159 // result has multiple mapping values, it needs to be materialized as a
160 // single value.
161 SmallVector<Value, 2> replacedValues;
162 replacedValues.reserve(op.getNumResults());
163 for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
164 auto decomposedValues = llvm::to_vector<6>(
165 llvm::map_range(expandedResultIndices[i],
166 [&](unsigned i) { return newCallOp.getResult(i); }));
167 if (decomposedValues.empty()) {
168 // No replacement is required.
169 replacedValues.push_back(nullptr);
170 } else if (decomposedValues.size() == 1) {
171 replacedValues.push_back(decomposedValues.front());
172 } else {
173 // Materialize a single Value to replace the original Value.
174 Value materialized = getTypeConverter()->materializeArgumentConversion(
175 rewriter, op.getLoc(), op.getType(i), decomposedValues);
176 replacedValues.push_back(materialized);
177 }
178 }
179 rewriter.replaceOp(op, replacedValues);
180 return success();
181 }
182 };
183 } // namespace
184
populateDecomposeCallGraphTypesPatterns(MLIRContext * context,TypeConverter & typeConverter,ValueDecomposer & decomposer,OwningRewritePatternList & patterns)185 void mlir::populateDecomposeCallGraphTypesPatterns(
186 MLIRContext *context, TypeConverter &typeConverter,
187 ValueDecomposer &decomposer, OwningRewritePatternList &patterns) {
188 patterns.insert<DecomposeCallGraphTypesForCallOp,
189 DecomposeCallGraphTypesForFuncArgs,
190 DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
191 decomposer);
192 }
193