//===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_

#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
class VectorTransferOpInterface;

namespace scf {
class IfOp;
} // namespace scf

/// Collect a set of patterns to convert from the Vector dialect to itself.
/// Should be merged with populateVectorToSCFLoweringPattern.
void populateVectorToVectorConversionPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    ArrayRef<int64_t> coarseVectorShape = {},
    ArrayRef<int64_t> fineVectorShape = {});

namespace vector {

/// Entry point for unrolling declarative pattern rewrites.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
///    `numUnrolledInstances` are computed from the `targetShape`. For now it
///    is assumed the unrolling factors divide the vector sizes.
/// 2. a fakeFork cast op is inserted that takes the operand and returns
///    `numUnrolledInstances` results of type `unrolledVectorType`.
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
///    result of the fakeFork cast op.
/// 4. a fakeJoin cast op takes all these results and merges them into a
///    single aggregate vector result whose size matches the original
///    non-unrolled op operand types.
///
/// Example:
///
///    opA(operand0, operand1)  // numUnrolledInstances = 3
///
///            operand0                   operand1
///               |                          |
///             fork                       fork
///        <----------gather all fork ops --------->
///              /|\                        /|\
///          f00 f01 f02               f10 f11 f12
///        <---------- clone op 3 times --------->
///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                 \            |            /
///      <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// until all the fakeFork and fakeJoin ops are removed.
///
/// This will be extended in the future to support more advanced use cases than
/// simple pointwise ops.
SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
                                                 Operation *op,
                                                 ArrayRef<int64_t> targetShape);
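
/// Example usage of `unrollSingleResultVectorOp` (a minimal sketch; `addOp`
/// and the 2x2 target shape are hypothetical values chosen for illustration):
/// ```
///   // Unroll a single-result pointwise op into 2x2 pieces and replace all
///   // uses of the original result with the re-assembled vector.
///   OpBuilder builder(addOp);
///   SmallVector<Value, 1> newResults =
///       unrollSingleResultVectorOp(builder, addOp, /*targetShape=*/{2, 2});
///   addOp->getResult(0).replaceAllUsesWith(newResults.front());
/// ```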

/// Unroll a transfer_write op. Break up the vector source into a tuple of
/// vectors matching the given shape. Then store each element with its own
/// transfer_write.
///
/// Example:
/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
/// ->
/// %0 = vector.extract_slices %A, [2, 4], [1, 1] :
///        vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
/// %1 = vector.tuple_get %0, 0 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32>
/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                    ArrayRef<int64_t> targetShape);

/// Options that control the vector unrolling.
struct UnrollVectorOptions {
  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
  /// Callback function that indicates whether vector unrolling should be
  /// attempted on the operation.
  FilterConstraintFnType filterConstraint = nullptr;
  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
    filterConstraint = constraint;
    return *this;
  }

  using NativeShapeFnType =
      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns `llvm::None`.
  NativeShapeFnType nativeShape = nullptr;
  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
    nativeShape = fn;
    return *this;
  }

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
      return tsShape;
    };
    return *this;
  }
};
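
/// Example configuration of `UnrollVectorOptions` (a minimal sketch; unrolling
/// only `vector.contract` ops to a 4x4x4 native shape is an arbitrary choice
/// made for illustration, typically used with the `UnrollVectorPattern`
/// declared below):
/// ```
///   UnrollVectorOptions options;
///   options.setNativeShape({4, 4, 4})
///       .setFilterConstraint(
///           [](Operation *op) { return success(isa<ContractionOp>(op)); });
/// ```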

/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
/// declaratively.
struct UnrollVectorPattern : public RewritePattern {
  using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
  UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (options.filterConstraint && failed(options.filterConstraint(op)))
      return failure();
    if (!options.nativeShape) {
      return op->emitError("vector unrolling expects the native shape or "
                           "native shape callback function to be set");
    }
    auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
    if (!unrollableVectorOp)
      return failure();
    auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
    if (!maybeUnrollShape)
      return failure();
    Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
    if (!targetShape)
      return op->emitError("failed to get target shape for vector unroll");
    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
    if (!maybeShapeRatio ||
        llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
      return failure();
    if (isa<TransferWriteOp>(op)) {
      if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
        return failure();
      rewriter.eraseOp(op);
      return success();
    }
    if (op->getNumResults() != 1)
      return failure();
    auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
    if (resultVector.size() != 1)
      return failure();
    rewriter.replaceOp(op, resultVector.front());
    return success();
  }

private:
  UnrollVectorOptions options;
};

/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may
/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer,
/// the scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///       // fastpath, direct cast
///       memref_cast %A: memref<A...> to compatibleMemRefType
///       scf.yield %view : compatibleMemRefType, index, index
///     } else {
///       // slowpath, masked vector.transfer or linalg.copy.
///       memref_cast %alloc: memref<B...> to compatibleMemRefType
///       scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'d buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(
    OpBuilder &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options = VectorTransformsOptions(),
    scf::IfOp *ifOp = nullptr);
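
/// Example usage of `splitFullAndPartialTransfer` (a minimal sketch; `xferOp`
/// is a hypothetical vector.transfer_read assumed to satisfy the preconditions
/// listed above):
/// ```
///   if (succeeded(splitFullAndPartialTransferPrecondition(xferOp))) {
///     OpBuilder b(xferOp.getOperation());
///     scf::IfOp ifOp;
///     if (succeeded(splitFullAndPartialTransfer(b, xferOp,
///                                               VectorTransformsOptions(),
///                                               &ifOp))) {
///       // `ifOp` now points to the newly created fastpath/slowpath
///       // conditional.
///     }
///   }
/// ```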

/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
        filter(filter) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

struct DistributeOps {
  ExtractMapOp extract;
  InsertMapOp insert;
};

/// Distribute an N-D vector pointwise operation over a range of given ids
/// taking *all* values in [0 .. multiplicity - 1] (e.g. loop induction
/// variable or SPMD id). This transformation only inserts
/// vector.extract_map/vector.insert_map. It is meant to be used with
/// canonicalization patterns to propagate and fold the vector
/// insert_map/extract_map operations.
/// Transforms:
///    %v = addf %a, %b : vector<32xf32>
/// to:
///    %v = addf %a, %b : vector<32xf32>
///    %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
///    %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32>
Optional<DistributeOps>
distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
                           ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
                           const AffineMap &map);

/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
///    %v = addf %a, %b : vector<32xf32>
///    %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
/// to:
///    %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
///    %db = vector.extract_map %b, %id, 32 : vector<32xf32> into vector<1xf32>
///    %dv = addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
  using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
  PointwiseExtractPattern(
      MLIRContext *context, FilterConstraintType constraint =
                                [](ExtractMapOp op) { return success(); })
      : OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
  LogicalResult matchAndRewrite(ExtractMapOp extract,
                                PatternRewriter &rewriter) const override;

private:
  FilterConstraintType filter;
};
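
/// Example usage of `distributPointwiseVectorOp` (a minimal sketch; `addOp`,
/// `threadId` and the multiplicity of 32 are hypothetical values chosen for
/// illustration):
/// ```
///   OpBuilder builder(addOp);
///   Optional<DistributeOps> ops = distributPointwiseVectorOp(
///       builder, addOp, /*id=*/{threadId}, /*multiplicity=*/{32},
///       /*map=*/AffineMap::getMultiDimIdentityMap(1, addOp->getContext()));
///   if (ops) {
///     // Redirect all other uses of the original result to the insert_map and
///     // let canonicalization propagate the extract_map/insert_map pair.
///     SmallPtrSet<Operation *, 1> exceptions{ops->extract.getOperation()};
///     addOp->getResult(0).replaceAllUsesExcept(ops->insert.getResult(),
///                                              exceptions);
///   }
/// ```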

/// Implements transfer op write to read forwarding and dead transfer write
/// optimizations.
void transferOpflowOpt(FuncOp func);

} // namespace vector

//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
//===----------------------------------------------------------------------===//

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformsOptions,
      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformsOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformsOptions,
      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformsOptions;
  FilterConstraintType filter;
};
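
/// Example registration of the contraction lowering patterns (a minimal
/// sketch; it assumes `VectorTransformsOptions` exposes a
/// `vectorContractLowering` field with an `OuterProduct` value, and that
/// `funcOp` and `context` are in scope and the patterns are applied by the
/// usual greedy rewrite driver):
/// ```
///   vector::VectorTransformsOptions options;
///   options.vectorContractLowering =
///       vector::VectorContractLowering::OuterProduct;
///   OwningRewritePatternList patterns;
///   patterns.insert<ContractionOpToOuterProductOpLowering,
///                   ContractionOpToMatmulOpLowering>(options, context);
///   applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
/// ```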

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorTransformsOptions vectorTransformsOptions,
      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformsOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
                        MLIRContext *context,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformsOptions;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                      int64_t rhsIndex, PatternRewriter &rewriter) const;
  // Lower one reduction dimension.
  Value lowerReduction(vector::ContractionOp op,
                       PatternRewriter &rewriter) const;
};

} // namespace mlir

#endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_