1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Vector dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_VECTOR_VECTOROPS_H
14 #define MLIR_DIALECT_VECTOR_VECTOROPS_H
15 
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/VectorInterfaces.h"
23 #include "mlir/Interfaces/ViewLikeInterface.h"
24 
25 namespace mlir {
// Forward declarations: the declarations that follow only refer to these
// types through pointers and references, so their full definitions are not
// needed here.
class OwningRewritePatternList;
class MLIRContext;
28 namespace vector {
29 
30 /// Collect a set of vector-to-vector canonicalization patterns.
31 void populateVectorToVectorCanonicalizationPatterns(
32     OwningRewritePatternList &patterns, MLIRContext *context);
33 
34 /// Collect a set of vector-to-vector transformation patterns.
35 void populateVectorToVectorTransformationPatterns(
36     OwningRewritePatternList &patterns, MLIRContext *context);
37 
38 /// Collect a set of vector slices transformation patterns:
39 ///    ExtractSlicesOpLowering, InsertSlicesOpLowering
40 /// Useful for clients that want to express all vector "slices"
41 /// ops in terms of more elementary vector "slice" ops. If all
42 /// "produced" tuple values are "consumed" (the most common
43 /// use for "slices" ops), this lowering removes all tuple related
44 /// operations as well (through DCE and folding). If tuple values
45 /// "leak" coming in, however, some tuple related ops will remain.
46 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
47                                           MLIRContext *context);
48 
/// Enum to control the lowering of `vector.contract` operations.
enum class VectorContractLowering {
  /// Progressively lower to finer grained `vector.contract` and dot-products.
  Dot = 0,
  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
  Matmul = 1,
  /// Lower to `vector.outerproduct`.
  OuterProduct = 2,
};

/// Enum to control the lowering of `vector.transpose` operations.
enum class VectorTransposeLowering {
  /// Lower transpose into element-wise extracts and inserts.
  EltWise = 0,
  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
  /// intrinsics.
  Flat = 1,
};

/// Enum to control the splitting of `vector.transfer` operations into masked
/// and unmasked variants.
enum class VectorTransferSplit {
  /// Do not split vector transfer operations.
  None = 0,
  /// Split using masked + unmasked vector.transfer operations.
  VectorTransfer = 1,
  /// Split using an unmasked vector.transfer + linalg.fill + linalg.copy
  /// operations.
  LinalgCopy = 2,
  /// Do not split vector transfer operations but instead mark them as
  /// "unmasked".
  ForceUnmasked = 3
};

/// Structure to control the behavior of vector transform patterns.
///
/// Every option has a default value, and every setter returns `*this` so that
/// options can be chained, e.g.:
///   VectorTransformsOptions()
///       .setVectorContractLowering(VectorContractLowering::Matmul)
///       .setVectorTransposeLowering(VectorTransposeLowering::Flat);
struct VectorTransformsOptions {
  /// Option to control the lowering of vector.contract.
  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
  /// Sets `vectorContractLowering`. Named after the field it sets, like the
  /// other setters of this struct.
  VectorTransformsOptions &
  setVectorContractLowering(VectorContractLowering opt) {
    vectorContractLowering = opt;
    return *this;
  }
  /// Legacy spelling of `setVectorContractLowering`, kept so that existing
  /// callers keep compiling. Prefer `setVectorContractLowering` in new code.
  VectorTransformsOptions &
  setVectorTransformsOptions(VectorContractLowering opt) {
    return setVectorContractLowering(opt);
  }

  /// Option to control the lowering of vector.transpose.
  VectorTransposeLowering vectorTransposeLowering =
      VectorTransposeLowering::EltWise;
  /// Sets `vectorTransposeLowering`.
  VectorTransformsOptions &
  setVectorTransposeLowering(VectorTransposeLowering opt) {
    vectorTransposeLowering = opt;
    return *this;
  }

  /// Option to control the splitting of vector transfers.
  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
  /// Sets `vectorTransferSplit`.
  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
    vectorTransferSplit = opt;
    return *this;
  }
};
103 
104 /// Collect a set of transformation patterns that are related to contracting
105 /// or expanding vector operations:
106 ///   ContractionOpLowering,
107 ///   ShapeCastOp2DDownCastRewritePattern,
108 ///   ShapeCastOp2DUpCastRewritePattern
109 ///   BroadcastOpLowering,
110 ///   TransposeOpLowering
111 ///   OuterproductOpLowering
112 /// These transformation express higher level vector ops in terms of more
113 /// elementary extraction, insertion, reduction, product, and broadcast ops.
114 void populateVectorContractLoweringPatterns(
115     OwningRewritePatternList &patterns, MLIRContext *context,
116     VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
117 
118 /// Returns the integer type required for subscripts in the vector dialect.
119 IntegerType getVectorSubscriptType(Builder &builder);
120 
121 /// Returns an integer array attribute containing the given values using
122 /// the integer type required for subscripts in the vector dialect.
123 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
124 
125 namespace impl {
126 /// Build the default minor identity map suitable for a vector transfer. This
127 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
128 /// rank of the identity map must take the vector element type into account.
129 AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
130                                       VectorType vectorType);
131 } // namespace impl
132 } // end namespace vector
133 } // end namespace mlir
134 
135 #define GET_OP_CLASSES
136 #include "mlir/Dialect/Vector/VectorOps.h.inc"
137 #include "mlir/Dialect/Vector/VectorOpsDialect.h.inc"
138 
139 #endif // MLIR_DIALECT_VECTOR_VECTOROPS_H
140