//===- Builders.cpp - MLIR Declarative Vector Builders --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/EDSC/Builders.h"
10 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
11 #include "mlir/Dialect/Vector/VectorOps.h"
12 #include "mlir/EDSC/Builders.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/Builders.h"
15 
16 using namespace mlir;
17 using namespace mlir::edsc;
18 using namespace mlir::edsc::intrinsics;
19 using namespace mlir::edsc::ops;
20 
vector_contraction(StructuredIndexed A,StructuredIndexed B,StructuredIndexed C,ArrayRef<IteratorType> iteratorTypes)21 Value mlir::edsc::ops::vector_contraction(
22     StructuredIndexed A, StructuredIndexed B, StructuredIndexed C,
23     ArrayRef<IteratorType> iteratorTypes) {
24   using IndexingExprs = ArrayRef<ArrayRef<AffineExpr>>;
25   return vector_contract(
26       A.getValue(), B.getValue(), C.getValue(),
27       IndexingExprs{A.getExprs(), B.getExprs(), C.getExprs()},
28       ArrayRef<StringRef>{
29           llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString))});
30 }
31 
vector_contraction_matmul(Value A,Value B,Value C)32 Value mlir::edsc::ops::vector_contraction_matmul(Value A, Value B, Value C) {
33   AffineExpr m, n, k;
34   bindDims(ScopedContext::getContext(), m, n, k);
35   return vector_contraction(StructuredIndexed(A, {m, k}),
36                             StructuredIndexed(B, {k, n}),
37                             StructuredIndexed(C, {m, n}),
38                             {IteratorType::Parallel, IteratorType::Parallel,
39                              IteratorType::Reduction});
40 }
41