1 //===- Interchange.cpp - Linalg interchange transformation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg interchange transformation.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/Linalg/Utils/Utils.h"
17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <type_traits>
29 
30 #define DEBUG_TYPE "linalg-interchange"
31 
32 using namespace mlir;
33 using namespace mlir::linalg;
34 
interchangeGenericLinalgOpPrecondition(Operation * op,ArrayRef<unsigned> interchangeVector)35 LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
36     Operation *op, ArrayRef<unsigned> interchangeVector) {
37   if (interchangeVector.empty())
38     return failure();
39   // Transformation applies to generic ops only.
40   if (!isa<GenericOp, IndexedGenericOp>(op))
41     return failure();
42   LinalgOp linOp = cast<LinalgOp>(op);
43   // Transformation applies to buffers only.
44   if (!linOp.hasBufferSemantics())
45     return failure();
46   // Permutation must be applicable.
47   if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
48     return failure();
49   // Permutation map must be invertible.
50   if (!inversePermutation(
51           AffineMap::getPermutationMap(interchangeVector, op->getContext())))
52     return failure();
53   return success();
54 }
55 
interchange(LinalgOp op,ArrayRef<unsigned> interchangeVector)56 LinalgOp mlir::linalg::interchange(LinalgOp op,
57                                    ArrayRef<unsigned> interchangeVector) {
58   if (interchangeVector.empty())
59     return op;
60 
61   MLIRContext *context = op.getContext();
62   auto permutationMap = inversePermutation(
63       AffineMap::getPermutationMap(interchangeVector, context));
64   assert(permutationMap && "expected permutation to be invertible");
65   SmallVector<Attribute, 4> newIndexingMaps;
66   auto indexingMaps = op.indexing_maps().getValue();
67   for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) {
68     AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
69     if (!permutationMap.isEmpty())
70       m = m.compose(permutationMap);
71     newIndexingMaps.push_back(AffineMapAttr::get(m));
72   }
73   auto itTypes = op.iterator_types().getValue();
74   SmallVector<Attribute, 4> itTypesVector;
75   for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
76     itTypesVector.push_back(itTypes[i]);
77   applyPermutationToVector(itTypesVector, interchangeVector);
78 
79   op.setAttr(getIndexingMapsAttrName(),
80              ArrayAttr::get(newIndexingMaps, context));
81   op.setAttr(getIteratorTypesAttrName(),
82              ArrayAttr::get(itTypesVector, context));
83 
84   return op;
85 }
86