1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg dialect fusion-on-tensors pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24
25 using namespace mlir;
26 using namespace mlir::linalg;
27
28 /// Implementation of fusion of generic ops and indexed_generic ops.
30 static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
31 unsigned consumerIdx) {
32 // Producer and consumer must have tensor semantics.
33 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
34 return false;
35
36 // Verify that
37 // - the producer has all "parallel" iterator types.
38 if (producer.getNumParallelLoops() != producer.getNumLoops())
39 return false;
40
41 // Get the consumer index map. The number of results of the consumer index
42 // map must match the number of loops of the producer.
43 AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
44 if (consumerIndexMap.getNumResults() != producer.getNumLoops())
45 return false;
46
47 // Finally the index_map for the result must be invertible. For now just
48 // verify it is a permutation.
49 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
50 return producerResultIndexMap.isPermutation();
51 }
52
53 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
54 /// the `producer` to use in the fused operation given the indexing map of the
55 /// result of the producer in the consumer.
56 static void getIndexingMapOfProducerOperandsInFusedOp(
57 LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
58 SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
59 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
60 // from consumer loop -> consumer arg tensor index/producer result tensor
61 // index. The fused loop is the same as the consumer loop. For each producer arg
62 // the indexing map to be computed is a map from consumer loop -> producer
63 // arg tensor index.
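// For illustration (hypothetical maps): if the producer's output indexing map
// is (d0, d1) -> (d1, d0), one of its argument maps is the identity
// (d0, d1) -> (d0, d1), and the consumer accesses the producer's result with
// (d0, d1, d2) -> (d0, d2), then the map computed below for that producer
// argument is (d0, d1, d2) -> (d2, d0), i.e. fused (consumer) loops to
// producer argument tensor indices.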
64
65 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
66 // producerResultIndexMap is a map from producer loop -> tensor index.
67 // Compute the inverse to get map from tensor index -> producer loop.
68 // The inverse is a map from producer result tensor index -> producer loop.
69 AffineMap invProducerResultIndexMap =
70 inversePermutation(producerResultIndexMap);
71 assert(invProducerResultIndexMap &&
72 "expected producer result indexig map to be invertible");
73 for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
74 // argMap is a map from producer loop -> producer arg tensor index.
75 AffineMap argMap = producer.getInputIndexingMap(argNum);
76
77 // Compose argMap with invProducerResultIndexMap to get a map from
78 // producer result tensor index -> producer arg tensor index.
79 AffineMap t1 = argMap.compose(invProducerResultIndexMap);
80
81 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
82 // consumer loop/ fused loop -> producer arg tensor index.
83 AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
84 fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
85 }
86 }
87
88 /// Generate the region of the fused tensor operation. The region of the fused
89 /// op must be empty.
90 static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
91 Operation *fusedOp, LinalgOp producer,
92 LinalgOp consumer,
93 AffineMap consumerToProducerLoopsMap,
94 unsigned consumerIdx, unsigned nloops) {
95 // Build the region of the fused op.
96 Block &producerBlock = producer->getRegion(0).front();
97 Block &consumerBlock = consumer->getRegion(0).front();
98 Block *fusedBlock = new Block();
99 fusedOp->getRegion(0).push_back(fusedBlock);
100 BlockAndValueMapping mapper;
101 OpBuilder::InsertionGuard guard(rewriter);
102 rewriter.setInsertionPointToStart(fusedBlock);
103
104 // The block arguments are
105 // [index_0, index_1, ... ,
106 // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
107 // producer_operand_0, ... , producer_operand_(n-1),
108 // consumer_operand_(`consumerIdx`), ... , consumer_operand_(m-1)],
109 // where n is the number of the producer's operands and m is the number of
110 // the consumer's operands.
111 // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
112 // generic op. In this case, there are no indices in block arguments.
113 unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
114 ? producer.getNumLoops()
115 : 0;
116 unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
117 ? consumer.getNumLoops()
118 : 0;
119 unsigned numFusedOpIndices =
120 (isa<IndexedGenericOp>(producer.getOperation()) ||
121 isa<IndexedGenericOp>(consumer.getOperation()))
122 ? std::max(producer.getNumLoops(), consumer.getNumLoops())
123 : 0;
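// For example (hypothetical op mix): fusing an indexed_generic producer with
// 2 loops into a generic consumer with 3 loops gives numProducerIndices = 2,
// numConsumerIndices = 0 and numFusedOpIndices = 3, so the fused block starts
// with 3 index arguments.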
124 // Firstly, add all the indices to the block arguments.
125 for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
126 fusedBlock->addArgument(rewriter.getIndexType());
127 // Map the arguments for the unmodified args from the consumer.
128 for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
129 if (consumerArg.index() == consumerIdx + numConsumerIndices) {
130 // Map the arguments for the args from the producer.
131 for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
132 // If producer is an indexed_generic op, map the indices from consumer
133 // loop to producer loop (because the fusedOp is built based on
134 // consumer's perspective).
135 if (producerArg.index() < numProducerIndices) {
136 auto newIndex = rewriter.create<mlir::AffineApplyOp>(
137 producer.getLoc(),
138 consumerToProducerLoopsMap.getSubMap(producerArg.index()),
139 fusedBlock->getArguments().take_front(numFusedOpIndices));
140 mapper.map(producerArg.value(), newIndex);
141 } else {
142 mapper.map(producerArg.value(),
143 fusedBlock->addArgument(producerArg.value().getType()));
144 }
145 }
146 continue;
147 }
148
149 // If consumer is an indexed_generic op, map the indices to the block
150 // arguments directly. Otherwise, add the same type of argument and map to
151 // it.
152 if (consumerArg.index() < numConsumerIndices) {
153 mapper.map(consumerArg.value(),
154 fusedBlock->getArgument(consumerArg.index()));
155 } else {
156 mapper.map(consumerArg.value(),
157 fusedBlock->addArgument(consumerArg.value().getType()));
158 }
159 }
160
161 // Add operations from producer (except the yield operation) to the fused
162 // op.
163 for (auto &op : producerBlock.getOperations()) {
164 if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
165 // Lookup the value the yield operation is mapped to.
166 Value yieldVal = yieldOp.getOperand(0);
167 if (Value clonedVal = mapper.lookupOrNull(yieldVal))
168 mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
169 clonedVal);
170 continue;
171 }
172 rewriter.clone(op, mapper);
173 }
174 for (auto &op : consumerBlock.getOperations())
175 rewriter.clone(op, mapper);
176 }
177
178 static Optional<SmallVector<Value, 1>>
179 fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
180 PatternRewriter &rewriter) {
181 if (!areTensorOpsFusable(producer, consumer, consumerIdx))
182 return llvm::None;
183
184 unsigned numFusedOperands =
185 producer.getNumInputs() + consumer.getNumInputs() - 1;
186
187 // Compute the fused operands list.
188 SmallVector<Value, 2> fusedOperands;
189 fusedOperands.reserve(numFusedOperands);
190 auto consumerOperands = consumer.getInputs();
191 auto producerOperands = producer.getInputs();
192 fusedOperands.assign(consumerOperands.begin(),
193 std::next(consumerOperands.begin(), consumerIdx));
194 fusedOperands.append(producerOperands.begin(), producerOperands.end());
195 fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
196 consumerOperands.end());
197
198 // Compute indexing_maps for the fused operation. The indexing_maps for the
199 // operands of the consumer that aren't fused are the same. The
200 // indexing_maps for the producers need to be computed based on the
201 // indexing_map of the operand at consumerIdx in the consumer.
202 SmallVector<Attribute, 4> fusedIndexMaps;
203 auto consumerIndexMaps = consumer.indexing_maps();
204 fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
205 fusedIndexMaps.assign(consumerIndexMaps.begin(),
206 std::next(consumerIndexMaps.begin(), consumerIdx));
207 // Compute indexing maps for the producer args in the fused operation.
208 getIndexingMapOfProducerOperandsInFusedOp(
209 producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
210
211 // Append the indexing maps for the remaining consumer operands.
212 fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
213 consumerIndexMaps.end());
214
215 // Generate the fused op.
216 // Tensor-level fusion is only on ops without initTensors and outputBuffers.
217 LinalgOp fusedOp;
218 if (isa<GenericOp>(producer.getOperation()) &&
219 isa<GenericOp>(consumer.getOperation())) {
220 fusedOp =
221 rewriter
222 .create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
223 /*inputs=*/fusedOperands,
224 /*outputBuffers=*/ValueRange{},
225 /*initTensors=*/ValueRange{},
226 rewriter.getArrayAttr(fusedIndexMaps),
227 consumer.iterator_types(),
228 /*doc=*/nullptr,
229 /*library_call=*/nullptr,
230 /*sparse=*/nullptr)
231 .getOperation();
232 } else {
233 fusedOp = rewriter
234 .create<IndexedGenericOp>(
235 consumer.getLoc(), consumer->getResultTypes(),
236 /*inputs=*/fusedOperands,
237 /*outputBuffers=*/ValueRange{},
238 /*initTensors=*/ValueRange{},
239 rewriter.getArrayAttr(fusedIndexMaps),
240 consumer.iterator_types(),
241 /*doc=*/nullptr,
242 /*library_call=*/nullptr,
243 /*sparse=*/nullptr)
244 .getOperation();
245 }
246
247 // Construct an AffineMap from consumer loops to producer loops.
248 // consumer loop -> tensor index
249 AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
250 // producer loop -> tensor index
251 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
252 // tensor index -> producer loop
253 AffineMap invProducerResultIndexMap =
254 inversePermutation(producerResultIndexMap);
255 assert(invProducerResultIndexMap &&
256 "expected producer result indexig map to be invertible");
257 // consumer loop -> producer loop
258 AffineMap consumerToProducerLoopsMap =
259 invProducerResultIndexMap.compose(consumerResultIndexMap);
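// A hypothetical example of this composition: with consumerResultIndexMap
// (d0, d1, d2) -> (d1, d2) and producerResultIndexMap (d0, d1) -> (d1, d0)
// (whose inverse is (d0, d1) -> (d1, d0)), consumerToProducerLoopsMap is
// (d0, d1, d2) -> (d2, d1), i.e. producer loop 0 is driven by consumer loop
// d2 and producer loop 1 by consumer loop d1.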
260
261 generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
262 consumer, consumerToProducerLoopsMap, consumerIdx,
263 consumer.getNumLoops());
264 return SmallVector<Value, 1>(fusedOp->getResults());
265 }
266
267 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
268 /// provided, given the shape of the source tensor that corresponds to the
269 /// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
270 /// are "row-major" ordered logically.
271 ///
272 /// For example:
273 ///
274 /// %0 = op ... : tensor<?x?x4x5xf32>
275 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
276 ///
277 /// and reshape:
278 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
279 /// affine_map<(i, j, k, l) -> (j, k, l)>] :
280 /// tensor<?x?x4x5xf32> into tensor<?x?xf32>
281 ///
282 /// would be rewritten into:
283 /// %0 = op ... : tensor<?x?x4x5xf32>
284 /// with output index_map
285 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
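/// (The linearized expression above follows from the canonical row-major
/// strides of the collapsed group: for the trailing shape (?, 4, 5) the
/// strides of (d1, d2, d3) are (20, 5, 1), hence d1 * 20 + d2 * 5 + d3.)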
286 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
287 ArrayRef<int64_t> sourceShape,
288 ArrayRef<AffineMap> reassociationMaps) {
289 SmallVector<AffineExpr, 4> resultExprs;
290 resultExprs.reserve(reassociationMaps.size());
291 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
292 MLIRContext *context = sourceMap.getContext();
293
294 // Compute the result exprs based on the reassociation maps.
295 for (AffineMap map : reassociationMaps) {
296 ArrayRef<AffineExpr> collapsedDims = map.getResults();
297 // Assume that they are in-order and contiguous (already checked in
298 // verifier).
299 assert(!collapsedDims.empty());
300 unsigned startDim =
301 collapsedDims.front().cast<AffineDimExpr>().getPosition();
302 AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
303 sourceShape.slice(startDim, collapsedDims.size()),
304 sourceExprs.slice(startDim, collapsedDims.size()), context);
305 resultExprs.push_back(linearizedExpr);
306 }
307 return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
308 resultExprs, context);
309 }
310
311 /// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer` is
312 /// true) or its producer (if `asProducer` is false) given the indexing map at
313 /// its use.
314 static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
315 AffineMap useIndexMap,
316 bool asProducer) {
317 RankedTensorType returnType = reshapeOp.getResultType();
318 RankedTensorType operandType = reshapeOp.getSrcType();
319 // Reshape is fusable with its consumer (i.e. reshape as a producer) when its
320 // operand is of lesser rank than the result. Fusing when operand has higher
321 // rank will require use of mods and divs in the indexing maps of the fused op
322 // which would make it non-invertible. Similarly reshape is fused with its
323 // producer (i.e. reshape as consumer) only if the return type has lesser
324 // rank.
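// For intuition: linearizing collapsed dimensions only produces affine
// expressions such as d0 * 4 + d1, whereas going in the opposite direction
// would require floordiv/mod expressions, which this linearization-based
// folding cannot represent.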
325 if ((asProducer && reshapeOp.getSrcType().hasStaticShape() &&
326 returnType.getRank() < operandType.getRank()) ||
327 (!asProducer && reshapeOp.getResultType().hasStaticShape() &&
328 operandType.getRank() < returnType.getRank()))
329 return false;
330 return useIndexMap.isPermutation();
331 }
332
333 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
334 /// is a linalg.generic operation, then create a `linalg.generic` operation with
335 /// the given `args`. Expects `op` to be `linalg.generic` or
336 /// `linalg.indexed_generic`.
337 template <typename... Args>
338 static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
339 Args... args) {
340 if (isa<GenericOp>(op.getOperation()))
341 return rewriter.create<GenericOp>(args...);
342 if (isa<IndexedGenericOp>(op.getOperation()))
343 return rewriter.create<IndexedGenericOp>(args...);
344 llvm_unreachable(
345 "expected only linalg.generic or linalg.indexed_generic ops");
346 return nullptr;
347 }
348
349 /// Conditions for folding a generic/indexed-generic operation with a reshape op
350 /// by expanding the iteration space dimensionality for tensor operations. These
351 /// are preconditions assumed by `fuseWithReshapeByExpansion` which implements
352 /// the following fusion pattern.
353 ///
354 /// Consider
355 ///
356 /// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
357 /// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
358 /// affine_map<(d0, d1, d2) -> (d1, d2)>,
359 /// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
360 /// %d = linalg.tensor_reshape %c
361 /// [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
362 /// affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
363 /// affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
364 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
365 ///
366 /// The reshape can be folded into the `linalgOp` if the
367 /// generic/indexed-generic op loop dimensionality is increased to match the
368 /// result (operand) of the tensor_reshape when the reshape is expanding
369 /// (folding). The indexing_map of the fused tensor in the `linalgOp` and the
370 /// reassociation map helps compute the indexing maps of the modified op. For
371 /// the above example, based on the reassociation map it can be concluded that
372 ///
373 /// - The loop used to access the first dimension of the fused tensor is split
374 /// into two.
375 /// - The loop used to access the second dimension of the fused tensor is kept
376 /// as is.
377 /// - The loop used to access the third dimension of the fused tensor is split
378 /// into three.
379 ///
380 /// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
381 /// modified op, then
382 ///
383 /// d0 -> e0, e1
384 /// d1 -> e2, e3, e4
385 /// d2 -> e5
386 ///
387 /// substituting this, the generic op can be rewritten as
388 ///
389 /// %d = linalg.generic ins(%0, %1 : )
390 /// indexing_maps =
391 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
392 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
393 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
394 ///
395 /// Since operands to the linalg generic are now 5D, reshapes can be introduced
396 /// to make it consistent
397 ///
398 /// %0 = linalg.tensor_reshape %a
399 /// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2)>,
400 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4)>,
401 /// affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)>]
402 /// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
403 /// %1 = linalg.tensor_reshape %b
404 /// [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2)>,
405 /// affine_map<(e0, e1, e2, e3) -> (e3)>]
406 /// : tensor<?x?xf32> into tensor<?x?x?x?xf32>
407 ///
408 /// The added reshapes are again expanding patterns, so they will get fused
409 /// with their producers if possible.
410 static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
411 unsigned fusedTensorIndex) {
412 // Is fusable only if:
413 // - The linalgOp is a generic or indexed_generic op with tensor semantics.
414 // - All the indexing maps for operands and results in linalgOp are projected
415 // permutations.
416 // - The fused tensor is not a scalar.
417 // - All the loops in linalgOp are parallel loops.
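// (As a reminder, affine_map<(d0, d1, d2) -> (d0, d2)> is a projected
// permutation while affine_map<(d0, d1) -> (d0 + d1)> is not; the dimension
// expansion performed during fusion requires every result to be a plain
// dimension expression.)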
418 return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
419 linalgOp.hasTensorSemantics() &&
420 llvm::all_of(linalgOp.indexing_maps().getValue(),
421 [](Attribute attr) {
422 return attr.cast<AffineMapAttr>()
423 .getValue()
424 .isProjectedPermutation();
425 }) &&
426 linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
427 llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
428 return attr.cast<StringAttr>().getValue() ==
429 getParallelIteratorTypeName();
430 });
431 }
432
433 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
434 /// op as explained in `isFusableWithReshapeByDimExpansion`. Assumes that those
435 /// conditions have been satisfied.
436 static Optional<SmallVector<Value, 1>>
437 fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
438 unsigned fusedTensorIndex,
439 PatternRewriter &rewriter) {
440 assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
441 "preconditions for fuse operation failed");
442 // Check if reshape is expanding or collapsing.
443 bool isExpanding =
444 reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
445 RankedTensorType expandedType =
446 isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
447 AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
448
449 // The reshape is folding/expanding consecutive dimensions. Given the indexing
450 // map of the fused tensor find the number of dimensions each of the loops of
451 // the original op is expanded into. Also record the shape of the expanded
452 // dimensions.
453 ArrayRef<int64_t> expandedShape = expandedType.getShape();
454 Optional<SmallVector<int64_t, 4>> origOpLoopRange =
455 getStaticLoopRanges(linalgOp);
456 if (!origOpLoopRange) {
457 linalgOp.emitError("unable to find loop range for operation");
458 return llvm::None;
459 }
460 SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
461 SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
462 fusedIndexMap.getNumDims());
463 auto reassociationMaps = reshapeOp.getReassociationMaps();
464 for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
465 unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
466 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
467 numFoldedDims[pos] = foldedDims.getNumResults();
468 ArrayRef<int64_t> shape =
469 expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
470 expandedDimsShape[pos].assign(shape.begin(), shape.end());
471 }
472 // The remaining dimensions remain the same.
473 for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
474 if (expandedDimsShape[i].empty())
475 expandedDimsShape[i] = {(*origOpLoopRange)[i]};
476
477 if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
478 // For indexed generic op, the region contains arguments that represent the
479 // induction variable value of the loops. In the fused op these values are
480 // obtained by linearizing the expanded dimensions. For now just check that
481 // the extents used in the linearization (all the expanded dims except the
482 // front) are statically known. For the dynamic case, we would need shape
483 // information on these dimensions to compute these.
484 for (auto &expandedShape : expandedDimsShape) {
485 if (expandedShape.size() == 1)
486 continue;
487 for (int64_t expandedDimShape : llvm::make_range(
488 std::next(expandedShape.begin()), expandedShape.end())) {
489 if (ShapedType::isDynamic(expandedDimShape)) {
490 linalgOp.emitError(
491 "unable to fuse indexed generic op where the expanded dim is "
492 "dynamic");
493 return llvm::None;
494 }
495 }
496 }
497 }
498
499 // The remapping of the loops is the prefix sum of numFoldedDims (with a
500 // leading zero): remapping[i] is the first expanded dimension of loop i.
501 SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
502 unsigned sum = 0;
503 for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) {
504 sum += numFoldedDim.value();
505 remapping[numFoldedDim.index() + 1] = sum;
506 }
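// For example, numFoldedDims = [2, 1, 3] gives remapping = [0, 2, 3, 6]:
// original loop 0 expands to loops [0, 2), loop 1 to [2, 3) and loop 2 to
// [3, 6).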
507
508 SmallVector<AffineMap, 4> expandedOpIndexingMaps;
509 // Compute the modified indexing maps by replacing every loop (AffineDimExpr)
510 // in the original indexing map with the sequence of loops that it is expanded
511 // to.
512 for (AffineMap indexingMap : linalgOp.getIndexingMaps()) {
513 SmallVector<AffineExpr, 4> newExprs;
514 for (AffineExpr expr : indexingMap.getResults()) {
515 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
516 for (unsigned newPos :
517 llvm::seq<unsigned>(remapping[pos], remapping[pos + 1])) {
518 newExprs.push_back(rewriter.getAffineDimExpr(newPos));
519 }
520 }
521 expandedOpIndexingMaps.push_back(
522 AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs,
523 rewriter.getContext()));
524 }
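// Continuing the example above (numFoldedDims = [2, 1, 3], so remapping =
// [0, 2, 3, 6]): an original indexing map (d0, d1, d2) -> (d2, d0) becomes
// (e0, e1, e2, e3, e4, e5) -> (e3, e4, e5, e0, e1).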
525
526 // The operands of the expanded op are computed by reshaping the original
527 // operands. The reshape depends on the ordering of the loops used to access
528 // the tensor in the original operation; each dimension is expanded into as
529 // many dimensions as its loop is expanded into (as computed by `remapping`).
530 auto getReshapeInfo =
531 [&](AffineMap operandIndexingMap,
532 SmallVectorImpl<ReassociationIndices> &reassociation,
533 SmallVectorImpl<int64_t> &expandedOpOperandShape) {
534 unsigned reshapeDims = 0;
535 for (AffineExpr expr : operandIndexingMap.getResults()) {
536 unsigned origDim = expr.cast<AffineDimExpr>().getPosition();
537 auto foldedDims = llvm::seq<int64_t>(
538 reshapeDims, reshapeDims + numFoldedDims[origDim]);
539 reassociation.emplace_back(foldedDims.begin(), foldedDims.end());
540 expandedOpOperandShape.append(expandedDimsShape[origDim].begin(),
541 expandedDimsShape[origDim].end());
542 reshapeDims += numFoldedDims[origDim];
543 }
544 };
545 SmallVector<Value, 4> expandedOpOperands;
546 for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
547 if (operand.index() == fusedTensorIndex) {
548 expandedOpOperands.push_back(reshapeOp.src());
549 continue;
550 }
551 AffineMap indexingMap = linalgOp.getIndexingMap(operand.index());
552 SmallVector<ReassociationIndices, 4> reassociation;
553 SmallVector<int64_t, 4> expandedOperandShape;
554 getReshapeInfo(indexingMap, reassociation, expandedOperandShape);
555 Type expandedOperandType = RankedTensorType::get(
556 expandedOperandShape,
557 operand.value().getType().cast<ShapedType>().getElementType());
558 if (expandedOperandType != operand.value().getType()) {
559 expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
560 linalgOp.getLoc(), expandedOperandType, operand.value(),
561 reassociation));
562 } else {
563 expandedOpOperands.push_back(operand.value());
564 }
565 }
566 SmallVector<Type, 1> resultTypes;
567 SmallVector<SmallVector<ReassociationIndices, 4>, 1> resultReassociation;
568 for (auto result : llvm::enumerate(linalgOp->getResults())) {
569 AffineMap indexingMap =
570 linalgOp.getIndexingMap(linalgOp.getNumInputs() + result.index());
571 SmallVector<ReassociationIndices, 4> reassociation;
572 SmallVector<int64_t, 4> expandedResultShape;
573 getReshapeInfo(indexingMap, reassociation, expandedResultShape);
574 resultTypes.push_back(RankedTensorType::get(
575 expandedResultShape,
576 result.value().getType().cast<ShapedType>().getElementType()));
577 resultReassociation.emplace_back(std::move(reassociation));
578 }
579
580 // The iterator types of the expanded op are all parallel.
581 SmallVector<StringRef, 4> iteratorTypes(remapping.back(),
582 getParallelIteratorTypeName());
583
584 LinalgOp fusedOp = createLinalgOpOfSameType(
585 linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
586 /*inputs=*/expandedOpOperands,
587 /*outputBuffers=*/ValueRange{},
588 /*initTensors=*/ValueRange{}, expandedOpIndexingMaps, iteratorTypes);
589 Region &fusedRegion = fusedOp->getRegion(0);
590 Region &originalRegion = linalgOp->getRegion(0);
591
592 if (isa<GenericOp>(linalgOp.getOperation())) {
593 rewriter.cloneRegionBefore(originalRegion, fusedRegion,
594 fusedRegion.begin());
595 } else {
596 assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
597 // Create an entry block in the fused region with the same number of
598 // arguments as the fused op.
599 Block *fusedEntryBlock = new Block;
600 fusedRegion.push_back(fusedEntryBlock);
601 rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.end());
602
603 // Merge the entry block of the fused op with the cloned blocks. For this
604 // compute the value for arguments of the region in the original operation
605 // in terms of the arguments of the fused op. Since the original operation
606 // is expanded, the expanded dimensions need to be folded back to get the
607 // replacement values for the arguments corresponding to the iteration indices.
608 // For now this expects that all the loop ranges are constants, which is
609 // true if the shapes are all static. This has already been checked in the
610 // precondition.
611 using namespace edsc::op;
612 using namespace edsc::intrinsics;
613 OpBuilder::InsertionGuard guard(rewriter);
614 SmallVector<Value, 4> argReplacements(originalRegion.getNumArguments());
615 rewriter.setInsertionPointToStart(fusedEntryBlock);
616 edsc::ScopedContext scopedContext(rewriter, fusedOp.getLoc());
617 IndexType indexType = rewriter.getIndexType();
618 for (unsigned i : llvm::seq<unsigned>(0, numFoldedDims.size())) {
619 Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
620 for (unsigned foldedDim = remapping[i] + 1; foldedDim != remapping[i + 1];
621 foldedDim++) {
622 int64_t expandedDimExtent =
623 expandedDimsShape[i][foldedDim - remapping[i]];
624 assert(!ShapedType::isDynamic(expandedDimExtent));
625 linearizedIndex =
626 linearizedIndex * std_constant_index(expandedDimExtent);
627 linearizedIndex =
628 linearizedIndex + fusedEntryBlock->addArgument(indexType);
629 }
630 argReplacements[i] = linearizedIndex;
631 }
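// For illustration: if an original loop expands into three loops with static
// extents [s0, s1, s2], the replacement induction value computed above is
// (e0 * s1 + e1) * s2 + e2, where e0, e1, e2 are the new index arguments.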
632 for (unsigned i :
633 llvm::seq<unsigned>(numFoldedDims.size(), argReplacements.size())) {
634 argReplacements[i] =
635 fusedEntryBlock->addArgument(originalRegion.getArgument(i).getType());
636 }
637 rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
638 argReplacements);
639 }
640
641 // Reshape the result values to their original shape if this is a collapsing
642 // reshape folded into its consumer.
643 SmallVector<Value, 1> resultVals;
644 for (auto result : llvm::enumerate(linalgOp->getResults())) {
645 if (!isExpanding &&
646 resultTypes[result.index()] != result.value().getType()) {
647 resultVals.push_back(rewriter.create<TensorReshapeOp>(
648 linalgOp.getLoc(), result.value().getType(),
649 fusedOp->getResult(result.index()),
650 resultReassociation[result.index()]));
651 } else {
652 resultVals.push_back(fusedOp->getResult(result.index()));
653 }
654 }
655 // Assuming a single result.
656 return resultVals;
657 }
658
659 namespace {
660
661 /// Pattern to fold tensor_reshape op with its consumer by using the source of
662 /// the reshape op as the operand in the consumer (instead of the result of the
663 /// tensor_reshape op) when the tensor_reshape op is collapsing. The
664 /// corresponding index map in the consumer needs to be modified to linearize
665 /// the folded dimension.
666 ///
667 /// For example,
668 ///
669 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
670 /// %0 = linalg.tensor_reshape %arg0
671 /// [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>,
672 /// affine_map<(i, j, k, l) -> (l)>]
673 /// : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
674 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
675 /// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
676 /// -> tensor<?x?x4x?xf32>
677 ///
678 /// can be folded into
679 ///
680 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
681 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
682 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
683 /// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
684 /// -> tensor<?x?x4x?xf32>
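/// (Here the consumer's map for %0 is the identity, so its inverse is also the
/// identity; linearizing the collapsed group (d1, d2) over the reshape result
/// shape [?, ?, 4, ?] gives strides (4, 1), hence the d1 * 4 + d2 term in
/// #map0 of the folded form.)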
685 template <typename LinalgOpTy>
686 struct FoldProducerReshapeOpByLinearization
687 : public OpRewritePattern<LinalgOpTy> {
688 using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
689
690 LogicalResult matchAndRewrite(LinalgOpTy op,
691 PatternRewriter &rewriter) const override {
692 if (!op.hasTensorSemantics())
693 return failure();
694 LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
695 for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
696 TensorReshapeOp reshapeOp =
697 operand.value().getDefiningOp<TensorReshapeOp>();
698 if (!reshapeOp ||
699 !isTensorReshapeOpFoldableByLinearization(
700 reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
701 /*asProducer =*/true))
702 continue;
703
704 // Compute the fused operands list.
705 SmallVector<Value, 2> fusedOperands(linalgOp.getInputs());
706 fusedOperands[operand.index()] = reshapeOp.src();
707
708 // Compute indexing_maps for the fused operation. The indexing_maps for
709 // the operands of the consumer that aren't fused are the same.
710 SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
711 op.indexing_maps().template getAsValueRange<AffineMapAttr>());
712
713 // Accepted consumer maps are either identity or permutation.
714 auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
715
716 // Compute the indexing map to use for the result of the producer.
717 AffineMap modifiedMap =
718 linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(),
719 reshapeOp.getReassociationMaps());
720 for (AffineExpr expr : modifiedMap.getResults()) {
721 if (!expr.isPureAffine())
722 return failure();
723 }
724 fusedIndexMaps[operand.index()] = modifiedMap;
725
726 // Further check that the resulting index maps can be fused and
727 // inverted. Without this the resultant op is not legal.
728 if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
729 return op.emitRemark("fused op loop bound computation failed");
730
731 rewriter.startRootUpdate(op);
732 op->setOperands(fusedOperands);
733 op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
734 rewriter.finalizeRootUpdate(op);
735 if (reshapeOp.use_empty())
736 rewriter.eraseOp(reshapeOp);
737 return success();
738 }
739 return op.emitRemark("no fusion candidates found");
740 }
741 };
742
743 /// Pattern to fuse a tensor_reshape op with its consumer
744 /// generic/indexed_generic op, when the reshape op is collapsing
745 /// dimensions. The dimensionality of the loop in the consumer is expanded.
746 template <typename GenericOpTy>
747 struct FoldWithProducerReshapeOpByExpansion
748 : public OpRewritePattern<GenericOpTy> {
749 using OpRewritePattern<GenericOpTy>::OpRewritePattern;
750
751 LogicalResult matchAndRewrite(GenericOpTy genericOp,
752 PatternRewriter &rewriter) const override {
753 LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
754 for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
755 TensorReshapeOp reshapeOp =
756 operand.value().getDefiningOp<TensorReshapeOp>();
757 if (!reshapeOp)
758 continue;
759
760 // Fold only if
761 // - The tensor reshape op is folding.
762 // - All constraints of fusing with reshape by expansion are met.
763 if (reshapeOp.getSrcType().getRank() <
764 reshapeOp.getResultType().getRank() ||
765 !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
766 continue;
767
768 Optional<SmallVector<Value, 1>> replacementValues =
769 fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
770 rewriter);
771 if (!replacementValues)
772 return failure();
773 rewriter.replaceOp(genericOp, replacementValues.getValue());
774 if (reshapeOp.use_empty())
775 rewriter.eraseOp(reshapeOp);
776 return success();
777 }
778 return failure();
779 }
780 };
781
782 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
783 /// map in the consumer needs to be modified to linearize the folded dimension.
784 struct FoldConsumerReshapeOpByLinearization
785 : public OpRewritePattern<TensorReshapeOp> {
786 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
787
788 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
789 PatternRewriter &rewriter) const override {
790 LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
791 if (!producer ||
792 !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
793 !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
794 !isTensorReshapeOpFoldableByLinearization(
795 reshapeOp, producer.getOutputIndexingMap(0), /*asProducer =*/false))
796 return failure();
797 // The indexing_maps for the operands of the fused operation are same as
798 // those for the operands of the producer.
799 SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
800 producer.indexing_maps().getAsValueRange<AffineMapAttr>());
801
802 auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
803
804 // Compute the indexing map to use for the operand of the producer.
805 AffineMap modifiedMap =
806 linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
807 reshapeOp.getReassociationMaps());
808 for (AffineExpr expr : modifiedMap.getResults()) {
809 if (!expr.isPureAffine())
810 return reshapeOp.emitRemark("fused op indexing map is not affine");
811 }
812 fusedIndexMaps.back() = modifiedMap;
813
814 // Further check that the resulting index maps can be fused and
815 // inverted. Without this the resultant op is not legal.
816 if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
817 return reshapeOp.emitRemark("fused op loop bound computation failed");
818
819 LinalgOp fusedOp = createLinalgOpOfSameType(
820 producer, rewriter, rewriter.getUnknownLoc(), reshapeOp.getResultType(),
821 /*inputs=*/producer.getInputs(),
822 /*outputBuffers=*/ValueRange{},
823 /*initTensors=*/ValueRange{}, // no init tensors for now.
824 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
825 producer.iterator_types(),
826 /*doc=*/nullptr,
827 /*library_call=*/nullptr,
828 /*sparse=*/nullptr);
829 auto &fusedRegion = fusedOp->getRegion(0);
830 rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
831 fusedRegion.begin());
832 rewriter.replaceOp(reshapeOp, fusedOp->getResults());
833 if (producer.use_empty())
834 rewriter.eraseOp(producer);
835 return success();
836 }
837 };
838
839 /// Pattern to fold a tensor_reshape op with its producer generic op if the
840 /// tensor_reshape op is expanding, by expanding the dimensionality of the loop
841 /// in the producer op.
842 struct FoldReshapeWithGenericOpByExpansion
843 : public OpRewritePattern<TensorReshapeOp> {
844 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
845 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
846 PatternRewriter &rewriter) const override {
847 // Fold only if
848 // - The tensor reshape op is an expanding case.
849 // - All constraints of fusing with reshape by expansion are met.
850 if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
851 return failure();
852 LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
853 if (!producer || producer.getNumOutputs() != 1 ||
854 !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
855 return failure();
856 Optional<SmallVector<Value, 1>> replacementValues =
857 fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
858 rewriter);
859 if (!replacementValues)
860 return failure();
861 rewriter.replaceOp(reshapeOp, replacementValues.getValue());
862 if (producer.use_empty())
863 rewriter.eraseOp(producer);
864 return success();
865 }
866 };
867
868 /// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
869 template <typename LinalgOpTy>
870 struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
871 using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
872
873 LogicalResult matchAndRewrite(LinalgOpTy op,
874 PatternRewriter &rewriter) const override {
875 if (!op.hasTensorSemantics())
876 return failure();
877 LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
878 for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
879 ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
880 if (!constantOp ||
881 !constantOp.value().cast<DenseElementsAttr>().isSplat())
882 continue;
883
884 // The indexing_maps for the operands of the fused operation are the same
885 // as those for the operands of the linalgOp, without the indexing map at
886 // operand.index().
887 SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
888 linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
889 fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
890
891 // The operand list is the same as that of the linalgOp, with the argument
892 // for the constant at operand.index() dropped.
893 SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
894 fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
895
896 // Create a constant scalar value from the splat constant.
897 Value scalarConstant = rewriter.create<ConstantOp>(
898 constantOp.getLoc(),
899 constantOp.value().cast<DenseElementsAttr>().getSplatValue());
900
901 LinalgOp fusedOp = createLinalgOpOfSameType(
902 linalgOp, rewriter, rewriter.getUnknownLoc(),
903 linalgOp->getResultTypes(),
904 /*inputs=*/fusedOperands,
905 /*outputBuffers=*/ValueRange{},
906 /*initTensors=*/ValueRange{}, // no init tensors for now.
907 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
908 linalgOp.iterator_types(),
909 /*doc=*/nullptr,
910 /*library_call=*/nullptr,
911 /*sparse=*/nullptr);
912
913 // Map the block argument corresponding to the replaced argument with the
914 // scalar constant.
915 Region &linalgOpRegion = linalgOp->getRegion(0);
916 Block &entryBlock = *linalgOpRegion.begin();
917 unsigned argIndex = entryBlock.getNumArguments() -
918 linalgOp.getNumInputs() + operand.index();
919 BlockAndValueMapping mapping;
920 mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
921 Region &fusedRegion = fusedOp->getRegion(0);
922 rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
923 fusedRegion.begin(), mapping);
924 rewriter.replaceOp(linalgOp, fusedOp->getResults());
925 if (constantOp.use_empty())
926 rewriter.eraseOp(constantOp);
927 return success();
928 }
929 return failure();
930 }
931 };
932 } // namespace
933
934 Optional<SmallVector<Value, 1>>
935 mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
936 unsigned consumerIdx) {
937 if (consumerIdx >= consumer->getNumOperands())
938 return llvm::None;
939 Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
940 if (!producer || producer->getNumResults() != 1)
941 return llvm::None;
942
943 // Fuse when consumer is GenericOp or IndexedGenericOp.
944 if (!isa<GenericOp, IndexedGenericOp>(consumer) ||
945 !isa<GenericOp, IndexedGenericOp>(producer))
946 return llvm::None;
947
948 return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
949 consumerIdx, rewriter);
950 }
951
952 namespace {
953 /// Pattern to fuse a generic op with the producer of one of its operands.
954 template <typename LinalgOpTy>
955 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
956 using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
957
958 LogicalResult matchAndRewrite(LinalgOpTy op,
959 PatternRewriter &rewriter) const override {
960 // Find the first operand that is defined by another generic op on tensors.
961 for (auto operandNum : llvm::seq<unsigned>(0, op->getNumOperands())) {
962 Operation *producer = op->getOperand(operandNum).getDefiningOp();
963 if (!producer)
964 continue;
965 Optional<SmallVector<Value, 1>> fusedOpResults =
966 fuseTensorOps(rewriter, op, operandNum);
967 if (fusedOpResults) {
968 rewriter.replaceOp(op, *fusedOpResults);
969 if (producer->use_empty())
970 rewriter.eraseOp(producer);
971 return success();
972 }
973 }
974 return failure();
975 }
976 };
977
978 /// Pass that fuses generic ops on tensors. Used only for testing.
979 struct FusionOfTensorOpsPass
980 : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
981 void runOnOperation() override {
982 OwningRewritePatternList patterns;
983 Operation *op = getOperation();
984 populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
985 applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
986 }
987 };
988
989 /// Pass to test folding of reshape op with generic/indexed_generic ops by
990 /// linearization.
991 struct FoldReshapeOpsByLinearizationPass
992 : public LinalgFoldReshapeOpsByLinearizationBase<
993 FoldReshapeOpsByLinearizationPass> {
994 void runOnOperation() override {
995 OwningRewritePatternList patterns;
996 Operation *op = getOperation();
997 populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
998 applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
999 }
1000 };
1001
1002 } // namespace
1003
1004 void mlir::populateFoldReshapeOpsByLinearizationPatterns(
1005 MLIRContext *context, OwningRewritePatternList &patterns) {
1006 patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
1007 FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
1008 FoldConsumerReshapeOpByLinearization>(context);
1009 }
1010
1011 void mlir::populateFoldReshapeOpsByExpansionPatterns(
1012 MLIRContext *context, OwningRewritePatternList &patterns) {
1013 patterns.insert<FoldReshapeWithGenericOpByExpansion,
1014 FoldWithProducerReshapeOpByExpansion<GenericOp>,
1015 FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
1016 context);
1017 }
1018
1019 void mlir::populateLinalgTensorOpsFusionPatterns(
1020 MLIRContext *context, OwningRewritePatternList &patterns) {
1021 patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
1022 FoldSplatConstants<GenericOp>,
1023 FoldSplatConstants<IndexedGenericOp>>(context);
1024 populateFoldReshapeOpsByExpansionPatterns(context, patterns);
1025 GenericOp::getCanonicalizationPatterns(patterns, context);
1026 IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
1027 TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
1028 }
1029
1030 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
1031 return std::make_unique<FusionOfTensorOpsPass>();
1032 }
1033
1034 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1035 return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1036 }
1037