//===- Interchange.cpp - Linalg interchange transformation ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the linalg interchange transformation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include #define DEBUG_TYPE "linalg-interchange" using namespace mlir; using namespace mlir::linalg; LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( Operation *op, ArrayRef interchangeVector) { if (interchangeVector.empty()) return failure(); // Transformation applies to generic ops only. if (!isa(op)) return failure(); LinalgOp linOp = cast(op); // Transformation applies to buffers only. if (!linOp.hasBufferSemantics()) return failure(); // Permutation must be applicable. if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size()) return failure(); // Permutation map must be invertible. if (!inversePermutation( AffineMap::getPermutationMap(interchangeVector, op->getContext()))) return failure(); return success(); } LinalgOp mlir::linalg::interchange(LinalgOp op, ArrayRef interchangeVector) { if (interchangeVector.empty()) return op; MLIRContext *context = op.getContext(); auto permutationMap = inversePermutation( AffineMap::getPermutationMap(interchangeVector, context)); assert(permutationMap && "expected permutation to be invertible"); SmallVector newIndexingMaps; auto indexingMaps = op.indexing_maps().getValue(); for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) { AffineMap m = indexingMaps[i].cast().getValue(); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(AffineMapAttr::get(m)); } auto itTypes = op.iterator_types().getValue(); SmallVector itTypesVector; for (unsigned i = 0, e = itTypes.size(); i != e; ++i) itTypesVector.push_back(itTypes[i]); applyPermutationToVector(itTypesVector, interchangeVector); op.setAttr(getIndexingMapsAttrName(), ArrayAttr::get(newIndexingMaps, context)); op.setAttr(getIteratorTypesAttrName(), ArrayAttr::get(itTypesVector, context)); return op; }