1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Dialect/Affine/EDSC/Builders.h"
16 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
24 #include "mlir/Dialect/Vector/VectorOps.h"
25 #include "mlir/Dialect/Vector/VectorTransforms.h"
26 #include "mlir/Dialect/Vector/VectorUtils.h"
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Location.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/OperationSupport.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/TypeUtilities.h"
37 #include "mlir/IR/Types.h"
38 #include "mlir/Interfaces/VectorInterfaces.h"
39 
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43 
44 #define DEBUG_TYPE "vector-to-vector"
45 
46 using namespace mlir;
47 using llvm::dbgs;
48 
49 // Helper to find an index in an affine map.
getResultIndex(AffineMap map,int64_t index)50 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
51   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
52     int64_t idx = map.getDimPosition(i);
53     if (idx == index)
54       return i;
55   }
56   return None;
57 }
58 
59 // Helper to construct iterator types with one index removed.
adjustIter(ArrayAttr iteratorTypes,int64_t index)60 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
61                                             int64_t index) {
62   SmallVector<Attribute, 4> results;
63   for (auto it : llvm::enumerate(iteratorTypes)) {
64     int64_t idx = it.index();
65     if (idx == index)
66       continue;
67     results.push_back(it.value());
68   }
69   return results;
70 }
71 
72 // Helper to construct an affine map with one index removed.
adjustMap(AffineMap map,int64_t index,PatternRewriter & rewriter)73 static AffineMap adjustMap(AffineMap map, int64_t index,
74                            PatternRewriter &rewriter) {
75   auto *ctx = rewriter.getContext();
76   SmallVector<AffineExpr, 4> results;
77   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
78     int64_t idx = map.getDimPosition(i);
79     if (idx == index)
80       continue;
81     // Re-insert remaining indices, but renamed when occurring
82     // after the removed index.
83     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
84     results.push_back(targetExpr);
85   }
86   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
87 }
88 
89 // Helper to drop dimension from vector type.
adjustType(VectorType tp,int64_t index)90 static Type adjustType(VectorType tp, int64_t index) {
91   int64_t rank = tp.getRank();
92   Type eltType = tp.getElementType();
93   if (rank == 1) {
94     assert(index == 0 && "index for scalar result out of bounds");
95     return eltType;
96   }
97   SmallVector<int64_t, 4> adjustedShape;
98   for (int64_t i = 0; i < rank; ++i) {
99     // Omit dimension at the given index.
100     if (i == index)
101       continue;
102     // Otherwise, add dimension back.
103     adjustedShape.push_back(tp.getDimSize(i));
104   }
105   return VectorType::get(adjustedShape, eltType);
106 }
107 
108 // Helper method to possibly drop a dimension in a load.
109 // TODO
reshapeLoad(Location loc,Value val,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)110 static Value reshapeLoad(Location loc, Value val, VectorType type,
111                          int64_t index, int64_t pos,
112                          PatternRewriter &rewriter) {
113   if (index == -1)
114     return val;
115   Type lowType = adjustType(type, 0);
116   // At extraction dimension?
117   if (index == 0) {
118     auto posAttr = rewriter.getI64ArrayAttr(pos);
119     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
120   }
121   // Unroll leading dimensions.
122   VectorType vType = lowType.cast<VectorType>();
123   VectorType resType = adjustType(type, index).cast<VectorType>();
124   Value result =
125       rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
126   for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
127     auto posAttr = rewriter.getI64ArrayAttr(d);
128     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
129     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
130     result =
131         rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
132   }
133   return result;
134 }
135 
136 // Helper method to possibly drop a dimension in a store.
137 // TODO
reshapeStore(Location loc,Value val,Value result,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)138 static Value reshapeStore(Location loc, Value val, Value result,
139                           VectorType type, int64_t index, int64_t pos,
140                           PatternRewriter &rewriter) {
141   // Unmodified?
142   if (index == -1)
143     return val;
144   // At insertion dimension?
145   if (index == 0) {
146     auto posAttr = rewriter.getI64ArrayAttr(pos);
147     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
148   }
149   // Unroll leading dimensions.
150   Type lowType = adjustType(type, 0);
151   VectorType vType = lowType.cast<VectorType>();
152   Type insType = adjustType(vType, 0);
153   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
154     auto posAttr = rewriter.getI64ArrayAttr(d);
155     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
156     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
157     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
158     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
159   }
160   return result;
161 }
162 
163 // Clones `op` into a new operations that takes `operands` and returns
164 // `resultTypes`.
cloneOpWithOperandsAndTypes(OpBuilder & builder,Location loc,Operation * op,ArrayRef<Value> operands,ArrayRef<Type> resultTypes)165 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
166                                               Operation *op,
167                                               ArrayRef<Value> operands,
168                                               ArrayRef<Type> resultTypes) {
169   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
170                      op->getAttrs());
171   return builder.createOperation(res);
172 }
173 
174 // Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
175 // for each index 'i' in inputElements with a valid mapping in 'indexMap'.
getMappedElements(const DenseMap<int64_t,int64_t> & indexMap,ArrayRef<int64_t> inputElements,SmallVectorImpl<int64_t> & resultElements)176 static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
177                               ArrayRef<int64_t> inputElements,
178                               SmallVectorImpl<int64_t> &resultElements) {
179   assert(indexMap.size() == resultElements.size());
180   assert(inputElements.size() >= resultElements.size());
181   for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
182     auto it = indexMap.find(i);
183     if (it != indexMap.end())
184       resultElements[it->second] = inputElements[i];
185   }
186 }
187 
188 // Returns a tuple type with vector element types for each resulting slice
189 // of 'vectorType' unrolled by 'sizes' and 'strides'.
190 // TODO: Move this to a utility function and share it with
191 // Extract/InsertSlicesOp verification.
generateExtractSlicesOpResultType(VectorType vectorType,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,OpBuilder & builder)192 static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
193                                                    ArrayRef<int64_t> sizes,
194                                                    ArrayRef<int64_t> strides,
195                                                    OpBuilder &builder) {
196   assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
197   assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
198   assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());
199 
200   // Compute shape ratio of 'shape' and 'sizes'.
201   auto shape = vectorType.getShape();
202   auto maybeDimSliceCounts = shapeRatio(shape, sizes);
203   assert(maybeDimSliceCounts.hasValue());
204   auto sliceDimCounts = *maybeDimSliceCounts;
205 
206   // Compute strides w.r.t number of slices in each dimension.
207   auto sliceStrides = computeStrides(sliceDimCounts);
208   int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
209   SmallVector<Type, 4> vectorTypes(sliceCount);
210   for (unsigned i = 0; i < sliceCount; ++i) {
211     auto vectorOffsets = delinearize(sliceStrides, i);
212     auto elementOffsets =
213         computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
214     auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
215     // Create Vector type and add to 'vectorTypes[i]'.
216     vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
217   }
218   return TupleType::get(vectorTypes, builder.getContext());
219 }
220 
221 // UnrolledVectorState aggregates per-operand/result vector state required for
222 // unrolling.
223 struct UnrolledVectorState {
224   SmallVector<int64_t, 4> unrolledShape;
225   SmallVector<int64_t, 4> unrollFactors;
226   SmallVector<int64_t, 8> basis;
227   int64_t numInstances;
228   Value slicesTuple;
229 };
230 
231 // Populates 'state' with unrolled shape, unroll factors, basis and
232 // num unrolled instances for 'vectorType'.
initUnrolledVectorState(VectorType vectorType,Value initValue,const DenseMap<int64_t,int64_t> & indexMap,ArrayRef<int64_t> targetShape,UnrolledVectorState & state,OpBuilder & builder)233 static void initUnrolledVectorState(VectorType vectorType, Value initValue,
234                                     const DenseMap<int64_t, int64_t> &indexMap,
235                                     ArrayRef<int64_t> targetShape,
236                                     UnrolledVectorState &state,
237                                     OpBuilder &builder) {
238   // Compute unrolled shape of 'vectorType'.
239   state.unrolledShape.resize(vectorType.getRank());
240   getMappedElements(indexMap, targetShape, state.unrolledShape);
241   // Compute unroll factors for unrolled shape.
242   auto maybeUnrollFactors =
243       shapeRatio(vectorType.getShape(), state.unrolledShape);
244   assert(maybeUnrollFactors.hasValue());
245   state.unrollFactors = *maybeUnrollFactors;
246   // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
247   state.basis = computeStrides(state.unrollFactors);
248   state.numInstances = computeMaxLinearIndex(state.unrollFactors);
249   state.slicesTuple = nullptr;
250   if (initValue != nullptr) {
251     // Create ExtractSlicesOp.
252     SmallVector<int64_t, 4> sizes(state.unrolledShape);
253     SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
254     auto tupleType =
255         generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
256     state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
257         initValue.getLoc(), tupleType, initValue, sizes, strides);
258   }
259 }
260 
261 // Computes and returns the linear index of the unrolled vector at
262 // 'vectorOffsets' within the vector represented by 'state'.
263 static int64_t
getUnrolledVectorLinearIndex(UnrolledVectorState & state,ArrayRef<int64_t> vectorOffsets,DenseMap<int64_t,int64_t> & indexMap)264 getUnrolledVectorLinearIndex(UnrolledVectorState &state,
265                              ArrayRef<int64_t> vectorOffsets,
266                              DenseMap<int64_t, int64_t> &indexMap) {
267   // Compute vector offsets.
268   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
269   getMappedElements(indexMap, vectorOffsets, sliceOffsets);
270   // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
271   return linearize(sliceOffsets, state.basis);
272 }
273 
274 // Returns an unrolled vector at 'vectorOffsets' within the vector
275 // represented by 'state'. The vector is created from a slice of 'initValue'
276 // if not present in 'cache'.
getOrCreateUnrolledVectorSlice(Location loc,UnrolledVectorState & state,ArrayRef<int64_t> vectorOffsets,ArrayRef<int64_t> offsets,DenseMap<int64_t,int64_t> & indexMap,Value initValue,SmallVectorImpl<Value> & cache,OpBuilder & builder)277 static Value getOrCreateUnrolledVectorSlice(
278     Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
279     ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
280     Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
281   // Compute slice offsets.
282   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
283   getMappedElements(indexMap, offsets, sliceOffsets);
284   // TODO: Support non-1 strides.
285   SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
286   // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
287   int64_t sliceLinearIndex =
288       getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
289   assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
290   auto valueSlice = cache[sliceLinearIndex];
291   if (valueSlice == nullptr) {
292     // Return tuple element at 'sliceLinearIndex'.
293     auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
294     auto initValueType = initValue.getType().cast<VectorType>();
295     auto vectorType =
296         VectorType::get(state.unrolledShape, initValueType.getElementType());
297     // Initialize 'cache' with slice from 'initValue'.
298     valueSlice = builder.create<vector::TupleGetOp>(
299         loc, vectorType, state.slicesTuple, tupleIndex);
300     // Store value back to 'cache'.
301     cache[sliceLinearIndex] = valueSlice;
302   }
303   return valueSlice;
304 }
305 
306 // VectorState aggregates per-operand/result vector state required for
307 // creating slices of vector operands, and clones of the operation being
308 // unrolled.
309 struct VectorState {
310   // The type of this vector.
311   VectorType type;
312   // Map from iteration space index to vector dimension index.
313   DenseMap<int64_t, int64_t> indexMap;
314   // Index of this value in operation's operand list (-1 if not an operand).
315   int64_t operandIndex = -1;
316   // Accumulator iterator flag.
317   bool isAcc = false;
318 };
319 
320 //
321 // unrollSingleResultStructuredOp
322 //
323 // Returns a value representing the result of structured operation 'op'
324 // with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
325 // A list of VectorState objects must be specified in 'vectors', where
326 // each VectorState in the list represents a vector operand or vector result
327 // (if the operation does not have an accumulator operand).
328 // The VectorState at index 'resultIndex' in the list must be the state
// associated with the operation's single result (i.e. either its accumulator
330 // operand or vector result value).
331 //
332 // Example:
333 //
334 //  // Before unrolling
335 //
336 //   operand0                operand1                operand2
337 //       \                      |                      /
338 //        -------------------- opA --------------------
339 //
340 //  // After unrolling by 2
341 //
342 //   operand0                operand1                operand2
343 //   /      \                /      \                /      \
344 // slice00  slice01       slice10  slice11        slice20  slice21
345 //   \         |            |          |            /          |
346 //    -------------------- opA0 --------------------           |
347 //             |            |          |                       |
348 //              \           |          |                      /
349 //               -------------------- opA1 -------------------
350 //                          |          |
351 //                           \        /
352 //                           insertslice
353 //                                |
354 
355 // TODO: Add the following canonicalization/simplification patterns:
356 // *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
357 //    InsertStridedSlice operand to StridedSlice.
358 // *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
359 //    if there are duplicate identical StridedSlice ops from SourceOp, and
360 //    rewrites itself to use the first duplicate. This transformation should
//    cause users of identical StridedSlice ops to reuse the same StridedSlice
362 //    operation, and leave the duplicate StridedSlice ops with no users
363 //    (removable with DCE).
364 
365 // TODO: Generalize this to support structured ops beyond
366 // vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
unrollSingleResultStructuredOp(Operation * op,ArrayRef<int64_t> iterationBounds,std::vector<VectorState> & vectors,unsigned resultIndex,ArrayRef<int64_t> targetShape,OpBuilder & builder)367 static Value unrollSingleResultStructuredOp(Operation *op,
368                                             ArrayRef<int64_t> iterationBounds,
369                                             std::vector<VectorState> &vectors,
370                                             unsigned resultIndex,
371                                             ArrayRef<int64_t> targetShape,
372                                             OpBuilder &builder) {
373   auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
374   if (!shapedType || !shapedType.hasStaticShape())
375     assert(false && "Expected a statically shaped result type");
376 
377   // Compute unroll factors for 'iterationBounds' based on 'targetShape'
378   auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
379   if (!maybeUnrollFactors.hasValue())
380     assert(false && "Failed to compute unroll factors for target shape");
381   auto unrollFactors = *maybeUnrollFactors;
382 
383   // Compute unrolled vector state for each vector in 'vectors'.
384   unsigned numVectors = vectors.size();
385   SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
386   for (unsigned i = 0; i < numVectors; ++i) {
387     int64_t operandIndex = vectors[i].operandIndex;
388     auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
389     initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
390                             targetShape, unrolledVectorState[i], builder);
391   }
392   // Compute number of total unrolled instances.
393   auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
394   auto sliceStrides = computeStrides(unrollFactors);
395 
396   auto &resultValueState = unrolledVectorState[resultIndex];
397   auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
398                                             shapedType.getElementType());
399 
400   // Initialize caches for intermediate vector results.
401   std::vector<SmallVector<Value, 4>> caches(numVectors);
402   for (unsigned i = 0; i < numVectors; ++i)
403     caches[i].resize(unrolledVectorState[i].numInstances);
404 
405   // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
406   for (unsigned i = 0; i < numUnrolledInstances; ++i) {
407     auto vectorOffsets = delinearize(sliceStrides, i);
408     auto elementOffsets =
409         computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
410     // Get cached slice (or create slice) for each operand at 'offsets'.
411     SmallVector<Value, 3> operands;
412     operands.resize(op->getNumOperands());
413     for (unsigned i = 0; i < numVectors; ++i) {
414       int64_t operandIndex = vectors[i].operandIndex;
415       if (operandIndex < 0)
416         continue; // Output
417       auto operand = op->getOperand(operandIndex);
418       operands[operandIndex] = getOrCreateUnrolledVectorSlice(
419           op->getLoc(), unrolledVectorState[i], vectorOffsets, elementOffsets,
420           vectors[i].indexMap, operand, caches[i], builder);
421     }
422     // Create op on sliced vector arguments.
423     auto resultVector =
424         cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
425                                     unrolledResultType)
426             ->getResult(0);
427 
428     // Compute linear result index.
429     int64_t linearIndex = getUnrolledVectorLinearIndex(
430         resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
431     // Update result cache at 'linearIndex'.
432     caches[resultIndex][linearIndex] = resultVector;
433   }
434 
435   // Create TupleOp of unrolled result vectors.
436   SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
437   SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
438   for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
439     vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
440     vectorTupleValues[i] = caches[resultIndex][i];
441   }
442   TupleType tupleType = builder.getTupleType(vectorTupleTypes);
443   Value tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
444                                                   vectorTupleValues);
445 
446   // Create InsertSlicesOp(Tuple(result_vectors)).
447   auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
448   SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
449   SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
450 
451   Value insertSlicesOp = builder.create<vector::InsertSlicesOp>(
452       op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
453       builder.getI64ArrayAttr(strides));
454   return insertSlicesOp;
455 }
456 
getVectorContractionOpUnrollState(vector::ContractionOp contractionOp,ArrayRef<int64_t> targetShape,std::vector<VectorState> & vectors,unsigned & resultIndex)457 static void getVectorContractionOpUnrollState(
458     vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
459     std::vector<VectorState> &vectors, unsigned &resultIndex) {
460   // Get map from iteration space index to lhs/rhs/result shape index.
461   std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
462   contractionOp.getIterationIndexMap(iterationIndexMapList);
463   unsigned numIterators = iterationIndexMapList.size();
464   vectors.resize(numIterators);
465   unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
466   for (unsigned i = 0; i < numIterators; ++i) {
467     vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
468     vectors[i].indexMap = iterationIndexMapList[i];
469     vectors[i].operandIndex = i;
470     vectors[i].isAcc = i == accOperandIndex ? true : false;
471   }
472 
473   if (llvm::size(contractionOp.masks()) == 2) {
474     // Add vectors for lhs/rhs vector mask arguments. Masks have the
475     // same vector shape lhs/rhs args, so copy their index maps.
476     vectors.push_back({contractionOp.getLHSVectorMaskType(),
477                        vectors[0].indexMap, accOperandIndex + 1, false});
478     vectors.push_back({contractionOp.getRHSVectorMaskType(),
479                        vectors[1].indexMap, accOperandIndex + 2, false});
480   }
481   // TODO: Use linalg style 'args_in'/'args_out' to partition
482   // 'vectors' instead of 'resultIndex'.
483   resultIndex = accOperandIndex;
484 }
485 
getVectorElementwiseOpUnrollState(Operation * op,ArrayRef<int64_t> targetShape,std::vector<VectorState> & vectors,unsigned & resultIndex)486 static void getVectorElementwiseOpUnrollState(Operation *op,
487                                               ArrayRef<int64_t> targetShape,
488                                               std::vector<VectorState> &vectors,
489                                               unsigned &resultIndex) {
490   // Verify that operation and operands all have the same vector shape.
491   auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
492   assert(resultType && "Expected op with vector result type");
493   auto resultShape = resultType.getShape();
494   // Verify that all operands have the same vector type as result.
495   assert(llvm::all_of(op->getOperandTypes(),
496                       [=](Type type) { return type == resultType; }));
497 
498   // Create trivial elementwise identity index map based on 'resultShape'.
499   DenseMap<int64_t, int64_t> indexMap;
500   indexMap.reserve(resultShape.size());
501   for (unsigned i = 0; i < resultShape.size(); ++i)
502     indexMap[i] = i;
503 
504   // Create VectorState each operand and single result.
505   unsigned numVectors = op->getNumOperands() + op->getNumResults();
506   vectors.resize(numVectors);
507   for (unsigned i = 0; i < op->getNumOperands(); ++i)
508     vectors[i] = {resultType, indexMap, i, false};
509   vectors[numVectors - 1] = {resultType, indexMap, -1, false};
510   resultIndex = numVectors - 1;
511 }
512 
513 /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
514 /// calls 'fn' with linear index and indices for each slice.
generateTransferOpSlices(Type memrefElementType,VectorType vectorType,TupleType tupleType,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,ArrayRef<Value> indices,OpBuilder & builder,function_ref<void (unsigned,ArrayRef<Value>)> fn)515 static void generateTransferOpSlices(
516     Type memrefElementType, VectorType vectorType, TupleType tupleType,
517     ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
518     OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
519   // Compute strides w.r.t. to slice counts in each dimension.
520   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
521   assert(maybeDimSliceCounts.hasValue());
522   auto sliceDimCounts = *maybeDimSliceCounts;
523   auto sliceStrides = computeStrides(sliceDimCounts);
524 
525   int64_t numSlices = tupleType.size();
526   unsigned numSliceIndices = indices.size();
527   // Compute 'indexOffset' at which to update 'indices', which is equal
528   // to the memref rank (indices.size) minus the effective 'vectorRank'.
529   // The effective 'vectorRank', is equal to the rank of the vector type
530   // minus the rank of the memref vector element type (if it has one).
531   //
532   // For example:
533   //
534   //   Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
535   //   transfer_read/write ops which read/write vectors of type
536   //   'vector<2x1x2x4xf32>'. The memref rank is 3, and the effective
537   //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
538   //
539   unsigned vectorRank = vectorType.getRank();
540   if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
541     assert(vectorRank >= memrefVectorElementType.getRank());
542     vectorRank -= memrefVectorElementType.getRank();
543   }
544   unsigned indexOffset = numSliceIndices - vectorRank;
545 
546   auto *ctx = builder.getContext();
547   for (unsigned i = 0; i < numSlices; ++i) {
548     auto vectorOffsets = delinearize(sliceStrides, i);
549     auto elementOffsets =
550         computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
551     // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
552     SmallVector<Value, 4> sliceIndices(numSliceIndices);
553     for (unsigned j = 0; j < numSliceIndices; ++j) {
554       if (j < indexOffset) {
555         sliceIndices[j] = indices[j];
556       } else {
557         auto expr = getAffineDimExpr(0, ctx) +
558                     getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
559         auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
560         sliceIndices[j] = builder.create<AffineApplyOp>(
561             indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
562       }
563     }
564     // Call 'fn' to generate slice 'i' at 'sliceIndices'.
565     fn(i, sliceIndices);
566   }
567 }
568 
569 /// Returns true if 'map' is a suffix of an identity affine map, false
570 /// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
isIdentitySuffix(AffineMap map)571 static bool isIdentitySuffix(AffineMap map) {
572   if (map.getNumDims() < map.getNumResults())
573     return false;
574   ArrayRef<AffineExpr> results = map.getResults();
575   Optional<int> lastPos;
576   for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
577     auto expr = results[i].dyn_cast<AffineDimExpr>();
578     if (!expr)
579       return false;
580     int currPos = static_cast<int>(expr.getPosition());
581     if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
582       return false;
583     lastPos = currPos;
584   }
585   return true;
586 }
587 
588 /// Unroll transfer_read ops to the given shape and create an aggregate with all
589 /// the chunks.
unrollTransferReadOp(vector::TransferReadOp readOp,ArrayRef<int64_t> targetShape,OpBuilder & builder)590 static Value unrollTransferReadOp(vector::TransferReadOp readOp,
591                                   ArrayRef<int64_t> targetShape,
592                                   OpBuilder &builder) {
593   if (!isIdentitySuffix(readOp.permutation_map()))
594     return nullptr;
595   auto sourceVectorType = readOp.getVectorType();
596   SmallVector<int64_t, 4> strides(targetShape.size(), 1);
597 
598   Location loc = readOp.getLoc();
599   auto memrefElementType =
600       readOp.memref().getType().cast<MemRefType>().getElementType();
601   auto tupleType = generateExtractSlicesOpResultType(
602       sourceVectorType, targetShape, strides, builder);
603   int64_t numSlices = tupleType.size();
604 
605   SmallVector<Value, 4> vectorTupleValues(numSlices);
606   SmallVector<Value, 4> indices(readOp.indices().begin(),
607                                 readOp.indices().end());
608   auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
609     // Get VectorType for slice 'i'.
610     auto sliceVectorType = tupleType.getType(index);
611     // Create split TransferReadOp for 'sliceUser'.
612     // `masked` attribute propagates conservatively: if the coarse op didn't
613     // need masking, the fine op doesn't either.
614     vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
615         loc, sliceVectorType, readOp.memref(), sliceIndices,
616         readOp.permutation_map(), readOp.padding(),
617         readOp.masked() ? *readOp.masked() : ArrayAttr());
618   };
619   generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
620                            targetShape, strides, indices, builder, createSlice);
621 
622   // Create tuple of splice transfer read operations.
623   Value tupleOp =
624       builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
625   // Replace 'readOp' with result 'insertSlicesResult'.
626   Value newVec = builder.create<vector::InsertSlicesOp>(
627       loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
628       builder.getI64ArrayAttr(strides));
629   return newVec;
630 }
631 
632 // Entry point for unrolling declarative pattern rewrite for transfer_write op.
633 LogicalResult
unrollTransferWriteOp(OpBuilder & builder,Operation * op,ArrayRef<int64_t> targetShape)634 mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
635                                     ArrayRef<int64_t> targetShape) {
636   auto writeOp = cast<vector::TransferWriteOp>(op);
637   if (!isIdentitySuffix(writeOp.permutation_map()))
638     return failure();
639   VectorType sourceVectorType = writeOp.getVectorType();
640   SmallVector<int64_t, 4> strides(targetShape.size(), 1);
641   TupleType tupleType = generateExtractSlicesOpResultType(
642       sourceVectorType, targetShape, strides, builder);
643   Location loc = writeOp.getLoc();
644   Value tuple = builder.create<vector::ExtractSlicesOp>(
645       loc, tupleType, writeOp.vector(), targetShape, strides);
646   auto memrefElementType =
647       writeOp.memref().getType().cast<MemRefType>().getElementType();
648   SmallVector<Value, 4> indices(writeOp.indices().begin(),
649                                 writeOp.indices().end());
650   auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
651     auto element = builder.create<vector::TupleGetOp>(
652         loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
653     builder.create<vector::TransferWriteOp>(
654         loc, element.getResult(), writeOp.memref(), sliceIndices,
655         writeOp.permutation_map(),
656         writeOp.masked() ? *writeOp.masked() : ArrayAttr());
657   };
658   generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
659                            targetShape, strides, indices, builder, createSlice);
660   return success();
661 }
662 
663 // Entry point for unrolling declarative pattern rewrites.
664 SmallVector<Value, 1>
unrollSingleResultVectorOp(OpBuilder & builder,Operation * op,ArrayRef<int64_t> targetShape)665 mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
666                                          ArrayRef<int64_t> targetShape) {
667   assert(op->getNumResults() == 1 && "Expected single result operation");
668 
669   // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
670   SmallVector<int64_t, 6> iterationBounds;
671   auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
672   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
673   assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
674 
675   std::vector<VectorState> vectors;
676   unsigned resultIndex;
677 
678   if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
679     return SmallVector<Value, 1>{
680         unrollTransferReadOp(readOp, targetShape, builder)};
681 
682   if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
683     // Populate state for vector ContractionOp.
684     getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
685                                       resultIndex);
686   } else {
687     // Populate state for vector elementwise op.
688     getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
689   }
690 
691   // Unroll 'op' with 'iterationBounds' to 'targetShape'.
692   return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
693       op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
694 }
695 
696 namespace {
697 
698 // Splits vector TransferReadOp into smaller TransferReadOps based on slicing
699 // scheme of its unique ExtractSlicesOp user.
700 struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
701   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
702 
matchAndRewrite__anona2fda2e20511::SplitTransferReadOp703   LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
704                                 PatternRewriter &rewriter) const override {
705     // TODO: Support splitting TransferReadOp with non-identity
706     // permutation maps. Repurpose code from MaterializeVectors transformation.
707     if (!isIdentitySuffix(xferReadOp.permutation_map()))
708       return failure();
709     // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
710     Value xferReadResult = xferReadOp.getResult();
711     auto extractSlicesOp =
712         dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
713     if (!xferReadResult.hasOneUse() || !extractSlicesOp)
714       return failure();
715 
716     // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
717     SmallVector<int64_t, 4> sizes;
718     extractSlicesOp.getSizes(sizes);
719     SmallVector<int64_t, 4> strides;
720     extractSlicesOp.getStrides(strides);
721     assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
722 
723     Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
724     if (!newVec)
725       return failure();
726     rewriter.replaceOp(xferReadOp, newVec);
727     return success();
728   }
729 };
730 
731 // Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
732 struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
733   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
734 
matchAndRewrite__anona2fda2e20511::SplitTransferWriteOp735   LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
736                                 PatternRewriter &rewriter) const override {
737     // TODO: Support splitting TransferWriteOp with non-identity
738     // permutation maps. Repurpose code from MaterializeVectors transformation.
739     if (!isIdentitySuffix(xferWriteOp.permutation_map()))
740       return failure();
741     // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
742     auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
743     auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
744     if (!insertSlicesOp)
745       return failure();
746 
747     // Get TupleOp operand of 'insertSlicesOp'.
748     auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
749         insertSlicesOp.vectors().getDefiningOp());
750     if (!tupleOp)
751       return failure();
752 
753     // Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
754     auto sourceTupleType = insertSlicesOp.getSourceTupleType();
755     auto resultVectorType = insertSlicesOp.getResultVectorType();
756     SmallVector<int64_t, 4> sizes;
757     insertSlicesOp.getSizes(sizes);
758     SmallVector<int64_t, 4> strides;
759     insertSlicesOp.getStrides(strides);
760 
761     Location loc = xferWriteOp.getLoc();
762     auto memrefElementType =
763         xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
764     SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
765                                   xferWriteOp.indices().end());
766     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
767       // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
768       // `masked` attribute propagates conservatively: if the coarse op didn't
769       // need masking, the fine op doesn't either.
770       rewriter.create<vector::TransferWriteOp>(
771           loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
772           xferWriteOp.permutation_map(),
773           xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
774     };
775     generateTransferOpSlices(memrefElementType, resultVectorType,
776                              sourceTupleType, sizes, strides, indices, rewriter,
777                              createSlice);
778 
779     // Erase old 'xferWriteOp'.
780     rewriter.eraseOp(xferWriteOp);
781     return success();
782   }
783 };
784 
785 /// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
786 /// on vector types.
787 struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
788   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
789 
matchAndRewrite__anona2fda2e20511::ShapeCastOpDecomposer790   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
791                                 PatternRewriter &rewriter) const override {
792     // Check if 'shapeCastOp' has tuple source/result type.
793     auto sourceTupleType =
794         shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
795     auto resultTupleType =
796         shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
797     if (!sourceTupleType || !resultTupleType)
798       return failure();
799     assert(sourceTupleType.size() == resultTupleType.size());
800 
801     // Create single-vector ShapeCastOp for each source tuple element.
802     Location loc = shapeCastOp.getLoc();
803     SmallVector<Value, 8> resultElements;
804     resultElements.reserve(resultTupleType.size());
805     for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
806       auto sourceElement = rewriter.create<vector::TupleGetOp>(
807           loc, sourceTupleType.getType(i), shapeCastOp.source(),
808           rewriter.getI64IntegerAttr(i));
809       resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
810           loc, resultTupleType.getType(i), sourceElement));
811     }
812 
813     // Replace 'shapeCastOp' with tuple of 'resultElements'.
814     rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
815                                                  resultElements);
816     return success();
817   }
818 };
819 
820 /// Returns the producer Value of the same type as 'consumerValue', by tracking
821 /// the tuple index and offsets of the consumer vector value through the
822 /// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
823 /// and ShapeCastOp) from consumer to producer. Each operation in the chain is
824 /// structured, and so the tuple index and offsets can be mapped from result to
825 /// input, while visiting each operation in the chain.
826 /// Returns nullptr on failure.
getProducerValue(Value consumerValue)827 static Value getProducerValue(Value consumerValue) {
828   auto consumerVectorType = consumerValue.getType().cast<VectorType>();
829   // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
830   int64_t tupleIndex = -1;
831   SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
832   auto *op = consumerValue.getDefiningOp();
833   while (op != nullptr) {
834     if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
835       assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
836 
837       // Update 'tupleIndex' and next defining 'op' to visit.
838       tupleIndex = tupleGetOp.getIndex();
839       op = tupleGetOp.vectors().getDefiningOp();
840     } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
841       assert(tupleIndex >= 0);
842 
843       // Compute slice strides for 'extractSlicesOp'.
844       SmallVector<int64_t, 4> sizes;
845       extractSlicesOp.getSizes(sizes);
846       auto sliceStrides = computeStrides(
847           extractSlicesOp.getSourceVectorType().getShape(), sizes);
848 
849       // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
850       // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
851       auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
852       auto elementOffsets =
853           computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
854 
855       // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
856       // to the 'extractSlicesOp' input vector type.
857       assert(offsets.size() == elementOffsets.size());
858       for (unsigned i = 0, e = offsets.size(); i < e; ++i)
859         offsets[i] += elementOffsets[i];
860 
861       // Clear 'tupleIndex' and update next defining 'op' to visit.
862       tupleIndex = -1;
863       op = extractSlicesOp.vector().getDefiningOp();
864     } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
865       assert(tupleIndex == -1);
866 
867       // Compute slice strides for 'insertSlicesOp'.
868       SmallVector<int64_t, 4> sizes;
869       insertSlicesOp.getSizes(sizes);
870       auto sliceStrides = computeStrides(
871           insertSlicesOp.getResultVectorType().getShape(), sizes);
872 
873       // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
874       // of 'insertSlicesOp' result vector type at 'offsets'.
875       SmallVector<int64_t, 4> vectorOffsets(offsets.size());
876       assert(offsets.size() == sizes.size());
877       for (unsigned i = 0, e = offsets.size(); i < e; ++i)
878         vectorOffsets[i] = offsets[i] / sizes[i];
879 
880       // Compute the source tuple element index.
881       tupleIndex = linearize(vectorOffsets, sliceStrides);
882 
883       // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
884       // relative to input tuple element vector type at 'tupleIndex'.
885       auto elementOffsets =
886           computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
887       assert(offsets.size() == elementOffsets.size());
888       for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
889         offsets[i] -= elementOffsets[i];
890         assert(offsets[i] >= 0);
891       }
892 
893       // Update next defining 'op' to visit.
894       op = insertSlicesOp.vectors().getDefiningOp();
895     } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
896       assert(tupleIndex >= 0);
897 
898       // Return tuple element 'value' at 'tupleIndex' if it matches type.
899       auto value = tupleOp.getOperand(tupleIndex);
900       if (value.getType() == consumerVectorType)
901         return value;
902 
903       // Update 'tupleIndex' and next defining 'op' to visit.
904       tupleIndex = -1;
905       op = value.getDefiningOp();
906     } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
907       if (shapeCastOp.source().getType().isa<TupleType>())
908         return nullptr;
909       assert(tupleIndex == -1);
910       auto sourceVectorType = shapeCastOp.getSourceVectorType();
911       auto sourceVectorShape = sourceVectorType.getShape();
912       unsigned sourceVectorRank = sourceVectorType.getRank();
913       auto resultVectorType = shapeCastOp.getResultVectorType();
914       auto resultVectorShape = resultVectorType.getShape();
915       unsigned resultVectorRank = resultVectorType.getRank();
916 
917       int i = sourceVectorRank - 1;
918       int j = resultVectorRank - 1;
919 
920       // Check that source/result vector shape prefixes match while updating
921       // 'newOffsets'.
922       SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
923       for (auto it : llvm::zip(llvm::reverse(sourceVectorShape),
924                                llvm::reverse(resultVectorShape))) {
925         if (std::get<0>(it) != std::get<1>(it))
926           return nullptr;
927         newOffsets[i--] = offsets[j--];
928       }
929 
930       // Check that remaining prefix of source/result vector shapes are all 1s.
931       // Currently we only support producer/consumer tracking through trivial
932       // shape cast ops. Examples:
933       //   %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
934       //   %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
935       assert(i == -1 || j == -1);
936       if (i >= 0 &&
937           !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i,
938                        [](int64_t v) { return v == 1; }))
939         return nullptr;
940       if (j >= 0 &&
941           !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j,
942                        [](int64_t v) { return v == 1; }))
943         return nullptr;
944 
945       offsets.swap(newOffsets);
946       op = shapeCastOp.source().getDefiningOp();
947     } else {
948       // Check if 'op' produces a Value with the same type as 'consumerValue'.
949       if (op->getNumResults() == 1 &&
950           op->getResult(0).getType() == consumerVectorType)
951         return op->getResult(0);
952       return nullptr;
953     }
954   }
955   return nullptr;
956 }
957 
958 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
959 //
960 // Example:
961 //
962 //  The following MLIR with cancelling ShapeCastOps:
963 //
964 //   %0 = source : vector<5x4x2xf32>
965 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
966 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
967 //   %3 = user %2 : vector<5x4x2xf32>
968 //
969 //  Should canonicalize to the following:
970 //
971 //   %0 = source : vector<5x4x2xf32>
972 //   %1 = user %0 : vector<5x4x2xf32>
973 //
974 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
975   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
976 
matchAndRewrite__anona2fda2e20511::ShapeCastOpFolder977   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
978                                 PatternRewriter &rewriter) const override {
979     // Check if we can replace 'shapeCastOp' result with its producer.
980     if (auto producer = getProducerValue(shapeCastOp.getResult())) {
981       rewriter.replaceOp(shapeCastOp, producer);
982       return success();
983     }
984 
985     // Check if 'shapeCastOp' has vector source/result type.
986     auto sourceVectorType =
987         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
988     auto resultVectorType =
989         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
990     if (!sourceVectorType || !resultVectorType)
991       return failure();
992 
993     // Check if shape cast op source operand is also a shape cast op.
994     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
995         shapeCastOp.source().getDefiningOp());
996     if (!sourceShapeCastOp)
997       return failure();
998     auto operandSourceVectorType =
999         sourceShapeCastOp.source().getType().cast<VectorType>();
1000     auto operandResultVectorType =
1001         sourceShapeCastOp.result().getType().cast<VectorType>();
1002 
1003     // Check if shape cast operations invert each other.
1004     if (operandSourceVectorType != resultVectorType ||
1005         operandResultVectorType != sourceVectorType)
1006       return failure();
1007 
1008     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
1009     return success();
1010   }
1011 };
1012 
1013 // Patter rewrite which forward tuple elements to their users.
1014 // User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
1015 //   -> User(Producer)
1016 struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
1017   using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;
1018 
matchAndRewrite__anona2fda2e20511::TupleGetFolderOp1019   LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
1020                                 PatternRewriter &rewriter) const override {
1021     if (auto producer = getProducerValue(tupleGetOp.getResult())) {
1022       rewriter.replaceOp(tupleGetOp, producer);
1023       return success();
1024     }
1025     return failure();
1026   }
1027 };
1028 
1029 /// Progressive lowering of ExtractSlicesOp to tuple of ExtractStridedSliceOp.
1030 /// One:
1031 ///   %x = vector.extract_slices %0
1032 /// is replaced by:
1033 ///   %a = vector.strided_slice %0
1034 ///   %b = vector.strided_slice %0
1035 ///   ..
1036 ///   %x = vector.tuple %a, %b, ..
1037 class ExtractSlicesOpLowering
1038     : public OpRewritePattern<vector::ExtractSlicesOp> {
1039 public:
1040   using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
1041 
matchAndRewrite(vector::ExtractSlicesOp op,PatternRewriter & rewriter) const1042   LogicalResult matchAndRewrite(vector::ExtractSlicesOp op,
1043                                 PatternRewriter &rewriter) const override {
1044     auto loc = op.getLoc();
1045 
1046     VectorType vectorType = op.getSourceVectorType();
1047     auto shape = vectorType.getShape();
1048 
1049     SmallVector<int64_t, 4> sizes;
1050     op.getSizes(sizes);
1051     SmallVector<int64_t, 4> strides;
1052     op.getStrides(strides); // all-ones at the moment
1053 
1054     // For each element in the tuple, generate the proper strided slice.
1055     TupleType tupleType = op.getResultTupleType();
1056     int64_t tupleSize = tupleType.size();
1057     SmallVector<Value, 4> tupleValues(tupleSize);
1058     auto sliceStrides = computeStrides(shape, sizes);
1059     for (int64_t i = 0; i < tupleSize; ++i) {
1060       auto vectorOffsets = delinearize(sliceStrides, i);
1061       auto elementOffsets =
1062           computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
1063       auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
1064       // Insert in tuple.
1065       tupleValues[i] = rewriter.create<vector::ExtractStridedSliceOp>(
1066           loc, op.vector(), elementOffsets, sliceSizes, strides);
1067     }
1068 
1069     rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
1070     return success();
1071   }
1072 };
1073 
1074 /// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
1075 /// One:
1076 ///   %x = vector.insert_slices %0
1077 /// is replaced by:
1078 ///   %r0 = zero-result
1079 ///   %t1 = vector.tuple_get %0, 0
1080 ///   %r1 = vector.insert_strided_slice %r0, %t1
1081 ///   %t2 = vector.tuple_get %0, 1
1082 ///   %r2 = vector.insert_strided_slice %r1, %t2
1083 ///   ..
1084 ///   %x  = ..
1085 class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
1086 public:
1087   using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
1088 
matchAndRewrite(vector::InsertSlicesOp op,PatternRewriter & rewriter) const1089   LogicalResult matchAndRewrite(vector::InsertSlicesOp op,
1090                                 PatternRewriter &rewriter) const override {
1091     auto loc = op.getLoc();
1092 
1093     VectorType vectorType = op.getResultVectorType();
1094     auto shape = vectorType.getShape();
1095 
1096     SmallVector<int64_t, 4> sizes;
1097     op.getSizes(sizes);
1098     SmallVector<int64_t, 4> strides;
1099     op.getStrides(strides); // all-ones at the moment
1100 
1101     // Prepare result.
1102     Value result = rewriter.create<ConstantOp>(
1103         loc, vectorType, rewriter.getZeroAttr(vectorType));
1104 
1105     // For each element in the tuple, extract the proper strided slice.
1106     TupleType tupleType = op.getSourceTupleType();
1107     int64_t tupleSize = tupleType.size();
1108     auto sliceStrides = computeStrides(shape, sizes);
1109     for (int64_t i = 0; i < tupleSize; ++i) {
1110       auto vectorOffsets = delinearize(sliceStrides, i);
1111       auto elementOffsets =
1112           computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
1113       // Extract from tuple into the result.
1114       auto index = rewriter.getI64IntegerAttr(i);
1115       auto tupleGet = rewriter.create<vector::TupleGetOp>(
1116           loc, tupleType.getType(i), op.getOperand(), index);
1117       result = rewriter.create<vector::InsertStridedSliceOp>(
1118           loc, tupleGet, result, elementOffsets, strides);
1119     }
1120 
1121     rewriter.replaceOp(op, result);
1122     return success();
1123   }
1124 };
1125 
1126 /// Progressive lowering of BroadcastOp.
1127 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
1128 public:
1129   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
1130 
matchAndRewrite(vector::BroadcastOp op,PatternRewriter & rewriter) const1131   LogicalResult matchAndRewrite(vector::BroadcastOp op,
1132                                 PatternRewriter &rewriter) const override {
1133     auto loc = op.getLoc();
1134     VectorType dstType = op.getVectorType();
1135     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
1136     Type eltType = dstType.getElementType();
1137 
1138     // Determine rank of source and destination.
1139     int64_t srcRank = srcType ? srcType.getRank() : 0;
1140     int64_t dstRank = dstType.getRank();
1141 
1142     // Duplicate this rank.
1143     // For example:
1144     //   %x = broadcast %y  : k-D to n-D, k < n
1145     // becomes:
1146     //   %b = broadcast %y  : k-D to (n-1)-D
1147     //   %x = [%b,%b,%b,%b] : n-D
1148     // becomes:
1149     //   %b = [%y,%y]       : (n-1)-D
1150     //   %x = [%b,%b,%b,%b] : n-D
1151     if (srcRank < dstRank) {
1152       // Scalar to any vector can use splat.
1153       if (srcRank == 0) {
1154         rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
1155         return success();
1156       }
1157       // Duplication.
1158       VectorType resType =
1159           VectorType::get(dstType.getShape().drop_front(), eltType);
1160       Value bcst =
1161           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
1162       Value result = rewriter.create<ConstantOp>(loc, dstType,
1163                                                  rewriter.getZeroAttr(dstType));
1164       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1165         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1166       rewriter.replaceOp(op, result);
1167       return success();
1168     }
1169 
1170     // Find non-matching dimension, if any.
1171     assert(srcRank == dstRank);
1172     int64_t m = -1;
1173     for (int64_t r = 0; r < dstRank; r++)
1174       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
1175         m = r;
1176         break;
1177       }
1178 
1179     // All trailing dimensions are the same. Simply pass through.
1180     if (m == -1) {
1181       rewriter.replaceOp(op, op.source());
1182       return success();
1183     }
1184 
1185     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
1186     if (srcRank == 1) {
1187       assert(m == 0);
1188       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1189       rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
1190       return success();
1191     }
1192 
1193     // Any non-matching dimension forces a stretch along this rank.
1194     // For example:
1195     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
1196     // becomes:
1197     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
1198     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
1199     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
1200     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
1201     //   %x = [%a,%b,%c,%d]
1202     // becomes:
1203     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
1204     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
1205     //   %a = [%u, %v]
1206     //   ..
1207     //   %x = [%a,%b,%c,%d]
1208     VectorType resType =
1209         VectorType::get(dstType.getShape().drop_front(), eltType);
1210     Value result = rewriter.create<ConstantOp>(loc, dstType,
1211                                                rewriter.getZeroAttr(dstType));
1212     if (m == 0) {
1213       // Stetch at start.
1214       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1215       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1216       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1217         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1218     } else {
1219       // Stetch not at start.
1220       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
1221         Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
1222         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1223         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1224       }
1225     }
1226     rewriter.replaceOp(op, result);
1227     return success();
1228   }
1229 };
1230 
1231 /// Progressive lowering of TransposeOp.
1232 /// One:
1233 ///   %x = vector.transpose %y, [1, 0]
1234 /// is replaced by:
1235 ///   %z = constant dense<0.000000e+00>
1236 ///   %0 = vector.extract %y[0, 0]
1237 ///   %1 = vector.insert %0, %z [0, 0]
1238 ///   ..
1239 ///   %x = vector.insert .., .. [.., ..]
1240 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
1241 public:
1242   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1243 
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,MLIRContext * context)1244   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
1245                       MLIRContext *context)
1246       : OpRewritePattern<vector::TransposeOp>(context),
1247         vectorTransformsOptions(vectorTransformsOptions) {}
1248 
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const1249   LogicalResult matchAndRewrite(vector::TransposeOp op,
1250                                 PatternRewriter &rewriter) const override {
1251     auto loc = op.getLoc();
1252 
1253     VectorType resType = op.getResultType();
1254 
1255     // Set up convenience transposition table.
1256     SmallVector<int64_t, 4> transp;
1257     for (auto attr : op.transp())
1258       transp.push_back(attr.cast<IntegerAttr>().getInt());
1259 
1260     // Handle a true 2-D matrix transpose differently when requested.
1261     if (vectorTransformsOptions.vectorTransposeLowering ==
1262             vector::VectorTransposeLowering::Flat &&
1263         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
1264       Type flattenedType =
1265           VectorType::get(resType.getNumElements(), resType.getElementType());
1266       auto matrix =
1267           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
1268       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
1269       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
1270       Value trans = rewriter.create<vector::FlatTransposeOp>(
1271           loc, flattenedType, matrix, rows, columns);
1272       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
1273       return success();
1274     }
1275 
1276     // Generate fully unrolled extract/insert ops.
1277     Value result = rewriter.create<ConstantOp>(loc, resType,
1278                                                rewriter.getZeroAttr(resType));
1279     SmallVector<int64_t, 4> lhs(transp.size(), 0);
1280     SmallVector<int64_t, 4> rhs(transp.size(), 0);
1281     rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
1282                                          op.vector(), result, rewriter));
1283     return success();
1284   }
1285 
1286 private:
1287   // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
1288   // operation when al ranks are exhausted.
expandIndices(Location loc,VectorType resType,int64_t pos,SmallVector<int64_t,4> & transp,SmallVector<int64_t,4> & lhs,SmallVector<int64_t,4> & rhs,Value input,Value result,PatternRewriter & rewriter) const1289   Value expandIndices(Location loc, VectorType resType, int64_t pos,
1290                       SmallVector<int64_t, 4> &transp,
1291                       SmallVector<int64_t, 4> &lhs,
1292                       SmallVector<int64_t, 4> &rhs, Value input, Value result,
1293                       PatternRewriter &rewriter) const {
1294     if (pos >= resType.getRank()) {
1295       auto ridx = rewriter.getI64ArrayAttr(rhs);
1296       auto lidx = rewriter.getI64ArrayAttr(lhs);
1297       Type eltType = resType.getElementType();
1298       Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
1299       return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
1300     }
1301     for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
1302       lhs[pos] = d;
1303       rhs[transp[pos]] = d;
1304       result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
1305                              result, rewriter);
1306     }
1307     return result;
1308   }
1309 
1310   /// Options to control the vector patterns.
1311   vector::VectorTransformsOptions vectorTransformsOptions;
1312 };
1313 
1314 /// Progressive lowering of OuterProductOp.
1315 /// One:
1316 ///   %x = vector.outerproduct %lhs, %rhs, %acc
1317 /// is replaced by:
1318 ///   %z = zero-result
1319 ///   %0 = vector.extract %lhs[0]
1320 ///   %1 = vector.broadcast %0
1321 ///   %2 = vector.extract %acc[0]
1322 ///   %3 = vector.fma %1, %rhs, %2
1323 ///   %4 = vector.insert %3, %z[0]
1324 ///   ..
1325 ///   %x = vector.insert %.., %..[N-1]
1326 ///
1327 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1328 public:
1329   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
1330 
matchAndRewrite(vector::OuterProductOp op,PatternRewriter & rewriter) const1331   LogicalResult matchAndRewrite(vector::OuterProductOp op,
1332                                 PatternRewriter &rewriter) const override {
1333     auto loc = op.getLoc();
1334 
1335     VectorType lhsType = op.getOperandVectorTypeLHS();
1336     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
1337     VectorType resType = op.getVectorType();
1338     Type eltType = resType.getElementType();
1339     bool isInt = eltType.isa<IntegerType>();
1340     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
1341 
1342     if (!rhsType) {
1343       // Special case: AXPY operation.
1344       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
1345       rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
1346       return success();
1347     }
1348 
1349     Value result = rewriter.create<ConstantOp>(loc, resType,
1350                                                rewriter.getZeroAttr(resType));
1351     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1352       auto pos = rewriter.getI64ArrayAttr(d);
1353       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
1354       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1355       Value r = nullptr;
1356       if (acc)
1357         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
1358       Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
1359       result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
1360     }
1361     rewriter.replaceOp(op, result);
1362     return success();
1363   }
1364 
1365 private:
genMult(Location loc,Value x,Value y,Value acc,bool isInt,PatternRewriter & rewriter)1366   static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
1367                        PatternRewriter &rewriter) {
1368     if (acc) {
1369       if (isInt)
1370         return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
1371                                        acc);
1372       return rewriter.create<vector::FMAOp>(loc, x, y, acc);
1373     }
1374     if (isInt)
1375       return rewriter.create<MulIOp>(loc, x, y);
1376     return rewriter.create<MulFOp>(loc, x, y);
1377   }
1378 };
1379 
1380 /// Progressive lowering of ConstantMaskOp.
1381 /// One:
1382 ///   %x = vector.constant_mask [a,b]
1383 /// is replaced by:
1384 ///   %z = zero-result
1385 ///   %l = vector.constant_mask [b]
1386 ///   %4 = vector.insert %l, %z[0]
1387 ///   ..
1388 ///   %x = vector.insert %l, %..[a-1]
1389 /// until a one-dimensional vector is reached. All these operations
1390 /// will be folded at LLVM IR level.
1391 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
1392 public:
1393   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
1394 
matchAndRewrite(vector::ConstantMaskOp op,PatternRewriter & rewriter) const1395   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
1396                                 PatternRewriter &rewriter) const override {
1397     auto loc = op.getLoc();
1398     auto dstType = op.getResult().getType().cast<VectorType>();
1399     auto eltType = dstType.getElementType();
1400     auto dimSizes = op.mask_dim_sizes();
1401     int64_t rank = dimSizes.size();
1402     int64_t trueDim = std::min(dstType.getDimSize(0),
1403                                dimSizes[0].cast<IntegerAttr>().getInt());
1404 
1405     if (rank == 1) {
1406       // Express constant 1-D case in explicit vector form:
1407       //   [T,..,T,F,..,F].
1408       SmallVector<bool, 4> values(dstType.getDimSize(0));
1409       for (int64_t d = 0; d < trueDim; d++)
1410         values[d] = true;
1411       rewriter.replaceOpWithNewOp<ConstantOp>(
1412           op, dstType, rewriter.getBoolVectorAttr(values));
1413       return success();
1414     }
1415 
1416     VectorType lowType =
1417         VectorType::get(dstType.getShape().drop_front(), eltType);
1418     SmallVector<int64_t, 4> newDimSizes;
1419     for (int64_t r = 1; r < rank; r++)
1420       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
1421     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
1422         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
1423     Value result = rewriter.create<ConstantOp>(loc, dstType,
1424                                                rewriter.getZeroAttr(dstType));
1425     for (int64_t d = 0; d < trueDim; d++) {
1426       auto pos = rewriter.getI64ArrayAttr(d);
1427       result =
1428           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
1429     }
1430     rewriter.replaceOp(op, result);
1431     return success();
1432   }
1433 };
1434 
1435 /// Progressive lowering of CreateMaskOp.
1436 /// One:
1437 ///   %x = vector.create_mask %a, ... : vector<dx...>
1438 /// is replaced by:
1439 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
1440 ///   %0 = cmpi "slt", %ci, %a       |
1441 ///   %1 = select %0, %l, %zeroes    |
1442 ///   %r = vector.insert %1, %pr [i] | d-times
1443 ///   %x = ....
1444 /// until a one-dimensional vector is reached.
1445 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
1446 public:
1447   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
1448 
matchAndRewrite(vector::CreateMaskOp op,PatternRewriter & rewriter) const1449   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1450                                 PatternRewriter &rewriter) const override {
1451     auto loc = op.getLoc();
1452     auto dstType = op.getResult().getType().cast<VectorType>();
1453     auto eltType = dstType.getElementType();
1454     int64_t dim = dstType.getDimSize(0);
1455     int64_t rank = dstType.getRank();
1456     Value idx = op.getOperand(0);
1457 
1458     if (rank == 1)
1459       return failure(); // leave for lowering
1460 
1461     VectorType lowType =
1462         VectorType::get(dstType.getShape().drop_front(), eltType);
1463     Value trueVal = rewriter.create<vector::CreateMaskOp>(
1464         loc, lowType, op.getOperands().drop_front());
1465     Value falseVal = rewriter.create<ConstantOp>(loc, lowType,
1466                                                  rewriter.getZeroAttr(lowType));
1467     Value result = rewriter.create<ConstantOp>(loc, dstType,
1468                                                rewriter.getZeroAttr(dstType));
1469     for (int64_t d = 0; d < dim; d++) {
1470       Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
1471       Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
1472       Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
1473       auto pos = rewriter.getI64ArrayAttr(d);
1474       result =
1475           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
1476     }
1477     rewriter.replaceOp(op, result);
1478     return success();
1479   }
1480 };
1481 
1482 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
1483 /// vectors progressively on the way to target llvm.matrix intrinsics.
1484 /// This iterates over the most major dimension of the 2-D vector and performs
1485 /// rewrites into:
1486 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
1487 class ShapeCastOp2DDownCastRewritePattern
1488     : public OpRewritePattern<vector::ShapeCastOp> {
1489 public:
1490   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1491 
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const1492   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1493                                 PatternRewriter &rewriter) const override {
1494     auto sourceVectorType = op.getSourceVectorType();
1495     auto resultVectorType = op.getResultVectorType();
1496     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
1497       return failure();
1498 
1499     auto loc = op.getLoc();
1500     Value desc = rewriter.create<ConstantOp>(
1501         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1502     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
1503     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
1504       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
1505       desc = rewriter.create<vector::InsertStridedSliceOp>(
1506           loc, vec, desc,
1507           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
1508     }
1509     rewriter.replaceOp(op, desc);
1510     return success();
1511   }
1512 };
1513 
1514 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
1515 /// vectors progressively on the way from targeting llvm.matrix intrinsics.
1516 /// This iterates over the most major dimension of the 2-D vector and performs
1517 /// rewrites into:
1518 ///   vector.strided_slice from 1-D + vector.insert into 2-D
1519 class ShapeCastOp2DUpCastRewritePattern
1520     : public OpRewritePattern<vector::ShapeCastOp> {
1521 public:
1522   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1523 
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const1524   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1525                                 PatternRewriter &rewriter) const override {
1526     auto sourceVectorType = op.getSourceVectorType();
1527     auto resultVectorType = op.getResultVectorType();
1528     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
1529       return failure();
1530 
1531     auto loc = op.getLoc();
1532     Value desc = rewriter.create<ConstantOp>(
1533         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1534     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
1535     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
1536       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
1537           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
1538           /*sizes=*/mostMinorVectorSize,
1539           /*strides=*/1);
1540       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
1541     }
1542     rewriter.replaceOp(op, desc);
1543     return success();
1544   }
1545 };
1546 
1547 // We typically should not lower general shape cast operations into data
1548 // movement instructions, since the assumption is that these casts are
1549 // optimized away during progressive lowering. For completeness, however,
1550 // we fall back to a reference implementation that moves all elements
1551 // into the right place if we get here.
1552 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
1553 public:
1554   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
1555 
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const1556   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1557                                 PatternRewriter &rewriter) const override {
1558     Location loc = op.getLoc();
1559     auto sourceVectorType = op.getSourceVectorType();
1560     auto resultVectorType = op.getResultVectorType();
1561     // Intended 2D/1D lowerings with better implementations.
1562     int64_t srcRank = sourceVectorType.getRank();
1563     int64_t resRank = resultVectorType.getRank();
1564     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
1565       return failure();
1566     // Compute number of elements involved in the reshape.
1567     int64_t numElts = 1;
1568     for (int64_t r = 0; r < srcRank; r++)
1569       numElts *= sourceVectorType.getDimSize(r);
1570     // Replace with data movement operations:
1571     //    x[0,0,0] = y[0,0]
1572     //    x[0,0,1] = y[0,1]
1573     //    x[0,1,0] = y[0,2]
1574     // etc., incrementing the two index vectors "row-major"
1575     // within the source and result shape.
1576     SmallVector<int64_t, 4> srcIdx(srcRank);
1577     SmallVector<int64_t, 4> resIdx(resRank);
1578     Value result = rewriter.create<ConstantOp>(
1579         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1580     for (int64_t i = 0; i < numElts; i++) {
1581       if (i != 0) {
1582         incIdx(srcIdx, sourceVectorType, srcRank - 1);
1583         incIdx(resIdx, resultVectorType, resRank - 1);
1584       }
1585       Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
1586       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
1587     }
1588     rewriter.replaceOp(op, result);
1589     return success();
1590   }
1591 
1592 private:
incIdx(SmallVector<int64_t,4> & idx,VectorType tp,int64_t r)1593   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
1594     assert(0 <= r && r < tp.getRank());
1595     if (++idx[r] == tp.getDimSize(r)) {
1596       idx[r] = 0;
1597       incIdx(idx, tp, r - 1);
1598     }
1599   }
1600 };
1601 
1602 } // namespace
1603 
1604 namespace mlir {
1605 
1606 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1607 /// semantics to:
1608 /// ```
1609 ///    %flattened_a = vector.shape_cast %a
1610 ///    %flattened_b = vector.shape_cast %b
1611 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1612 ///    %d = vector.shape_cast %%flattened_d
1613 ///    %e = add %c, %d
1614 /// ```
1615 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1616 //
1617 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
1618 /// the vector.contract op is a row-major matrix multiply.
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1619 LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
1620     vector::ContractionOp op, PatternRewriter &rewriter) const {
1621   // TODO: implement masks
1622   if (llvm::size(op.masks()) != 0)
1623     return failure();
1624   if (vectorTransformsOptions.vectorContractLowering !=
1625       vector::VectorContractLowering::Matmul)
1626     return failure();
1627   if (failed(filter(op)))
1628     return failure();
1629 
1630   auto iteratorTypes = op.iterator_types().getValue();
1631   if (!isParallelIterator(iteratorTypes[0]) ||
1632       !isParallelIterator(iteratorTypes[1]) ||
1633       !isReductionIterator(iteratorTypes[2]))
1634     return failure();
1635 
1636   if (!isRowMajorMatmul(op.indexing_maps()))
1637     return failure();
1638 
1639   Type elementType = op.getLhsType().getElementType();
1640   if (!elementType.isIntOrFloat())
1641     return failure();
1642 
1643   VectorType lhsType = op.getLhsType();
1644   VectorType rhsType = op.getRhsType();
1645   int64_t lhsRows = lhsType.getDimSize(0);
1646   int64_t lhsColumns = lhsType.getDimSize(1);
1647   int64_t rhsColumns = rhsType.getDimSize(1);
1648 
1649   Type flattenedLHSType =
1650       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1651   Type flattenedRHSType =
1652       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1653   auto lhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedLHSType,
1654                                                   op.lhs());
1655   auto rhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedRHSType,
1656                                                   op.rhs());
1657 
1658   Value mul = rewriter.create<vector::MatmulOp>(op.getLoc(), lhs, rhs, lhsRows,
1659                                                 lhsColumns, rhsColumns);
1660   mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
1661                                              mul);
1662   if (elementType.isa<IntegerType>())
1663     rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
1664   else
1665     rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
1666 
1667   return success();
1668 }
1669 
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
/// Concretely, eight layouts are recognized for a matmat (iterators parallel,
/// parallel, reduction) and four for a matvec (iterators parallel,
/// reduction). Any other form fails to match.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformsOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  Location loc = op.getLoc();
  int64_t reductionSize = 0;
  VectorType lhsType = op.getLhsType();
  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

  // Set up the parallel/reduction structure in right form.
  // Every case below rewrites `lhs`/`rhs` (swapping and/or transposing) so
  // that the reduction dimension is the leading one of both. Then extracting
  // index k from each yields the k-th outer-product operands. `lhs` ends up
  // supplying the result rows and `rhs` the result columns (or a scalar, in
  // the matvec case).
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.iterator_types().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      reductionSize = lhsType.getDimSize(1);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
      reductionSize = lhsType.getDimSize(1);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      // No need to permute anything.
      reductionSize = lhsType.getDimSize(0);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      // Just permute the rhs.
      reductionSize = lhsType.getDimSize(0);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed result: the original rhs supplies the rows as-is, and the
      // permuted original lhs supplies the columns (operands are swapped).
      reductionSize = lhsType.getDimSize(1);
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
      // Transposed result: swap the operands and permute both.
      reductionSize = lhsType.getDimSize(1);
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // No need to permute anything, but still swap lhs and rhs.
      reductionSize = lhsType.getDimSize(0);
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      // Transposed result: the permuted original rhs supplies the rows, and
      // the original lhs supplies the columns as-is (operands are swapped).
      reductionSize = lhsType.getDimSize(0);
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // Case mat-vec: transpose.
      reductionSize = lhsType.getDimSize(1);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      // Case mat-trans-vec: ready to go.
      reductionSize = lhsType.getDimSize(0);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      // Case vec-mat: swap and transpose.
      reductionSize = lhsType.getDimSize(0);
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      // Case vec-mat-trans: swap and ready to go.
      reductionSize = lhsType.getDimSize(0);
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
  } else {
    return failure();
  }
  assert(reductionSize > 0);

  // Unroll outer-products along reduction.
  // Each step extracts row k of both normalized operands and accumulates
  // their outer product into `res`. In the matvec cases the rhs extract
  // yields a scalar.
  for (int64_t k = 0; k < reductionSize; ++k) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
  }
  rewriter.replaceOp(op, res);
  return success();
}
1796 
1797 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1798 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1799                                             PatternRewriter &rewriter) const {
1800   // TODO: implement masks
1801   if (llvm::size(op.masks()) != 0)
1802     return failure();
1803 
1804   if (failed(filter(op)))
1805     return failure();
1806 
1807   if (vectorTransformsOptions.vectorContractLowering !=
1808       vector::VectorContractLowering::Dot)
1809     return failure();
1810 
1811   auto iteratorTypes = op.iterator_types().getValue();
1812   static constexpr std::array<int64_t, 2> perm = {1, 0};
1813   Location loc = op.getLoc();
1814   Value lhs = op.lhs(), rhs = op.rhs();
1815 
1816   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1817   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1818   AffineExpr m, n, k;
1819   bindDims(rewriter.getContext(), m, n, k);
1820   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1821   //
1822   // In the following we wish to make the reduction dimension innermost so we
1823   // can load vectors and just fmul + reduce into a scalar.
1824   //
1825   if (isParallelIterator(iteratorTypes[0]) &&
1826       isParallelIterator(iteratorTypes[1]) &&
1827       isReductionIterator(iteratorTypes[2])) {
1828     //
1829     // Two outer parallel, one inner reduction (matmat flavor).
1830     //
1831     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1832       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1833     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1834       // No need to permute anything.
1835     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1836       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1837       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1838     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1839       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1840     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1841       // This is the classical row-major matmul. Just permute the lhs.
1842       Value tmp = lhs;
1843       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1844       rhs = tmp;
1845     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1846       std::swap(lhs, rhs);
1847     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1848       Value tmp = lhs;
1849       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1850       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1851     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1852       Value tmp = rhs;
1853       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1854       lhs = tmp;
1855     } else {
1856       return failure();
1857     }
1858   } else if (isParallelIterator(iteratorTypes[0]) &&
1859              isReductionIterator(iteratorTypes[1])) {
1860     //
1861     // One outer parallel, one inner reduction (matvec flavor)
1862     //
1863     if (maps == infer({{m, n}, {n}, {m}})) {
1864       // No need to permute anything.
1865     } else if (maps == infer({{n, m}, {n}, {m}})) {
1866       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1867     } else if (maps == infer({{n}, {m, n}, {m}})) {
1868       std::swap(lhs, rhs);
1869     } else if (maps == infer({{n}, {n, m}, {m}})) {
1870       std::swap(lhs, rhs);
1871       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1872     } else {
1873       return failure();
1874     }
1875   } else {
1876     return failure();
1877   }
1878 
1879   VectorType dstType = op.getResultType().cast<VectorType>();
1880   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1881          "Expected dst type of rank 1 or 2");
1882 
1883   unsigned rank = dstType.getRank();
1884   unsigned dstRows = dstType.getShape()[0];
1885   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1886 
1887   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1888   Value res =
1889       rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
1890   for (unsigned r = 0; r < dstRows; ++r) {
1891     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1892     for (unsigned c = 0; c < dstColumns; ++c) {
1893       Value b = rank == 1
1894                     ? rhs
1895                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1896       Value m = rewriter.create<MulFOp>(op.getLoc(), a, b);
1897       Value reduced = rewriter.create<vector::ReductionOp>(
1898           op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
1899           m, ValueRange{});
1900 
1901       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1902                                               : SmallVector<int64_t, 2>{r, c};
1903       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1904     }
1905   }
1906   if (auto acc = op.acc())
1907     res = rewriter.create<AddFOp>(op.getLoc(), res, acc);
1908   rewriter.replaceOp(op, res);
1909   return success();
1910 }
1911 
1912 /// Progressive lowering of ContractionOp.
1913 /// One:
1914 ///   %x = vector.contract with at least one free/batch dimension
1915 /// is replaced by:
1916 ///   %a = vector.contract with one less free/batch dimension
1917 ///   %b = vector.contract with one less free/batch dimension
1918 ///   ..
1919 ///   %x = combine %a %b ..
1920 /// until a pure contraction is reached (no free/batch dimensions),
1921 /// which is replaced by a dot-product.
1922 ///
1923 /// This only kicks in when either VectorTransformsOptions is set
1924 /// to DOT or when other contraction patterns fail.
1925 //
1926 // TODO: break down into transpose/reshape/cast ops
1927 //               when they become available to avoid code dup
1928 // TODO: investigate lowering order impact on performance
1929 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1930 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1931                                        PatternRewriter &rewriter) const {
1932   // TODO: implement masks.
1933   if (llvm::size(op.masks()) != 0)
1934     return failure();
1935 
1936   if (failed(filter(op)))
1937     return failure();
1938 
1939   // TODO: support mixed mode contract lowering.
1940   if (op.getLhsType().getElementType() !=
1941           getElementTypeOrSelf(op.getAccType()) ||
1942       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1943     return failure();
1944 
1945   // TODO: implement benefits, cost models.
1946   MLIRContext *ctx = op.getContext();
1947   ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
1948   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1949     return success();
1950   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
1951   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1952     return success();
1953   ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx);
1954   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1955     return success();
1956 
1957   // Find first batch dimension in LHS/RHS, and lower when found.
1958   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1959   if (!batchDimMap.empty()) {
1960     int64_t lhsIndex = batchDimMap[0].first;
1961     int64_t rhsIndex = batchDimMap[0].second;
1962     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1963     return success();
1964   }
1965 
1966   // Collect contracting dimensions.
1967   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1968       op.getContractingDimMap();
1969   DenseSet<int64_t> lhsContractingDimSet;
1970   DenseSet<int64_t> rhsContractingDimSet;
1971   for (auto &dimPair : contractingDimMap) {
1972     lhsContractingDimSet.insert(dimPair.first);
1973     rhsContractingDimSet.insert(dimPair.second);
1974   }
1975 
1976   // Find first free dimension in LHS, and lower when found.
1977   VectorType lhsType = op.getLhsType();
1978   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1979     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1980       rewriter.replaceOp(
1981           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1982       return success();
1983     }
1984   }
1985 
1986   // Find first free dimension in RHS, and lower when found.
1987   VectorType rhsType = op.getRhsType();
1988   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1989     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1990       rewriter.replaceOp(
1991           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1992       return success();
1993     }
1994   }
1995 
1996   // Lower the first remaining reduction dimension.
1997   if (!contractingDimMap.empty()) {
1998     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1999     return success();
2000   }
2001 
2002   return failure();
2003 }
2004 
2005 // Lower one parallel dimension.
2006 // TODO: consider reusing existing contract unrolling
lowerParallel(vector::ContractionOp op,int64_t lhsIndex,int64_t rhsIndex,PatternRewriter & rewriter) const2007 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
2008                                            int64_t lhsIndex, int64_t rhsIndex,
2009                                            PatternRewriter &rewriter) const {
2010   VectorType lhsType = op.getLhsType();
2011   VectorType rhsType = op.getRhsType();
2012   VectorType resType = op.getResultType().cast<VectorType>();
2013   // Find the iterator type index and result index.
2014   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2015   int64_t iterIndex = -1;
2016   int64_t dimSize = -1;
2017   if (lhsIndex >= 0) {
2018     iterIndex = iMap[0].getDimPosition(lhsIndex);
2019     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
2020            "parallel index should be free in LHS or batch in LHS/RHS");
2021     dimSize = lhsType.getDimSize(lhsIndex);
2022   } else {
2023     assert(rhsIndex >= 0 && "missing parallel index");
2024     iterIndex = iMap[1].getDimPosition(rhsIndex);
2025     dimSize = rhsType.getDimSize(rhsIndex);
2026   }
2027   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
2028   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
2029   assert(lookup.hasValue() && "parallel index not listed in reduction");
2030   int64_t resIndex = lookup.getValue();
2031   // Construct new iterator types and affine map array attribute.
2032   std::array<AffineMap, 3> lowIndexingMaps = {
2033       adjustMap(iMap[0], iterIndex, rewriter),
2034       adjustMap(iMap[1], iterIndex, rewriter),
2035       adjustMap(iMap[2], iterIndex, rewriter)};
2036   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2037   auto lowIter =
2038       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2039   // Unroll into a series of lower dimensional vector.contract ops.
2040   Location loc = op.getLoc();
2041   Value result =
2042       rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
2043   for (int64_t d = 0; d < dimSize; ++d) {
2044     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2045     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2046     auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
2047     Value lowContract = rewriter.create<vector::ContractionOp>(
2048         loc, lhs, rhs, acc, lowAffine, lowIter);
2049     result =
2050         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
2051   }
2052   return result;
2053 }
2054 
2055 // Lower one reduction dimension.
lowerReduction(vector::ContractionOp op,PatternRewriter & rewriter) const2056 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
2057                                             PatternRewriter &rewriter) const {
2058   auto loc = op.getLoc();
2059   VectorType lhsType = op.getLhsType();
2060   VectorType rhsType = op.getRhsType();
2061   Type resType = op.getResultType();
2062   assert(!resType.isa<VectorType>());
2063   // Use iterator index 0.
2064   int64_t iterIndex = 0;
2065   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
2066   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
2067   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
2068   assert(lookupLhs.hasValue() && "missing LHS parallel index");
2069   assert(lookupRhs.hasValue() && "missing RHS parallel index");
2070   int64_t lhsIndex = lookupLhs.getValue();
2071   int64_t rhsIndex = lookupRhs.getValue();
2072   int64_t dimSize = lhsType.getDimSize(lhsIndex);
2073   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
2074   // Base case.
2075   if (lhsType.getRank() == 1) {
2076     assert(rhsType.getRank() == 1 && "corrupt contraction");
2077     Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
2078     StringAttr kind = rewriter.getStringAttr("add");
2079     return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
2080                                                 op.acc());
2081   }
2082   // Construct new iterator types and affine map array attribute.
2083   std::array<AffineMap, 3> lowIndexingMaps = {
2084       adjustMap(iMap[0], iterIndex, rewriter),
2085       adjustMap(iMap[1], iterIndex, rewriter),
2086       adjustMap(iMap[2], iterIndex, rewriter)};
2087   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
2088   auto lowIter =
2089       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
2090   // Unroll into a series of lower dimensional vector.contract ops.
2091   // By feeding the initial accumulator into the first contraction,
2092   // and the result of each contraction into the next, eventually
2093   // the sum of all reductions is computed.
2094   Value result = op.acc();
2095   for (int64_t d = 0; d < dimSize; ++d) {
2096     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
2097     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
2098     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
2099                                                     lowAffine, lowIter);
2100   }
2101   return result;
2102 }
2103 
2104 } // namespace mlir
2105 
extractConstantIndex(Value v)2106 static Optional<int64_t> extractConstantIndex(Value v) {
2107   if (auto cstOp = v.getDefiningOp<ConstantIndexOp>())
2108     return cstOp.getValue();
2109   if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
2110     if (affineApplyOp.getAffineMap().isSingleConstant())
2111       return affineApplyOp.getAffineMap().getSingleConstantResult();
2112   return None;
2113 }
2114 
2115 // Missing foldings of scf.if make it necessary to perform poor man's folding
2116 // eagerly, especially in the case of unrolling. In the future, this should go
2117 // away once scf.if folds properly.
createScopedFoldedSLE(Value v,Value ub)2118 static Value createScopedFoldedSLE(Value v, Value ub) {
2119   using namespace edsc::op;
2120   auto maybeCstV = extractConstantIndex(v);
2121   auto maybeCstUb = extractConstantIndex(ub);
2122   if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
2123     return Value();
2124   return sle(v, ub);
2125 }
2126 
2127 // Operates under a scoped context to build the condition to ensure that a
2128 // particular VectorTransferOpInterface is unmasked.
createScopedInBoundsCond(VectorTransferOpInterface xferOp)2129 static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
2130   assert(xferOp.permutation_map().isMinorIdentity() &&
2131          "Expected minor identity map");
2132   Value inBoundsCond;
2133   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2134     // Zip over the resulting vector shape and memref indices.
2135     // If the dimension is known to be unmasked, it does not participate in the
2136     // construction of `inBoundsCond`.
2137     if (!xferOp.isMaskedDim(resultIdx))
2138       return;
2139     int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
2140     using namespace edsc::op;
2141     using namespace edsc::intrinsics;
2142     // Fold or create the check that `index + vector_size` <= `memref_size`.
2143     Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
2144     Value cond =
2145         createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
2146     if (!cond)
2147       return;
2148     // Conjunction over all dims for which we are in-bounds.
2149     inBoundsCond = inBoundsCond ? inBoundsCond && cond : cond;
2150   });
2151   return inBoundsCond;
2152 }
2153 
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)2154 LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
2155     VectorTransferOpInterface xferOp) {
2156   // TODO: expand support to these 2 cases.
2157   if (!xferOp.permutation_map().isMinorIdentity())
2158     return failure();
2159   // Must have some masked dimension to be a candidate for splitting.
2160   if (!xferOp.hasMaskedDim())
2161     return failure();
2162   // Don't split transfer operations directly under IfOp, this avoids applying
2163   // the pattern recursively.
2164   // TODO: improve the filtering condition to make it more applicable.
2165   if (isa<scf::IfOp>(xferOp->getParentOp()))
2166     return failure();
2167   return success();
2168 }
2169 
2170 /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
2171 /// be cast. If the MemRefTypes don't have the same rank or are not strided,
2172 /// return null; otherwise:
2173 ///   1. if `aT` and `bT` are cast-compatible, return `aT`.
2174 ///   2. else return a new MemRefType obtained by iterating over the shape and
2175 ///   strides and:
2176 ///     a. keeping the ones that are static and equal across `aT` and `bT`.
2177 ///     b. using a dynamic shape and/or stride for the dimensions that don't
2178 ///        agree.
getCastCompatibleMemRefType(MemRefType aT,MemRefType bT)2179 static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
2180   if (MemRefCastOp::areCastCompatible(aT, bT))
2181     return aT;
2182   if (aT.getRank() != bT.getRank())
2183     return MemRefType();
2184   int64_t aOffset, bOffset;
2185   SmallVector<int64_t, 4> aStrides, bStrides;
2186   if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
2187       failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
2188       aStrides.size() != bStrides.size())
2189     return MemRefType();
2190 
2191   ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
2192   int64_t resOffset;
2193   SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
2194       resStrides(bT.getRank(), 0);
2195   for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
2196     resShape[idx] =
2197         (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
2198     resStrides[idx] = (aStrides[idx] == bStrides[idx])
2199                           ? aStrides[idx]
2200                           : MemRefType::kDynamicStrideOrOffset;
2201   }
2202   resOffset =
2203       (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
2204   return MemRefType::get(
2205       resShape, aT.getElementType(),
2206       makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
2207 }
2208 
2209 /// Operates under a scoped context to build the intersection between the
2210 /// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
2211 // TODO: view intersection/union/differences should be a proper std op.
createScopedSubViewIntersection(VectorTransferOpInterface xferOp,Value alloc)2212 static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
2213                                              Value alloc) {
2214   using namespace edsc::intrinsics;
2215   int64_t memrefRank = xferOp.getMemRefType().getRank();
2216   // TODO: relax this precondition, will require rank-reducing subviews.
2217   assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
2218          "Expected memref rank to match the alloc rank");
2219   Value one = std_constant_index(1);
2220   ValueRange leadingIndices =
2221       xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
2222   SmallVector<Value, 4> sizes;
2223   sizes.append(leadingIndices.begin(), leadingIndices.end());
2224   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
2225     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
2226     Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
2227     Value dimAlloc = std_dim(alloc, resultIdx);
2228     Value index = xferOp.indices()[indicesIdx];
2229     AffineExpr i, j, k;
2230     bindDims(xferOp.getContext(), i, j, k);
2231     SmallVector<AffineMap, 4> maps =
2232         AffineMap::inferFromExprList(MapList{{i - j, k}});
2233     // affine_min(%dimMemRef - %index, %dimAlloc)
2234     Value affineMin = affine_min(index.getType(), maps[0],
2235                                  ValueRange{dimMemRef, index, dimAlloc});
2236     sizes.push_back(affineMin);
2237   });
2238   return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
2239                       SmallVector<Value, 4>(memrefRank, one));
2240 }
2241 
2242 /// Given an `xferOp` for which:
2243 ///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2244 ///   2. a memref of single vector `alloc` has been allocated.
2245 /// Produce IR resembling:
2246 /// ```
2247 ///    %1:3 = scf.if (%inBounds) {
2248 ///      memref_cast %A: memref<A...> to compatibleMemRefType
2249 ///      scf.yield %view, ... : compatibleMemRefType, index, index
2250 ///    } else {
2251 ///      %2 = linalg.fill(%alloc, %pad)
2252 ///      %3 = subview %view [...][...][...]
2253 ///      linalg.copy(%3, %alloc)
2254 ///      memref_cast %alloc: memref<B...> to compatibleMemRefType
2255 ///      scf.yield %4, ... : compatibleMemRefType, index, index
2256 ///   }
2257 /// ```
2258 /// Return the produced scf::IfOp.
createScopedFullPartialLinalgCopy(vector::TransferReadOp xferOp,TypeRange returnTypes,Value inBoundsCond,MemRefType compatibleMemRefType,Value alloc)2259 static scf::IfOp createScopedFullPartialLinalgCopy(
2260     vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2261     MemRefType compatibleMemRefType, Value alloc) {
2262   using namespace edsc;
2263   using namespace edsc::intrinsics;
2264   scf::IfOp fullPartialIfOp;
2265   Value zero = std_constant_index(0);
2266   Value memref = xferOp.memref();
2267   conditionBuilder(
2268       returnTypes, inBoundsCond,
2269       [&]() -> scf::ValueVector {
2270         Value res = memref;
2271         if (compatibleMemRefType != xferOp.getMemRefType())
2272           res = std_memref_cast(memref, compatibleMemRefType);
2273         scf::ValueVector viewAndIndices{res};
2274         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2275                               xferOp.indices().end());
2276         return viewAndIndices;
2277       },
2278       [&]() -> scf::ValueVector {
2279         linalg_fill(alloc, xferOp.padding());
2280         // Take partial subview of memref which guarantees no dimension
2281         // overflows.
2282         Value memRefSubView = createScopedSubViewIntersection(
2283             cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
2284         linalg_copy(memRefSubView, alloc);
2285         Value casted = std_memref_cast(alloc, compatibleMemRefType);
2286         scf::ValueVector viewAndIndices{casted};
2287         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
2288                               zero);
2289         return viewAndIndices;
2290       },
2291       &fullPartialIfOp);
2292   return fullPartialIfOp;
2293 }
2294 
2295 /// Given an `xferOp` for which:
2296 ///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
2297 ///   2. a memref of single vector `alloc` has been allocated.
2298 /// Produce IR resembling:
2299 /// ```
2300 ///    %1:3 = scf.if (%inBounds) {
2301 ///      memref_cast %A: memref<A...> to compatibleMemRefType
2302 ///      scf.yield %view, ... : compatibleMemRefType, index, index
2303 ///    } else {
2304 ///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
2305 ///      %3 = vector.type_cast %extra_alloc :
2306 ///        memref<...> to memref<vector<...>>
2307 ///      store %2, %3[] : memref<vector<...>>
2308 ///      %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
2309 ///      scf.yield %4, ... : compatibleMemRefType, index, index
2310 ///   }
2311 /// ```
2312 /// Return the produced scf::IfOp.
createScopedFullPartialVectorTransferRead(vector::TransferReadOp xferOp,TypeRange returnTypes,Value inBoundsCond,MemRefType compatibleMemRefType,Value alloc)2313 static scf::IfOp createScopedFullPartialVectorTransferRead(
2314     vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
2315     MemRefType compatibleMemRefType, Value alloc) {
2316   using namespace edsc;
2317   using namespace edsc::intrinsics;
2318   scf::IfOp fullPartialIfOp;
2319   Value zero = std_constant_index(0);
2320   Value memref = xferOp.memref();
2321   conditionBuilder(
2322       returnTypes, inBoundsCond,
2323       [&]() -> scf::ValueVector {
2324         Value res = memref;
2325         if (compatibleMemRefType != xferOp.getMemRefType())
2326           res = std_memref_cast(memref, compatibleMemRefType);
2327         scf::ValueVector viewAndIndices{res};
2328         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
2329                               xferOp.indices().end());
2330         return viewAndIndices;
2331       },
2332       [&]() -> scf::ValueVector {
2333         Operation *newXfer =
2334             ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
2335         Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
2336         std_store(vector, vector_type_cast(
2337                               MemRefType::get({}, vector.getType()), alloc));
2338 
2339         Value casted = std_memref_cast(alloc, compatibleMemRefType);
2340         scf::ValueVector viewAndIndices{casted};
2341         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
2342                               zero);
2343 
2344         return viewAndIndices;
2345       },
2346       &fullPartialIfOp);
2347   return fullPartialIfOp;
2348 }
2349 
2350 /// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
2351 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
2352 /// newly created conditional upon function return.
2353 /// To accomodate for the fact that the original vector.transfer indexing may be
2354 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
2355 /// scf.if op returns a view and values of type index.
2356 /// At this time, only vector.transfer_read case is implemented.
2357 ///
2358 /// Example (a 2-D vector.transfer_read):
2359 /// ```
2360 ///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
2361 /// ```
2362 /// is transformed into:
2363 /// ```
2364 ///    %1:3 = scf.if (%inBounds) {
2365 ///      // fastpath, direct cast
2366 ///      memref_cast %A: memref<A...> to compatibleMemRefType
2367 ///      scf.yield %view : compatibleMemRefType, index, index
2368 ///    } else {
2369 ///      // slowpath, masked vector.transfer or linalg.copy.
2370 ///      memref_cast %alloc: memref<B...> to compatibleMemRefType
2371 ///      scf.yield %4 : compatibleMemRefType, index, index
2372 //     }
2373 ///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
2374 /// ```
2375 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
2376 ///
2377 /// Preconditions:
2378 ///  1. `xferOp.permutation_map()` must be a minor identity map
2379 ///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
2380 ///  must be equal. This will be relaxed in the future but requires
2381 ///  rank-reducing subviews.
splitFullAndPartialTransfer(OpBuilder & b,VectorTransferOpInterface xferOp,VectorTransformsOptions options,scf::IfOp * ifOp)2382 LogicalResult mlir::vector::splitFullAndPartialTransfer(
2383     OpBuilder &b, VectorTransferOpInterface xferOp,
2384     VectorTransformsOptions options, scf::IfOp *ifOp) {
2385   using namespace edsc;
2386   using namespace edsc::intrinsics;
2387 
2388   if (options.vectorTransferSplit == VectorTransferSplit::None)
2389     return failure();
2390 
2391   SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
2392   auto unmaskedAttr = b.getBoolArrayAttr(bools);
2393   if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
2394     xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
2395     return success();
2396   }
2397 
2398   assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
2399          "Expected splitFullAndPartialTransferPrecondition to hold");
2400   auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
2401 
2402   // TODO: add support for write case.
2403   if (!xferReadOp)
2404     return failure();
2405 
2406   OpBuilder::InsertionGuard guard(b);
2407   if (xferOp.memref().getDefiningOp())
2408     b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
2409   else
2410     b.setInsertionPoint(xferOp);
2411   ScopedContext scope(b, xferOp.getLoc());
2412   Value inBoundsCond = createScopedInBoundsCond(
2413       cast<VectorTransferOpInterface>(xferOp.getOperation()));
2414   if (!inBoundsCond)
2415     return failure();
2416 
2417   // Top of the function `alloc` for transient storage.
2418   Value alloc;
2419   {
2420     FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
2421     OpBuilder::InsertionGuard guard(b);
2422     b.setInsertionPointToStart(&funcOp.getRegion().front());
2423     auto shape = xferOp.getVectorType().getShape();
2424     Type elementType = xferOp.getVectorType().getElementType();
2425     alloc = std_alloca(MemRefType::get(shape, elementType), ValueRange{},
2426                        b.getI64IntegerAttr(32));
2427   }
2428 
2429   MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
2430       xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
2431 
2432   // Read case: full fill + partial copy -> unmasked vector.xfer_read.
2433   SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
2434                                    b.getIndexType());
2435   returnTypes[0] = compatibleMemRefType;
2436   scf::IfOp fullPartialIfOp =
2437       options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
2438           ? createScopedFullPartialVectorTransferRead(
2439                 xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
2440                 alloc)
2441           : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
2442                                               inBoundsCond,
2443                                               compatibleMemRefType, alloc);
2444   if (ifOp)
2445     *ifOp = fullPartialIfOp;
2446 
2447   // Unmask the existing read op, it always reads from a full buffer.
2448   for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
2449     xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
2450   xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
2451 
2452   return success();
2453 }
2454 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const2455 LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
2456     Operation *op, PatternRewriter &rewriter) const {
2457   auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
2458   if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
2459       failed(filter(xferOp)))
2460     return failure();
2461   rewriter.startRootUpdate(xferOp);
2462   if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
2463     rewriter.finalizeRootUpdate(xferOp);
2464     return success();
2465   }
2466   rewriter.cancelRootUpdate(xferOp);
2467   return failure();
2468 }
2469 
matchAndRewrite(ExtractMapOp extract,PatternRewriter & rewriter) const2470 LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
2471     ExtractMapOp extract, PatternRewriter &rewriter) const {
2472   Operation *definedOp = extract.vector().getDefiningOp();
2473   if (!definedOp || definedOp->getNumResults() != 1)
2474     return failure();
2475   // TODO: Create an interfaceOp for elementwise operations.
2476   if (!isa<AddFOp>(definedOp))
2477     return failure();
2478   Location loc = extract.getLoc();
2479   SmallVector<Value, 4> extractOperands;
2480   for (OpOperand &operand : definedOp->getOpOperands())
2481     extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
2482         loc, extract.getResultType(), operand.get(), extract.ids()));
2483   Operation *newOp = cloneOpWithOperandsAndTypes(
2484       rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
2485   rewriter.replaceOp(extract, newOp->getResult(0));
2486   return success();
2487 }
2488 
distributPointwiseVectorOp(OpBuilder & builder,Operation * op,ArrayRef<Value> ids,ArrayRef<int64_t> multiplicity,const AffineMap & map)2489 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
2490     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
2491     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
2492   OpBuilder::InsertionGuard guard(builder);
2493   builder.setInsertionPointAfter(op);
2494   Location loc = op->getLoc();
2495   if (op->getNumResults() != 1)
2496     return {};
2497   Value result = op->getResult(0);
2498   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
2499   if (!type || map.getNumResults() != multiplicity.size())
2500     return {};
2501   // For each dimension being distributed check that the size is a multiple of
2502   // the multiplicity. To handle more sizes we would need to support masking.
2503   unsigned multiplictyCount = 0;
2504   for (auto exp : map.getResults()) {
2505     auto affinExp = exp.dyn_cast<AffineDimExpr>();
2506     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
2507         type.getDimSize(affinExp.getPosition()) %
2508                 multiplicity[multiplictyCount++] !=
2509             0)
2510       return {};
2511   }
2512   DistributeOps ops;
2513   ops.extract =
2514       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
2515   ops.insert =
2516       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
2517   return ops;
2518 }
2519 
2520 struct TransferReadExtractPattern
2521     : public OpRewritePattern<vector::TransferReadOp> {
TransferReadExtractPatternTransferReadExtractPattern2522   TransferReadExtractPattern(MLIRContext *context)
2523       : OpRewritePattern<vector::TransferReadOp>(context) {}
matchAndRewriteTransferReadExtractPattern2524   LogicalResult matchAndRewrite(vector::TransferReadOp read,
2525                                 PatternRewriter &rewriter) const override {
2526     if (!read.getResult().hasOneUse())
2527       return failure();
2528     auto extract =
2529         dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
2530     if (!extract)
2531       return failure();
2532     edsc::ScopedContext scope(rewriter, read.getLoc());
2533     using mlir::edsc::op::operator+;
2534     using mlir::edsc::op::operator*;
2535     using namespace mlir::edsc::intrinsics;
2536     SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
2537     AffineMap map = extract.map();
2538     unsigned idCount = 0;
2539     for (auto expr : map.getResults()) {
2540       unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2541       indices[pos] =
2542           indices[pos] +
2543           extract.ids()[idCount++] *
2544               std_constant_index(extract.getResultType().getDimSize(pos));
2545     }
2546     Value newRead = vector_transfer_read(extract.getType(), read.memref(),
2547                                          indices, read.permutation_map(),
2548                                          read.padding(), read.maskedAttr());
2549     Value dest = rewriter.create<ConstantOp>(
2550         read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType()));
2551     newRead = rewriter.create<vector::InsertMapOp>(read.getLoc(), newRead, dest,
2552                                                    extract.ids());
2553     rewriter.replaceOp(read, newRead);
2554     return success();
2555   }
2556 };
2557 
2558 struct TransferWriteInsertPattern
2559     : public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteInsertPatternTransferWriteInsertPattern2560   TransferWriteInsertPattern(MLIRContext *context)
2561       : OpRewritePattern<vector::TransferWriteOp>(context) {}
matchAndRewriteTransferWriteInsertPattern2562   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2563                                 PatternRewriter &rewriter) const override {
2564     auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
2565     if (!insert)
2566       return failure();
2567     edsc::ScopedContext scope(rewriter, write.getLoc());
2568     using mlir::edsc::op::operator+;
2569     using mlir::edsc::op::operator*;
2570     using namespace mlir::edsc::intrinsics;
2571     SmallVector<Value, 4> indices(write.indices().begin(),
2572                                   write.indices().end());
2573     AffineMap map = insert.map();
2574     unsigned idCount = 0;
2575     for (auto expr : map.getResults()) {
2576       unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2577       indices[pos] =
2578           indices[pos] +
2579           insert.ids()[idCount++] *
2580               std_constant_index(insert.getSourceVectorType().getDimSize(pos));
2581     }
2582     vector_transfer_write(insert.vector(), write.memref(), indices,
2583                           write.permutation_map(), write.maskedAttr());
2584     rewriter.eraseOp(write);
2585     return success();
2586   }
2587 };
2588 
2589 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
2590 // TODO: Add this as DRR pattern.
populateVectorToVectorTransformationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2591 void mlir::vector::populateVectorToVectorTransformationPatterns(
2592     OwningRewritePatternList &patterns, MLIRContext *context) {
2593   // clang-format off
2594   patterns.insert<ShapeCastOpDecomposer,
2595                   ShapeCastOpFolder,
2596                   SplitTransferReadOp,
2597                   SplitTransferWriteOp,
2598                   TupleGetFolderOp,
2599                   TransferReadExtractPattern,
2600                   TransferWriteInsertPattern>(context);
2601   // clang-format on
2602 }
2603 
populateVectorSlicesLoweringPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2604 void mlir::vector::populateVectorSlicesLoweringPatterns(
2605     OwningRewritePatternList &patterns, MLIRContext *context) {
2606   patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
2607 }
2608 
populateVectorContractLoweringPatterns(OwningRewritePatternList & patterns,MLIRContext * context,VectorTransformsOptions parameters)2609 void mlir::vector::populateVectorContractLoweringPatterns(
2610     OwningRewritePatternList &patterns, MLIRContext *context,
2611     VectorTransformsOptions parameters) {
2612   // clang-format off
2613   patterns.insert<BroadcastOpLowering,
2614                   CreateMaskOpLowering,
2615                   ConstantMaskOpLowering,
2616                   OuterProductOpLowering,
2617                   ShapeCastOp2DDownCastRewritePattern,
2618                   ShapeCastOp2DUpCastRewritePattern,
2619                   ShapeCastOpRewritePattern>(context);
2620   patterns.insert<TransposeOpLowering,
2621                   ContractionOpLowering,
2622                   ContractionOpToMatmulOpLowering,
2623                   ContractionOpToOuterProductOpLowering>(parameters, context);
2624   // clang-format on
2625 }
2626