1 //===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/BuiltinOps.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // ValueDecomposer
17 //===----------------------------------------------------------------------===//
18 
decomposeValue(OpBuilder & builder,Location loc,Type type,Value value,SmallVectorImpl<Value> & results)19 void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
20                                      Type type, Value value,
21                                      SmallVectorImpl<Value> &results) {
22   for (auto &conversion : decomposeValueConversions)
23     if (conversion(builder, loc, type, value, results))
24       return;
25   results.push_back(value);
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // DecomposeCallGraphTypesOpConversionPattern
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 /// Base OpConversionPattern class to make a ValueDecomposer available to
34 /// inherited patterns.
35 template <typename SourceOp>
36 class DecomposeCallGraphTypesOpConversionPattern
37     : public OpConversionPattern<SourceOp> {
38 public:
DecomposeCallGraphTypesOpConversionPattern(TypeConverter & typeConverter,MLIRContext * context,ValueDecomposer & decomposer,PatternBenefit benefit=1)39   DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
40                                              MLIRContext *context,
41                                              ValueDecomposer &decomposer,
42                                              PatternBenefit benefit = 1)
43       : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
44         decomposer(decomposer) {}
45 
46 protected:
47   ValueDecomposer &decomposer;
48 };
49 } // namespace
50 
51 //===----------------------------------------------------------------------===//
52 // DecomposeCallGraphTypesForFuncArgs
53 //===----------------------------------------------------------------------===//
54 
55 namespace {
56 /// Expand function arguments according to the provided TypeConverter and
57 /// ValueDecomposer.
58 struct DecomposeCallGraphTypesForFuncArgs
59     : public DecomposeCallGraphTypesOpConversionPattern<FuncOp> {
60   using DecomposeCallGraphTypesOpConversionPattern::
61       DecomposeCallGraphTypesOpConversionPattern;
62 
63   LogicalResult
matchAndRewrite__anond23ea8500211::DecomposeCallGraphTypesForFuncArgs64   matchAndRewrite(FuncOp op, ArrayRef<Value> operands,
65                   ConversionPatternRewriter &rewriter) const final {
66     auto functionType = op.getType();
67 
68     // Convert function arguments using the provided TypeConverter.
69     TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
70     for (auto argType : llvm::enumerate(functionType.getInputs())) {
71       SmallVector<Type, 2> decomposedTypes;
72       getTypeConverter()->convertType(argType.value(), decomposedTypes);
73       if (!decomposedTypes.empty())
74         conversion.addInputs(argType.index(), decomposedTypes);
75     }
76 
77     // If the SignatureConversion doesn't apply, bail out.
78     if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
79                                            &conversion)))
80       return failure();
81 
82     // Update the signature of the function.
83     SmallVector<Type, 2> newResultTypes;
84     getTypeConverter()->convertTypes(functionType.getResults(), newResultTypes);
85     rewriter.updateRootInPlace(op, [&] {
86       op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
87                                           newResultTypes));
88     });
89     return success();
90   }
91 };
92 } // namespace
93 
94 //===----------------------------------------------------------------------===//
95 // DecomposeCallGraphTypesForReturnOp
96 //===----------------------------------------------------------------------===//
97 
98 namespace {
99 /// Expand return operands according to the provided TypeConverter and
100 /// ValueDecomposer.
101 struct DecomposeCallGraphTypesForReturnOp
102     : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
103   using DecomposeCallGraphTypesOpConversionPattern::
104       DecomposeCallGraphTypesOpConversionPattern;
105   LogicalResult
matchAndRewrite__anond23ea8500411::DecomposeCallGraphTypesForReturnOp106   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
107                   ConversionPatternRewriter &rewriter) const final {
108     SmallVector<Value, 2> newOperands;
109     for (Value operand : operands)
110       decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
111                                 operand, newOperands);
112     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
113     return success();
114   }
115 };
116 } // namespace
117 
118 //===----------------------------------------------------------------------===//
119 // DecomposeCallGraphTypesForCallOp
120 //===----------------------------------------------------------------------===//
121 
122 namespace {
123 /// Expand call op operands and results according to the provided TypeConverter
124 /// and ValueDecomposer.
125 struct DecomposeCallGraphTypesForCallOp
126     : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
127   using DecomposeCallGraphTypesOpConversionPattern::
128       DecomposeCallGraphTypesOpConversionPattern;
129 
130   LogicalResult
matchAndRewrite__anond23ea8500511::DecomposeCallGraphTypesForCallOp131   matchAndRewrite(CallOp op, ArrayRef<Value> operands,
132                   ConversionPatternRewriter &rewriter) const final {
133 
134     // Create the operands list of the new `CallOp`.
135     SmallVector<Value, 2> newOperands;
136     for (Value operand : operands)
137       decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
138                                 operand, newOperands);
139 
140     // Create the new result types for the new `CallOp` and track the indices in
141     // the new call op's results that correspond to the old call op's results.
142     //
143     // expandedResultIndices[i] = "list of new result indices that old result i
144     // expanded to".
145     SmallVector<Type, 2> newResultTypes;
146     SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
147     for (Type resultType : op.getResultTypes()) {
148       unsigned oldSize = newResultTypes.size();
149       getTypeConverter()->convertType(resultType, newResultTypes);
150       auto &resultMapping = expandedResultIndices.emplace_back();
151       for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
152         resultMapping.push_back(i);
153     }
154 
155     CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCallee(),
156                                                newResultTypes, newOperands);
157 
158     // Build a replacement value for each result to replace its uses. If a
159     // result has multiple mapping values, it needs to be materialized as a
160     // single value.
161     SmallVector<Value, 2> replacedValues;
162     replacedValues.reserve(op.getNumResults());
163     for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
164       auto decomposedValues = llvm::to_vector<6>(
165           llvm::map_range(expandedResultIndices[i],
166                           [&](unsigned i) { return newCallOp.getResult(i); }));
167       if (decomposedValues.empty()) {
168         // No replacement is required.
169         replacedValues.push_back(nullptr);
170       } else if (decomposedValues.size() == 1) {
171         replacedValues.push_back(decomposedValues.front());
172       } else {
173         // Materialize a single Value to replace the original Value.
174         Value materialized = getTypeConverter()->materializeArgumentConversion(
175             rewriter, op.getLoc(), op.getType(i), decomposedValues);
176         replacedValues.push_back(materialized);
177       }
178     }
179     rewriter.replaceOp(op, replacedValues);
180     return success();
181   }
182 };
183 } // namespace
184 
populateDecomposeCallGraphTypesPatterns(MLIRContext * context,TypeConverter & typeConverter,ValueDecomposer & decomposer,OwningRewritePatternList & patterns)185 void mlir::populateDecomposeCallGraphTypesPatterns(
186     MLIRContext *context, TypeConverter &typeConverter,
187     ValueDecomposer &decomposer, OwningRewritePatternList &patterns) {
188   patterns.insert<DecomposeCallGraphTypesForCallOp,
189                   DecomposeCallGraphTypesForFuncArgs,
190                   DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
191                                                       decomposer);
192 }
193