1 //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15
16 using namespace mlir;
17
18 namespace {
19 /// A pass for testing call graph type decomposition.
20 ///
21 /// This instantiates the patterns with a TypeConverter and ValueDecomposer
22 /// that splits tuple types into their respective element types.
23 /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
24 struct TestDecomposeCallGraphTypes
25 : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
26
getDependentDialects__anon8352cea80111::TestDecomposeCallGraphTypes27 void getDependentDialects(DialectRegistry ®istry) const override {
28 registry.insert<test::TestDialect>();
29 }
runOnOperation__anon8352cea80111::TestDecomposeCallGraphTypes30 void runOnOperation() override {
31 ModuleOp module = getOperation();
32 auto *context = &getContext();
33 TypeConverter typeConverter;
34 ConversionTarget target(*context);
35 ValueDecomposer decomposer;
36 OwningRewritePatternList patterns;
37
38 target.addLegalDialect<test::TestDialect>();
39
40 target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
41 return typeConverter.isLegal(op.getOperandTypes());
42 });
43 target.addDynamicallyLegalOp<CallOp>(
44 [&](CallOp op) { return typeConverter.isLegal(op); });
45 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
46 return typeConverter.isSignatureLegal(op.getType());
47 });
48
49 typeConverter.addConversion([](Type type) { return type; });
50 typeConverter.addConversion(
51 [](TupleType tupleType, SmallVectorImpl<Type> &types) {
52 tupleType.getFlattenedTypes(types);
53 return success();
54 });
55
56 decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
57 TupleType resultType, Value value,
58 SmallVectorImpl<Value> &values) {
59 for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
60 Value res = builder.create<test::GetTupleElementOp>(
61 loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
62 values.push_back(res);
63 }
64 return success();
65 });
66
67 typeConverter.addArgumentMaterialization(
68 [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
69 Location loc) -> Optional<Value> {
70 if (inputs.size() == 1)
71 return llvm::None;
72 TypeRange TypeRange = inputs.getTypes();
73 SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
74 TupleType tuple = TupleType::get(types, builder.getContext());
75 Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
76 return value;
77 });
78
79 populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
80 patterns);
81
82 if (failed(applyPartialConversion(module, target, std::move(patterns))))
83 return signalPassFailure();
84 }
85 };
86
87 } // end anonymous namespace
88
89 namespace mlir {
90 namespace test {
registerTestDecomposeCallGraphTypes()91 void registerTestDecomposeCallGraphTypes() {
92 PassRegistration<TestDecomposeCallGraphTypes> pass(
93 "test-decompose-call-graph-types",
94 "Decomposes types at call graph boundaries.");
95 }
96 } // namespace test
97 } // namespace mlir
98