1 //===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
11 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
12 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
13 #include "mlir/Dialect/SCF/EDSC/Builders.h"
14 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
16 #include "mlir/IR/AffineExpr.h"
17 
18 using namespace mlir;
19 using namespace mlir::edsc;
20 using namespace mlir::edsc::intrinsics;
21 using namespace mlir::linalg;
22 using namespace mlir::scf;
23 
makeGenericLinalgOp(ArrayRef<IteratorType> iteratorTypes,ArrayRef<StructuredIndexed> inputs,ArrayRef<StructuredIndexed> outputBuffers,ArrayRef<Value> initTensors,ArrayRef<StructuredIndexed> resultTensorTypes,function_ref<void (ValueRange)> regionBuilder,ArrayRef<Value> otherValues,ArrayRef<Attribute> otherAttributes)24 Operation *mlir::edsc::makeGenericLinalgOp(
25     ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
26     ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
27     ArrayRef<StructuredIndexed> resultTensorTypes,
28     function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
29     ArrayRef<Attribute> otherAttributes) {
30   OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
31 
32   // Build maps
33   SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
34   exprsList.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
35   for (auto container : {inputs, outputBuffers, resultTensorTypes})
36     for (const StructuredIndexed &s : container)
37       exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end());
38   auto maps = AffineMap::inferFromExprList(exprsList);
39 
40   SmallVector<Type, 4> types;
41   assert(llvm::all_of(resultTensorTypes, [](const StructuredIndexed &s) {
42     return !s.hasValue();
43   }));
44   std::copy(resultTensorTypes.begin(), resultTensorTypes.end(),
45             std::back_inserter(types));
46 
47   SmallVector<Value, 4> inputValues, outputBufferValues, initTensorValues;
48   inputValues.reserve(inputs.size());
49   outputBufferValues.reserve(outputBuffers.size());
50   initTensorValues.reserve(initTensors.size());
51   std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues));
52   std::copy(outputBuffers.begin(), outputBuffers.end(),
53             std::back_inserter(outputBufferValues));
54   std::copy(initTensors.begin(), initTensors.end(),
55             std::back_inserter(initTensorValues));
56 
57   auto iteratorStrTypes =
58       llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
59   // clang-format off
60   auto *op =
61       edsc::ScopedContext::getBuilderRef()
62           .create<linalg::GenericOp>(
63               edsc::ScopedContext::getLocation(),
64               types,
65               inputValues,
66               outputBufferValues,
67               initTensorValues,
68               builder.getAffineMapArrayAttr(maps),
69               builder.getStrArrayAttr(iteratorStrTypes),
70               StringAttr() /*doc*/,
71               StringAttr() /*library_call*/,
72               ArrayAttr() /*sparse*/
73               /* TODO: other attributes in op */
74               )
75           .getOperation();
76   // clang-format on
77 
78   using namespace edsc;
79   SmallVector<Type, 4> blockTypes;
80   blockTypes.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
81   for (auto container : {inputs, outputBuffers})
82     for (const StructuredIndexed &s : container)
83       blockTypes.push_back(getElementTypeOrSelf(s.getType()));
84   for (Value v : initTensors)
85     blockTypes.push_back(getElementTypeOrSelf(v.getType()));
86 
87   assert(op->getNumRegions() == 1);
88   assert(op->getRegion(0).empty());
89   OpBuilder opBuilder(op);
90   ScopedContext scope(opBuilder, op->getLoc());
91   buildInNewBlock(op->getRegion(0), blockTypes, regionBuilder);
92   assert(llvm::hasSingleElement(op->getRegion(0)));
93   return op;
94 }
95 
mulRegionBuilder(ValueRange args)96 void mlir::edsc::ops::mulRegionBuilder(ValueRange args) {
97   using edsc::op::operator+;
98   using edsc::op::operator*;
99   assert(args.size() == 2 && "expected 2 block arguments");
100   Value a(args[0]), b(args[1]);
101   linalg_yield(a * b);
102 }
103 
macRegionBuilder(ValueRange args)104 void mlir::edsc::ops::macRegionBuilder(ValueRange args) {
105   using edsc::op::operator+;
106   using edsc::op::operator*;
107   assert(args.size() == 3 && "expected 3 block arguments");
108   Value a(args[0]), b(args[1]), c(args[2]);
109   linalg_yield(c + a * b);
110 }
111 
linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,StructuredIndexed I,StructuredIndexed O)112 Operation *mlir::edsc::ops::linalg_generic_pointwise(
113     UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) {
114   SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
115                                          IteratorType::Parallel);
116   auto fun = [&unaryOp](ValueRange args) {
117     assert(!args.empty() && "expected >= 1 block arguments");
118     Value a(args[0]);
119     linalg_yield(unaryOp(a));
120   };
121   if (O.getType().isa<RankedTensorType>())
122     return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{},
123                                /*initTensors=*/{}, /*resultTensorTypes=*/{O},
124                                fun);
125   return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{O},
126                              /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
127 }
128 
linalg_generic_pointwise_tanh(StructuredIndexed I,StructuredIndexed O)129 Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
130                                                           StructuredIndexed O) {
131   UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return std_tanh(a); });
132   return linalg_generic_pointwise(unOp, I, O);
133 }
134 
135 /// Binary pointwise operation (with broadcast) entry point.
linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)136 Operation *mlir::edsc::ops::linalg_generic_pointwise(
137     BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1,
138     StructuredIndexed I2, StructuredIndexed O) {
139   SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
140                                          IteratorType::Parallel);
141   auto fun = [&binaryOp](ValueRange args) {
142     assert(args.size() >= 2 && "expected >= 2 block arguments");
143     Value a(args[0]), b(args[1]);
144     linalg_yield(binaryOp(a, b));
145   };
146   if (O.getType().isa<RankedTensorType>())
147     return makeGenericLinalgOp(
148         iterTypes, /*inputs=*/{I1, I2}, /*outputBuffers=*/{},
149         /*initTensors=*/{}, /*resultTensorTypes=*/{O}, fun);
150   return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2},
151                              /*outputBuffers=*/{O},
152                              /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
153 }
154 
linalg_generic_pointwise_add(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)155 Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
156                                                          StructuredIndexed I2,
157                                                          StructuredIndexed O) {
158   using edsc::op::operator+;
159   BinaryPointwiseOpBuilder binOp(
160       [](Value a, Value b) -> Value { return a + b; });
161   return linalg_generic_pointwise(binOp, I1, I2, O);
162 }
163 
linalg_generic_pointwise_max(StructuredIndexed I1,StructuredIndexed I2,StructuredIndexed O)164 Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
165                                                          StructuredIndexed I2,
166                                                          StructuredIndexed O) {
167   BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value {
168     using edsc::op::sgt;
169     return std_select(sgt(a, b), a, b);
170   });
171   return linalg_generic_pointwise(binOp, I1, I2, O);
172 }
173 
174 Operation *
linalg_generic_matmul(Value vA,Value vB,Value vC,MatmulRegionBuilder regionBuilder)175 mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
176                                        MatmulRegionBuilder regionBuilder) {
177   // clang-format off
178   AffineExpr m, n, k;
179   bindDims(ScopedContext::getContext(), m, n, k);
180   StructuredIndexed A(vA), B(vB), C(vC);
181   return makeGenericLinalgOp(
182     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
183     /*inputs=*/{A({m, k}), B({k, n})},
184     /*outputBuffers=*/{C({m, n})},
185     /*initTensors=*/{},
186     /*resultTensorTypes=*/{},
187     regionBuilder);
188   // clang-format on
189 }
190 
191 Operation *
linalg_generic_matmul(Value vA,Value vB,Value vC,RankedTensorType tD,MatmulRegionBuilder regionBuilder)192 mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
193                                        RankedTensorType tD,
194                                        MatmulRegionBuilder regionBuilder) {
195   // clang-format off
196   AffineExpr m, n, k;
197   bindDims(ScopedContext::getContext(), m, n, k);
198   StructuredIndexed A(vA), B(vB), C(vC), D(tD);
199   return makeGenericLinalgOp(
200     {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
201     /*inputs=*/{A({m, k}), B({k, n})},
202     /*outputBuffers=*/{},
203     /*initTensors=*/{C({m, n})},
204     /*resultTensorTypes=*/{D({m, n})},
205     regionBuilder);
206   // clang-format on
207 }
208 
linalg_generic_conv_nhwc(Value vI,Value vW,Value vO,ArrayRef<int> strides,ArrayRef<int> dilations)209 Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
210                                                      Value vO,
211                                                      ArrayRef<int> strides,
212                                                      ArrayRef<int> dilations) {
213   MLIRContext *ctx = ScopedContext::getContext();
214   // TODO: some template magic to make everything rank-polymorphic.
215   assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
216   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
217 
218   // Some short names.
219   auto par = IteratorType::Parallel;
220   auto red = IteratorType::Reduction;
221   auto s = strides;
222   auto d = dilations;
223 
224   AffineExpr b, f, h, w, kh, kw, c;
225   bindDims(ctx, b, f, h, w, kh, kw, c);
226   unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
227   StructuredIndexed I(vI), W(vW), O(vO);
228   // clang-format off
229   return makeGenericLinalgOp(
230     {par, par, par, par, red, red, red},
231     /*inputs=*/{
232       I({b,
233          // Roundtrip to flattened form to serve as canonicalization and ensure
234          // consistent ordering of subexpressions.
235          simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
236          simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
237          c}),
238       W({kh, kw, c, f}) },
239     /*outputBuffers=*/{ O({b, h, w, f}) },
240     /*initTensors=*/{},
241     /*resultTensorTypes=*/{},
242     macRegionBuilder);
243   // clang-format on
244 }
245 
linalg_generic_dilated_conv_nhwc(Value vI,Value vW,Value vO,int depth_multiplier,ArrayRef<int> strides,ArrayRef<int> dilations)246 Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
247     Value vI, Value vW, Value vO, int depth_multiplier, ArrayRef<int> strides,
248     ArrayRef<int> dilations) {
249   MLIRContext *ctx = ScopedContext::getContext();
250   // TODO: some template magic to make everything rank-polymorphic.
251   assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
252   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
253 
254   // Some short names.
255   auto par = IteratorType::Parallel;
256   auto red = IteratorType::Reduction;
257   auto s = strides;
258   auto d = dilations;
259 
260   // clang-format off
261   AffineExpr b, dm, c, h, w, kh, kw;
262   bindDims(ctx, b, dm, c, h, w, kh, kw);
263   unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
264   StructuredIndexed I(vI), W(vW), O(vO);
265   return makeGenericLinalgOp(
266     {par, par, par, par, par, red, red},
267     /*inputs=*/{
268       I({b,
269          // Roundtrip to flattened form to serve as canonicalization and ensure
270          // consistent ordering of subexpressions.
271          simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
272          simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
273          c}),
274       W({kh, kw, c, dm})},
275     /*outputBuffers=*/{
276       O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
277     /*initTensors=*/{},
278     /*resultTensorTypes=*/{},
279     macRegionBuilder);
280   // clang-format on
281 }
282