1 //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
10 #define DIALECT_VECTOR_VECTORTRANSFORMS_H_
11 
12 #include "mlir/Dialect/Vector/VectorOps.h"
13 #include "mlir/Dialect/Vector/VectorUtils.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 namespace mlir {
18 class MLIRContext;
19 class OwningRewritePatternList;
20 class VectorTransferOpInterface;
21 
22 namespace scf {
23 class IfOp;
24 } // namespace scf
25 
/// Collect a set of patterns to convert from the Vector dialect to itself.
/// Should be merged with populateVectorToSCFLoweringPattern.
///
/// `patterns` is the list the new patterns are presumably appended to.
/// `coarseVectorShape` and `fineVectorShape` both default to empty;
/// NOTE(review): how the patterns interpret these shapes is defined in the
/// implementation, not in this header — confirm before relying on it.
void populateVectorToVectorConversionPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    ArrayRef<int64_t> coarseVectorShape = {},
    ArrayRef<int64_t> fineVectorShape = {});
32 
33 namespace vector {
34 
/// Entry point for unrolling declarative pattern rewrites.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
///   assumed the unrolling factors divide the vector sizes.
///   2. a fakeFork cast op is inserted that takes the operand and returns
///   `numUnrolledInstances` results of type `unrolledVectorType`.
///   3. the original op is cloned `numUnrolledInstances` times, once for each
///   result of the fakeFork cast op.
///   4. a fakeJoin cast op takes all these results and merges them into a
///   single aggregate vector result whose size matches the original
///   non-unrolled op operand types.
///
/// Example:
///
///    opA(operand0, operand1)  // numUnrolledInstances = 3
///
///            operand0                   operand1
///               |                          |
///             fork                       fork
///        <----------gather all fork ops --------->
///              /|\                        /|\   (3 results each)
///          f00 f01 f02                f10 f11 f12
///        <---------- clone op 3 times --------->
///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                 \            |            /
///      <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// until all the fakeFork and fakeJoin ops are removed.
///
/// This will be extended in the future to support more advanced use cases than
/// simple pointwise ops.
///
/// `builder` is used to create the new ops. The returned vector holds the
/// replacement value(s) for `op`; UnrollVectorPattern expects exactly one
/// value and treats any other count as a match failure.
SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
                                                 Operation *op,
                                                 ArrayRef<int64_t> targetShape);
71 
/// Unroll a transfer_write op. Break up the vector source into a tuple of
/// vectors matching the given shape. Then store each element with its own
/// transfer_write.
///
/// Example:
/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
/// ->
/// %0 = vector.extract_slices %A, [2, 4], [1, 1] :
///                vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
/// %1 = vector.tuple_get %0, 0 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32>
/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
///
/// Returns failure() if the op could not be unrolled. NOTE(review): the
/// original op is presumably left in place on success — UnrollVectorPattern
/// erases it itself after a successful call.
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                    ArrayRef<int64_t> targetShape);
87 
88 /// Options that control the vector unrolling.
89 struct UnrollVectorOptions {
90   using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
91   /// Callback function that indicates whether vector unrolling should be
92   /// attempted on the operation.
93   FilterConstraintFnType filterConstraint = nullptr;
setFilterConstraintUnrollVectorOptions94   UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
95     filterConstraint = constraint;
96     return *this;
97   }
98 
99   using NativeShapeFnType =
100       std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
101   /// Function that returns the shape of the vector to unroll to for a given
102   /// operation. The unrolling is aborted if the function returns `llvm::None`.
103   NativeShapeFnType nativeShape = nullptr;
setNativeShapeFnUnrollVectorOptions104   UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
105     nativeShape = fn;
106     return *this;
107   }
108 
109   /// Set the native shape to use for unrolling.
setNativeShapeUnrollVectorOptions110   UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
111     SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
112     nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
113       return tsShape;
114     };
115     return *this;
116   }
117 };
118 /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
119 /// declaratively.
120 struct UnrollVectorPattern : public RewritePattern {
121   using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
UnrollVectorPatternUnrollVectorPattern122   UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
123       : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
matchAndRewriteUnrollVectorPattern124   LogicalResult matchAndRewrite(Operation *op,
125                                 PatternRewriter &rewriter) const override {
126     if (options.filterConstraint && failed(options.filterConstraint(op)))
127       return failure();
128     if (!options.nativeShape) {
129       return op->emitError("vector unrolling expects the native shape or native"
130                            "shape call back function to be set");
131     }
132     auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
133     if (!unrollableVectorOp)
134       return failure();
135     auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
136     if (!maybeUnrollShape)
137       return failure();
138     Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
139     if (!targetShape)
140       return op->emitError("failed to get target shape for vector unroll");
141     auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
142     if (!maybeShapeRatio ||
143         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
144       return failure();
145     if (isa<TransferWriteOp>(op)) {
146       if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
147         return failure();
148       rewriter.eraseOp(op);
149       return success();
150     }
151     if (op->getNumResults() != 1)
152       return failure();
153     auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
154     if (resultVector.size() != 1)
155       return failure();
156     rewriter.replaceOp(op, resultVector.front());
157     return success();
158   }
159 
160 private:
161   UnrollVectorOptions options;
162 };
163 
/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref_cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, masked vector.transfer or linalg.copy.
///      memref_cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
///
/// `splitFullAndPartialTransferPrecondition` presumably checks the
/// preconditions listed above (NOTE(review): confirm against the
/// implementation). `splitFullAndPartialTransfer` creates the new IR with
/// `b`; `options` defaults to a default-constructed VectorTransformsOptions.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(
    OpBuilder &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options = VectorTransformsOptions(),
    scf::IfOp *ifOp = nullptr);
202 
203 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
204 /// may take an extra filter to perform selection at a finer granularity.
205 struct VectorTransferFullPartialRewriter : public RewritePattern {
206   using FilterConstraintType =
207       std::function<LogicalResult(VectorTransferOpInterface op)>;
208 
209   explicit VectorTransferFullPartialRewriter(
210       MLIRContext *context,
211       VectorTransformsOptions options = VectorTransformsOptions(),
212       FilterConstraintType filter =
213           [](VectorTransferOpInterface op) { return success(); },
214       PatternBenefit benefit = 1)
RewritePatternVectorTransferFullPartialRewriter215       : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
216         filter(filter) {}
217 
218   /// Performs the rewrite.
219   LogicalResult matchAndRewrite(Operation *op,
220                                 PatternRewriter &rewriter) const override;
221 
222 private:
223   VectorTransformsOptions options;
224   FilterConstraintType filter;
225 };
226 
/// Result of `distributPointwiseVectorOp`. NOTE(review): presumably `extract`
/// is the inserted `vector.extract_map` and `insert` the `vector.insert_map`
/// that follows it — confirm against the implementation.
struct DistributeOps {
  ExtractMapOp extract;
  InsertMapOp insert;
};
231 
/// Distribute an N-D vector pointwise operation over a range of given ids
/// taking *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable
/// or SPMD id). This transformation only inserts
/// vector.extract_map/vector.insert_map. It is meant to be used with
/// canonicalization patterns to propagate and fold the vector
/// insert_map/extract_map operations.
/// Transforms:
/// %v = addf %a, %b : vector<32xf32>
/// to:
/// %v = addf %a, %b : vector<32xf32>
/// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
/// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32>
///
/// NOTE(review): `id` and `multiplicity` are presumably parallel arrays (one
/// entry per distributed dimension, placed via `map`), and an empty Optional
/// presumably signals that `op` could not be distributed — confirm.
/// The function name is missing an "e" ("distribute"); it is kept as-is for
/// API compatibility.
Optional<DistributeOps>
distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
                           ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
                           const AffineMap &map);
248 /// Canonicalize an extra element using the result of a pointwise operation.
249 /// Transforms:
250 /// %v = addf %a, %b : vector32xf32>
251 /// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
252 /// to:
253 /// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
254 /// %db = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
255 /// %dv = addf %da, %db : vector<1xf32>
256 struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
257   using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
258   PointwiseExtractPattern(
259       MLIRContext *context, FilterConstraintType constraint =
260                                 [](ExtractMapOp op) { return success(); })
261       : OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
262   LogicalResult matchAndRewrite(ExtractMapOp extract,
263                                 PatternRewriter &rewriter) const override;
264 
265 private:
266   FilterConstraintType filter;
267 };
268 
/// Implements transfer op write to read forwarding and dead transfer write
/// optimizations on the vector transfer ops inside `func`.
/// NOTE(review): presumably rewrites `func` in place (there is no return
/// value) — confirm against the implementation.
void transferOpflowOpt(FuncOp func);
272 
273 } // namespace vector
274 
275 //===----------------------------------------------------------------------===//
276 // Finer-grained patterns exposed for more control over individual lowerings.
277 //===----------------------------------------------------------------------===//
278 
279 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
280 /// semantics to:
281 /// ```
282 ///    %flattened_a = vector.shape_cast %a
283 ///    %flattened_b = vector.shape_cast %b
284 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
285 ///    %d = vector.shape_cast %%flattened_d
286 ///    %e = add %c, %d
287 /// ```
288 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
289 //
290 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
291 /// the vector.contract op is a row-major matrix multiply.
292 class ContractionOpToMatmulOpLowering
293     : public OpRewritePattern<vector::ContractionOp> {
294 public:
295   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
296   using FilterConstraintType =
297       std::function<LogicalResult(vector::ContractionOp op)>;
298 
defaultFilter(vector::ContractionOp op)299   static LogicalResult defaultFilter(vector::ContractionOp op) {
300     return success();
301   }
302 
303   ContractionOpToMatmulOpLowering(
304       vector::VectorTransformsOptions vectorTransformsOptions,
305       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
306       : OpRewritePattern<vector::ContractionOp>(context),
307         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
308 
309   LogicalResult matchAndRewrite(vector::ContractionOp op,
310                                 PatternRewriter &rewriter) const override;
311 
312 private:
313   /// Options to control the vector patterns.
314   vector::VectorTransformsOptions vectorTransformsOptions;
315   FilterConstraintType filter;
316 };
317 
318 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
319 /// semantics to a reduction_size-unrolled sequence:
320 /// ```
321 ///    %at = vector.transpose %a, [1, 0]
322 ///    %bRow0 = vector.extract %b[0]
323 ///    %atRow0 = vector.extract %at[0]
324 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
325 ///    ...
326 ///    %bRowK = vector.extract %b[K]
327 ///    %atRowK = vector.extract %at[K]
328 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
329 /// ```
330 ///
331 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
332 /// the vector.contract op is a row-major matrix multiply.
333 class ContractionOpToOuterProductOpLowering
334     : public OpRewritePattern<vector::ContractionOp> {
335 public:
336   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
337   using FilterConstraintType =
338       std::function<LogicalResult(vector::ContractionOp op)>;
339 
defaultFilter(vector::ContractionOp op)340   static LogicalResult defaultFilter(vector::ContractionOp op) {
341     return success();
342   }
343 
344   ContractionOpToOuterProductOpLowering(
345       vector::VectorTransformsOptions vectorTransformsOptions,
346       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
347       : OpRewritePattern<vector::ContractionOp>(context),
348         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
349 
350   LogicalResult matchAndRewrite(vector::ContractionOp op,
351                                 PatternRewriter &rewriter) const override;
352 
353 private:
354   /// Options to control the vector patterns.
355   vector::VectorTransformsOptions vectorTransformsOptions;
356   FilterConstraintType filter;
357 };
358 
359 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
360 /// semantics to an output-size-unrolled sequence:
361 /// ```
362 ///    %out = constant ... : vector<MxNxelt_type>
363 ///    %bt = vector.transpose %b, [1, 0]
364 ///    %aRow0 = vector.extract %a[0]
365 ///    %btRow0 = vector.extract %bt[0]
366 ///    %c00 = vector.reduce %atRow0, %bRow0
367 ///    %out00 = vector.insert %c00, %out[0, 0]
368 ///    ...
369 ///    %aRowLast = vector.extract %at[M-1]
370 ///    %btRowLast = vector.extract %b[N-1]
371 ///    %cLastLast = vector.reduce %atRowLast, %bRowLast
372 ///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
373 /// ```
374 ///
375 /// This only kicks in when VectorTransformsOptions is set to Dot and
376 /// the vector.contract op is a row-major matmul or matvec.
377 class ContractionOpToDotLowering
378     : public OpRewritePattern<vector::ContractionOp> {
379 public:
380   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
381   using FilterConstraintType =
382       std::function<LogicalResult(vector::ContractionOp op)>;
383 
defaultFilter(vector::ContractionOp op)384   static LogicalResult defaultFilter(vector::ContractionOp op) {
385     return success();
386   }
387 
388   ContractionOpToDotLowering(
389       vector::VectorTransformsOptions vectorTransformsOptions,
390       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
391       : OpRewritePattern<vector::ContractionOp>(context),
392         vectorTransformsOptions(vectorTransformsOptions),
393         filter(defaultFilter) {}
394 
395   LogicalResult matchAndRewrite(vector::ContractionOp op,
396                                 PatternRewriter &rewriter) const override;
397 
398 private:
399   /// Options to control the vector patterns.
400   vector::VectorTransformsOptions vectorTransformsOptions;
401   FilterConstraintType filter;
402 };
403 
404 /// Progressive lowering of ContractionOp.
405 ///
406 /// One:
407 ///   %x = vector.contract with at least one free/batch dimension
408 /// is replaced by:
409 ///   %a = vector.contract with one less free/batch dimension
410 ///   %b = vector.contract with one less free/batch dimension
411 ///   ..
412 ///   %x = combine %a %b ..
413 /// until a pure contraction is reached (no free/batch dimensions),
414 /// which is replaced by a dot-product.
415 ///
416 /// This only kicks in when either VectorTransformsOptions is set
417 /// to Dot or when other contraction patterns fail.
418 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
419 public:
420   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
421   using FilterConstraintType =
422       std::function<LogicalResult(vector::ContractionOp op)>;
423 
defaultFilter(vector::ContractionOp op)424   static LogicalResult defaultFilter(vector::ContractionOp op) {
425     return success();
426   }
427 
428   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
429                         MLIRContext *context,
430                         FilterConstraintType constraint = defaultFilter)
431       : OpRewritePattern<vector::ContractionOp>(context),
432         vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
433 
434   LogicalResult matchAndRewrite(vector::ContractionOp op,
435                                 PatternRewriter &rewriter) const override;
436 
437 private:
438   /// Options to control the vector patterns.
439   vector::VectorTransformsOptions vectorTransformsOptions;
440   FilterConstraintType filter;
441   // Lower one parallel dimension.
442   Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
443                       int64_t rhsIndex, PatternRewriter &rewriter) const;
444   // Lower one reduction dimension.
445   Value lowerReduction(vector::ContractionOp op,
446                        PatternRewriter &rewriter) const;
447 };
448 
449 } // namespace mlir
450 
451 #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_
452