1 //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 /// A pass for testing call graph type decomposition.
20 ///
21 /// This instantiates the patterns with a TypeConverter and ValueDecomposer
22 /// that splits tuple types into their respective element types.
23 /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
24 struct TestDecomposeCallGraphTypes
25     : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
26 
getDependentDialects__anon8352cea80111::TestDecomposeCallGraphTypes27   void getDependentDialects(DialectRegistry &registry) const override {
28     registry.insert<test::TestDialect>();
29   }
runOnOperation__anon8352cea80111::TestDecomposeCallGraphTypes30   void runOnOperation() override {
31     ModuleOp module = getOperation();
32     auto *context = &getContext();
33     TypeConverter typeConverter;
34     ConversionTarget target(*context);
35     ValueDecomposer decomposer;
36     OwningRewritePatternList patterns;
37 
38     target.addLegalDialect<test::TestDialect>();
39 
40     target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
41       return typeConverter.isLegal(op.getOperandTypes());
42     });
43     target.addDynamicallyLegalOp<CallOp>(
44         [&](CallOp op) { return typeConverter.isLegal(op); });
45     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
46       return typeConverter.isSignatureLegal(op.getType());
47     });
48 
49     typeConverter.addConversion([](Type type) { return type; });
50     typeConverter.addConversion(
51         [](TupleType tupleType, SmallVectorImpl<Type> &types) {
52           tupleType.getFlattenedTypes(types);
53           return success();
54         });
55 
56     decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
57                                               TupleType resultType, Value value,
58                                               SmallVectorImpl<Value> &values) {
59       for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
60         Value res = builder.create<test::GetTupleElementOp>(
61             loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
62         values.push_back(res);
63       }
64       return success();
65     });
66 
67     typeConverter.addArgumentMaterialization(
68         [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
69            Location loc) -> Optional<Value> {
70           if (inputs.size() == 1)
71             return llvm::None;
72           TypeRange TypeRange = inputs.getTypes();
73           SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
74           TupleType tuple = TupleType::get(types, builder.getContext());
75           Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
76           return value;
77         });
78 
79     populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
80                                             patterns);
81 
82     if (failed(applyPartialConversion(module, target, std::move(patterns))))
83       return signalPassFailure();
84   }
85 };
86 
87 } // end anonymous namespace
88 
89 namespace mlir {
90 namespace test {
registerTestDecomposeCallGraphTypes()91 void registerTestDecomposeCallGraphTypes() {
92   PassRegistration<TestDecomposeCallGraphTypes> pass(
93       "test-decompose-call-graph-types",
94       "Decomposes types at call graph boundaries.");
95 }
96 } // namespace test
97 } // namespace mlir
98