1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/SCF.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Support/MathExtras.h"
14 #include "mlir/Transforms/InliningUtils.h"
15 
16 using namespace mlir;
17 using namespace mlir::scf;
18 
19 //===----------------------------------------------------------------------===//
20 // SCFDialect Dialect Interfaces
21 //===----------------------------------------------------------------------===//
22 
23 namespace {
24 struct SCFInlinerInterface : public DialectInlinerInterface {
25   using DialectInlinerInterface::DialectInlinerInterface;
26   // We don't have any special restrictions on what can be inlined into
27   // destination regions (e.g. while/conditional bodies). Always allow it.
isLegalToInline__anon4ade4bc50111::SCFInlinerInterface28   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
29                        BlockAndValueMapping &valueMapping) const final {
30     return true;
31   }
32   // Operations in scf dialect are always legal to inline since they are
33   // pure.
isLegalToInline__anon4ade4bc50111::SCFInlinerInterface34   bool isLegalToInline(Operation *, Region *, bool,
35                        BlockAndValueMapping &) const final {
36     return true;
37   }
38   // Handle the given inlined terminator by replacing it with a new operation
39   // as necessary. Required when the region has only one block.
handleTerminator__anon4ade4bc50111::SCFInlinerInterface40   void handleTerminator(Operation *op,
41                         ArrayRef<Value> valuesToRepl) const final {
42     auto retValOp = dyn_cast<scf::YieldOp>(op);
43     if (!retValOp)
44       return;
45 
46     for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
47       std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
48     }
49   }
50 };
51 } // end anonymous namespace
52 
53 //===----------------------------------------------------------------------===//
54 // SCFDialect
55 //===----------------------------------------------------------------------===//
56 
initialize()57 void SCFDialect::initialize() {
58   addOperations<
59 #define GET_OP_LIST
60 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
61       >();
62   addInterfaces<SCFInlinerInterface>();
63 }
64 
65 /// Default callback for IfOp builders. Inserts a yield without arguments.
buildTerminatedBody(OpBuilder & builder,Location loc)66 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
67   builder.create<scf::YieldOp>(loc);
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // ForOp
72 //===----------------------------------------------------------------------===//
73 
build(OpBuilder & builder,OperationState & result,Value lb,Value ub,Value step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)74 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
75                   Value ub, Value step, ValueRange iterArgs,
76                   BodyBuilderFn bodyBuilder) {
77   result.addOperands({lb, ub, step});
78   result.addOperands(iterArgs);
79   for (Value v : iterArgs)
80     result.addTypes(v.getType());
81   Region *bodyRegion = result.addRegion();
82   bodyRegion->push_back(new Block);
83   Block &bodyBlock = bodyRegion->front();
84   bodyBlock.addArgument(builder.getIndexType());
85   for (Value v : iterArgs)
86     bodyBlock.addArgument(v.getType());
87 
88   // Create the default terminator if the builder is not provided and if the
89   // iteration arguments are not provided. Otherwise, leave this to the caller
90   // because we don't know which values to return from the loop.
91   if (iterArgs.empty() && !bodyBuilder) {
92     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
93   } else if (bodyBuilder) {
94     OpBuilder::InsertionGuard guard(builder);
95     builder.setInsertionPointToStart(&bodyBlock);
96     bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
97                 bodyBlock.getArguments().drop_front());
98   }
99 }
100 
verify(ForOp op)101 static LogicalResult verify(ForOp op) {
102   if (auto cst = op.step().getDefiningOp<ConstantIndexOp>())
103     if (cst.getValue() <= 0)
104       return op.emitOpError("constant step operand must be positive");
105 
106   // Check that the body defines as single block argument for the induction
107   // variable.
108   auto *body = op.getBody();
109   if (!body->getArgument(0).getType().isIndex())
110     return op.emitOpError(
111         "expected body first argument to be an index argument for "
112         "the induction variable");
113 
114   auto opNumResults = op.getNumResults();
115   if (opNumResults == 0)
116     return success();
117   // If ForOp defines values, check that the number and types of
118   // the defined values match ForOp initial iter operands and backedge
119   // basic block arguments.
120   if (op.getNumIterOperands() != opNumResults)
121     return op.emitOpError(
122         "mismatch in number of loop-carried values and defined values");
123   if (op.getNumRegionIterArgs() != opNumResults)
124     return op.emitOpError(
125         "mismatch in number of basic block args and defined values");
126   auto iterOperands = op.getIterOperands();
127   auto iterArgs = op.getRegionIterArgs();
128   auto opResults = op.getResults();
129   unsigned i = 0;
130   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
131     if (std::get<0>(e).getType() != std::get<2>(e).getType())
132       return op.emitOpError() << "types mismatch between " << i
133                               << "th iter operand and defined value";
134     if (std::get<1>(e).getType() != std::get<2>(e).getType())
135       return op.emitOpError() << "types mismatch between " << i
136                               << "th iter region arg and defined value";
137 
138     i++;
139   }
140 
141   return RegionBranchOpInterface::verifyTypes(op);
142 }
143 
144 /// Prints the initialization list in the form of
145 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
146 /// where 'inner' values are assumed to be region arguments and 'outer' values
147 /// are regular SSA values.
printInitializationList(OpAsmPrinter & p,Block::BlockArgListType blocksArgs,ValueRange initializers,StringRef prefix="")148 static void printInitializationList(OpAsmPrinter &p,
149                                     Block::BlockArgListType blocksArgs,
150                                     ValueRange initializers,
151                                     StringRef prefix = "") {
152   assert(blocksArgs.size() == initializers.size() &&
153          "expected same length of arguments and initializers");
154   if (initializers.empty())
155     return;
156 
157   p << prefix << '(';
158   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
159     p << std::get<0>(it) << " = " << std::get<1>(it);
160   });
161   p << ")";
162 }
163 
print(OpAsmPrinter & p,ForOp op)164 static void print(OpAsmPrinter &p, ForOp op) {
165   p << op.getOperationName() << " " << op.getInductionVar() << " = "
166     << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
167 
168   printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
169                           " iter_args");
170   if (!op.getIterOperands().empty())
171     p << " -> (" << op.getIterOperands().getTypes() << ')';
172   p.printRegion(op.region(),
173                 /*printEntryBlockArgs=*/false,
174                 /*printBlockTerminators=*/op.hasIterOperands());
175   p.printOptionalAttrDict(op.getAttrs());
176 }
177 
parseForOp(OpAsmParser & parser,OperationState & result)178 static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
179   auto &builder = parser.getBuilder();
180   OpAsmParser::OperandType inductionVariable, lb, ub, step;
181   // Parse the induction variable followed by '='.
182   if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
183     return failure();
184 
185   // Parse loop bounds.
186   Type indexType = builder.getIndexType();
187   if (parser.parseOperand(lb) ||
188       parser.resolveOperand(lb, indexType, result.operands) ||
189       parser.parseKeyword("to") || parser.parseOperand(ub) ||
190       parser.resolveOperand(ub, indexType, result.operands) ||
191       parser.parseKeyword("step") || parser.parseOperand(step) ||
192       parser.resolveOperand(step, indexType, result.operands))
193     return failure();
194 
195   // Parse the optional initial iteration arguments.
196   SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
197   SmallVector<Type, 4> argTypes;
198   regionArgs.push_back(inductionVariable);
199 
200   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
201     // Parse assignment list and results type list.
202     if (parser.parseAssignmentList(regionArgs, operands) ||
203         parser.parseArrowTypeList(result.types))
204       return failure();
205     // Resolve input operands.
206     for (auto operand_type : llvm::zip(operands, result.types))
207       if (parser.resolveOperand(std::get<0>(operand_type),
208                                 std::get<1>(operand_type), result.operands))
209         return failure();
210   }
211   // Induction variable.
212   argTypes.push_back(indexType);
213   // Loop carried variables
214   argTypes.append(result.types.begin(), result.types.end());
215   // Parse the body region.
216   Region *body = result.addRegion();
217   if (regionArgs.size() != argTypes.size())
218     return parser.emitError(
219         parser.getNameLoc(),
220         "mismatch in number of loop-carried values and defined values");
221 
222   if (parser.parseRegion(*body, regionArgs, argTypes))
223     return failure();
224 
225   ForOp::ensureTerminator(*body, builder, result.location);
226 
227   // Parse the optional attribute list.
228   if (parser.parseOptionalAttrDict(result.attributes))
229     return failure();
230 
231   return success();
232 }
233 
getLoopBody()234 Region &ForOp::getLoopBody() { return region(); }
235 
isDefinedOutsideOfLoop(Value value)236 bool ForOp::isDefinedOutsideOfLoop(Value value) {
237   return !region().isAncestor(value.getParentRegion());
238 }
239 
moveOutOfLoop(ArrayRef<Operation * > ops)240 LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
241   for (auto op : ops)
242     op->moveBefore(*this);
243   return success();
244 }
245 
getForInductionVarOwner(Value val)246 ForOp mlir::scf::getForInductionVarOwner(Value val) {
247   auto ivArg = val.dyn_cast<BlockArgument>();
248   if (!ivArg)
249     return ForOp();
250   assert(ivArg.getOwner() && "unlinked block argument");
251   auto *containingOp = ivArg.getOwner()->getParentOp();
252   return dyn_cast_or_null<ForOp>(containingOp);
253 }
254 
255 /// Return operands used when entering the region at 'index'. These operands
256 /// correspond to the loop iterator operands, i.e., those exclusing the
257 /// induction variable. LoopOp only has one region, so 0 is the only valid value
258 /// for `index`.
getSuccessorEntryOperands(unsigned index)259 OperandRange ForOp::getSuccessorEntryOperands(unsigned index) {
260   assert(index == 0 && "invalid region index");
261 
262   // The initial operands map to the loop arguments after the induction
263   // variable.
264   return initArgs();
265 }
266 
267 /// Given the region at `index`, or the parent operation if `index` is None,
268 /// return the successor regions. These are the regions that may be selected
269 /// during the flow of control. `operands` is a set of optional attributes that
270 /// correspond to a constant value for each operand, or null if that operand is
271 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)272 void ForOp::getSuccessorRegions(Optional<unsigned> index,
273                                 ArrayRef<Attribute> operands,
274                                 SmallVectorImpl<RegionSuccessor> &regions) {
275   // If the predecessor is the ForOp, branch into the body using the iterator
276   // arguments.
277   if (!index.hasValue()) {
278     regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
279     return;
280   }
281 
282   // Otherwise, the loop may branch back to itself or the parent operation.
283   assert(index.getValue() == 0 && "expected loop region");
284   regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
285   regions.push_back(RegionSuccessor(getResults()));
286 }
287 
getNumRegionInvocations(ArrayRef<Attribute> operands,SmallVectorImpl<int64_t> & countPerRegion)288 void ForOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
289                                     SmallVectorImpl<int64_t> &countPerRegion) {
290   assert(countPerRegion.empty());
291   countPerRegion.resize(1);
292 
293   auto lb = operands[0].dyn_cast_or_null<IntegerAttr>();
294   auto ub = operands[1].dyn_cast_or_null<IntegerAttr>();
295   auto step = operands[2].dyn_cast_or_null<IntegerAttr>();
296 
297   // Loop bounds are not known statically.
298   if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) {
299     countPerRegion[0] = -1;
300     return;
301   }
302 
303   countPerRegion[0] =
304       ceilDiv(ub.getValue().getSExtValue() - lb.getValue().getSExtValue(),
305               step.getValue().getSExtValue());
306 }
307 
buildLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ValueRange steps,ValueRange iterArgs,function_ref<ValueVector (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilder)308 LoopNest mlir::scf::buildLoopNest(
309     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
310     ValueRange steps, ValueRange iterArgs,
311     function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
312         bodyBuilder) {
313   assert(lbs.size() == ubs.size() &&
314          "expected the same number of lower and upper bounds");
315   assert(lbs.size() == steps.size() &&
316          "expected the same number of lower bounds and steps");
317 
318   // If there are no bounds, call the body-building function and return early.
319   if (lbs.empty()) {
320     ValueVector results =
321         bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
322                     : ValueVector();
323     assert(results.size() == iterArgs.size() &&
324            "loop nest body must return as many values as loop has iteration "
325            "arguments");
326     return LoopNest();
327   }
328 
329   // First, create the loop structure iteratively using the body-builder
330   // callback of `ForOp::build`. Do not create `YieldOp`s yet.
331   OpBuilder::InsertionGuard guard(builder);
332   SmallVector<scf::ForOp, 4> loops;
333   SmallVector<Value, 4> ivs;
334   loops.reserve(lbs.size());
335   ivs.reserve(lbs.size());
336   ValueRange currentIterArgs = iterArgs;
337   Location currentLoc = loc;
338   for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
339     auto loop = builder.create<scf::ForOp>(
340         currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
341         [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
342             ValueRange args) {
343           ivs.push_back(iv);
344           // It is safe to store ValueRange args because it points to block
345           // arguments of a loop operation that we also own.
346           currentIterArgs = args;
347           currentLoc = nestedLoc;
348         });
349     // Set the builder to point to the body of the newly created loop. We don't
350     // do this in the callback because the builder is reset when the callback
351     // returns.
352     builder.setInsertionPointToStart(loop.getBody());
353     loops.push_back(loop);
354   }
355 
356   // For all loops but the innermost, yield the results of the nested loop.
357   for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
358     builder.setInsertionPointToEnd(loops[i].getBody());
359     builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
360   }
361 
362   // In the body of the innermost loop, call the body building function if any
363   // and yield its results.
364   builder.setInsertionPointToStart(loops.back().getBody());
365   ValueVector results = bodyBuilder
366                             ? bodyBuilder(builder, currentLoc, ivs,
367                                           loops.back().getRegionIterArgs())
368                             : ValueVector();
369   assert(results.size() == iterArgs.size() &&
370          "loop nest body must return as many values as loop has iteration "
371          "arguments");
372   builder.setInsertionPointToEnd(loops.back().getBody());
373   builder.create<scf::YieldOp>(loc, results);
374 
375   // Return the loops.
376   LoopNest res;
377   res.loops.assign(loops.begin(), loops.end());
378   return res;
379 }
380 
buildLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ValueRange steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilder)381 LoopNest mlir::scf::buildLoopNest(
382     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
383     ValueRange steps,
384     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
385   // Delegate to the main function by wrapping the body builder.
386   return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
387                        [&bodyBuilder](OpBuilder &nestedBuilder,
388                                       Location nestedLoc, ValueRange ivs,
389                                       ValueRange) -> ValueVector {
390                          if (bodyBuilder)
391                            bodyBuilder(nestedBuilder, nestedLoc, ivs);
392                          return {};
393                        });
394 }
395 
396 /// Replaces the given op with the contents of the given single-block region,
397 /// using the operands of the block terminator to replace operation results.
replaceOpWithRegion(PatternRewriter & rewriter,Operation * op,Region & region,ValueRange blockArgs={})398 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
399                                 Region &region, ValueRange blockArgs = {}) {
400   assert(llvm::hasSingleElement(region) && "expected single-region block");
401   Block *block = &region.front();
402   Operation *terminator = block->getTerminator();
403   ValueRange results = terminator->getOperands();
404   rewriter.mergeBlockBefore(block, op, blockArgs);
405   rewriter.replaceOp(op, results);
406   rewriter.eraseOp(terminator);
407 }
408 
409 namespace {
410 // Fold away ForOp iter arguments that are also yielded by the op.
411 // These arguments must be defined outside of the ForOp region and can just be
412 // forwarded after simplifying the op inits, yields and returns.
413 //
414 // The implementation uses `mergeBlockBefore` to steal the content of the
415 // original ForOp and avoid cloning.
416 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
417   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
418 
matchAndRewrite__anon4ade4bc50511::ForOpIterArgsFolder419   LogicalResult matchAndRewrite(scf::ForOp forOp,
420                                 PatternRewriter &rewriter) const final {
421     bool canonicalize = false;
422     Block &block = forOp.region().front();
423     auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
424 
425     // An internal flat vector of block transfer
426     // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
427     // transformed block argument mappings. This plays the role of a
428     // BlockAndValueMapping for the particular use case of calling into
429     // `mergeBlockBefore`.
430     SmallVector<bool, 4> keepMask;
431     keepMask.reserve(yieldOp.getNumOperands());
432     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
433         newResultValues;
434     newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
435     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
436     newIterArgs.reserve(forOp.getNumIterOperands());
437     newYieldValues.reserve(yieldOp.getNumOperands());
438     newResultValues.reserve(forOp.getNumResults());
439     for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
440                              forOp.getRegionIterArgs(), // iter inside region
441                              yieldOp.getOperands()      // iter yield
442                              )) {
443       // Forwarded is `true` when the region `iter` argument is yielded.
444       bool forwarded = (std::get<1>(it) == std::get<2>(it));
445       keepMask.push_back(!forwarded);
446       canonicalize |= forwarded;
447       if (forwarded) {
448         newBlockTransferArgs.push_back(std::get<0>(it));
449         newResultValues.push_back(std::get<0>(it));
450         continue;
451       }
452       newIterArgs.push_back(std::get<0>(it));
453       newYieldValues.push_back(std::get<2>(it));
454       newBlockTransferArgs.push_back(Value()); // placeholder with null value
455       newResultValues.push_back(Value());      // placeholder with null value
456     }
457 
458     if (!canonicalize)
459       return failure();
460 
461     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
462         forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
463         newIterArgs);
464     Block &newBlock = newForOp.region().front();
465 
466     // Replace the null placeholders with newly constructed values.
467     newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
468     for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
469          idx != e; ++idx) {
470       Value &blockTransferArg = newBlockTransferArgs[1 + idx];
471       Value &newResultVal = newResultValues[idx];
472       assert((blockTransferArg && newResultVal) ||
473              (!blockTransferArg && !newResultVal));
474       if (!blockTransferArg) {
475         blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
476         newResultVal = newForOp.getResult(collapsedIdx++);
477       }
478     }
479 
480     Block &oldBlock = forOp.region().front();
481     assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
482            "unexpected argument size mismatch");
483 
484     // No results case: the scf::ForOp builder already created a zero
485     // reult terminator. Merge before this terminator and just get rid of the
486     // original terminator that has been merged in.
487     if (newIterArgs.empty()) {
488       auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
489       rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
490       rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
491       rewriter.replaceOp(forOp, newResultValues);
492       return success();
493     }
494 
495     // No terminator case: merge and rewrite the merged terminator.
496     auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
497       OpBuilder::InsertionGuard g(rewriter);
498       rewriter.setInsertionPoint(mergedTerminator);
499       SmallVector<Value, 4> filteredOperands;
500       filteredOperands.reserve(newResultValues.size());
501       for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
502         if (keepMask[idx])
503           filteredOperands.push_back(mergedTerminator.getOperand(idx));
504       rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
505                                     filteredOperands);
506     };
507 
508     rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
509     auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
510     cloneFilteredTerminator(mergedYieldOp);
511     rewriter.eraseOp(mergedYieldOp);
512     rewriter.replaceOp(forOp, newResultValues);
513     return success();
514   }
515 };
516 
517 /// Rewriting pattern that erases loops that are known not to iterate and
518 /// replaces single-iteration loops with their bodies.
519 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
520   using OpRewritePattern<ForOp>::OpRewritePattern;
521 
matchAndRewrite__anon4ade4bc50511::SimplifyTrivialLoops522   LogicalResult matchAndRewrite(ForOp op,
523                                 PatternRewriter &rewriter) const override {
524     // If the upper bound is the same as the lower bound, the loop does not
525     // iterate, just remove it.
526     if (op.lowerBound() == op.upperBound()) {
527       rewriter.replaceOp(op, op.getIterOperands());
528       return success();
529     }
530 
531     auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
532     auto ub = op.upperBound().getDefiningOp<ConstantOp>();
533     if (!lb || !ub)
534       return failure();
535 
536     // If the loop is known to have 0 iterations, remove it.
537     llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
538     llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
539     if (lbValue.sge(ubValue)) {
540       rewriter.replaceOp(op, op.getIterOperands());
541       return success();
542     }
543 
544     auto step = op.step().getDefiningOp<ConstantOp>();
545     if (!step)
546       return failure();
547 
548     // If the loop is known to have 1 iteration, inline its body and remove the
549     // loop.
550     llvm::APInt stepValue = lb.getValue().cast<IntegerAttr>().getValue();
551     if ((lbValue + stepValue).sge(ubValue)) {
552       SmallVector<Value, 4> blockArgs;
553       blockArgs.reserve(op.getNumIterOperands() + 1);
554       blockArgs.push_back(op.lowerBound());
555       llvm::append_range(blockArgs, op.getIterOperands());
556       replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
557       return success();
558     }
559 
560     return failure();
561   }
562 };
563 } // namespace
564 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)565 void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
566                                         MLIRContext *context) {
567   results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // IfOp
572 //===----------------------------------------------------------------------===//
573 
build(OpBuilder & builder,OperationState & result,Value cond,bool withElseRegion)574 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
575                  bool withElseRegion) {
576   build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
577 }
578 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,bool withElseRegion)579 void IfOp::build(OpBuilder &builder, OperationState &result,
580                  TypeRange resultTypes, Value cond, bool withElseRegion) {
581   auto addTerminator = [&](OpBuilder &nested, Location loc) {
582     if (resultTypes.empty())
583       IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
584                              loc);
585   };
586 
587   build(builder, result, resultTypes, cond, addTerminator,
588         withElseRegion ? addTerminator
589                        : function_ref<void(OpBuilder &, Location)>());
590 }
591 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)592 void IfOp::build(OpBuilder &builder, OperationState &result,
593                  TypeRange resultTypes, Value cond,
594                  function_ref<void(OpBuilder &, Location)> thenBuilder,
595                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
596   assert(thenBuilder && "the builder callback for 'then' must be present");
597 
598   result.addOperands(cond);
599   result.addTypes(resultTypes);
600 
601   OpBuilder::InsertionGuard guard(builder);
602   Region *thenRegion = result.addRegion();
603   builder.createBlock(thenRegion);
604   thenBuilder(builder, result.location);
605 
606   Region *elseRegion = result.addRegion();
607   if (!elseBuilder)
608     return;
609 
610   builder.createBlock(elseRegion);
611   elseBuilder(builder, result.location);
612 }
613 
build(OpBuilder & builder,OperationState & result,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)614 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
615                  function_ref<void(OpBuilder &, Location)> thenBuilder,
616                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
617   build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
618 }
619 
verify(IfOp op)620 static LogicalResult verify(IfOp op) {
621   if (op.getNumResults() != 0 && op.elseRegion().empty())
622     return op.emitOpError("must have an else block if defining values");
623 
624   return RegionBranchOpInterface::verifyTypes(op);
625 }
626 
parseIfOp(OpAsmParser & parser,OperationState & result)627 static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
628   // Create the regions for 'then'.
629   result.regions.reserve(2);
630   Region *thenRegion = result.addRegion();
631   Region *elseRegion = result.addRegion();
632 
633   auto &builder = parser.getBuilder();
634   OpAsmParser::OperandType cond;
635   Type i1Type = builder.getIntegerType(1);
636   if (parser.parseOperand(cond) ||
637       parser.resolveOperand(cond, i1Type, result.operands))
638     return failure();
639   // Parse optional results type list.
640   if (parser.parseOptionalArrowTypeList(result.types))
641     return failure();
642   // Parse the 'then' region.
643   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
644     return failure();
645   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
646 
647   // If we find an 'else' keyword then parse the 'else' region.
648   if (!parser.parseOptionalKeyword("else")) {
649     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
650       return failure();
651     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
652   }
653 
654   // Parse the optional attribute list.
655   if (parser.parseOptionalAttrDict(result.attributes))
656     return failure();
657   return success();
658 }
659 
print(OpAsmPrinter & p,IfOp op)660 static void print(OpAsmPrinter &p, IfOp op) {
661   bool printBlockTerminators = false;
662 
663   p << IfOp::getOperationName() << " " << op.condition();
664   if (!op.results().empty()) {
665     p << " -> (" << op.getResultTypes() << ")";
666     // Print yield explicitly if the op defines values.
667     printBlockTerminators = true;
668   }
669   p.printRegion(op.thenRegion(),
670                 /*printEntryBlockArgs=*/false,
671                 /*printBlockTerminators=*/printBlockTerminators);
672 
673   // Print the 'else' regions if it exists and has a block.
674   auto &elseRegion = op.elseRegion();
675   if (!elseRegion.empty()) {
676     p << " else";
677     p.printRegion(elseRegion,
678                   /*printEntryBlockArgs=*/false,
679                   /*printBlockTerminators=*/printBlockTerminators);
680   }
681 
682   p.printOptionalAttrDict(op.getAttrs());
683 }
684 
685 /// Given the region at `index`, or the parent operation if `index` is None,
686 /// return the successor regions. These are the regions that may be selected
687 /// during the flow of control. `operands` is a set of optional attributes that
688 /// correspond to a constant value for each operand, or null if that operand is
689 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)690 void IfOp::getSuccessorRegions(Optional<unsigned> index,
691                                ArrayRef<Attribute> operands,
692                                SmallVectorImpl<RegionSuccessor> &regions) {
693   // The `then` and the `else` region branch back to the parent operation.
694   if (index.hasValue()) {
695     regions.push_back(RegionSuccessor(getResults()));
696     return;
697   }
698 
699   // Don't consider the else region if it is empty.
700   Region *elseRegion = &this->elseRegion();
701   if (elseRegion->empty())
702     elseRegion = nullptr;
703 
704   // Otherwise, the successor is dependent on the condition.
705   bool condition;
706   if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
707     condition = condAttr.getValue().isOneValue();
708   } else {
709     // If the condition isn't constant, both regions may be executed.
710     regions.push_back(RegionSuccessor(&thenRegion()));
711     regions.push_back(RegionSuccessor(elseRegion));
712     return;
713   }
714 
715   // Add the successor regions using the condition.
716   regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
717 }
718 
719 namespace {
720 // Pattern to remove unused IfOp results.
721 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
722   using OpRewritePattern<IfOp>::OpRewritePattern;
723 
transferBody__anon4ade4bc50811::RemoveUnusedResults724   void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
725                     PatternRewriter &rewriter) const {
726     // Move all operations to the destination block.
727     rewriter.mergeBlocks(source, dest);
728     // Replace the yield op by one that returns only the used values.
729     auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
730     SmallVector<Value, 4> usedOperands;
731     llvm::transform(usedResults, std::back_inserter(usedOperands),
732                     [&](OpResult result) {
733                       return yieldOp.getOperand(result.getResultNumber());
734                     });
735     rewriter.updateRootInPlace(yieldOp,
736                                [&]() { yieldOp->setOperands(usedOperands); });
737   }
738 
matchAndRewrite__anon4ade4bc50811::RemoveUnusedResults739   LogicalResult matchAndRewrite(IfOp op,
740                                 PatternRewriter &rewriter) const override {
741     // Compute the list of used results.
742     SmallVector<OpResult, 4> usedResults;
743     llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
744                   [](OpResult result) { return !result.use_empty(); });
745 
746     // Replace the operation if only a subset of its results have uses.
747     if (usedResults.size() == op.getNumResults())
748       return failure();
749 
750     // Compute the result types of the replacement operation.
751     SmallVector<Type, 4> newTypes;
752     llvm::transform(usedResults, std::back_inserter(newTypes),
753                     [](OpResult result) { return result.getType(); });
754 
755     // Create a replacement operation with empty then and else regions.
756     auto emptyBuilder = [](OpBuilder &, Location) {};
757     auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.condition(),
758                                        emptyBuilder, emptyBuilder);
759 
760     // Move the bodies and replace the terminators (note there is a then and
761     // an else region since the operation returns results).
762     transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
763     transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
764 
765     // Replace the operation by the new one.
766     SmallVector<Value, 4> repResults(op.getNumResults());
767     for (auto en : llvm::enumerate(usedResults))
768       repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
769     rewriter.replaceOp(op, repResults);
770     return success();
771   }
772 };
773 
774 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
775   using OpRewritePattern<IfOp>::OpRewritePattern;
776 
matchAndRewrite__anon4ade4bc50811::RemoveStaticCondition777   LogicalResult matchAndRewrite(IfOp op,
778                                 PatternRewriter &rewriter) const override {
779     auto constant = op.condition().getDefiningOp<ConstantOp>();
780     if (!constant)
781       return failure();
782 
783     if (constant.getValue().cast<BoolAttr>().getValue())
784       replaceOpWithRegion(rewriter, op, op.thenRegion());
785     else if (!op.elseRegion().empty())
786       replaceOpWithRegion(rewriter, op, op.elseRegion());
787     else
788       rewriter.eraseOp(op);
789 
790     return success();
791   }
792 };
793 } // namespace
794 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)795 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
796                                        MLIRContext *context) {
797   results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // ParallelOp
802 //===----------------------------------------------------------------------===//
803 
build(OpBuilder & builder,OperationState & result,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange initVals,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilderFn)804 void ParallelOp::build(
805     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
806     ValueRange upperBounds, ValueRange steps, ValueRange initVals,
807     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
808         bodyBuilderFn) {
809   result.addOperands(lowerBounds);
810   result.addOperands(upperBounds);
811   result.addOperands(steps);
812   result.addOperands(initVals);
813   result.addAttribute(
814       ParallelOp::getOperandSegmentSizeAttr(),
815       builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
816                                 static_cast<int32_t>(upperBounds.size()),
817                                 static_cast<int32_t>(steps.size()),
818                                 static_cast<int32_t>(initVals.size())}));
819   result.addTypes(initVals.getTypes());
820 
821   OpBuilder::InsertionGuard guard(builder);
822   unsigned numIVs = steps.size();
823   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
824   Region *bodyRegion = result.addRegion();
825   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
826 
827   if (bodyBuilderFn) {
828     builder.setInsertionPointToStart(bodyBlock);
829     bodyBuilderFn(builder, result.location,
830                   bodyBlock->getArguments().take_front(numIVs),
831                   bodyBlock->getArguments().drop_front(numIVs));
832   }
833   ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
834 }
835 
build(OpBuilder & builder,OperationState & result,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)836 void ParallelOp::build(
837     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
838     ValueRange upperBounds, ValueRange steps,
839     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
840   // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
841   // we don't capture a reference to a temporary by constructing the lambda at
842   // function level.
843   auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
844                                            Location nestedLoc, ValueRange ivs,
845                                            ValueRange) {
846     bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
847   };
848   function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
849   if (bodyBuilderFn)
850     wrapper = wrappedBuilderFn;
851 
852   build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
853         wrapper);
854 }
855 
verify(ParallelOp op)856 static LogicalResult verify(ParallelOp op) {
857   // Check that there is at least one value in lowerBound, upperBound and step.
858   // It is sufficient to test only step, because it is ensured already that the
859   // number of elements in lowerBound, upperBound and step are the same.
860   Operation::operand_range stepValues = op.step();
861   if (stepValues.empty())
862     return op.emitOpError(
863         "needs at least one tuple element for lowerBound, upperBound and step");
864 
865   // Check whether all constant step values are positive.
866   for (Value stepValue : stepValues)
867     if (auto cst = stepValue.getDefiningOp<ConstantIndexOp>())
868       if (cst.getValue() <= 0)
869         return op.emitOpError("constant step operand must be positive");
870 
871   // Check that the body defines the same number of block arguments as the
872   // number of tuple elements in step.
873   Block *body = op.getBody();
874   if (body->getNumArguments() != stepValues.size())
875     return op.emitOpError()
876            << "expects the same number of induction variables: "
877            << body->getNumArguments()
878            << " as bound and step values: " << stepValues.size();
879   for (auto arg : body->getArguments())
880     if (!arg.getType().isIndex())
881       return op.emitOpError(
882           "expects arguments for the induction variable to be of index type");
883 
884   // Check that the yield has no results
885   Operation *yield = body->getTerminator();
886   if (yield->getNumOperands() != 0)
887     return yield->emitOpError() << "not allowed to have operands inside '"
888                                 << ParallelOp::getOperationName() << "'";
889 
890   // Check that the number of results is the same as the number of ReduceOps.
891   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
892   auto resultsSize = op.results().size();
893   auto reductionsSize = reductions.size();
894   auto initValsSize = op.initVals().size();
895   if (resultsSize != reductionsSize)
896     return op.emitOpError()
897            << "expects number of results: " << resultsSize
898            << " to be the same as number of reductions: " << reductionsSize;
899   if (resultsSize != initValsSize)
900     return op.emitOpError()
901            << "expects number of results: " << resultsSize
902            << " to be the same as number of initial values: " << initValsSize;
903 
904   // Check that the types of the results and reductions are the same.
905   for (auto resultAndReduce : llvm::zip(op.results(), reductions)) {
906     auto resultType = std::get<0>(resultAndReduce).getType();
907     auto reduceOp = std::get<1>(resultAndReduce);
908     auto reduceType = reduceOp.operand().getType();
909     if (resultType != reduceType)
910       return reduceOp.emitOpError()
911              << "expects type of reduce: " << reduceType
912              << " to be the same as result type: " << resultType;
913   }
914   return success();
915 }
916 
parseParallelOp(OpAsmParser & parser,OperationState & result)917 static ParseResult parseParallelOp(OpAsmParser &parser,
918                                    OperationState &result) {
919   auto &builder = parser.getBuilder();
920   // Parse an opening `(` followed by induction variables followed by `)`
921   SmallVector<OpAsmParser::OperandType, 4> ivs;
922   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
923                                      OpAsmParser::Delimiter::Paren))
924     return failure();
925 
926   // Parse loop bounds.
927   SmallVector<OpAsmParser::OperandType, 4> lower;
928   if (parser.parseEqual() ||
929       parser.parseOperandList(lower, ivs.size(),
930                               OpAsmParser::Delimiter::Paren) ||
931       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
932     return failure();
933 
934   SmallVector<OpAsmParser::OperandType, 4> upper;
935   if (parser.parseKeyword("to") ||
936       parser.parseOperandList(upper, ivs.size(),
937                               OpAsmParser::Delimiter::Paren) ||
938       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
939     return failure();
940 
941   // Parse step values.
942   SmallVector<OpAsmParser::OperandType, 4> steps;
943   if (parser.parseKeyword("step") ||
944       parser.parseOperandList(steps, ivs.size(),
945                               OpAsmParser::Delimiter::Paren) ||
946       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
947     return failure();
948 
949   // Parse init values.
950   SmallVector<OpAsmParser::OperandType, 4> initVals;
951   if (succeeded(parser.parseOptionalKeyword("init"))) {
952     if (parser.parseOperandList(initVals, /*requiredOperandCount=*/-1,
953                                 OpAsmParser::Delimiter::Paren))
954       return failure();
955   }
956 
957   // Parse optional results in case there is a reduce.
958   if (parser.parseOptionalArrowTypeList(result.types))
959     return failure();
960 
961   // Now parse the body.
962   Region *body = result.addRegion();
963   SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
964   if (parser.parseRegion(*body, ivs, types))
965     return failure();
966 
967   // Set `operand_segment_sizes` attribute.
968   result.addAttribute(
969       ParallelOp::getOperandSegmentSizeAttr(),
970       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
971                                 static_cast<int32_t>(upper.size()),
972                                 static_cast<int32_t>(steps.size()),
973                                 static_cast<int32_t>(initVals.size())}));
974 
975   // Parse attributes.
976   if (parser.parseOptionalAttrDict(result.attributes))
977     return failure();
978 
979   if (!initVals.empty())
980     parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
981                            result.operands);
982   // Add a terminator if none was parsed.
983   ForOp::ensureTerminator(*body, builder, result.location);
984 
985   return success();
986 }
987 
print(OpAsmPrinter & p,ParallelOp op)988 static void print(OpAsmPrinter &p, ParallelOp op) {
989   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
990     << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
991     << ")";
992   if (!op.initVals().empty())
993     p << " init (" << op.initVals() << ")";
994   p.printOptionalArrowTypeList(op.getResultTypes());
995   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
996   p.printOptionalAttrDict(
997       op.getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
998 }
999 
getLoopBody()1000 Region &ParallelOp::getLoopBody() { return region(); }
1001 
isDefinedOutsideOfLoop(Value value)1002 bool ParallelOp::isDefinedOutsideOfLoop(Value value) {
1003   return !region().isAncestor(value.getParentRegion());
1004 }
1005 
moveOutOfLoop(ArrayRef<Operation * > ops)1006 LogicalResult ParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1007   for (auto op : ops)
1008     op->moveBefore(*this);
1009   return success();
1010 }
1011 
getParallelForInductionVarOwner(Value val)1012 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
1013   auto ivArg = val.dyn_cast<BlockArgument>();
1014   if (!ivArg)
1015     return ParallelOp();
1016   assert(ivArg.getOwner() && "unlinked block argument");
1017   auto *containingOp = ivArg.getOwner()->getParentOp();
1018   return dyn_cast<ParallelOp>(containingOp);
1019 }
1020 
1021 namespace {
1022 // Collapse loop dimensions that perform a single iteration.
1023 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
1024   using OpRewritePattern<ParallelOp>::OpRewritePattern;
1025 
matchAndRewrite__anon4ade4bc50f11::CollapseSingleIterationLoops1026   LogicalResult matchAndRewrite(ParallelOp op,
1027                                 PatternRewriter &rewriter) const override {
1028     BlockAndValueMapping mapping;
1029     // Compute new loop bounds that omit all single-iteration loop dimensions.
1030     SmallVector<Value, 2> newLowerBounds;
1031     SmallVector<Value, 2> newUpperBounds;
1032     SmallVector<Value, 2> newSteps;
1033     newLowerBounds.reserve(op.lowerBound().size());
1034     newUpperBounds.reserve(op.upperBound().size());
1035     newSteps.reserve(op.step().size());
1036     for (auto dim : llvm::zip(op.lowerBound(), op.upperBound(), op.step(),
1037                               op.getInductionVars())) {
1038       Value lowerBound, upperBound, step, iv;
1039       std::tie(lowerBound, upperBound, step, iv) = dim;
1040       // Collect the statically known loop bounds.
1041       auto lowerBoundConstant =
1042           dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp());
1043       auto upperBoundConstant =
1044           dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp());
1045       auto stepConstant =
1046           dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp());
1047       // Replace the loop induction variable by the lower bound if the loop
1048       // performs a single iteration. Otherwise, copy the loop bounds.
1049       if (lowerBoundConstant && upperBoundConstant && stepConstant &&
1050           (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) > 0 &&
1051           (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) <=
1052               stepConstant.getValue()) {
1053         mapping.map(iv, lowerBound);
1054       } else {
1055         newLowerBounds.push_back(lowerBound);
1056         newUpperBounds.push_back(upperBound);
1057         newSteps.push_back(step);
1058       }
1059     }
1060     // Exit if all or none of the loop dimensions perform a single iteration.
1061     if (newLowerBounds.size() == 0 ||
1062         newLowerBounds.size() == op.lowerBound().size())
1063       return failure();
1064     // Replace the parallel loop by lower-dimensional parallel loop.
1065     auto newOp =
1066         rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
1067                                     newSteps, op.initVals(), nullptr);
1068     // Clone the loop body and remap the block arguments of the collapsed loops
1069     // (inlining does not support a cancellable block argument mapping).
1070     rewriter.cloneRegionBefore(op.region(), newOp.region(),
1071                                newOp.region().begin(), mapping);
1072     rewriter.replaceOp(op, newOp.getResults());
1073     return success();
1074   }
1075 };
1076 
1077 /// Removes parallel loops in which at least one lower/upper bound pair consists
1078 /// of the same values - such loops have an empty iteration domain.
1079 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
1080   using OpRewritePattern<ParallelOp>::OpRewritePattern;
1081 
matchAndRewrite__anon4ade4bc50f11::RemoveEmptyParallelLoops1082   LogicalResult matchAndRewrite(ParallelOp op,
1083                                 PatternRewriter &rewriter) const override {
1084     for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
1085       if (std::get<0>(dim) == std::get<1>(dim)) {
1086         rewriter.replaceOp(op, op.initVals());
1087         return success();
1088       }
1089     }
1090     return failure();
1091   }
1092 };
1093 
1094 } // namespace
1095 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1096 void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1097                                              MLIRContext *context) {
1098   results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
1099       context);
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // ReduceOp
1104 //===----------------------------------------------------------------------===//
1105 
build(OpBuilder & builder,OperationState & result,Value operand,function_ref<void (OpBuilder &,Location,Value,Value)> bodyBuilderFn)1106 void ReduceOp::build(
1107     OpBuilder &builder, OperationState &result, Value operand,
1108     function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
1109   auto type = operand.getType();
1110   result.addOperands(operand);
1111 
1112   OpBuilder::InsertionGuard guard(builder);
1113   Region *bodyRegion = result.addRegion();
1114   Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type});
1115   if (bodyBuilderFn)
1116     bodyBuilderFn(builder, result.location, body->getArgument(0),
1117                   body->getArgument(1));
1118 }
1119 
verify(ReduceOp op)1120 static LogicalResult verify(ReduceOp op) {
1121   // The region of a ReduceOp has two arguments of the same type as its operand.
1122   auto type = op.operand().getType();
1123   Block &block = op.reductionOperator().front();
1124   if (block.empty())
1125     return op.emitOpError("the block inside reduce should not be empty");
1126   if (block.getNumArguments() != 2 ||
1127       llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
1128         return arg.getType() != type;
1129       }))
1130     return op.emitOpError()
1131            << "expects two arguments to reduce block of type " << type;
1132 
1133   // Check that the block is terminated by a ReduceReturnOp.
1134   if (!isa<ReduceReturnOp>(block.getTerminator()))
1135     return op.emitOpError("the block inside reduce should be terminated with a "
1136                           "'scf.reduce.return' op");
1137 
1138   return success();
1139 }
1140 
parseReduceOp(OpAsmParser & parser,OperationState & result)1141 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1142   // Parse an opening `(` followed by the reduced value followed by `)`
1143   OpAsmParser::OperandType operand;
1144   if (parser.parseLParen() || parser.parseOperand(operand) ||
1145       parser.parseRParen())
1146     return failure();
1147 
1148   Type resultType;
1149   // Parse the type of the operand (and also what reduce computes on).
1150   if (parser.parseColonType(resultType) ||
1151       parser.resolveOperand(operand, resultType, result.operands))
1152     return failure();
1153 
1154   // Now parse the body.
1155   Region *body = result.addRegion();
1156   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1157     return failure();
1158 
1159   return success();
1160 }
1161 
print(OpAsmPrinter & p,ReduceOp op)1162 static void print(OpAsmPrinter &p, ReduceOp op) {
1163   p << op.getOperationName() << "(" << op.operand() << ") ";
1164   p << " : " << op.operand().getType();
1165   p.printRegion(op.reductionOperator());
1166 }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // ReduceReturnOp
1170 //===----------------------------------------------------------------------===//
1171 
verify(ReduceReturnOp op)1172 static LogicalResult verify(ReduceReturnOp op) {
1173   // The type of the return value should be the same type as the type of the
1174   // operand of the enclosing ReduceOp.
1175   auto reduceOp = cast<ReduceOp>(op->getParentOp());
1176   Type reduceType = reduceOp.operand().getType();
1177   if (reduceType != op.result().getType())
1178     return op.emitOpError() << "needs to have type " << reduceType
1179                             << " (the type of the enclosing ReduceOp)";
1180   return success();
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // WhileOp
1185 //===----------------------------------------------------------------------===//
1186 
getSuccessorEntryOperands(unsigned index)1187 OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
1188   assert(index == 0 &&
1189          "WhileOp is expected to branch only to the first region");
1190 
1191   return inits();
1192 }
1193 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1194 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
1195                                   ArrayRef<Attribute> operands,
1196                                   SmallVectorImpl<RegionSuccessor> &regions) {
1197   (void)operands;
1198 
1199   if (!index.hasValue()) {
1200     regions.emplace_back(&before(), before().getArguments());
1201     return;
1202   }
1203 
1204   assert(*index < 2 && "there are only two regions in a WhileOp");
1205   if (*index == 0) {
1206     regions.emplace_back(&after(), after().getArguments());
1207     regions.emplace_back(getResults());
1208     return;
1209   }
1210 
1211   regions.emplace_back(&before(), before().getArguments());
1212 }
1213 
1214 /// Parses a `while` op.
1215 ///
1216 /// op ::= `scf.while` assignments `:` function-type region `do` region
1217 ///         `attributes` attribute-dict
1218 /// initializer ::= /* empty */ | `(` assignment-list `)`
1219 /// assignment-list ::= assignment | assignment `,` assignment-list
1220 /// assignment ::= ssa-value `=` ssa-value
parseWhileOp(OpAsmParser & parser,OperationState & result)1221 static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
1222   SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
1223   Region *before = result.addRegion();
1224   Region *after = result.addRegion();
1225 
1226   OptionalParseResult listResult =
1227       parser.parseOptionalAssignmentList(regionArgs, operands);
1228   if (listResult.hasValue() && failed(listResult.getValue()))
1229     return failure();
1230 
1231   FunctionType functionType;
1232   llvm::SMLoc typeLoc = parser.getCurrentLocation();
1233   if (failed(parser.parseColonType(functionType)))
1234     return failure();
1235 
1236   result.addTypes(functionType.getResults());
1237 
1238   if (functionType.getNumInputs() != operands.size()) {
1239     return parser.emitError(typeLoc)
1240            << "expected as many input types as operands "
1241            << "(expected " << operands.size() << " got "
1242            << functionType.getNumInputs() << ")";
1243   }
1244 
1245   // Resolve input operands.
1246   if (failed(parser.resolveOperands(operands, functionType.getInputs(),
1247                                     parser.getCurrentLocation(),
1248                                     result.operands)))
1249     return failure();
1250 
1251   return failure(
1252       parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
1253       parser.parseKeyword("do") || parser.parseRegion(*after) ||
1254       parser.parseOptionalAttrDictWithKeyword(result.attributes));
1255 }
1256 
1257 /// Prints a `while` op.
print(OpAsmPrinter & p,scf::WhileOp op)1258 static void print(OpAsmPrinter &p, scf::WhileOp op) {
1259   p << op.getOperationName();
1260   printInitializationList(p, op.before().front().getArguments(), op.inits(),
1261                           " ");
1262   p << " : ";
1263   p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
1264   p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
1265   p << " do";
1266   p.printRegion(op.after());
1267   p.printOptionalAttrDictWithKeyword(op.getAttrs());
1268 }
1269 
1270 /// Verifies that two ranges of types match, i.e. have the same number of
1271 /// entries and that types are pairwise equals. Reports errors on the given
1272 /// operation in case of mismatch.
1273 template <typename OpTy>
verifyTypeRangesMatch(OpTy op,TypeRange left,TypeRange right,StringRef message)1274 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
1275                                            TypeRange right, StringRef message) {
1276   if (left.size() != right.size())
1277     return op.emitOpError("expects the same number of ") << message;
1278 
1279   for (unsigned i = 0, e = left.size(); i < e; ++i) {
1280     if (left[i] != right[i]) {
1281       InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
1282                                 << message;
1283       diag.attachNote() << "for argument " << i << ", found " << left[i]
1284                         << " and " << right[i];
1285       return diag;
1286     }
1287   }
1288 
1289   return success();
1290 }
1291 
1292 /// Verifies that the first block of the given `region` is terminated by a
1293 /// YieldOp. Reports errors on the given operation if it is not the case.
1294 template <typename TerminatorTy>
verifyAndGetTerminator(scf::WhileOp op,Region & region,StringRef errorMessage)1295 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
1296                                            StringRef errorMessage) {
1297   Operation *terminatorOperation = region.front().getTerminator();
1298   if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
1299     return yield;
1300 
1301   auto diag = op.emitOpError(errorMessage);
1302   if (terminatorOperation)
1303     diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
1304   return nullptr;
1305 }
1306 
verify(scf::WhileOp op)1307 static LogicalResult verify(scf::WhileOp op) {
1308   if (failed(RegionBranchOpInterface::verifyTypes(op)))
1309     return failure();
1310 
1311   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
1312       op, op.before(),
1313       "expects the 'before' region to terminate with 'scf.condition'");
1314   if (!beforeTerminator)
1315     return failure();
1316 
1317   TypeRange trailingTerminatorOperands = beforeTerminator.args().getTypes();
1318   if (failed(verifyTypeRangesMatch(op, trailingTerminatorOperands,
1319                                    op.after().getArgumentTypes(),
1320                                    "trailing operands of the 'before' block "
1321                                    "terminator and 'after' region arguments")))
1322     return failure();
1323 
1324   if (failed(verifyTypeRangesMatch(
1325           op, trailingTerminatorOperands, op.getResultTypes(),
1326           "trailing operands of the 'before' block terminator and op results")))
1327     return failure();
1328 
1329   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
1330       op, op.after(),
1331       "expects the 'after' region to terminate with 'scf.yield'");
1332   return success(afterTerminator != nullptr);
1333 }
1334 
1335 //===----------------------------------------------------------------------===//
1336 // YieldOp
1337 //===----------------------------------------------------------------------===//
1338 
parseYieldOp(OpAsmParser & parser,OperationState & result)1339 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
1340   SmallVector<OpAsmParser::OperandType, 4> operands;
1341   SmallVector<Type, 4> types;
1342   llvm::SMLoc loc = parser.getCurrentLocation();
1343   // Parse variadic operands list, their types, and resolve operands to SSA
1344   // values.
1345   if (parser.parseOperandList(operands) ||
1346       parser.parseOptionalColonTypeList(types) ||
1347       parser.resolveOperands(operands, types, loc, result.operands))
1348     return failure();
1349   return success();
1350 }
1351 
print(OpAsmPrinter & p,scf::YieldOp op)1352 static void print(OpAsmPrinter &p, scf::YieldOp op) {
1353   p << op.getOperationName();
1354   if (op.getNumOperands() != 0)
1355     p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
1356 }
1357 
1358 //===----------------------------------------------------------------------===//
1359 // TableGen'd op method definitions
1360 //===----------------------------------------------------------------------===//
1361 
1362 #define GET_OP_CLASSES
1363 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
1364