1 //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "linalg-generalization"
28 
29 using namespace mlir;
30 
31 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
32 // the given `namedOp` does not have a region builder.
createGenericOpFromNamedOp(linalg::LinalgOp namedOp,OpBuilder & builder)33 static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
34                                                     OpBuilder &builder) {
35   auto regionBuilder = namedOp.getRegionBuilder();
36   if (!regionBuilder) {
37     LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
38     return nullptr;
39   }
40 
41   SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
42   auto iterators = llvm::to_vector<4>(
43       namedOp.iterator_types().getAsValueRange<StringAttr>());
44   auto resultTypes = namedOp.getOutputTensorTypes();
45   SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
46 
47   return builder.create<linalg::GenericOp>(
48       namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputBuffers(),
49       namedOp.getInitTensors(), indexingMaps, iterators,
50       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
51         edsc::ScopedContext scope(bodyBuilder, loc);
52         regionBuilder(*bodyBuilder.getBlock());
53       });
54 }
55 
56 namespace {
57 
58 /// Base class for all linalg generalization patterns. A subclass must provide
59 /// the following method:
60 ///   linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
61 /// for creating the generic op.
62 // TODO: remove this pattern after migrating all manually-written named ops
63 // into auto-generated ones.
64 template <typename ConcretePattern, typename RootOp>
65 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
LinalgGeneralizationPattern__anondb85c27e0211::LinalgGeneralizationPattern66   LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
67                               PatternBenefit benefit = 1)
68       : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
69 
matchAndRewrite__anondb85c27e0211::LinalgGeneralizationPattern70   LogicalResult matchAndRewrite(RootOp rootOp,
71                                 PatternRewriter &rewriter) const override {
72     auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
73     if (!linalgOp)
74       return failure();
75     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
76       return failure();
77 
78     auto *pattern = static_cast<const ConcretePattern *>(this);
79     linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
80     if (!genericOp)
81       return failure();
82 
83     rewriter.replaceOp(rootOp, genericOp.getResults());
84     marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
85     return success();
86   }
87 
88 private:
89   linalg::LinalgMarker marker;
90 };
91 
92 struct GeneralizeConvOp
93     : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
94   using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
95 
96   linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const;
97 };
98 
99 /// Catch-all pattern for converting all named ops with a region builder into
100 /// linalg.generic.
101 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern__anondb85c27e0211::LinalgNamedOpGeneralizationPattern102   LinalgNamedOpGeneralizationPattern(MLIRContext *context,
103                                      linalg::LinalgMarker marker,
104                                      PatternBenefit benefit = 1)
105       : RewritePattern(benefit, MatchAnyOpTypeTag()),
106         marker(std::move(marker)) {}
107 
matchAndRewrite__anondb85c27e0211::LinalgNamedOpGeneralizationPattern108   LogicalResult matchAndRewrite(Operation *rootOp,
109                                 PatternRewriter &rewriter) const override {
110     auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
111     if (!linalgOp)
112       return failure();
113     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
114       return failure();
115 
116     // No nothing to do for linalg.generic and linalg.indexed_generic.
117     if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
118       return failure();
119 
120     linalg::GenericOp genericOp =
121         createGenericOpFromNamedOp(linalgOp, rewriter);
122     if (!genericOp)
123       return failure();
124 
125     rewriter.replaceOp(rootOp, genericOp.getResults());
126     marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
127     return success();
128   }
129 
130 private:
131   linalg::LinalgMarker marker;
132 };
133 
134 struct LinalgGeneralizationPass
135     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
136   void runOnFunction() override;
137 };
138 
139 } // namespace
140 
runOnFunction()141 void LinalgGeneralizationPass::runOnFunction() {
142   FuncOp func = getFunction();
143   OwningRewritePatternList patterns;
144   linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
145   linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
146   applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
147 }
148 
createGenericOp(linalg::ConvOp convOp,OpBuilder & builder) const149 linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
150                                                     OpBuilder &builder) const {
151   SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
152   auto iterators =
153       llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
154   return builder.create<linalg::GenericOp>(
155       convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
156       convOp.getInputBuffers(), convOp.getOutputBuffers(),
157       /*initTensors=*/ValueRange(), indexingMaps, iterators,
158       [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
159         Value mul =
160             bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
161         Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
162         bodyBuilder.create<linalg::YieldOp>(bodyLoc, add);
163       });
164 }
165 
populateLinalgConvGeneralizationPatterns(MLIRContext * context,OwningRewritePatternList & patterns,linalg::LinalgMarker marker)166 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
167     MLIRContext *context, OwningRewritePatternList &patterns,
168     linalg::LinalgMarker marker) {
169   patterns.insert<GeneralizeConvOp>(context, marker);
170 }
171 
populateLinalgNamedOpsGeneralizationPatterns(MLIRContext * context,OwningRewritePatternList & patterns,linalg::LinalgMarker marker)172 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
173     MLIRContext *context, OwningRewritePatternList &patterns,
174     linalg::LinalgMarker marker) {
175   patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
176 }
177 
createLinalgGeneralizationPass()178 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
179   return std::make_unique<LinalgGeneralizationPass>();
180 }
181