1 //===- VectorToSCF.h - Utils to convert from the vector dialect -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
10 #define MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
11 
12 #include "mlir/IR/PatternMatch.h"
13 
14 namespace mlir {
// Forward declarations of the MLIR types that the declarations in this header
// only use by pointer or by reference.
class MLIRContext;
class OwningRewritePatternList;
class Pass;
18 
/// Control whether unrolling is used when lowering vector transfer ops to SCF.
///
/// Case 1:
/// =======
/// When `unroll` is false, a temporary buffer is created through which
/// individual 1-D vectors are staged. This is consistent with the lack of an
/// LLVM instruction to dynamically index into an aggregate (see the Vector
/// dialect lowering to LLVM deep dive).
/// An instruction such as:
/// ```
///    vector.transfer_write %vec, %A[%base, %base] :
///      vector<17x15xf32>, memref<?x?xf32>
/// ```
/// Lowers to pseudo-IR resembling:
/// ```
///    %0 = alloc() : memref<17xvector<15xf32>>
///    %1 = vector.type_cast %0 :
///      memref<17xvector<15xf32>> to memref<vector<17x15xf32>>
///    store %vec, %1[] : memref<vector<17x15xf32>>
///    %dim = dim %A, 0 : memref<?x?xf32>
///    affine.for %I = 0 to 17 {
///      %add = affine.apply %I + %base
///      %cmp = cmpi "slt", %add, %dim : index
///      scf.if %cmp {
///        %vec_1d = load %0[%I] : memref<17xvector<15xf32>>
///        vector.transfer_write %vec_1d, %A[%add, %base] :
///          vector<15xf32>, memref<?x?xf32>
///      }
///    }
/// ```
///
/// Case 2:
/// =======
/// When `unroll` is true, the temporary buffer is skipped and static indices
/// into aggregates can be used (see the Vector dialect lowering to LLVM deep
/// dive).
/// An instruction such as:
/// ```
///    vector.transfer_write %vec, %A[%base, %base] :
///      vector<3x15xf32>, memref<?x?xf32>
/// ```
/// Lowers to pseudo-IR resembling:
/// ```
///    %0 = vector.extract %arg2[0] : vector<3x15xf32>
///    vector.transfer_write %0, %arg0[%arg1, %arg1] :
///      vector<15xf32>, memref<?x?xf32>
///    %1 = affine.apply #map1()[%arg1]
///    %2 = vector.extract %arg2[1] : vector<3x15xf32>
///    vector.transfer_write %2, %arg0[%1, %arg1] :
///      vector<15xf32>, memref<?x?xf32>
///    %3 = affine.apply #map2()[%arg1]
///    %4 = vector.extract %arg2[2] : vector<3x15xf32>
///    vector.transfer_write %4, %arg0[%3, %arg1] :
///      vector<15xf32>, memref<?x?xf32>
/// ```
struct VectorTransferToSCFOptions {
  /// Selects the lowering strategy: `false` (the default) stages 1-D vectors
  /// through a temporary buffer (Case 1), `true` unrolls and uses static
  /// indices instead (Case 2).
  bool unroll = false;

  /// Sets `unroll` to `u` and returns `*this` so that setter calls can be
  /// chained on a single options object.
  VectorTransferToSCFOptions &setUnroll(bool u) {
    unroll = u;
    return *this;
  }
};
75 
76 /// Implements lowering of TransferReadOp and TransferWriteOp to a
77 /// proper abstraction for the hardware.
78 ///
79 /// There are multiple cases.
80 ///
81 /// Case A: Permutation Map does not permute or broadcast.
82 /// ======================================================
83 ///
84 /// Progressive lowering occurs to 1-D vector transfer ops according to the
85 /// description in `VectorTransferToSCFOptions`.
86 ///
87 /// Case B: Permutation Map permutes and/or broadcasts.
88 /// ======================================================
89 ///
90 /// This path will be progressively deprecated and folded into the case above by
91 /// using vector broadcast and transpose operations.
92 ///
93 /// This path only emits a simple loop nest that performs clipped pointwise
94 /// copies from a remote to a locally allocated memory.
95 ///
96 /// Consider the case:
97 ///
98 /// ```mlir
99 ///    // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
100 ///    // vector<32x256xf32> and pad with %f0 to handle the boundary case:
101 ///    %f0 = constant 0.0f : f32
102 ///    scf.for %i0 = 0 to %0 {
103 ///      scf.for %i1 = 0 to %1 step %c256 {
104 ///        scf.for %i2 = 0 to %2 step %c32 {
105 ///          %v = vector.transfer_read %A[%i0, %i1, %i2], %f0
106 ///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
107 ///               memref<?x?x?xf32>, vector<32x256xf32>
108 ///    }}}
109 /// ```
110 ///
111 /// The rewriters construct loop and indices that access MemRef A in a pattern
112 /// resembling the following (while guaranteeing an always full-tile
113 /// abstraction):
114 ///
115 /// ```mlir
116 ///    scf.for %d2 = 0 to %c256 {
117 ///      scf.for %d1 = 0 to %c32 {
118 ///        %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
119 ///        %tmp[%d2, %d1] = %s
120 ///      }
121 ///    }
122 /// ```
123 ///
124 /// In the current state, only a clipping transfer is implemented by `clip`,
125 /// which creates individual indexing expressions of the form:
126 ///
127 /// ```mlir-dsc
128 ///    auto condMax = i + ii < N;
129 ///    auto max = std_select(condMax, i + ii, N - one)
130 ///    auto cond = i + ii < zero;
131 ///    std_select(cond, zero, max);
132 /// ```
133 ///
134 /// In the future, clipping should not be the only way and instead we should
135 /// load vectors + mask them. Similarly on the write side, load/mask/store for
136 /// implementing RMW behavior.
137 ///
138 /// Lowers TransferOp into a combination of:
139 ///   1. local memory allocation;
140 ///   2. perfect loop nest over:
141 ///      a. scalar load/stores from local buffers (viewed as a scalar memref);
142 ///      a. scalar store/load to original memref (with clipping).
143 ///   3. vector_load/store
144 ///   4. local memory deallocation.
145 /// Minor variations occur depending on whether a TransferReadOp or
146 /// a TransferWriteOp is rewritten.
147 template <typename TransferOpTy>
148 struct VectorTransferRewriter : public RewritePattern {
149   explicit VectorTransferRewriter(VectorTransferToSCFOptions options,
150                                   MLIRContext *context);
151 
152   /// Used for staging the transfer in a local buffer.
153   MemRefType tmpMemRefType(TransferOpTy transfer) const;
154 
155   /// Performs the rewrite.
156   LogicalResult matchAndRewrite(Operation *op,
157                                 PatternRewriter &rewriter) const override;
158 
159   /// See description of `VectorTransferToSCFOptions`.
160   VectorTransferToSCFOptions options;
161 };
162 
163 /// Collect a set of patterns to convert from the Vector dialect to SCF + std.
164 void populateVectorToSCFConversionPatterns(
165     OwningRewritePatternList &patterns, MLIRContext *context,
166     const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
167 
168 /// Create a pass to convert a subset of vector ops to SCF.
169 std::unique_ptr<Pass> createConvertVectorToSCFPass(
170     const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
171 
172 } // namespace mlir
173 
174 #endif // MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
175