//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Shape/IR/Shape.h"
11 #include "mlir/Dialect/Shape/Transforms/Passes.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17
18 using namespace mlir;
19 using namespace mlir::shape;
20
21 namespace {
22 /// Converts `shape.num_elements` to `shape.reduce`.
23 struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
24 public:
25 using OpRewritePattern::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(NumElementsOp op,
28 PatternRewriter &rewriter) const final;
29 };
30 } // namespace
31
32 LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const33 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
34 PatternRewriter &rewriter) const {
35 auto loc = op.getLoc();
36 Type valueType = op.getResult().getType();
37 Value init = op->getDialect()
38 ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
39 valueType, loc)
40 ->getResult(0);
41 ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
42
43 // Generate reduce operator.
44 Block *body = reduce.getBody();
45 OpBuilder b = OpBuilder::atBlockEnd(body);
46 Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
47 body->getArgument(2));
48 b.create<shape::YieldOp>(loc, product);
49
50 rewriter.replaceOp(op, reduce.result());
51 return success();
52 }
53
54 namespace {
55 struct ShapeToShapeLowering
56 : public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
57 void runOnFunction() override;
58 };
59 } // namespace
60
runOnFunction()61 void ShapeToShapeLowering::runOnFunction() {
62 MLIRContext &ctx = getContext();
63
64 OwningRewritePatternList patterns;
65 populateShapeRewritePatterns(&ctx, patterns);
66
67 ConversionTarget target(getContext());
68 target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
69 target.addIllegalOp<NumElementsOp>();
70 if (failed(mlir::applyPartialConversion(getFunction(), target,
71 std::move(patterns))))
72 signalPassFailure();
73 }
74
populateShapeRewritePatterns(MLIRContext * context,OwningRewritePatternList & patterns)75 void mlir::populateShapeRewritePatterns(MLIRContext *context,
76 OwningRewritePatternList &patterns) {
77 patterns.insert<NumElementsOpConverter>(context);
78 }
79
createShapeToShapeLowering()80 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
81 return std::make_unique<ShapeToShapeLowering>();
82 }
83