//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_

#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
class BufferizeTypeConverter;
class FrozenRewritePatternList;

namespace linalg {

struct LinalgFusionOptions;
struct LinalgTilingOptions;

//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;

struct TiledLinalgOp {
  LinalgOp op;
  SmallVector<Operation *, 8> loops;
  SmallVector<Value, 4> tensorResults;
};

/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes);

/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(MLIRContext *context,
                                     BufferizeTypeConverter &converter,
                                     OwningRewritePatternList &patterns);
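
// A minimal usage sketch for the populate-style entry points above (assuming
// a `FuncOp func` is in scope; the greedy driver is declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   OwningRewritePatternList patterns;
//   BufferizeTypeConverter converter;
//   populateLinalgBufferizePatterns(func.getContext(), converter, patterns);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));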

/// Performs standalone tiling of a single LinalgOp by `tileSizes` and permutes
/// the loop nest according to `interchangeVector`.
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `interchangeVector`
/// must be equal to the length of `tileSizes`.
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// Returns a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
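///
/// A minimal usage sketch (assuming an `OpBuilder b` and a `LinalgOp op` are
/// in scope; the tile sizes and interchange are illustrative):
/// ```
///   Optional<TiledLinalgOp> tiled = tileLinalgOp(
///       b, op,
///       LinalgTilingOptions().setTileSizes({8, 16}).setInterchange({1, 0}));
///   if (!tiled)
///     return failure();
/// ```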
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                     const LinalgTilingOptions &options);

/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
/// proceeds as follows:
/// - Find outer parallel loops in these ops that can be fused.
/// - Tile fusable outer parallel loops of the last operation in the sequence.
/// - Fuse the remaining operations with the tiled operation.
///
/// For example, consider the sequence of matmuls below:
///
///   linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
///                 outs(%arg2 : memref<256x32xf32>)
///   linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
///                 outs(%arg4 : memref<256x32xf32>)
///
/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
/// matmuls row-wise. For example, the fused computation for the above is shown
/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
/// along the rows of the matrix. The entire rows of the first matmul operation
/// need to be computed before they can be used for the second matmul. The
/// second matmul is further tiled (similar to normal tiling).
///
/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
///   %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
///     : memref<256x32xf32> to memref<16x32xf32, #map0>
///   %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
///     : memref<256x32xf32> to memref<16x32xf32, #map0>
///   %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
///     : memref<256x32xf32> to memref<16x32xf32, #map0>
///   %3 = subview %arg1[0, 0] [32, 32] [1, 1]
///     : memref<32x32xf32> to memref<32x32xf32, #map1>
///   %4 = subview %arg3[0, 0] [32, 32] [1, 1]
///     : memref<32x32xf32> to memref<32x32xf32, #map1>
///   linalg.matmul
///     ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
///     outs(%0 : memref<16x32xf32, #map0>)
///   linalg.matmul
///     ins(%0, %4 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
///     outs(%1 : memref<16x32xf32, #map0>)
/// }
///
/// `tilingOptions` is used to tile the last operation in `ops`. Based on how
/// tile+fuse is implemented, the fused loops are generated from that last
/// operation; in particular, the tile sizes for the fused loops are the ones
/// it specifies. The following tiling options are handled differently in
/// tile+fuse (compared to tile only):
/// - Interchange of the tiling loops is not supported right now.
/// - Only the fused loops are distributed.
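///
/// A minimal usage sketch (assuming an `OpBuilder b`, a sequence of
/// `LinalgOp`s in `ops`, and an `Aliases` instance are in scope; the tile
/// sizes are illustrative):
/// ```
///   LinalgDependenceGraph graph(aliases, ops);
///   Optional<TiledAndFusedLinalgOps> res = tileAndFuseLinalgOps(
///       b, ops, graph, LinalgTilingOptions().setTileSizes({16, 0}));
/// ```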
struct TiledAndFusedLinalgOps {
  /// Operation obtained by tiling the last operation in the sequence of `ops`
  /// passed to `tileAndFuseLinalgOps`.
  LinalgOp op;
  /// The dimensions of the loops that are fused.
  std::set<unsigned> fusedLoopDims;
  /// The generated fused operations (created within the fused loops).
  SmallVector<LinalgOp, 1> fusedProducers;
  /// The fused loops generated.
  SmallVector<Operation *, 4> fusedLoops;
};
Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                     const LinalgDependenceGraph &dependenceGraph,
                     const LinalgTilingOptions &tilingOptions);

/// Interchanges the `iterator_types` and `indexing_maps` dimensions of `op`.
/// This is an in-place transformation controlled by `interchangeVector`.
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
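///
/// For example (an illustrative sketch, assuming `op` is a 3-D op):
/// ```
///   // Permute (i,j,k) -> (j,k,i).
///   LinalgOp interchanged = interchange(op, {1, 2, 0});
/// ```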
LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);

/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, it contains the dynamic size of the
/// subview. The callback should return the buffer to use.
using AllocBufferCallbackFn = std::function<Optional<Value>(
    OpBuilder &b, SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
    OperationFolder *folder)>;

/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
using DeallocBufferCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value buffer)>;

/// Callback function type used to insert a copy from the original subview to
/// the subview of the promoted region (for read operands), or from the subview
/// of the promoted region to the original subview (for results). The copy has
/// to happen from `src` to `dst`.
using CopyCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;
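
// For example, a copy callback that emits a plain linalg.copy (a sketch of
// what the default behavior looks like; a real callback could emit a custom
// DMA or masked copy instead):
//
//   CopyCallbackFn copyFn = [](OpBuilder &b, Value src, Value dst) {
//     b.create<linalg::CopyOp>(src.getLoc(), src, dst);
//     return success();
//   };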

struct LinalgPromotionOptions {
  /// Indices of subViews to promote. If `None`, try to promote all operands.
  Optional<DenseSet<unsigned>> operandsToPromote = None;
  LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
    operandsToPromote = DenseSet<unsigned>();
    operandsToPromote->insert(operands.begin(), operands.end());
    return *this;
  }
  /// If the ith element of `useFullTiles` is true, the full view should be
  /// used for the promoted buffer of the ith operand in `operandsToPromote`.
  /// Otherwise the partial view will be used. The decision defaults to
  /// `useFullTileBuffersDefault` when `useFullTileBuffers` is None and for
  /// operands missing from `useFullTileBuffers`.
  Optional<llvm::SmallBitVector> useFullTileBuffers = None;
  LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
    unsigned size = useFullTiles.size();
    llvm::SmallBitVector tmp(size, false);
    for (unsigned i = 0; i < size; ++i)
      tmp[i] = useFullTiles[i];
    useFullTileBuffers = tmp;
    return *this;
  }
  /// If true, all operands unspecified by `useFullTileBuffers` will use the
  /// full view; otherwise the partial view.
  bool useFullTileBuffersDefault = false;
  LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
    useFullTileBuffersDefault = use;
    return *this;
  }
  /// Allow the use of dynamically-sized buffers.
  bool dynamicBuffers = false;
  LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) {
    dynamicBuffers = dynamic;
    return *this;
  }
  /// Alignment of promoted buffer. If `None` do not specify alignment.
  Optional<unsigned> alignment = None;
  LinalgPromotionOptions &setAlignment(unsigned align) {
    alignment = align;
    return *this;
  }
  /// Use alloca with the default allocation scheme.
  bool useAlloca = false;
  LinalgPromotionOptions &setUseAlloca(bool use) {
    useAlloca = use;
    return *this;
  }
  /// Callback function to do the allocation of the promoted buffer. If None,
  /// then the default allocation scheme of allocating a memref<?xi8> buffer
  /// followed by a view operation is used.
  Optional<AllocBufferCallbackFn> allocationFn = None;
  Optional<DeallocBufferCallbackFn> deallocationFn = None;
  LinalgPromotionOptions &
  setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
                               DeallocBufferCallbackFn const &deallocFn) {
    allocationFn = allocFn;
    deallocationFn = deallocFn;
    return *this;
  }
  /// Callback function to do the copy of data to and from the promoted
  /// subview. If None then a linalg.copy is used.
  Optional<CopyCallbackFn> copyInFn = None;
  Optional<CopyCallbackFn> copyOutFn = None;
  LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
                                          CopyCallbackFn const &copyOut) {
    copyInFn = copyIn;
    copyOutFn = copyOut;
    return *this;
  }
};
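
// A minimal sketch of composing promotion options through the fluent setters
// (the operand indices and alignment are illustrative):
//
//   LinalgPromotionOptions()
//       .setOperandsToPromote({0, 1})
//       .setUseFullTileBuffersByDefault(true)
//       .setAlignment(16);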

/// Creates a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that can
/// be computed for the size of the result of `subView`. Returns the allocated
/// buffer as `fullLocalView` and the view that matches the size of the result
/// of the subview operation as `partialLocalView`.
struct PromotionInfo {
  Value fullLocalView;
  Value partialLocalView;
};
Optional<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
                          AllocBufferCallbackFn allocationFn,
                          OperationFolder *folder = nullptr);

/// Promotes the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
///   2. Take a full view on the buffer.
///   3. Take a partial slice of the full view in step 2. and copy into it.
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
/// true.
///
/// Returns the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                   LinalgPromotionOptions options,
                                   OperationFolder *folder = nullptr);

/// Emit a suitable vector form for a Linalg op with fully static shape.
void vectorizeLinalgOp(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);

/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
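
// A minimal usage sketch of the loop-lowering entry points (assuming `b` is
// set to a valid insertion point and `op` is an `Operation *` implementing
// the LinalgOp interface):
//
//   if (failed(linalgOpToParallelLoops(b, op)))
//     return failure();
//   // The emitted loops do not replace `op`; erase it explicitly, as
//   // LinalgLoweringPattern below does.
//   op->erase();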

//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Checks whether a `generic` or `indexed_generic` operation can be emitted
/// with its `indexing_maps` and `iterator_types` permuted according to
/// `interchangeVector`.
LogicalResult
interchangeGenericLinalgOpPrecondition(Operation *op,
                                       ArrayRef<unsigned> interchangeVector);

/// Checks whether the std.subviews feeding the linalg operation can be
/// promoted.
LogicalResult promoteSubviewsPrecondition(Operation *op,
                                          LinalgPromotionOptions options);

/// Checks whether a linalg op can be rewritten into a suitable vector form
/// (e.g. a vector.contract op for a contraction).
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
struct LinalgTransforms {
  static const StringLiteral kLinalgTransformMarker;
};

/// Helper class to control common attribute matching and setting behavior.
struct LinalgMarker {
  explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
                        Optional<Identifier> replacement = None);
  LinalgMarker(LinalgMarker &&) = default;
  LinalgMarker(const LinalgMarker &) = default;
  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
  void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;

private:
  SmallVector<Identifier, 4> matchDisjunction;
  Optional<Identifier> replacement;
};

///
/// Linalg tiling patterns.
///
/// Apply the `tileLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `tileLinalgOp` for more details.
enum class LinalgTilingLoopType {
  Loops = 0,
  AffineLoops = 1,
  ParallelLoops = 2,
};

using TileSizeComputationFunction =
    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;

struct LinalgTilingOptions {
  /// Computation function that returns the tile sizes for each operation.
  /// Construction of constant tile sizes should be delayed so that it
  /// interoperates with folding.
  TileSizeComputationFunction tileSizeComputationFunction = nullptr;

  LinalgTilingOptions &
  setTileSizeComputationFunction(TileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
  /// values must not fold away when tiling. Otherwise, use a more robust
  /// `tileSizeComputationFunction`.
  LinalgTilingOptions &setTileSizes(SmallVector<Value, 4> ts) {
    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);

  /// The interchange vector to reorder the tiled loops.
  SmallVector<unsigned, 4> interchangeVector = {};

  LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
    interchangeVector.assign(interchange.begin(), interchange.end());
    return *this;
  }

  /// The type of tile loops to generate.
  LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;

  LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
    loopType = lt;
    return *this;
  }

  /// When non-`None`, specifies the distribution of the generated tile loops
  /// to processors.
  Optional<LinalgLoopDistributionOptions> distribution = None;

  LinalgTilingOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    distribution = std::move(distributionOptions);
    return *this;
  }
};
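
// A minimal sketch of composing tiling options (the tile sizes, interchange
// and loop type are illustrative):
//
//   LinalgTilingOptions()
//       .setTileSizes({8, 16, 4})
//       .setInterchange({1, 0, 2})
//       .setLoopType(LinalgTilingLoopType::ParallelLoops);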

/// Canonicalization patterns relevant to apply after tiling patterns. These
/// are applied automatically by the tiling pass but need to be applied
/// manually when tiling is called programmatically.
OwningRewritePatternList
getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

struct LinalgBaseTilingPattern : public RewritePattern {
  // Entry point to match any LinalgOp OpInterface.
  LinalgBaseTilingPattern(LinalgTilingOptions options,
                          LinalgMarker marker = LinalgMarker(),
                          PatternBenefit benefit = 1);
  // Entry point to match a specific Linalg op.
  LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
                          LinalgTilingOptions options,
                          LinalgMarker marker = LinalgMarker(),
                          PatternBenefit benefit = 1);
  LogicalResult
  matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
                      SmallVectorImpl<Value> &tensorResults) const;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// Options to control tiling.
  LinalgTilingOptions options;
};

template <typename OpTy>
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
  LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options,
                      LinalgMarker marker = LinalgMarker(),
                      PatternBenefit benefit = 1)
      : LinalgBaseTilingPattern(OpTy::getOperationName(), context, options,
                                marker, benefit) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> tensorResults;
    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
                                                            tensorResults)))
      return failure();
    if (tensorResults.empty())
      rewriter.eraseOp(op);
    else
      rewriter.replaceOp(op, tensorResults);
    return success();
  }
};
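
// A minimal sketch of registering a tiling pattern and tagging the result
// through a marker (the op choice and "TILE" identifier are illustrative):
//
//   OwningRewritePatternList patterns;
//   patterns.insert<LinalgTilingPattern<MatmulOp>>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 8, 8}),
//       LinalgMarker({}, Identifier::get("TILE", ctx)));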

struct LinalgFusionOptions {
  /// List of operand indices to use for fusion.
  llvm::SmallSet<unsigned, 1> indicesToFuse = {};
  LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
    indicesToFuse.insert(operands.begin(), operands.end());
    return *this;
  }
};

struct LinalgBaseTileAndFusePattern : public RewritePattern {
  LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
                               const LinalgDependenceGraph &dependenceGraph,
                               LinalgTilingOptions tilingOptions,
                               LinalgFusionOptions fusionOptions,
                               LinalgMarker marker = LinalgMarker(),
                               LinalgMarker fusedOpMarker = LinalgMarker(),
                               LinalgMarker originalOpMarker = LinalgMarker(),
                               PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// Dependence graph needed for fusion.
  const LinalgDependenceGraph &dependenceGraph;
  /// Options to control tiling.
  LinalgTilingOptions tilingOptions;
  /// Options to control fusion.
  LinalgFusionOptions fusionOptions;
  /// Marker to control application of the pattern.
  LinalgMarker marker;
  /// Marker set on the fused op after tile and fuse.
  LinalgMarker fusedOpMarker;
  /// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
  /// to build the dependence graph change, then the dependenceGraph needs to
  /// be recomputed. To avoid invalidating the dependenceGraph as the
  /// transformation happens, the original producer can be tagged with a marker
  /// that can later be used to delete the original operations.
  LinalgMarker originalOpMarker;
};

template <typename OpTy>
struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
  LinalgTileAndFusePattern(MLIRContext *context,
                           const LinalgDependenceGraph &dependenceGraph,
                           LinalgTilingOptions tilingOptions,
                           LinalgFusionOptions fusionOptions,
                           LinalgMarker marker = LinalgMarker(),
                           LinalgMarker fusedOpMarker = LinalgMarker(),
                           LinalgMarker originalOpMarker = LinalgMarker(),
                           PatternBenefit benefit = 1)
      : LinalgBaseTileAndFusePattern(
            OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
            fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
};

///
/// Linalg interchange patterns.
///
/// Apply the `interchange` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct LinalgBaseInterchangePattern : public RewritePattern {
  LinalgBaseInterchangePattern(StringRef opName, MLIRContext *context,
                               ArrayRef<unsigned> interchangeVector,
                               LinalgMarker marker = LinalgMarker(),
                               PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// The interchange vector to reorder the iterators and indexing_maps dims.
  SmallVector<unsigned, 8> interchangeVector;
};

template <typename OpTy>
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
  LinalgInterchangePattern(MLIRContext *context,
                           ArrayRef<unsigned> interchangeVector,
                           LinalgMarker marker = LinalgMarker(),
                           PatternBenefit benefit = 1)
      : LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
                                     interchangeVector, marker, benefit) {}
};

///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViews` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
  LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
                             LinalgPromotionOptions options,
                             LinalgMarker marker = LinalgMarker(),
                             PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// Promotion options.
  LinalgPromotionOptions options;
};

template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
  LinalgPromotionPattern(MLIRContext *context, LinalgPromotionOptions options,
                         LinalgMarker marker = LinalgMarker(),
                         PatternBenefit benefit = 1)
      : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
                                   marker, benefit) {}
};

///
/// Linalg vectorization patterns.
///
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct LinalgBaseVectorizationPattern : public RewritePattern {
  LinalgBaseVectorizationPattern(StringRef opName, MLIRContext *context,
                                 LinalgMarker marker = LinalgMarker(),
                                 PatternBenefit benefit = 1);
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
};

template <typename OpTy>
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
  LinalgVectorizationPattern(MLIRContext *context,
                             LinalgMarker marker = LinalgMarker(),
                             PatternBenefit benefit = 1)
      : LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
                                       marker, benefit) {}
};
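
// A minimal sketch of chaining markers across staged transformations: only
// vectorize ops previously tagged "TILE" by a tiling pattern (the op choice
// and identifier are illustrative):
//
//   patterns.insert<LinalgVectorizationPattern<MatmulOp>>(
//       ctx, LinalgMarker({Identifier::get("TILE", ctx)}));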

///
/// Linalg lowering patterns.
///
/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
/// `marker` controls LinalgTransformMarker matching and update when specified.
/// See `linalgLowerOpToLoops` for more details.
enum class LinalgLoweringType {
  LibraryCall = 0,
  Loops = 1,
  AffineLoops = 2,
  ParallelLoops = 3
};
template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
  LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
                        LinalgMarker marker = LinalgMarker(),
                        PatternBenefit benefit = 1)
      : RewritePattern(OpTy::getOperationName(), {}, benefit, context),
        marker(marker), loweringType(loweringType) {}
  // TODO: Move implementation to .cpp once named ops are auto-generated.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
      return failure();

    if (loweringType == LinalgLoweringType::LibraryCall) {
      // TODO: Move lowering to library calls here.
      return failure();
    } else if (loweringType == LinalgLoweringType::Loops) {
      if (failed(linalgOpToLoops(rewriter, op)))
        return failure();
    } else if (loweringType == LinalgLoweringType::AffineLoops) {
      if (failed(linalgOpToAffineLoops(rewriter, op)))
        return failure();
    } else if (failed(linalgOpToParallelLoops(rewriter, op))) {
      return failure();
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgMarker marker;
  /// Controls whether the pattern lowers to library calls, scf.for, affine.for
  /// or scf.parallel.
  LinalgLoweringType loweringType;
};

/// Linalg generalization patterns.

/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    LinalgMarker marker = LinalgMarker());

/// Populates `patterns` with patterns to convert linalg.conv ops to
/// linalg.generic ops.
void populateLinalgConvGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    LinalgMarker marker = LinalgMarker());

//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = std.view %alloc ...
///    %subView = subview %allocOrView ...
///    [optional] linalg.fill(%allocOrView, %cst) ...
///    ...
///    linalg.copy(%in, %subView) ...
///    vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = std.view %alloc ...
///    [unchanged] %subView = subview %allocOrView ...
///    ...
///    vector.transfer_read %in[...], %cst ...
/// ```
/// where there is no interleaved use between linalg.copy and transfer_read,
/// and no interleaved use between linalg.fill and linalg.copy (if linalg.fill
/// is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override;
};

/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = std.view %alloc ...
///    %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %allocOrView[...]
///    linalg.copy(%subView, %out)
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = std.view %alloc ...
///    [unchanged] %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %out[...]
/// ```
/// where there is no interleaved use between transfer_write and linalg.copy.
/// This is a custom rewrite to forward partial writes to vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override;
};

/// Canonicalize AffineMinOp operations in the context of enclosing scf.for and
/// scf.parallel loops by:
///   1. building an affine map where uses of the induction variable of a loop
///   are replaced by either the min (i.e. `%lb`) or the max
///   (i.e. `%lb + %step * floordiv(%ub - 1 - %lb, %step)`) expression,
///   depending on whether the induction variable is used with a positive or
///   negative coefficient;
///   2. checking whether any of the results of this affine map is known to be
///   greater than all other results;
///   3. replacing the AffineMinOp by the result of (2).
// TODO: move to a more appropriate place when it is determined. For now Linalg
// depends both on Affine and SCF but they do not depend on each other.
struct AffineMinSCFCanonicalizationPattern
    : public OpRewritePattern<AffineMinOp> {
  using OpRewritePattern<AffineMinOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMinOp minOp,
                                PatternRewriter &rewriter) const override;
};

/// Converts a Convolution op into a vector contraction.
///
/// The conversion expects the dimensions marked as false in the *mask* to be
/// of size 1. This ensures that the ConvOp can be lowered to a vector
/// contraction over the dimensions marked as true in the *mask*.
///
/// A good example for vectorization is ConvNHWCOp, which is a 2-D Conv op
/// with channels as the last dimension. Let's vectorize the last 3 dimensions.
/// The initial op definition looks like this:
/// ```
/// linalg.conv_2d_nhwc  %arg0, %arg1, %arg2 :
///   (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>)
/// ```
/// This op can be expressed as a dot product between %arg0 (input) and
/// %arg1 (kernel) which is written into the first entry of %arg2 (output).
/// This is the ConvOp this pass expects and converts into:
/// ```
/// #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
/// #map1 = affine_map<(d0, d1, d2) -> ()>
/// .....
/// %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %c0_f32
///   : memref<1x3x3x3xf32>, vector<3x3x3xf32>
/// %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %c0_f32
///   : memref<1x3x3x3xf32>, vector<3x3x3xf32>
/// %2 = vector.contract {indexing_maps = [#map0, #map0, #map1],
///   iterator_types = ["reduction", "reduction", "reduction"]} %0, %1,
///   %c0_f32 : vector<3x3x3xf32>, vector<3x3x3xf32> into f32
/// store %2, %arg2[%c0, %c0, %c0, %c0] : memref<?x?x?x?xf32>
/// ```
/// where the first two operations read the input and kernel memory buffers
/// into vectors. Subsequently, the vectors are contracted together and the
/// result is written to the first entry of the output buffer.
template <typename ConvOp, int N>
class ConvOpVectorization : public OpRewritePattern<ConvOp> {
  using OpRewritePattern<ConvOp>::OpRewritePattern;
  SmallVector<bool, 4> mask;

public:
  ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk)
      : OpRewritePattern<ConvOp>(context) {
    assert(msk.size() == N && "Mask size does not match rank");
    this->mask = msk;
  }

  LogicalResult matchAndRewrite(ConvOp op,
                                PatternRewriter &rewriter) const override;
};
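
// A minimal sketch of registering the pattern for a 2-D NHWC convolution,
// vectorizing the last three dimensions as in the example above (the batch
// dimension, marked false, must have size 1):
//
//   patterns.insert<ConvOpVectorization<ConvNHWCOp, 4>>(
//       ctx, SmallVector<bool, 4>{false, true, true, true});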

//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
/// Helper function to allow applying rewrite patterns, interleaved with more
/// global transformations, in a staged fashion:
///   1. the first stage consists of a list of FrozenRewritePatternList. Each
///   FrozenRewritePatternList in this list is applied once, in order.
///   2. the second stage consists of a single FrozenRewritePatternList that is
///   applied greedily until convergence.
///   3. the third stage consists of applying a lambda, generally used for
///   non-local transformation effects. This allows creating custom fused
///   transformations where patterns can be ordered and applied at a finer
///   granularity than a sequence of traditional compiler passes.
LogicalResult applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
    const FrozenRewritePatternList &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
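
// A minimal usage sketch (assuming `op` and two OwningRewritePatternList
// instances, `tilingPatterns` and `canonicalizationPatterns`, are in scope):
//
//   SmallVector<FrozenRewritePatternList, 1> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   FrozenRewritePatternList stage2(std::move(canonicalizationPatterns));
//   if (failed(applyStagedPatterns(op, stage1, stage2)))
//     return failure();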

//===----------------------------------------------------------------------===//
// Support for sparse tensor code generation.
//
// The sparse compiler part of MLIR lowers a tensor expression formulated as a
// Linalg operation into a sequence of loops depending on what dimensions of the
// tensors are marked dense or sparse. The generated code distinguishes between:
// (1) for-loops that iterate over a single dense dimension,
// (2) for-loops that iterate over a single sparse dimension,
// (3) while-loops that co-iterate over several sparse dimensions.
// The for-loops may be subsequently optimized for parallel or vector execution.
//
// For more details, see the Dialect/Linalg/Transforms/Sparsification.cpp file.
//===----------------------------------------------------------------------===//

/// Defines a parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" (thus not "reduction") is a candidate
/// for parallelization. The loop is made parallel if (1) allowed by the
/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
/// outermost loop only), and (2) the generated code is an actual for-loop
/// (and not a co-iterating while-loop).
enum class SparseParallelizationStrategy {
  kNone,
  kDenseOuterLoop,
  kAnyStorageOuterLoop,
  kDenseAnyLoop,
  kAnyStorageAnyLoop
  // TODO: support reduction parallelization too?
};

/// Defines a vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate (full SIMD for "parallel" loops and horizontal
/// SIMD for "reduction" loops). A loop is actually vectorized if (1) allowed
/// by the strategy, and (2) the emitted code is an actual for-loop (and not
/// a co-iterating while-loop).
enum class SparseVectorizationStrategy {
  kNone,
  kDenseInnerLoop,
  kAnyStorageInnerLoop
};

/// Defines a type for "pointer" and "index" storage in the sparse storage
/// scheme, with a choice between the native platform-dependent index width,
/// 64-bit integers, or 32-bit integers. A narrow width obviously reduces
/// the memory footprint of the sparse storage scheme, but the width should
/// suffice to define the total required range (viz. the maximum number of
/// stored entries per indirection level for the "pointers" and the maximum
/// value of each tensor index over all dimensions for the "indices").
enum class SparseIntType { kNative, kI64, kI32 };

/// Sparsification options.
struct SparsificationOptions {
  SparsificationOptions(SparseParallelizationStrategy p,
                        SparseVectorizationStrategy v, unsigned vl,
                        SparseIntType pt, SparseIntType it)
      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
        ptrType(pt), indType(it) {}
  SparsificationOptions()
      : SparsificationOptions(SparseParallelizationStrategy::kNone,
                              SparseVectorizationStrategy::kNone, 1u,
                              SparseIntType::kNative, SparseIntType::kNative) {}
  SparseParallelizationStrategy parallelizationStrategy;
  SparseVectorizationStrategy vectorizationStrategy;
  unsigned vectorLength;
  SparseIntType ptrType;
  SparseIntType indType;
};

/// Set up sparsification rewriting rules with the given options.
void populateSparsificationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    const SparsificationOptions &options = SparsificationOptions());
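
// A minimal usage sketch (the strategy and vector length choices are
// illustrative):
//
//   OwningRewritePatternList patterns;
//   SparsificationOptions options(
//       SparseParallelizationStrategy::kAnyStorageOuterLoop,
//       SparseVectorizationStrategy::kDenseInnerLoop, /*vl=*/16,
//       SparseIntType::kNative, SparseIntType::kI32);
//   populateSparsificationPatterns(ctx, patterns, options);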

} // namespace linalg
} // namespace mlir

#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_