1 //===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Shape/IR/Shape.h"
11 #include "mlir/Dialect/Shape/Transforms/Passes.h"
12 #include "mlir/Transforms/DialectConversion.h"
13 
14 using namespace mlir;
15 using namespace mlir::shape;
16 
17 namespace {
18 class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
19 public:
20   using OpConversionPattern::OpConversionPattern;
21 
22   LogicalResult
matchAndRewrite(AssumingOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const23   matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
24                   ConversionPatternRewriter &rewriter) const final {
25     SmallVector<Type, 2> newResultTypes;
26     newResultTypes.reserve(op.getNumResults());
27     for (auto result : op.getResults()) {
28       auto originalType = result.getType();
29       Type convertedType = getTypeConverter()->convertType(originalType);
30       newResultTypes.push_back(convertedType);
31     }
32 
33     auto newAssumingOp =
34         rewriter.create<AssumingOp>(op.getLoc(), newResultTypes, op.witness());
35     rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(),
36                                 newAssumingOp.doRegion().end());
37     rewriter.replaceOp(op, newAssumingOp.getResults());
38 
39     return success();
40   }
41 };
42 } // namespace
43 
44 namespace {
45 class ConvertAssumingYieldOpTypes
46     : public OpConversionPattern<AssumingYieldOp> {
47 public:
48   using OpConversionPattern::OpConversionPattern;
49 
50   LogicalResult
matchAndRewrite(AssumingYieldOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const51   matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
52                   ConversionPatternRewriter &rewriter) const final {
53     rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
54     return success();
55   }
56 };
57 } // namespace
58 
populateShapeStructuralTypeConversionsAndLegality(MLIRContext * context,TypeConverter & typeConverter,OwningRewritePatternList & patterns,ConversionTarget & target)59 void mlir::populateShapeStructuralTypeConversionsAndLegality(
60     MLIRContext *context, TypeConverter &typeConverter,
61     OwningRewritePatternList &patterns, ConversionTarget &target) {
62   patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
63       typeConverter, context);
64   target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
65     return typeConverter.isLegal(op.getResultTypes());
66   });
67   target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
68     return typeConverter.isLegal(op.getOperandTypes());
69   });
70 }
71