1 //===- CodegenStrategy.h - Linalg programmable codegen strategy -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
10 #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
11 
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include <memory>
#include <utility>
14 
15 namespace mlir {
16 
17 class FuncOp;
18 
19 namespace linalg {
20 
21 /// Abstract Transformation class applied in a sequence that also handles state
22 /// through markers.
23 struct Transformation {
24   virtual ~Transformation() = default;
25   virtual OwningRewritePatternList
26   buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) = 0;
27   linalg::LinalgMarker marker;
28 };
29 
30 /// Promotion transformation enqueues a particular stage-1 pattern for
31 /// `Tile<LinalgOpType>`with the appropriate `options`.
32 template <typename LinalgOpType>
33 struct Tile : public Transformation {
TileTile34   explicit Tile(linalg::LinalgTilingOptions options) : options(options) {}
35 
36   OwningRewritePatternList
buildRewritePatternsTile37   buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
38     OwningRewritePatternList tilingPatterns;
39     tilingPatterns.insert<linalg::LinalgTilingPattern<LinalgOpType>>(
40         context, options, m);
41     return tilingPatterns;
42   }
43 
44 private:
45   linalg::LinalgTilingOptions options;
46 };
47 
48 /// Promotion transformation enqueues a particular stage-1 pattern for
49 /// `Promote<LinalgOpType>`with the appropriate `options`.
50 template <typename LinalgOpType>
51 struct Promote : public Transformation {
PromotePromote52   explicit Promote(linalg::LinalgPromotionOptions options) : options(options) {}
53 
54   OwningRewritePatternList
buildRewritePatternsPromote55   buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
56     OwningRewritePatternList promotionPatterns;
57     promotionPatterns.insert<linalg::LinalgPromotionPattern<LinalgOpType>>(
58         context, options, m);
59     return promotionPatterns;
60   }
61 
62 private:
63   linalg::LinalgPromotionOptions options;
64 };
65 
66 /// Vectorization transformation enqueues a particular stage-1 pattern for
67 /// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
68 /// transfer rewrite forwarding patterns.
69 template <typename LinalgOpType>
70 struct Vectorize : public Transformation {
71   OwningRewritePatternList
buildRewritePatternsVectorize72   buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
73     OwningRewritePatternList vectorizationPatterns;
74     // FillOp may interfere with forwarding patterns atm, so we bump up the
75     // priority of LinalgCopyVTRForwardingPattern /
76     // LinalgCopyVTWForwardingPattern.
77     vectorizationPatterns
78         .insert<linalg::LinalgVectorizationPattern<LinalgOpType>>(context, m);
79     vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
80                                  linalg::LinalgCopyVTWForwardingPattern>(
81         context, /*benefit=*/2);
82     return vectorizationPatterns;
83   }
84 };
85 
86 /// Codegen strategy controls how a Linalg op is progressively lowered.
87 /// The application uses a 3-level staged patterns strategy which allows
88 /// ordering transformations by using the Linalg `applyStagedPatterns` function,
89 /// where:
90 ///   1. The first stage consists of the successive `tile`, `promote` and
91 ///   `vectorize` patterns, applied sequentially.
92 ///   2. The second stage consists of common local canonicalization patterns
93 ///   that are applied eagerly after each stage-1 pattern.
94 ///   3. the third stage consists of more global transformation, also applied
95 ///   eagerly, after all stage-2 patterns. Such more global transformations
96 struct CodegenStrategy {
97   /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
98   /// `options`.
99   template <typename LinalgOpType>
tileCodegenStrategy100   CodegenStrategy &tile(linalg::LinalgTilingOptions options) {
101     transformationSequence.emplace_back(new Tile<LinalgOpType>(options));
102     return *this;
103   }
104   /// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
105   /// with tiling `options`.
106   template <typename LinalgOpType>
tileIfCodegenStrategy107   CodegenStrategy &tileIf(bool b, linalg::LinalgTilingOptions options) {
108     return b ? tile<LinalgOpType>(options) : *this;
109   }
110   /// Append a pattern to add a level of promotion for `LinalgOpType` with
111   /// promotion `options`.
112   template <typename LinalgOpType>
promoteCodegenStrategy113   CodegenStrategy &promote(linalg::LinalgPromotionOptions options) {
114     transformationSequence.emplace_back(new Promote<LinalgOpType>(options));
115     return *this;
116   }
117   /// Conditionally append a pattern to add a level of promotion for
118   /// `LinalgOpType` with promotion `options`.
119   template <typename LinalgOpType>
promoteIfCodegenStrategy120   CodegenStrategy &promoteIf(bool b, linalg::LinalgPromotionOptions options) {
121     return b ? promote<LinalgOpType>(options) : *this;
122     return *this;
123   }
124   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
125   template <typename LinalgOpType>
vectorizeCodegenStrategy126   CodegenStrategy &vectorize() {
127     transformationSequence.emplace_back(new Vectorize<LinalgOpType>());
128     return *this;
129   }
130   /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
131   /// operation.
132   template <typename LinalgOpType>
vectorizeIfCodegenStrategy133   CodegenStrategy &vectorizeIf(bool b) {
134     return b ? vectorize<LinalgOpType>() : *this;
135     return *this;
136   }
137   /// Configure the post staged-patterns late vector transformations.
138   CodegenStrategy &
setVectorTransformsOptionsCodegenStrategy139   setVectorTransformsOptions(vector::VectorTransformsOptions options) {
140     vectorTransformsOptions = options;
141     return *this;
142   }
143   /// Configure the post staged-patterns late vector.transfer to scf conversion.
144   CodegenStrategy &
setVectorTransferToSCFOptionsCodegenStrategy145   setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
146     vectorToSCFOptions = options;
147     return *this;
148   }
149 
150   /// Apply the transformation patterns in sequence with cleanup transformations
151   /// interleaved.
152   void transform(FuncOp func) const;
153 
154 private:
155   LogicalResult postPatternTransforms(Operation *func) const;
156 
157   vector::VectorTransformsOptions vectorTransformsOptions;
158   VectorTransferToSCFOptions vectorToSCFOptions;
159   SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
160 };
161 
162 } // namespace linalg
163 } // namespace mlir
164 
165 #endif // MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
166