1 //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Shape/IR/Shape.h"
11 #include "mlir/Dialect/Shape/Transforms/Passes.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 
18 using namespace mlir;
19 using namespace mlir::shape;
20 
21 namespace {
22 /// Converts `shape.num_elements` to `shape.reduce`.
23 struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
24 public:
25   using OpRewritePattern::OpRewritePattern;
26 
27   LogicalResult matchAndRewrite(NumElementsOp op,
28                                 PatternRewriter &rewriter) const final;
29 };
30 } // namespace
31 
32 LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const33 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
34                                         PatternRewriter &rewriter) const {
35   auto loc = op.getLoc();
36   Type valueType = op.getResult().getType();
37   Value init = op->getDialect()
38                    ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
39                                          valueType, loc)
40                    ->getResult(0);
41   ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
42 
43   // Generate reduce operator.
44   Block *body = reduce.getBody();
45   OpBuilder b = OpBuilder::atBlockEnd(body);
46   Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
47                                   body->getArgument(2));
48   b.create<shape::YieldOp>(loc, product);
49 
50   rewriter.replaceOp(op, reduce.result());
51   return success();
52 }
53 
54 namespace {
55 struct ShapeToShapeLowering
56     : public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
57   void runOnFunction() override;
58 };
59 } // namespace
60 
runOnFunction()61 void ShapeToShapeLowering::runOnFunction() {
62   MLIRContext &ctx = getContext();
63 
64   OwningRewritePatternList patterns;
65   populateShapeRewritePatterns(&ctx, patterns);
66 
67   ConversionTarget target(getContext());
68   target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
69   target.addIllegalOp<NumElementsOp>();
70   if (failed(mlir::applyPartialConversion(getFunction(), target,
71                                           std::move(patterns))))
72     signalPassFailure();
73 }
74 
populateShapeRewritePatterns(MLIRContext * context,OwningRewritePatternList & patterns)75 void mlir::populateShapeRewritePatterns(MLIRContext *context,
76                                         OwningRewritePatternList &patterns) {
77   patterns.insert<NumElementsOpConverter>(context);
78 }
79 
createShapeToShapeLowering()80 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
81   return std::make_unique<ShapeToShapeLowering>();
82 }
83