1 //===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_UTILS_H_
10 #define MLIR_DIALECT_LINALG_UTILS_H_
11 
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"

#include <cassert>
#include <utility>
22 
// Pull the EDSC indexed-value helpers into scope so they can be named
// unqualified below (see `GenerateLoopNest::IndexedValueTy`).
// NOTE(review): these using-declarations sit at global scope in a header, so
// every file that includes it also sees both names unqualified.
using mlir::edsc::intrinsics::AffineIndexedValue;
using mlir::edsc::intrinsics::StdIndexedValue;
25 
26 namespace mlir {
// Forward declarations of classes in the `mlir` namespace.
// NOTE(review): of these, only `AffineForOp` and `PatternRewriter` are named
// by the declarations in this header; the others are presumably kept for the
// benefit of files that include it — confirm before removing any.
class AffineExpr;
class AffineForOp;
class AffineMap;
class OperationFolder;
class PatternRewriter;
32 
33 namespace linalg {
// Forward declaration of the dependence graph type used by the fusion helpers.
// NOTE(review): `FusableOpDependencesTy` below names the nested type
// `LinalgDependenceGraph::LinalgDependenceGraphElem`, which requires the
// complete class (this header includes Analysis/DependenceAnalysis.h), so this
// forward declaration looks redundant — confirm before removing.
class LinalgDependenceGraph;
35 
/// A struct containing the Linalg producer before and after fusion.
/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op
/// before the consumer Linalg op, until enough canonicalizations have applied.
struct FusionInfo {
  /// The producer op as it was before fusion.
  LinalgOp originalProducer;
  /// The producer op after fusion.
  LinalgOp fusedProducer;
};
43 
/// A struct containing common matchers over a linalg op's region.
struct RegionMatcher {
  /// Kinds of binary operation that `matchAsScalarBinaryOp` can report.
  enum class BinaryOpKind {
    /// NOTE(review): presumably integer addition — confirm against the
    /// matcher's implementation.
    IAdd,
  };

  /// Matches the given linalg op if its body performs a binary operation on
  /// int or float scalar values, and returns the kind of that binary op.
  ///
  /// The linalg op's region is expected to be
  /// ```
  /// {
  ///   ^bb(%a: <scalar-type>, %b: <scalar-type>):
  ///     %0 = <binary-op> %a, %b: <scalar-type>
  ///     linalg.yield %0: <scalar-type>
  /// }
  /// ```
  ///
  /// NOTE(review): presumably returns `None` when the region does not have
  /// this form or the binary op is not one of `BinaryOpKind` — confirm against
  /// the implementation.
  static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
};
63 
/// Checks if an iterator_type attribute is parallel.
bool isParallelIteratorType(Attribute attr);
66 
/// Checks if an iterator_type attribute is reduction.
bool isReductionIteratorType(Attribute attr);
69 
/// Checks if an iterator_type attribute is window.
bool isWindowIteratorType(Attribute attr);
72 
/// Checks whether the specific `producer` is the last write to exactly the
/// whole `consumedView`. This checks structural dominance, and that the
/// dependence is a RAW without any interleaved write to any piece of
/// `consumedView`.
/// `graph` supplies the dependence information and `consumer` is the op that
/// reads `consumedView`.
bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                               LinalgOp consumer, Value consumedView,
                               LinalgOp producer);
79 
/// Checks whether fusing the specific `producer` of the `consumedView` into
/// `consumer` is feasible. This checks that `producer` is the last write of
/// `consumedView` and that no interleaved dependence would be violated (RAW,
/// WAR or WAW).
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                   Value consumedView, LinalgOp producer);
85 
/// Map from an operation to the list of dependence graph elements recorded for
/// it.
using FusableOpDependencesTy = llvm::MapVector<
    Operation *,
    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
/// NOTE(review): undocumented in this header. Judging by the name and types,
/// this presumably collects, for the ops in `ops`, the dependences in
/// `dependenceGraph` along which fusion is possible — confirm against the
/// implementation.
FusableOpDependencesTy
findAllFusableDependences(ArrayRef<LinalgOp> ops,
                          const LinalgDependenceGraph &dependenceGraph);
92 
/// Fuses producer into consumer if the producer is structurally feasible and
/// the fusion would not violate dependencies.
/// Implements the fusion part of the "tileAndFuse on buffers"
/// transformation and thus requires the `consumerIdx`^th operand of `consumer`
/// to be a `subview` op (generally obtained by applying the tiling
/// transformation).
/// NOTE(review): presumably returns `None` when no fusion is performed —
/// confirm against the implementation.
Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
                                          unsigned consumerIdx,
                                          const LinalgDependenceGraph &graph);
/// Tensor counterpart of `fuseProducerOfBuffer`.
/// This implements the fusion part of the "tileAndFuse on tensors"
/// transformation and thus requires the `consumerIdx`^th operand of `consumer`
/// to be the result of a `subtensor` op (generally obtained by applying the
/// tiling transformation).
/// NOTE(review): presumably returns `None` when no fusion is performed —
/// confirm against the implementation.
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
                                          unsigned consumerIdx);
109 
/// Fuses a linalg operation on tensors with the producer of the operand at
/// position `consumerIdx` of `consumer`.
/// NOTE(review): the returned values are presumably the results of the fused
/// op, and `None` presumably signals that fusion did not happen — confirm
/// against the implementation.
Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
                                              Operation *consumer,
                                              unsigned consumerIdx);
115 
/// Like `getShape`, but only returns statically-known information, without
/// generating any new IR. For each shape dimension, the returned vector holds
/// a value >= 0 if that dimension is statically known, or -1 otherwise.
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
120 
/// Returns the statically-known loop ranges of the `linalgOp`. Composes
/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. An entry of -1
/// marks a loop range that is not statically known.
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
126 
127 /// Apply the permutation defined by `permutation` to `inVec`.
128 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
129 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
130 /// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
131 template <typename T, unsigned N>
applyPermutationToVector(SmallVector<T,N> & inVec,ArrayRef<unsigned> permutation)132 void applyPermutationToVector(SmallVector<T, N> &inVec,
133                               ArrayRef<unsigned> permutation) {
134   SmallVector<T, N> auxVec(inVec.size());
135   for (unsigned i = 0; i < permutation.size(); ++i)
136     auxVec[i] = inVec[permutation[i]];
137   inVec = auxVec;
138 }
139 
/// Scheme used to distribute loops to processors.
/// In the examples below, `%procId` and `%nprocs` are the processor ID and the
/// number of processors for the loop being distributed.
enum class DistributionMethod {
  /// Cyclic distribution where no assumption is made about the dynamic
  /// relationship between number of processors and number of iterations of the
  /// distributed loop. Distributes the following loop
  ///
  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
  ///
  /// to
  ///
  /// scf.parallel (%iv) = (%lb + %procId * %step) to (%ub) step (%step * %nprocs)
  Cyclic = 0,

  /// Cyclic distribution where the number of processors can be assumed to be
  /// more than or equal to the number of iterations of the distributed loop. In
  /// such cases, a simple in-bounds check is enough (instead of materializing a
  /// loop). Distributes the following loop
  ///
  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
  ///
  /// to
  ///
  /// %iv = %lb + %procId * %step
  /// %cond = cmpi "slt", %iv, %ub
  /// scf.if %cond {
  ///   ...
  /// }
  CyclicNumProcsGeNumIters = 1,

  /// Cyclic distribution where the number of processors can be assumed to be
  /// equal to the number of iterations of the distributed loop. In such cases,
  /// no bounds check is needed. Distributes the following loop
  ///
  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
  ///
  /// to
  ///
  /// %iv = %lb + %procId * %step
  CyclicNumProcsEqNumIters = 2
};
180 
/// Processor ID (`procId`) and number of processors (`nprocs`) used to
/// distribute one parallel loop.
struct ProcInfo {
  Value procId;
  Value nprocs;
};
/// Callback function type used to get processor ID, and number of processors
/// used for distribution for all parallel loops generated. It receives the
/// builder, the location and the ranges of the parallel loops, and returns a
/// list of `ProcInfo`.
using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
    OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
189 
/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
struct LinalgLoopDistributionOptions {
  /// Callback function that returns the Values for processor ID (`procId`), and
  /// number of processors (`nprocs`) used to execute the parallel loops. The
  /// number of `{procId, nprocs}` pairs returned must be equal to the number of
  /// `parallelLoopRanges` passed into the callback, which in turn is the same
  /// as the number of parallel loops for which the `distributionMethod` is
  /// specified below.
  ProcInfoCallBackFn procInfo;
  /// Specification of how to distribute the `scf.parallel` loops that are
  /// generated. As the `scf.parallel` loops are generated, the elements of
  /// this vector are used (from left to right) and the specified distribution
  /// is applied. If the vector has fewer elements than the number of
  /// `scf.parallel` loops generated, then no distribution is applied.
  /// Empty by default.
  SmallVector<DistributionMethod, 0> distributionMethod = {};
};
207 
208 /// Utility class used to generate nested loops with ranges described by
209 /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
210 /// is used to generate the body of the innermost loop. It is passed a range
211 /// of loop induction variables.
212 template <typename LoopTy>
213 struct GenerateLoopNest {
214   using IndexedValueTy =
215       typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
216                                 AffineIndexedValue, StdIndexedValue>::type;
217 
218   static void
219   doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
220        ArrayRef<Attribute> iteratorTypes,
221        function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
222        Optional<LinalgLoopDistributionOptions> = None);
223 };
224 
225 } // namespace linalg
226 } // namespace mlir
227 
228 #endif // MLIR_DIALECT_LINALG_UTILS_H_
229