1 //===- CodegenStrategy.h - Linalg programmable codegen strategy -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ 10 #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ 11 12 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 15 namespace mlir { 16 17 class FuncOp; 18 19 namespace linalg { 20 21 /// Abstract Transformation class applied in a sequence that also handles state 22 /// through markers. 23 struct Transformation { 24 virtual ~Transformation() = default; 25 virtual OwningRewritePatternList 26 buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) = 0; 27 linalg::LinalgMarker marker; 28 }; 29 30 /// Promotion transformation enqueues a particular stage-1 pattern for 31 /// `Tile<LinalgOpType>`with the appropriate `options`. 32 template <typename LinalgOpType> 33 struct Tile : public Transformation { TileTile34 explicit Tile(linalg::LinalgTilingOptions options) : options(options) {} 35 36 OwningRewritePatternList buildRewritePatternsTile37 buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override { 38 OwningRewritePatternList tilingPatterns; 39 tilingPatterns.insert<linalg::LinalgTilingPattern<LinalgOpType>>( 40 context, options, m); 41 return tilingPatterns; 42 } 43 44 private: 45 linalg::LinalgTilingOptions options; 46 }; 47 48 /// Promotion transformation enqueues a particular stage-1 pattern for 49 /// `Promote<LinalgOpType>`with the appropriate `options`. 50 template <typename LinalgOpType> 51 struct Promote : public Transformation { PromotePromote52 explicit Promote(linalg::LinalgPromotionOptions options) : options(options) {} 53 54 OwningRewritePatternList buildRewritePatternsPromote55 buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override { 56 OwningRewritePatternList promotionPatterns; 57 promotionPatterns.insert<linalg::LinalgPromotionPattern<LinalgOpType>>( 58 context, options, m); 59 return promotionPatterns; 60 } 61 62 private: 63 linalg::LinalgPromotionOptions options; 64 }; 65 66 /// Vectorization transformation enqueues a particular stage-1 pattern for 67 /// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector 68 /// transfer rewrite forwarding patterns. 69 template <typename LinalgOpType> 70 struct Vectorize : public Transformation { 71 OwningRewritePatternList buildRewritePatternsVectorize72 buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override { 73 OwningRewritePatternList vectorizationPatterns; 74 // FillOp may interfere with forwarding patterns atm, so we bump up the 75 // priority of LinalgCopyVTRForwardingPattern / 76 // LinalgCopyVTWForwardingPattern. 77 vectorizationPatterns 78 .insert<linalg::LinalgVectorizationPattern<LinalgOpType>>(context, m); 79 vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern, 80 linalg::LinalgCopyVTWForwardingPattern>( 81 context, /*benefit=*/2); 82 return vectorizationPatterns; 83 } 84 }; 85 86 /// Codegen strategy controls how a Linalg op is progressively lowered. 87 /// The application uses a 3-level staged patterns strategy which allows 88 /// ordering transformations by using the Linalg `applyStagedPatterns` function, 89 /// where: 90 /// 1. The first stage consists of the successive `tile`, `promote` and 91 /// `vectorize` patterns, applied sequentially. 92 /// 2. The second stage consists of common local canonicalization patterns 93 /// that are applied eagerly after each stage-1 pattern. 94 /// 3. the third stage consists of more global transformation, also applied 95 /// eagerly, after all stage-2 patterns. Such more global transformations 96 struct CodegenStrategy { 97 /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling 98 /// `options`. 99 template <typename LinalgOpType> tileCodegenStrategy100 CodegenStrategy &tile(linalg::LinalgTilingOptions options) { 101 transformationSequence.emplace_back(new Tile<LinalgOpType>(options)); 102 return *this; 103 } 104 /// Conditionally append a pattern to add a level of tiling for `LinalgOpType` 105 /// with tiling `options`. 106 template <typename LinalgOpType> tileIfCodegenStrategy107 CodegenStrategy &tileIf(bool b, linalg::LinalgTilingOptions options) { 108 return b ? tile<LinalgOpType>(options) : *this; 109 } 110 /// Append a pattern to add a level of promotion for `LinalgOpType` with 111 /// promotion `options`. 112 template <typename LinalgOpType> promoteCodegenStrategy113 CodegenStrategy &promote(linalg::LinalgPromotionOptions options) { 114 transformationSequence.emplace_back(new Promote<LinalgOpType>(options)); 115 return *this; 116 } 117 /// Conditionally append a pattern to add a level of promotion for 118 /// `LinalgOpType` with promotion `options`. 119 template <typename LinalgOpType> promoteIfCodegenStrategy120 CodegenStrategy &promoteIf(bool b, linalg::LinalgPromotionOptions options) { 121 return b ? promote<LinalgOpType>(options) : *this; 122 return *this; 123 } 124 /// Append a pattern to rewrite `LinalgOpType` as a vector operation. 125 template <typename LinalgOpType> vectorizeCodegenStrategy126 CodegenStrategy &vectorize() { 127 transformationSequence.emplace_back(new Vectorize<LinalgOpType>()); 128 return *this; 129 } 130 /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector 131 /// operation. 132 template <typename LinalgOpType> vectorizeIfCodegenStrategy133 CodegenStrategy &vectorizeIf(bool b) { 134 return b ? vectorize<LinalgOpType>() : *this; 135 return *this; 136 } 137 /// Configure the post staged-patterns late vector transformations. 138 CodegenStrategy & setVectorTransformsOptionsCodegenStrategy139 setVectorTransformsOptions(vector::VectorTransformsOptions options) { 140 vectorTransformsOptions = options; 141 return *this; 142 } 143 /// Configure the post staged-patterns late vector.transfer to scf conversion. 144 CodegenStrategy & setVectorTransferToSCFOptionsCodegenStrategy145 setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) { 146 vectorToSCFOptions = options; 147 return *this; 148 } 149 150 /// Apply the transformation patterns in sequence with cleanup transformations 151 /// interleaved. 152 void transform(FuncOp func) const; 153 154 private: 155 LogicalResult postPatternTransforms(Operation *func) const; 156 157 vector::VectorTransformsOptions vectorTransformsOptions; 158 VectorTransferToSCFOptions vectorToSCFOptions; 159 SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence; 160 }; 161 162 } // namespace linalg 163 } // namespace mlir 164 165 #endif // MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ 166