1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/MathExtras.h"
28 #include "llvm/Support/raw_ostream.h"
29
30 using namespace mlir;
31 using namespace mlir::linalg;
32
33 /// Fully compose map with operands and canonicalize the result.
34 /// Return the `createOrFold`'ed AffineApply op.
static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
36 AffineMap map,
37 ValueRange operandsRef) {
38 SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
39 fullyComposeAffineMapAndOperands(&map, &operands);
40 canonicalizeMapAndOperands(&map, &operands);
41 return b.createOrFold<AffineApplyOp>(loc, map, operands);
42 }
43
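/// Applies each result expression of `map` to `values` and returns the
/// resulting list of values, folding each application eagerly when possible.
/// Illustrative (hypothetical) example: applying
///   affine_map<(d0, d1) -> (d0 + d1, d0)>
/// to `[%a, %b]` yields `[%a + %b, %a]`, where the first entry is an
/// affine.apply unless it folds to a constant.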
SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
45 AffineMap map,
46 ValueRange values) {
47 SmallVector<Value, 4> res;
48 res.reserve(map.getNumResults());
49 unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
  // For each `expr` in `map`, apply `expr` to the given values. If the
  // resulting application can be folded into a Value, the folding occurs
  // eagerly.
53 for (auto expr : map.getResults()) {
54 AffineMap map = AffineMap::get(numDims, numSym, expr);
55 res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
56 }
57 return res;
58 }
59
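/// Returns a flat list of `dim` values, one per dimension of each shaped
/// operand of this LinalgOp, in operand order.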
SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
61 Location loc) {
62 SmallVector<Value, 4> res;
63 for (Value v : getShapedOperands()) {
64 ShapedType t = v.getType().template cast<ShapedType>();
65 for (unsigned i = 0, e = t.getRank(); i < e; ++i)
66 res.push_back(b.create<DimOp>(loc, v, i));
67 }
68 return res;
69 }
70
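/// Builds one `Range` per loop dimension. Each filled range is [0, size, 1),
/// where the size is taken from the operand dimension that the loops-to-shapes
/// map associates with that loop.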
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
72 AffineMap map = getLoopsToShapesMap();
73 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
74 auto viewSizes = createFlatListOfOperandDims(b, loc);
75 SmallVector<Range, 4> res(numDims);
76 Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
77 Value oneVal = b.create<ConstantIndexOp>(loc, 1);
78 for (unsigned idx = 0; idx < numRes; ++idx) {
79 auto result = map.getResult(idx);
80 if (auto d = result.dyn_cast<AffineDimExpr>()) {
81 if (res[d.getPosition()].offset)
82 continue;
83 res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
84 }
85 }
86 return res;
87 }
88
89 /// Forward declarations.
90 template <typename NamedStructuredOpType>
91 static void buildNamedStructuredOpRegionAndAttributes(
92 OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes,
93 TypeRange outputBufferTypes, TypeRange initTensorTypes,
94 TypeRange resultTypes);
95
96 static ParseResult
97 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
98 SmallVectorImpl<Type> &inputTypes,
99 SmallVectorImpl<Type> &outputBufferTypes,
100 SmallVectorImpl<Type> &initTensorTypes);
101
102 template <typename NamedStructuredOpType>
103 static ParseResult
104 parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
105 TypeRange inputTypes, TypeRange outputBufferTypes,
106 TypeRange initTensorTypes, TypeRange resultTypes);
107 static ParseResult
108 parseNamedStructuredOpResults(OpAsmParser &parser,
109 SmallVectorImpl<Type> &resultTypes);
110
111 template <typename NamedStructuredOpType>
112 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
113 OperationState &result);
114
115 template <typename NamedStructuredOpType>
116 static void printCommonStructuredOpParts(OpAsmPrinter &p,
117 NamedStructuredOpType op);
118
119 static void printNamedStructuredOpResults(OpAsmPrinter &p,
120 TypeRange resultTypes);
121
122 template <typename NamedStructuredOpType>
123 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
124
125 template <typename NamedStructuredOpType>
126 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
127
128 /// This is a common class used for patterns of the form
129 /// ```
130 /// someop(memrefcast) -> someop
131 /// ```
132 /// It folds the source of the memref_cast into the root operation directly.
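/// For instance (illustrative, types simplified):
/// ```
///   %0 = memref_cast %arg : memref<4xf32> to memref<?xf32>
///   linalg.fill(%0, %cst) : memref<?xf32>, f32
/// ```
/// is rewritten so that the fill operates directly on `%arg` when the cast can
/// be folded into the consumer.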
static LogicalResult foldMemRefCast(Operation *op) {
134 bool folded = false;
135 for (OpOperand &operand : op->getOpOperands()) {
136 auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
137 if (castOp && canFoldIntoConsumerOp(castOp)) {
138 operand.set(castOp.getOperand());
139 folded = true;
140 }
141 }
142 return success(folded);
143 }
144
145 ///////////////////// Operations defined with Tablegen /////////////////////////
// For operations that do not correspond to library calls (i.e. those defined
// in LinalgOps.td), we define an overloaded `print` function and a
// parse`ClassName` function.
149
150 //===----------------------------------------------------------------------===//
151 // GenericOps
152 //===----------------------------------------------------------------------===//
void GenericOp::build(
154 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
155 ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
156 ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
157 StringRef doc, StringRef libraryCall,
158 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
159 build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
160 builder.getAffineMapArrayAttr(indexingMaps),
161 builder.getStrArrayAttr(iteratorTypes),
162 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
163 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
164 ArrayAttr());
165 if (!bodyBuild)
166 return;
167
168 SmallVector<Type, 4> blockArgTypes;
169 for (ValueRange container : {inputs, outputBuffers, initTensors})
170 for (Value v : container)
171 blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
172
173 OpBuilder::InsertionGuard guard(builder);
174 auto ®ion = *result.regions.front();
175 Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes);
176 bodyBuild(builder, result.location, bodyBlock->getArguments());
177 }
178
void GenericOp::build(
180 OpBuilder &builder, OperationState &result, ValueRange inputs,
181 ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
182 ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
183 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
184 build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{},
185 indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild);
186 }
187
void GenericOp::build(
189 OpBuilder &builder, OperationState &result, ValueRange inputs,
190 ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
191 ArrayRef<StringRef> iteratorTypes,
192 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
193 build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
194 /*doc=*/"",
195 /*libraryCall=*/"", bodyBuild);
196 }
197
void GenericOp::build(
199 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
200 ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
201 ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
202 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
203 build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
204 indexingMaps, iteratorTypes,
205 /*doc=*/"",
206 /*libraryCall=*/"", bodyBuild);
207 }
void IndexedGenericOp::build(
209 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
210 ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
211 ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
212 StringRef doc, StringRef libraryCall,
213 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
214 bodyBuild) {
215 build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
216 builder.getAffineMapArrayAttr(indexingMaps),
217 builder.getStrArrayAttr(iteratorTypes),
218 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
219 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
220 ArrayAttr());
221 if (!bodyBuild)
222 return;
223
224 unsigned nLoops = iteratorTypes.size();
225 SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
226 for (ValueRange container : {inputs, outputBuffers, initTensors})
227 for (Value v : container)
228 blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
229
230 OpBuilder::InsertionGuard guard(builder);
231 auto ®ion = *result.regions.front();
232 Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes);
233 bodyBuild(builder, result.location,
234 bodyBlock->getArguments().take_front(nLoops),
235 bodyBlock->getArguments().drop_front(nLoops));
236 }
237
void IndexedGenericOp::build(
239 OpBuilder &builder, OperationState &result, ValueRange inputs,
240 ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
241 ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
242 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
243 bodyBuild) {
244 build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{},
245 indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild);
246 }
247
void IndexedGenericOp::build(
249 OpBuilder &builder, OperationState &result, ValueRange inputs,
250 ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
251 ArrayRef<StringRef> iteratorTypes,
252 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
253 bodyBuild) {
254 build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
255 /*doc=*/"", /*libraryCall=*/"", bodyBuild);
256 }
257
void IndexedGenericOp::build(
259 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
260 ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
261 ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
262 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
263 bodyBuild) {
264 build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
265 indexingMaps, iteratorTypes,
266 /*doc=*/"",
267 /*libraryCall=*/"", bodyBuild);
268 }
269
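/// Prints a generic or indexed_generic op. The printed form is roughly as
/// follows (illustrative sketch; trait attributes, operand lists and segment
/// sizes are simplified):
/// ```
///   linalg.generic {indexing_maps = [...], iterator_types = ["parallel", ...]}
///       ins(%A : memref<?xf32>) outs(%B : memref<?xf32>) {
///     ^bb0(%a: f32, %b: f32):
///       linalg.yield %a : f32
///   }
/// ```
/// with a trailing result-type list only when the op returns tensors.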
270 template <typename GenericOpType>
static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
272 p << op.getOperationName() << " ";
273
274 // Print extra attributes.
275 auto genericAttrNames = op.linalgTraitAttrNames();
276
277 llvm::StringSet<> genericAttrNamesSet;
278 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
279 SmallVector<NamedAttribute, 8> genericAttrs;
280 for (auto attr : op.getAttrs())
281 if (genericAttrNamesSet.count(attr.first.strref()) > 0)
282 genericAttrs.push_back(attr);
283 if (!genericAttrs.empty()) {
284 auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
285 p << genericDictAttr;
286 }
287
288 // Printing is shared with named ops, except for the region and attributes
289 printCommonStructuredOpParts(p, op);
290
291 genericAttrNames.push_back("operand_segment_sizes");
292 genericAttrNamesSet.insert(genericAttrNames.back());
293
294 bool hasExtraAttrs = false;
295 for (NamedAttribute n : op.getAttrs()) {
296 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref())))
297 break;
298 }
299 if (hasExtraAttrs) {
300 p << " attrs = ";
301 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/genericAttrNames);
302 }
303
304 // Print region.
305 if (!op.region().empty())
306 p.printRegion(op.region());
307
308 // Print results.
309 printNamedStructuredOpResults(p, op.result_tensors().getTypes());
310 }
311
static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
313
static void print(OpAsmPrinter &p, IndexedGenericOp op) {
315 printGenericOp(p, op);
316 }
317
static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
319 DictionaryAttr dictAttr;
  // Parse the core linalg traits (those checked by the verifier) into a
  // dictAttr. The attribute name is unimportant, as we will overwrite
  // result.attributes. The parsed traits must contain the information
  // necessary to pass the verifier.
324 if (parser.parseAttribute(dictAttr, "_", result.attributes))
325 return failure();
326 result.attributes.assign(dictAttr.getValue().begin(),
327 dictAttr.getValue().end());
328
329 // Parsing is shared with named ops, except for the region.
330 SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
331 if (parseCommonStructuredOpParts(parser, result, inputTypes,
332 outputBufferTypes, initTensorTypes))
333 return failure();
334
335 // Optional attributes may be added.
336 if (succeeded(parser.parseOptionalKeyword("attrs")))
337 if (failed(parser.parseEqual()) ||
338 failed(parser.parseOptionalAttrDict(result.attributes)))
339 return failure();
340
341 SmallVector<OpAsmParser::OperandType, 8> regionOperands;
342 std::unique_ptr<Region> region = std::make_unique<Region>();
343 SmallVector<Type, 8> operandTypes, regionTypes;
344 if (parser.parseRegion(*region, regionOperands, regionTypes))
345 return failure();
346 result.addRegion(std::move(region));
347
  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
350 // TODO: may need to move output parsing before region parsing.
351 // Need to wait for declarative assembly resolution to decide.
352 SmallVector<Type, 1> outputTensorsTypes;
353 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
354 return failure();
355 result.addTypes(outputTensorsTypes);
356
357 return success();
358 }
359
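/// Populates `effects` with the memory effects of a structured op: every
/// result is reported as allocated, every input buffer as read, and every
/// output buffer as both read and written (all on the default resource).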
static void getGenericEffectsImpl(
361 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
362 &effects,
363 ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
364 for (Value value : results) {
365 effects.emplace_back(MemoryEffects::Allocate::get(), value,
366 SideEffects::DefaultResource::get());
367 }
368 for (Value value : inputBuffers) {
369 effects.emplace_back(MemoryEffects::Read::get(), value,
370 SideEffects::DefaultResource::get());
371 }
372 for (Value value : outputBuffers) {
373 effects.emplace_back(MemoryEffects::Read::get(), value,
374 SideEffects::DefaultResource::get());
375 effects.emplace_back(MemoryEffects::Write::get(), value,
376 SideEffects::DefaultResource::get());
377 }
378 }
379
void GenericOp::getEffects(
381 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
382 &effects) {
383 getGenericEffectsImpl(effects, getOperation()->getResults(),
384 getInputBuffers(), getOutputBuffers());
385 }
386
void IndexedGenericOp::getEffects(
388 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
389 &effects) {
390 getGenericEffectsImpl(effects, getOperation()->getResults(),
391 getInputBuffers(), getOutputBuffers());
392 }
393
394 namespace {
395
396 template <typename GenericOpType>
397 struct BlockArgsVerifier {
398 static LogicalResult verify(GenericOpType op, Block &block);
399 };
400
401 template <typename GenericOpType>
LogicalResult BlockArgsVerifier<GenericOpType>::verify(GenericOpType op,
403 Block &block) {
404 auto nOperands = op.getNumOperands();
405 if (block.getNumArguments() != nOperands)
406 return op.emitOpError("expected number of block arguments to match number "
407 "of operands");
408
409 // Note: the number and type of yield values are checked in the YieldOp.
410 auto nInputViews = op.getNumInputs();
411 for (unsigned i = 0; i < nOperands; ++i) {
412 auto viewType = op.getShapedType(i);
413 if (viewType.getElementType() != block.getArgument(i).getType())
414 return op.emitOpError("expected block argument ")
415 << (i + 1) << " of the same type as elemental type of "
416 << ((i < nInputViews) ? "input " : "output ")
417 << "operand: " << viewType;
418 }
419 return success();
420 }
421
422 template <>
LogicalResult BlockArgsVerifier<IndexedGenericOp>::verify(IndexedGenericOp op,
424 Block &block) {
425 auto nInputViews = op.getNumInputs();
426 auto nLoops = op.getNumLoops();
427 auto nOperands = op.getNumOperands();
428 if (block.getNumArguments() != nOperands + nLoops)
429 return op.emitOpError(
430 "expected number of block arguments to match number of operands + "
431 "number of loops");
432
433 // Note: the number and type of yield values are checked in the YieldOp.
434 for (unsigned i = 0; i < nLoops; ++i)
435 if (!block.getArgument(i).getType().isIndex())
436 return op.emitOpError("expected block argument ")
437 << (i + 1) << " to be an index";
438
439 for (unsigned i = 0; i < nOperands; ++i) {
440 unsigned memrefArgIndex = i + nLoops;
441 auto viewType = op.getShapedType(i);
442 if (viewType.getElementType() !=
443 block.getArgument(memrefArgIndex).getType())
444 return op.emitOpError("expected block argument ")
445 << (memrefArgIndex + 1)
446 << " of the same type as elemental type of "
447 << ((i < nInputViews) ? "input " : "output ")
448 << "operand: " << viewType;
449 }
450 return success();
451 }
452
453 template <typename GenericOpType>
454 struct AnnotationsVerifier {
  static LogicalResult verify(GenericOpType op) { return success(); }
456 };
457
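/// Verifies the sparse annotations attached to a GenericOp, when present.
/// An illustrative (hypothetical) annotation for an op with one rank-2 input,
/// one rank-1 input and one rank-1 output could look like:
///   [ ["S", "D"],   // input 0: sparse in dim 0, dense in dim 1
///     ["D"],        // input 1: dense
///     ["D"] ]       // output: dense (sparse outputs are rejected below)
/// i.e. one array per tensor, with one "D" (dense) or "S" (sparse) entry per
/// dimension.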
458 template <>
LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
460 ArrayAttr sparseAttr = op.sparseAttr();
461 if (!sparseAttr)
462 return success();
463 // Verify consistency of sparse annotations.
464 if (!op.hasTensorSemantics())
465 return op.emitOpError("expected sparse annotations on tensors only");
466 if (op.getNumOutputs() != 1)
467 return op.emitOpError("expected single output tensor");
468 unsigned numTensors = op.getNumInputsAndOutputs();
469 if (sparseAttr.size() != numTensors)
470 return op.emitOpError("expected one sparse annotation for each tensor");
471 for (unsigned t = 0; t < numTensors; t++) {
472 auto dimAttr = sparseAttr[t].dyn_cast_or_null<ArrayAttr>();
473 if (!dimAttr)
474 return op.emitOpError("expected sparse annotation array for tensor ")
475 << t;
476 unsigned rank = op.getShapedType(t).getRank();
477 if (dimAttr.size() != rank)
478 return op.emitOpError("expected sparse annotation with rank ")
479 << rank << " for tensor " << t;
480 // Per-dimension annotations for each tensor consist of only "D" or "S".
481 for (unsigned d = 0; d < rank; d++) {
482 if (isDenseDim(dimAttr[d])) {
483 continue;
484 } else if (isSparseDim(dimAttr[d])) {
485 if (t == numTensors - 1)
486 return op.emitOpError("sparse output tensors not supported (yet)");
487 continue;
488 }
489 return op.emitOpError("expected sparse annotation at position ")
490 << d << " for tensor " << t;
491 }
492 }
493 return success();
494 }
495
496 } // namespace
497
498 template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
500 auto nLoops = op.getNumLoops();
501
502 if (op.inputs().size() + op.output_buffers().size() +
503 op.init_tensors().size() + op.getNumResults() ==
504 0)
505 return op.emitOpError("expected at least 1 Shaped operand or return");
506
507 auto ®ion = op.region();
508 if (!llvm::hasSingleElement(region))
509 return op.emitOpError("expected region with 1 block");
510 if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
511 return failure();
512
513 if (op.indexing_maps().size() != op.getNumInputsAndOutputs())
514 return op.emitOpError("expected the number of indexing_map (")
515 << op.indexing_maps().size()
516 << ") to be equal to the number of inputs and outputs ("
517 << op.getNumInputsAndOutputs() << ")";
518
519 SmallVector<AffineMap, 4> indexingMaps;
520 indexingMaps.reserve(op.indexing_maps().size());
521 for (auto en : llvm::enumerate(op.indexing_maps())) {
522 auto idx = en.index();
523 auto m = en.value().template cast<AffineMapAttr>().getValue();
524 indexingMaps.push_back(m); // Save reference to map for further checks.
525 auto view = op.getShapedType(idx);
526
527 if (m.getNumSymbols() != 0)
528 return op.emitOpError("unexpected symbols in indexing_map #") << idx;
529
530 if (m.getNumDims() != nLoops)
531 return op.emitOpError("expected indexing_map #")
532 << idx << " to have " << nLoops
533 << " dim(s) to match the number of loops";
534
535 if (m.getNumResults() != view.getRank())
536 return op.emitOpError("expected indexing_map #")
537 << idx << " results to match view rank: " << view;
538 }
539
540 if (!op.getShapesToLoopsMap())
541 return op.emitOpError("expected the shape-to-loops map to be non-null");
542
543 if (failed(AnnotationsVerifier<GenericOpType>::verify(op)))
544 return failure();
545
546 return success();
547 }
548
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
550
static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
552
553 //===----------------------------------------------------------------------===//
554 // ReshapeOp
555 //===----------------------------------------------------------------------===//
556
/// Collapses reassociation maps that are used in a pair of reshape ops where
/// one is a producer and the other is the consumer. Only valid to use this
/// method when both the producer and the consumer are collapsing dimensions or
/// both are expanding dimensions.
561 ///
562 /// For example,
563 /// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
564 /// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
565 /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
566 /// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
567 /// affine_map<(d0, d1, d2) -> (d2)>]
568 ///
569 /// is folded into
570 ///
571 /// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
572 /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
574 ArrayRef<AffineMap> mapsConsumer,
575 MLIRContext *context) {
  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty ArrayAttr.
578 if (mapsConsumer.empty() && !mapsProducer.empty())
579 return ArrayAttr::get(ArrayRef<Attribute>(), context);
580 if (mapsProducer.empty() || mapsConsumer.empty() ||
581 mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
582 mapsProducer.size() != mapsConsumer[0].getNumDims())
583 return nullptr;
584 unsigned numLhsDims = mapsProducer[0].getNumDims();
585 unsigned currDim = 0;
586 SmallVector<AffineExpr, 4> reassociations;
587 SmallVector<Attribute, 4> reassociationMaps;
588 for (AffineMap rhs : mapsConsumer) {
589 for (AffineExpr rhsExpr : rhs.getResults()) {
590 AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
591 for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
592 i < e; ++i) {
593 reassociations.push_back(getAffineDimExpr(currDim++, context));
594 }
595 }
596 reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
597 numLhsDims, /*numSymbols =*/0, reassociations, context)));
598 reassociations.clear();
599 }
600 return ArrayAttr::get(reassociationMaps, context);
601 }
602
603 namespace {
604 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
605 /// dimensions or are both expanding dimensions.
606 template <typename ReshapeOpTy>
607 struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
608 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
610 PatternRewriter &rewriter) const override {
611 auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
612 if (!srcReshapeOp)
613 return failure();
614
615 auto areReshapeOpsFoldable = [](ShapedType largerType,
616 ShapedType intermediateType,
617 ShapedType smallerType) -> bool {
618 return largerType.getRank() > intermediateType.getRank() &&
619 intermediateType.getRank() > smallerType.getRank();
620 };
621 // Check if producer and consumer are both expanding dims.
622 if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
623 srcReshapeOp.getSrcType())) {
624 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
625 reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
626 collapseReassociationMaps(reshapeOp.getReassociationMaps(),
627 srcReshapeOp.getReassociationMaps(),
628 rewriter.getContext()));
629 return success();
630 }
631 // Check if producer and consumer are both collapsing dims.
632 if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), reshapeOp.getSrcType(),
633 reshapeOp.getResultType())) {
634 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
635 reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
636 collapseReassociationMaps(srcReshapeOp.getReassociationMaps(),
637 reshapeOp.getReassociationMaps(),
638 rewriter.getContext()));
639 return success();
640 }
641 return failure();
642 }
643 };
644 } // namespace
645
646 template <typename ReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
648 ArrayRef<Attribute> operands) {
  // Fold producer-consumer reshape ops where the operand type of the producer
  // is the same as the return type of the consumer. This can only be verified
  // if the shapes in question are static.
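  // Illustrative example (reassociation maps elided):
  //   %1 = linalg.tensor_reshape %0 [...] : tensor<2x3xf32> into tensor<6xf32>
  //   %2 = linalg.tensor_reshape %1 [...] : tensor<6xf32> into tensor<2x3xf32>
  // Here %2 folds to %0, since the producer's source type equals the
  // consumer's result type and both are static.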
652 ReshapeOpTy reshapeSrcOp =
653 reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
654 if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
655 reshapeOp.getResultType().hasStaticShape() &&
656 reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
657 return reshapeSrcOp.src();
658 // Reshape of a constant can be replaced with a new constant.
659 if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
660 return elements.reshape(
661 reshapeOp.getResult().getType().template cast<ShapedType>());
662 }
663 return nullptr;
664 }
665
666 /// Return true if the reassociation specification is valid, false otherwise.
667 /// When false, the `invalidIndex` integer pointer is optionally filled with the
668 /// index of the offending reassociation map.
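/// For example, for a rank-4 source the reassociation
///   [affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
///    affine_map<(d0, d1, d2, d3) -> (d2, d3)>]
/// is valid: all maps share the same number of dims and their results cover
/// d0..d3 exactly once, in order.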
static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
670 int *invalidIndex = nullptr) {
671 if (reassociation.empty())
672 return true;
673 unsigned nDims = reassociation[0].getNumDims();
674 unsigned nextExpectedDim = 0;
675 for (auto it : llvm::enumerate(reassociation)) {
676 auto m = it.value();
677 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
678 if (invalidIndex)
679 *invalidIndex = it.index();
680 return false;
681 }
682 for (auto e : m.getResults()) {
683 auto d = e.dyn_cast<AffineDimExpr>();
684 if (!d || d.getPosition() != nextExpectedDim++) {
685 if (invalidIndex)
686 *invalidIndex = it.index();
687 return false;
688 }
689 }
690 }
691 if (nextExpectedDim != nDims) {
692 if (invalidIndex)
693 *invalidIndex = reassociation.size() - 1;
694 return false;
695 }
696 return true;
697 }
698
699 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
700 /// copies.
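/// For example, with sizes [4, 5] and strides [5, 1], the band [0, 2) is
/// reshapable because strides[0] == strides[1] * sizes[1]; a band containing a
/// dynamic size is not.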
static bool isReshapableDimBand(unsigned dim, unsigned extent,
702 ArrayRef<int64_t> sizes,
703 ArrayRef<AffineExpr> strides) {
704 assert(sizes.size() == strides.size() && "mismatched ranks");
705 // off by 1 indexing to avoid out of bounds
706 // V
707 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
708 // Only bands of static shapes are reshapable. This is due to the fact that
709 // there is no relation between dynamic sizes and dynamic strides: we do not
710 // have enough information to know whether a "-1" size corresponds to the
711 // proper symbol in the AffineExpr of a stride.
712 if (ShapedType::isDynamic(sizes[dim + 1]))
713 return false;
714 // TODO: Refine this by passing the proper nDims and nSymbols so we can
715 // simplify on the fly and catch more reshapable cases.
716 if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
717 return false;
718 }
719 return true;
720 }
721
722 /// Compute the MemRefType obtained by applying the `reassociation` (which is
723 /// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
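/// For example, collapsing a contiguous memref<3x4x5xf32> with the
/// reassociation [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] yields
/// memref<12x5xf32>.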
726 static MemRefType
computeReshapeCollapsedType(MemRefType type,
728 ArrayRef<AffineMap> reassociation) {
729 auto sizes = type.getShape();
730 AffineExpr offset;
731 SmallVector<AffineExpr, 4> strides;
732 auto status = getStridesAndOffset(type, strides, offset);
733 (void)status;
734 assert(succeeded(status) && "expected strided memref");
735
736 SmallVector<int64_t, 4> newSizes;
737 newSizes.reserve(reassociation.size());
738 SmallVector<AffineExpr, 4> newStrides;
739 newStrides.reserve(reassociation.size());
740
741 // Use the fact that reassociation is valid to simplify the logic: only use
742 // each map's rank.
743 assert(isReassociationValid(reassociation) && "invalid reassociation");
744 unsigned currentDim = 0;
745 for (AffineMap m : reassociation) {
746 unsigned dim = m.getNumResults();
747 int64_t size = 1;
748 AffineExpr stride = strides[currentDim + dim - 1];
749 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
750 size = ShapedType::kDynamicSize;
751 stride = AffineExpr();
752 } else {
753 for (unsigned d = 0; d < dim; ++d)
754 size *= sizes[currentDim + d];
755 }
756 newSizes.push_back(size);
757 newStrides.push_back(stride);
758 currentDim += dim;
759 }
760
761 // Early-exit: if `type` is contiguous, the result must be contiguous.
762 if (canonicalizeStridedLayout(type).getAffineMaps().empty())
763 return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
764
765 // Convert back to int64_t because we don't have enough information to create
766 // new strided layouts from AffineExpr only. This corresponds to a case where
767 // copies may be necessary.
768 int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
769 if (auto o = offset.dyn_cast<AffineConstantExpr>())
770 intOffset = o.getValue();
771 SmallVector<int64_t, 4> intStrides;
772 intStrides.reserve(strides.size());
773 for (auto stride : newStrides) {
774 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
775 intStrides.push_back(cst.getValue());
776 else
777 intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
778 }
779 auto layout =
780 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
781 return canonicalizeStridedLayout(
782 MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
783 }
784
/// Helper function that asserts each Attribute in `attrs` is of the proper
/// type and returns the corresponding vector of AffineMaps.
/// TODO: this should be evolved into a generic
/// `getRangeOfType<AffineMap>(ArrayAttr attrs)` that does not copy.
static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
790 return llvm::to_vector<8>(llvm::map_range(
791 attrs, [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
792 }
793
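/// Returns the maximum position of an `AffineExprTy` (e.g. AffineDimExpr)
/// occurring in `exprArrays`, or 0 if no such expression occurs.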
794 template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
796 unsigned pos = 0;
797 for (const auto &exprs : exprArrays) {
798 for (auto expr : exprs) {
799 expr.walk([&pos](AffineExpr e) {
800 if (auto d = e.dyn_cast<AffineExprTy>())
801 pos = std::max(pos, d.getPosition());
802 });
803 }
804 }
805 return pos;
806 }
807
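/// Converts the reassociation expressions into AffineMaps, each with
/// `maxDim + 1` dims and no symbols, asserting that the expressions are
/// symbol-free.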
808 static SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
810 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
811 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
812 "Expected symbol-less expressions");
813 SmallVector<AffineMap, 4> maps;
814 maps.reserve(reassociation.size());
815 for (const auto &exprs : reassociation) {
816 assert(!exprs.empty());
817 maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
818 }
819 return maps;
820 }
821
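/// Converts a reassociation given as lists of dimension indices into the
/// equivalent lists of AffineDimExprs (e.g. indices {0, 1} become {d0, d1}).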
822 static SmallVector<SmallVector<AffineExpr, 2>, 2>
convertReassociationIndicesToMaps(
824 OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
825 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
826 for (const auto &indices : reassociationIndices) {
827 SmallVector<AffineExpr, 2> reassociationMap;
828 reassociationMap.reserve(indices.size());
829 for (int64_t index : indices)
830 reassociationMap.push_back(b.getAffineDimExpr(index));
831 reassociationMaps.push_back(std::move(reassociationMap));
832 }
833 return reassociationMaps;
834 }
835
void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
837 Value src,
838 ArrayRef<ReassociationExprs> reassociation,
839 ArrayRef<NamedAttribute> attrs) {
840 auto maps = getSymbolLessAffineMaps(reassociation);
841 auto memRefType = src.getType().cast<MemRefType>();
842 auto resultType = computeReshapeCollapsedType(memRefType, maps);
843 build(b, result, resultType, src, attrs);
844 result.addAttribute(ReshapeOp::getReassociationAttrName(),
845 b.getAffineMapArrayAttr(maps));
846 }
847
void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
849 Type resultType, Value src,
850 ArrayRef<ReassociationExprs> reassociation,
851 ArrayRef<NamedAttribute> attrs) {
852 auto maps = getSymbolLessAffineMaps(reassociation);
853 build(b, result, resultType, src, attrs);
854 result.addAttribute(ReshapeOp::getReassociationAttrName(),
855 b.getAffineMapArrayAttr(maps));
856 }
857
Value mlir::linalg::ReshapeOp::getViewSource() { return src(); }
859
860 // Common verifier for reshape-like types. Fills `expandedType` and
861 // `collapsedType` with the proper `src` or `result` type.
862 template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
864 T &collapsedType) {
865 expandedType = op.getSrcType();
866 collapsedType = op.getResultType();
867 unsigned expandedRank = expandedType.getRank();
868 unsigned collapsedRank = collapsedType.getRank();
869 bool isCollapse = expandedRank > collapsedRank;
870 if (!isCollapse) {
871 std::swap(expandedRank, collapsedRank);
872 std::swap(expandedType, collapsedType);
873 }
874 if (expandedRank == 0)
875 return op.emitOpError("expected non-zero memref ranks");
876 if (expandedRank == collapsedRank)
877 return op.emitOpError("expected to collapse or expand dims");
878
879 if (collapsedRank == 0) {
    // If the collapsed rank is 0, then the expanded type must be statically
    // shaped with all dimensions equal to 1.
882 if (llvm::any_of(expandedType.getShape(),
883 [](int64_t dim) -> bool { return dim != 1; }))
884 return op.emitOpError(
885 "invalid to reshape tensor/memref with non-unit extent dimensions to "
886 "zero-rank tensor/memref");
887 return success();
888 }
889 if (collapsedRank != op.reassociation().size())
890 return op.emitOpError("expected rank of the collapsed type(")
891 << collapsedRank << ") to be the number of reassociation maps("
892 << op.reassociation().size() << ")";
893 auto maps = getAffineMaps(op.reassociation());
894 for (auto it : llvm::enumerate(maps))
895 if (it.value().getNumDims() != expandedRank)
896 return op.emitOpError("expected reassociation map #")
897 << it.index() << " of same rank as expanded memref("
898 << expandedRank << "), but got " << it.value().getNumDims();
899 int invalidIdx = 0;
900 if (!isReassociationValid(maps, &invalidIdx))
901 return op.emitOpError("expected reassociation map #")
902 << invalidIdx << " to be valid and contiguous";
903 return success();
904 }
905
static LogicalResult verify(ReshapeOp op) {
907 MemRefType expandedType, collapsedType;
908 if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
909 return failure();
910 auto maps = getAffineMaps(op.reassociation());
911 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
912 if (collapsedType != expectedType)
913 return op.emitOpError("expected collapsed type to be ")
914 << expectedType << ", but got " << collapsedType;
915 return success();
916 }
917
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
919 MLIRContext *context) {
920 results.insert<CollapseReshapeOps<ReshapeOp>>(context);
921 }
922
923 //===----------------------------------------------------------------------===//
924 // TensorReshapeOp
925 //===----------------------------------------------------------------------===//
926
927 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
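/// For example, tensor<?x4x5xf32> with the reassociation
/// [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] collapses to
/// tensor<?x5xf32>: a band that contains a dynamic size collapses to a dynamic
/// size.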
928 static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,
930 ArrayRef<AffineMap> reassociation) {
931 auto shape = type.getShape();
932 SmallVector<int64_t, 4> newShape;
933 newShape.reserve(reassociation.size());
934
935 // Use the fact that reassociation is valid to simplify the logic: only use
936 // each map's rank.
937 assert(isReassociationValid(reassociation) && "invalid reassociation");
938 unsigned currentDim = 0;
939 for (AffineMap m : reassociation) {
940 unsigned dim = m.getNumResults();
941 auto band = shape.slice(currentDim, dim);
942 int64_t size = 1;
943 if (llvm::is_contained(band, ShapedType::kDynamicSize))
944 size = ShapedType::kDynamicSize;
945 else
946 for (unsigned d = 0; d < dim; ++d)
947 size *= shape[currentDim + d];
948 newShape.push_back(size);
949 currentDim += dim;
950 }
951
952 return RankedTensorType::get(newShape, type.getElementType());
953 }
954
void mlir::linalg::TensorReshapeOp::build(
956 OpBuilder &b, OperationState &result, Value src,
957 ArrayRef<ReassociationExprs> reassociation,
958 ArrayRef<NamedAttribute> attrs) {
959 auto maps = getSymbolLessAffineMaps(reassociation);
960 auto resultType = computeTensorReshapeCollapsedType(
961 src.getType().cast<RankedTensorType>(), maps);
962 build(b, result, resultType, src, attrs);
963 result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
964 b.getAffineMapArrayAttr(maps));
965 }
966
void mlir::linalg::TensorReshapeOp::build(
968 OpBuilder &b, OperationState &result, Type resultType, Value src,
969 ArrayRef<ReassociationExprs> reassociation,
970 ArrayRef<NamedAttribute> attrs) {
971 auto maps = getSymbolLessAffineMaps(reassociation);
972 build(b, result, resultType, src, attrs);
973 result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
974 b.getAffineMapArrayAttr(maps));
975 }
976
static LogicalResult verify(TensorReshapeOp op) {
978 RankedTensorType expandedType, collapsedType;
979 if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
980 return failure();
981 auto maps = getAffineMaps(op.reassociation());
982 // TODO: expanding a ? with a non-constant is under-specified. Error
983 // out.
984 RankedTensorType expectedType =
985 computeTensorReshapeCollapsedType(expandedType, maps);
986 if (collapsedType != expectedType)
987 return op.emitOpError("expected collapsed type to be ")
988 << expectedType << ", but got " << collapsedType;
989 return success();
990 }
991
992 namespace {
993 /// Reshape of a splat constant can be replaced with a constant of the result
994 /// type.
995 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
996 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
998 PatternRewriter &rewriter) const override {
999 DenseElementsAttr attr;
1000 if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
1001 return failure();
1002 if (!attr || !attr.isSplat())
1003 return failure();
1004 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1005 reshapeOp.getResultType(), attr.getRawData(), true);
1006 rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
1007 return success();
1008 }
1009 };
1010 } // namespace
1011
void TensorReshapeOp::getCanonicalizationPatterns(
1013 OwningRewritePatternList &results, MLIRContext *context) {
1014 results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
1015 context);
1016 }
1017
1018 //===----------------------------------------------------------------------===//
1019 // SliceOp
1020 //===----------------------------------------------------------------------===//
void mlir::linalg::SliceOp::build(OpBuilder &b, OperationState &result,
1022 Value base, ValueRange indexings) {
1023 result.addOperands(base);
1024 result.addOperands(indexings);
1025
1026 auto memRefType = base.getType().cast<MemRefType>();
1027 int64_t offset;
1028 SmallVector<int64_t, 4> strides;
1029 auto res = getStridesAndOffset(memRefType, strides, offset);
1030 assert(succeeded(res) && strides.size() == indexings.size());
1031 (void)res;
1032
1033 unsigned rank = memRefType.getRank();
1034 // TODO: propagate static size and stride information when available.
1035 SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
1036 result.addTypes({MemRefType::Builder(memRefType)
1037 .setShape(sizes)
1038 .setAffineMaps(makeStridedLinearLayoutMap(
1039 strides, offset, b.getContext()))});
1040 }
1041
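/// Prints a SliceOp. Illustrative sketch of the printed form (types
/// simplified):
///   linalg.slice %base[%i, %r] : memref<?x?xf32>, index, !linalg.range,
///                                memref<?xf32>
/// i.e. the base view and its indexings, followed by the base view type, the
/// indexing types and the result type.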
static void print(OpAsmPrinter &p, SliceOp op) {
1043 auto indexings = op.indexings();
1044 p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings
1045 << "] ";
1046 p.printOptionalAttrDict(op.getAttrs());
1047 p << " : " << op.getBaseViewType();
1048 if (!indexings.empty())
1049 p << ", " << op.indexings().getTypes();
1050 p << ", " << op.getType();
1051 }
1052
static ParseResult parseSliceOp(OpAsmParser &parser, OperationState &result) {
1054 OpAsmParser::OperandType baseInfo;
1055 SmallVector<OpAsmParser::OperandType, 8> operands;
1056 SmallVector<Type, 8> types;
1057 if (parser.parseOperand(baseInfo) ||
1058 parser.parseOperandList(operands, OpAsmParser::Delimiter::Square) ||
1059 parser.parseOptionalAttrDict(result.attributes) ||
1060 parser.parseColonTypeList(types))
1061 return failure();
1062
1063 if (types.size() < 2)
1064 return parser.emitError(parser.getCurrentLocation(),
1065 "expected at least input and result view types");
1066
1067 ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
1068 return failure(
1069 parser.resolveOperand(baseInfo, types.front(), result.operands) ||
1070 (!operands.empty() &&
1071 parser.resolveOperands(operands, indexingTypes,
1072 operands.front().location, result.operands)) ||
1073 parser.addTypeToList(types.back(), result.types));
1074 }
1075
static LogicalResult verify(SliceOp op) {
1077 unsigned rank = op.getBaseViewRank();
1078 if (rank != llvm::size(op.indexings()))
1079 return op.emitOpError("expected ")
1080 << rank << " indexings, got " << llvm::size(op.indexings());
1081 unsigned index = 0;
1082 for (auto indexing : op.indexings()) {
1083 if (indexing.getType().isa<IndexType>())
1084 --rank;
1085 ++index;
1086 }
1087 if (op.getRank() != rank)
1088 return op.emitOpError() << "expected rank of the view(" << op.getRank()
1089 << ") to be the number of ranges(" << rank << ")";
1090 return success();
1091 }
1092
Value SliceOp::getViewSource() { return view(); }
1094
1095 //===----------------------------------------------------------------------===//
1096 // YieldOp
1097 //===----------------------------------------------------------------------===//
1098
static void print(OpAsmPrinter &p, linalg::YieldOp op) {
1100 p << op.getOperationName();
1101 if (op.getNumOperands() > 0)
1102 p << ' ' << op.getOperands();
1103 p.printOptionalAttrDict(op.getAttrs());
1104 if (op.getNumOperands() > 0)
1105 p << " : " << op.getOperandTypes();
1106 }
1107
static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
1109 SmallVector<OpAsmParser::OperandType, 2> opInfo;
1110 SmallVector<Type, 2> types;
1111 llvm::SMLoc loc = parser.getCurrentLocation();
1112 return failure(parser.parseOperandList(opInfo) ||
1113 parser.parseOptionalAttrDict(result.attributes) ||
1114 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1115 parser.resolveOperands(opInfo, types, loc, result.operands));
1116 }
1117
// Checks that the number and types of the operands match the element types of
// the enclosing LinalgOp interface's shaped output operands.
static LogicalResult verifyYield(linalg::YieldOp op,
1121 LinalgOp linalgOpInterface) {
1122 auto nOutputs = linalgOpInterface.getNumOutputs();
1123 if (op.getNumOperands() != nOutputs)
1124 return op.emitOpError("expected number of yield values (")
1125 << nOutputs << ") to match the number of operands of the enclosing "
1126 << "LinalgOp (" << op.getNumOperands() << ")";
1127
1128 for (unsigned i = 0; i != nOutputs; ++i) {
1129 auto elementType =
1130 linalgOpInterface.getOutputShapedType(i).getElementType();
1131 if (op.getOperand(i).getType() != elementType)
1132 return op.emitOpError("type of yield operand ")
1133 << (i + 1) << " (" << op.getOperand(i).getType()
1134 << ") doesn't match "
1135 << "the element type of the enclosing linalg.generic op ("
1136 << elementType << ")";
1137 }
1138 return success();
1139 }
1140
static LogicalResult verify(linalg::YieldOp op) {
1142 auto *parentOp = op->getParentOp();
1143 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1144 return op.emitOpError("expected single non-empty parent region");
1145
1146 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1147 return verifyYield(op, cast<LinalgOp>(parentOp));
1148
1149 return op.emitOpError("expected parent op with LinalgOp interface");
1150 }
1151
1152 /////// Operations corresponding to library calls defined with Tablegen ////////
1153
void FillOp::getEffects(
1155 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1156 &effects) {
1157 effects.emplace_back(MemoryEffects::Write::get(), output(),
1158 SideEffects::DefaultResource::get());
1159 }
1160
static LogicalResult verify(FillOp op) {
1162 auto viewType = op.getOutputShapedType(0);
1163 auto fillType = op.value().getType();
1164 if (viewType.getElementType() != fillType)
1165 return op.emitOpError("expects fill type to match view elemental type");
1166 return success();
1167 }
1168
void CopyOp::getEffects(
1170 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1171 &effects) {
1172 effects.emplace_back(MemoryEffects::Read::get(), input(),
1173 SideEffects::DefaultResource::get());
1174 effects.emplace_back(MemoryEffects::Write::get(), output(),
1175 SideEffects::DefaultResource::get());
1176 }
1177
static LogicalResult verify(CopyOp op) {
1179 auto outputViewType = op.getOutputShapedType(0);
1180 auto inputViewType = op.getInputShapedType(0);
1181 if (inputViewType.getElementType() != outputViewType.getElementType())
1182 return op.emitOpError("expects views of the same type");
1183 if (inputViewType.getRank() != outputViewType.getRank())
1184 return op.emitOpError("expects views of the same rank");
1185 auto rank = op.getNumParallelLoops();
1186 auto inputPermutationMap = op.inputPermutation();
1187 if (inputPermutationMap) {
1188 if (inputPermutationMap->getNumInputs() != rank)
1189 return op.emitOpError("expects optional input_permutation map of rank ")
1190 << rank;
1191 if (!inputPermutationMap->isPermutation())
1192 return op.emitOpError(
1193 "expects optional input_permutation map to be a permutation");
1194 }
1195 auto outputPermutationMap = op.outputPermutation();
1196 if (outputPermutationMap) {
1197 if (outputPermutationMap->getNumInputs() != rank)
1198 return op.emitOpError("expects optional output_permutation map of rank ")
1199 << rank;
1200 if (!outputPermutationMap->isPermutation())
1201 return op.emitOpError(
1202 "expects optional output_permutation map to be a permutation");
1203 }
1204 if (rank == 0 && inputPermutationMap)
1205 return op.emitOpError("expected no input permutation when rank == 0");
1206 if (rank == 0 && outputPermutationMap)
1207 return op.emitOpError("expected no output permutation when rank == 0");
1208 return success();
1209 }
1210
1211 template <typename LinalgPoolingOp>
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
1213 ArrayRef<Attribute> attrs,
1214 bool isStride) {
1215 auto strideOrDilation = isStride ? "stride" : "dilation";
1216 if (attrs.size() != op.getNumWindowLoops())
1217 return op.emitOpError("expects num ")
1218 << strideOrDilation
1219 << "s equal to number of window dimensions: " << attrs.size()
1220 << " vs " << op.getNumWindowLoops();
1221 return success();
1222 }
1223
void ConvOp::getEffects(
1225 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1226 &effects) {
1227 effects.emplace_back(MemoryEffects::Read::get(), input(),
1228 SideEffects::DefaultResource::get());
1229 effects.emplace_back(MemoryEffects::Read::get(), filter(),
1230 SideEffects::DefaultResource::get());
1231 effects.emplace_back(MemoryEffects::Write::get(), output(),
1232 SideEffects::DefaultResource::get());
1233 }
1234
static LogicalResult verify(ConvOp op) {
1236 auto oType = op.output().getType().cast<MemRefType>();
1237 auto fType = op.filter().getType().cast<MemRefType>();
1238 auto iType = op.input().getType().cast<MemRefType>();
1239 if (oType.getElementType() != iType.getElementType() ||
1240 oType.getElementType() != fType.getElementType())
1241 return op.emitOpError("expects memref elemental types to match");
1242 if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
1243 return op.emitOpError("expects memref ranks to match");
1244 if (oType.getRank() <= 2)
1245 return op.emitOpError("expects memref ranks to be greater than 2");
1246 if (auto strides = op.strides()) {
1247 if (failed(
1248 verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
1249 return failure();
1250 }
1251 if (auto dilations = op.dilations()) {
1252 if (failed(verifyStrideOrDilation(op, dilations->getValue(),
1253 /*isStride=*/false)))
1254 return failure();
1255 }
1256 return success();
1257 }
1258
1259 template <typename PoolingOp>
static LogicalResult verifySingleInputPoolingOp(PoolingOp op) {
1261 auto inputType = op.input().getType().template cast<MemRefType>();
1262 auto outputType = op.output().getType().template cast<MemRefType>();
1263 if (outputType.getElementType() != inputType.getElementType())
1264 return op.emitOpError("expects memref elemental types to match");
1265
1266 auto windowDimsType = op.windowDims().getType().template cast<MemRefType>();
1267 if (outputType.getRank() != inputType.getRank() ||
1268 outputType.getRank() != windowDimsType.getRank())
1269 return op.emitOpError("expects memref ranks to match");
1270
1271 if (auto strides = op.strides()) {
1272 if (failed(
1273 verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
1274 return failure();
1275 }
1276 if (auto dilations = op.dilations()) {
1277 if (failed(verifyStrideOrDilation(op, dilations->getValue(),
1278 /*isStride=*/false)))
1279 return failure();
1280 }
1281 return success();
1282 }
1283
1284 #define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME) \
1285 void OP_NAME::getEffects( \
1286 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
1287 &effects) { \
1288 effects.emplace_back(MemoryEffects::Read::get(), input(), \
1289 SideEffects::DefaultResource::get()); \
1290 effects.emplace_back(MemoryEffects::Write::get(), output(), \
1291 SideEffects::DefaultResource::get()); \
1292 }
1293
verify(PoolingMaxOp op)1294 static LogicalResult verify(PoolingMaxOp op) {
1295 return verifySingleInputPoolingOp(op);
1296 }
verify(PoolingMinOp op)1297 static LogicalResult verify(PoolingMinOp op) {
1298 return verifySingleInputPoolingOp(op);
1299 }
verify(PoolingSumOp op)1300 static LogicalResult verify(PoolingSumOp op) {
1301 return verifySingleInputPoolingOp(op);
1302 }
1303
1304 DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
1305 DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
1306 DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
1307
1308 namespace {
1309 struct EraseDeadLinalgOp;
1310 struct FoldTensorCastOp;
1311 } // namespace
1312
1313 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
1314
1315 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
1316
1317 #define GET_OP_CLASSES
1318 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1319
1320 #define GET_OP_CLASSES
1321 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1322
1323 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
1324 /// Assumes `op` is a LinalgOp.
getDimsOfType(Operation * op,StringRef iteratorTypeName,SmallVectorImpl<AffineExpr> & res)1325 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
1326 SmallVectorImpl<AffineExpr> &res) {
1327 if (!cast<LinalgOp>(op).iterator_types())
1328 return;
1329
1330 unsigned dim = 0;
1331 MLIRContext *ctx = op->getContext();
1332 for (auto tn :
1333 cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
1334 if (tn == iteratorTypeName)
1335 res.push_back(getAffineDimExpr(dim, ctx));
1336 ++dim;
1337 }
1338 }
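
// For example, a matmul-like op with iterator_types
// ["parallel", "parallel", "reduction"] yields {d0, d1} when queried for
// "parallel" loops and {d2} when queried for "reduction" loops.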

AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
                                             unsigned rank,
                                             MLIRContext *context) {
  if (maybeMap)
    return maybeMap.getValue();
  if (rank == 0)
    return AffineMap::get(context);
  return AffineMap::getMultiDimIdentityMap(rank, context);
}
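
// E.g. with no map provided and rank == 3 this returns the identity map
// (d0, d1, d2) -> (d0, d1, d2); with rank == 0 it returns the empty map.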

SmallVector<AffineExpr, 4>
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  SmallVector<AffineExpr, 4> res;
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}
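
// E.g. makeAffineDimExprs(2, startIdx, ctx) with startIdx == 3 returns
// {d3, d4} and advances startIdx to 5, so successive calls hand out disjoint
// dimension ranges.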

template <typename PoolingOp>
SmallVector<AffineExpr, 4>
mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
                                        ArrayRef<AffineExpr> outputDims,
                                        ArrayRef<AffineExpr> windowDims) {
  assert(outputDims.size() == windowDims.size());
  SmallVector<AffineExpr, 4> res;
  res.reserve(outputDims.size());
  for (unsigned i = 0, e = outputDims.size(); i < e; ++i) {
    // TODO: add a level of indirection to linalg.generic.
    auto expr = op.getStride(i) * outputDims[i] +
                op.getDilation(i) * windowDims[i] - op.getLowPad(i);
    res.push_back(expr);
  }
  return res;
}
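
// E.g. for a dimension with stride 2, dilation 1 and low padding 1, the
// expression built above is `2 * o + 1 * w - 1`, where `o` is the output
// index and `w` the window index along that dimension.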

#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
  template SmallVector<AffineExpr, 4>                                          \
  mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
      OP_TYPE op, ArrayRef<AffineExpr> outputDims,                             \
      ArrayRef<AffineExpr> windowDims);

INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp)

SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                ArrayRef<AffineExpr> b) {
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}

static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = t.dyn_cast<MemRefType>()) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    appendMangledType(ss, memref.getElementType());
  } else if (auto vec = t.dyn_cast<VectorType>()) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    appendMangledType(ss, vec.getElementType());
  } else if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
  } else {
    llvm_unreachable("Invalid type for linalg library name mangling");
  }
}

std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  name.reserve(128);
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_";
  auto types = op->getOperandTypes();
  llvm::interleave(
      types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
      [&]() { ss << "_"; });
  return ss.str();
}
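
// For instance, a `linalg.copy` whose two operands have type `memref<?x?xf32>`
// mangles to "linalg_copy_viewsxsxf32_viewsxsxf32": the op name with '.'
// replaced by '_', followed by one "view..." segment per operand in which
// dynamic sizes print as "s".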

// TODO: Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of
// OpInterfaces where a Linalg "named" op **isa** LinalgOp.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return foldReshapeOp(*this, operands);
}
OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp(*this, operands);
}

//===----------------------------------------------------------------------===//
// Auto-generated Linalg named ops.
//===----------------------------------------------------------------------===//

template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributesImpl(
    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
    TypeRange outputBufferTypes, TypeRange initTensorTypes,
    TypeRange resultTypes,
    std::function<void(unsigned, unsigned)> errorHandler) {
  // TODO: atm all operands go through getElementTypeOrSelf,
  // reconsider when we have evidence we need to.
  SmallVector<Type, 8> argTypes;
  for (auto containers : {inputTypes, outputBufferTypes, resultTypes})
    for (auto t : containers)
      argTypes.push_back(getElementTypeOrSelf(t));

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body = opBuilder.createBlock(&region, {}, argTypes);
  unsigned actual = body->getNumArguments();
  unsigned expected = NamedStructuredOpType::getNumRegionArgs();
  if (expected != actual)
    return errorHandler(expected, actual);

  opBuilder.setInsertionPointToStart(body);
  mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
  NamedStructuredOpType::regionBuilder(*body);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
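
// As a concrete illustration (assuming a named op akin to an f32 matmul with
// two inputs and one output buffer), the block created above gets three f32
// arguments, one per operand element type, and regionBuilder then emits the
// scalar payload that operates on them.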

template <typename NamedStructuredOpType>
void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
                                               OperationState &result,
                                               TypeRange inputTypes,
                                               TypeRange outputBufferTypes,
                                               TypeRange initTensorTypes,
                                               TypeRange resultTypes) {
  Region &region = *result.addRegion();
  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
      opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
      resultTypes, [&](unsigned expected, unsigned actual) {
        llvm::errs() << "region expects " << expected << " args, got "
                     << actual;
        assert(expected != actual && "incorrect number of arguments");
      });
}

template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
                             TypeRange inputTypes, TypeRange outputBufferTypes,
                             TypeRange initTensorTypes, TypeRange resultTypes) {
  ParseResult res = success();
  OpBuilder opBuilder(parser.getBuilder().getContext());
  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
      opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
      resultTypes, [&](unsigned expected, unsigned actual) {
        res = parser.emitError(parser.getCurrentLocation(),
                               llvm::formatv("region expects {0} args, got {1}",
                                             expected, actual));
      });
  return res;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (succeeded(parser.parseOptionalArrow()))
    if (parser.parseTypeList(resultTypes))
      return failure();
  return success();
}

static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputBufferTypes,
                             SmallVectorImpl<Type> &initTensorTypes) {
  llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc,
      initTensorsOperandsLoc;
  SmallVector<OpAsmParser::OperandType, 4> inputsOperands,
      outputBuffersOperands, initTensorsOperands;

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputBuffersOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() ||
        parser.parseOperandList(outputBuffersOperands) ||
        parser.parseColonTypeList(outputBufferTypes) || parser.parseRParen())
      return failure();
  }
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initTensorsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) ||
        parser.parseColonTypeList(initTensorTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputBuffersOperands, outputBufferTypes,
                             outputBuffersOperandsLoc, result.operands) ||
      parser.resolveOperands(initTensorsOperands, initTensorTypes,
                             initTensorsOperandsLoc, result.operands))
    return failure();

  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(inputsOperands.size()),
                           static_cast<int32_t>(outputBuffersOperands.size()),
                           static_cast<int32_t>(initTensorsOperands.size())}));
  return success();
}
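
// The clauses parsed above correspond, for a hypothetical named op, to
// assembly along the lines of:
//   linalg.some_named_op {attr = 42}
//       ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
//       outs(%C : memref<?x?xf32>)
//       init(%D : tensor<?x?xf32>)
// where each clause is optional; the trailing `-> type` result list is
// handled separately by parseNamedStructuredOpResults.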

template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result) {
  SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes,
                                   outputBufferTypes, initTensorTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
          parser, *region, inputTypes, outputBufferTypes, initTensorTypes,
          outputTensorsTypes))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                         NamedStructuredOpType op) {
  if (!op.inputs().empty())
    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
  if (!op.output_buffers().empty())
    p << " outs(" << op.output_buffers() << " : "
      << op.output_buffers().getTypes() << ")";
  if (!op.init_tensors().empty())
    p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes()
      << ") ";
}

template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
  p << op.getOperationName();
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{"operand_segment_sizes"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, op);

  // Results printing.
  printNamedStructuredOpResults(p, op.result_tensors().getTypes());

  // Region is elided.
}

template <typename NamedStructuredOpType>
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
  return verifyGenericOp<NamedStructuredOpType>(op);
}

namespace {
struct EraseDeadLinalgOp : public RewritePattern {
  EraseDeadLinalgOp(PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    for (Value v : linalgOp.getInputsAndOutputBuffers()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase when we see a memref<...x0x...> shape.
      auto mt = v.getType().dyn_cast<MemRefType>();
      if (!mt)
        continue;
      if (llvm::is_contained(mt.getShape(), 0)) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
    }
    return failure();
  }
};

struct FoldTensorCastOp : public RewritePattern {
  FoldTensorCastOp(PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // If none of the operands comes from a TensorCastOp that can be folded,
    // fail.
    bool hasTensorCastOperand =
        llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
          if (v.isa<BlockArgument>())
            return false;
          auto castOp = v.getDefiningOp<TensorCastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Inputs may fold.
    for (Value v : linalgOp.getInputs()) {
      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
      newOperands.push_back(
          canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
    }
    // Output buffers are memrefs; they don't fold.
    newOperands.append(linalgOp.getOutputBuffers().begin(),
                       linalgOp.getOutputBuffers().end());
    // Init tensors may fold, in which case the resultType must also change.
    for (Value v : linalgOp.getInitTensors()) {
      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
      newResultTypes.push_back(newOperands.back().getType());
    }
    auto extraOperands = linalgOp.getAssumedNonShapedOperands();
    newOperands.append(extraOperands.begin(), extraOperands.end());
    // Clone op.
    Operation *newOp =
        linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
    rewriter.replaceOp(op, newOp->getResults());

    return success();
  }
};
} // namespace

namespace {
// Deduplicate redundant args of a linalg op.
// An arg is redundant if it has the same Value and indexing map as another.
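// For example, a linalg.generic that reads the same Value %A twice through
// identical indexing maps can be rewritten to read %A once, with the payload
// block repaired so both original arguments refer to the surviving one.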
struct DeduplicateInputs : public RewritePattern {
  DeduplicateInputs(PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // This pattern reduces the number of arguments of an op, which breaks
    // the invariants of semantically charged named ops.
    if (!isa<GenericOp, IndexedGenericOp>(op))
      return failure();
    auto linalgOp = cast<LinalgOp>(op);

    // Associate each input to an equivalent "canonical" input that has the
    // same Value and indexing map.
    //
    // In the non-duplicate case, input `i` will have canonical input `i`. But
    // in the case of duplicated inputs, the canonical input could be some
    // other input `< i`. That is, a later input will have some earlier input
    // as its canonical input.
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
    // For later remapping tasks like deduplicating payload block arguments,
    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
    // convenient.
    SmallVector<int, 6> canonicalInputIndices;
    for (int i = 0, e = linalgOp.getNumInputs(); i != e; i++) {
      Value input = linalgOp.getInput(i);
      AffineMap indexingMap = linalgOp.getInputIndexingMap(i);
      // STL-like maps have a convenient behavior for our use case here. In the
      // case of duplicate keys, the insertion is rejected, and the returned
      // iterator gives access to the value already in the map.
      auto pair = canonicalInput.insert({{input, indexingMap}, i});
      canonicalInputIndices.push_back(pair.first->second);
    }

    // If there are no duplicate args, then bail out.
    if (canonicalInput.size() == linalgOp.getNumInputs())
      return failure();

    // The operands for the newly canonicalized op.
    SmallVector<Value, 6> newOperands;
    for (auto v : llvm::enumerate(linalgOp.getInputs()))
      if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
        newOperands.push_back(v.value());
    llvm::append_range(newOperands, linalgOp.getOutputBuffers());
    llvm::append_range(newOperands, linalgOp.getInitTensors());
    llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands());

    // Clone the old op with new operands.
    Operation *newOp = linalgOp.clone(rewriter, op->getLoc(),
                                      op->getResultTypes(), newOperands);
    auto newLinalgOp = cast<LinalgOp>(newOp);

    // Repair the indexing maps by filtering out the ones that have been
    // eliminated.
    SmallVector<AffineMap, 6> newIndexingMaps;
    for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
      if (canonicalInputIndices[i] == i)
        newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
    for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
      newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
    newOp->setAttr("indexing_maps",
                   rewriter.getAffineMapArrayAttr(newIndexingMaps));

    // Set the number of inputs to the new value. The `clone` call above kept
    // the value from the original op.
    newLinalgOp.setNumInputs(canonicalInput.size());

    // linalg.indexed_generic payloads have additional arguments prepended to
    // the block arg list. The number of such args is one per dimension of the
    // iteration space.
    int bbArgBaseOffset = 0;
    if (isa<IndexedGenericOp>(op))
      bbArgBaseOffset = newIndexingMaps[0].getNumInputs();

    // Repair the payload entry block by RAUW'ing redundant arguments and
    // erasing them.
    Block &payload = newOp->getRegion(0).front();
    for (int i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
      // Iterate in reverse, so that we erase later args first, preventing the
      // argument list from shifting unexpectedly and invalidating all our
      // indices.
      int reversed = e - i - 1;
      int canonicalIndex = canonicalInputIndices[reversed];
      if (canonicalInputIndices[reversed] == reversed)
        continue;
      payload.getArgument(bbArgBaseOffset + reversed)
          .replaceAllUsesWith(
              payload.getArgument(bbArgBaseOffset + canonicalIndex));
      payload.eraseArgument(bbArgBaseOffset + reversed);
    }

    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace

#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
  void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
                                        MLIRContext *context) {                \
    results.insert<EraseDeadLinalgOp>();                                       \
    results.insert<FoldTensorCastOp>();                                        \
    results.insert<DeduplicateInputs>();                                       \
  }                                                                            \
                                                                               \
  LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                          SmallVectorImpl<OpFoldResult> &) {                   \
    return foldMemRefCast(*this);                                              \
  }

CANONICALIZERS_AND_FOLDERS(ConvOp)
CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
CANONICALIZERS_AND_FOLDERS(CopyOp)
CANONICALIZERS_AND_FOLDERS(FillOp)
CANONICALIZERS_AND_FOLDERS(GenericOp)
CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)

// All named ops' canonicalizers and folders are auto-generated in the
// .cpp.inc.