1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "mlir/Transforms/Utils.h"
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/SaveAndRestore.h"
21 #include "llvm/Support/ScopedPrinter.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 #define DEBUG_TYPE "dialect-conversion"
27 
28 /// Recursively collect all of the operations to convert from within 'region'.
29 /// If 'target' is nonnull, operations that are recursively legal have their
30 /// regions pre-filtered to avoid considering them for legalization.
31 static LogicalResult
computeConversionSet(iterator_range<Region::iterator> region,Location regionLoc,std::vector<Operation * > & toConvert,ConversionTarget * target=nullptr)32 computeConversionSet(iterator_range<Region::iterator> region,
33                      Location regionLoc, std::vector<Operation *> &toConvert,
34                      ConversionTarget *target = nullptr) {
35   if (llvm::empty(region))
36     return success();
37 
38   // Traverse starting from the entry block.
39   SmallVector<Block *, 16> worklist(1, &*region.begin());
40   DenseSet<Block *> visitedBlocks;
41   visitedBlocks.insert(worklist.front());
42   while (!worklist.empty()) {
43     Block *block = worklist.pop_back_val();
44 
45     // Compute the conversion set of each of the nested operations.
46     for (Operation &op : *block) {
47       toConvert.emplace_back(&op);
48 
49       // Don't check this operation's children for conversion if the operation
50       // is recursively legal.
51       auto legalityInfo = target ? target->isLegal(&op)
52                                  : Optional<ConversionTarget::LegalOpDetails>();
53       if (legalityInfo && legalityInfo->isRecursivelyLegal)
54         continue;
55       for (auto &region : op.getRegions()) {
56         if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
57                                         toConvert, target)))
58           return failure();
59       }
60     }
61 
62     // Recurse to children that haven't been visited.
63     for (Block *succ : block->getSuccessors())
64       if (visitedBlocks.insert(succ).second)
65         worklist.push_back(succ);
66   }
67 
68   // Check that all blocks in the region were visited.
69   if (llvm::any_of(llvm::drop_begin(region, 1),
70                    [&](Block &block) { return !visitedBlocks.count(&block); }))
71     return emitError(regionLoc, "unreachable blocks were not converted");
72   return success();
73 }
74 
75 /// A utility function to log a successful result for the given reason.
76 template <typename... Args>
logSuccess(llvm::ScopedPrinter & os,StringRef fmt,Args &&...args)77 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
78   LLVM_DEBUG({
79     os.unindent();
80     os.startLine() << "} -> SUCCESS";
81     if (!fmt.empty())
82       os.getOStream() << " : "
83                       << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
84     os.getOStream() << "\n";
85   });
86 }
87 
88 /// A utility function to log a failure result for the given reason.
89 template <typename... Args>
logFailure(llvm::ScopedPrinter & os,StringRef fmt,Args &&...args)90 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
91   LLVM_DEBUG({
92     os.unindent();
93     os.startLine() << "} -> FAILURE : "
94                    << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
95                    << "\n";
96   });
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // ConversionValueMapping
101 //===----------------------------------------------------------------------===//
102 
103 namespace {
104 /// This class wraps a BlockAndValueMapping to provide recursive lookup
105 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
106 struct ConversionValueMapping {
107   /// Lookup a mapped value within the map. If a mapping for the provided value
108   /// does not exist then return the provided value. If `desiredType` is
109   /// non-null, returns the most recently mapped value with that type. If an
110   /// operand of that type does not exist, defaults to normal behavior.
111   Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
112 
113   /// Lookup a mapped value within the map, or return null if a mapping does not
114   /// exist. If a mapping exists, this follows the same behavior of
115   /// `lookupOrDefault`.
116   Value lookupOrNull(Value from) const;
117 
118   /// Map a value to the one provided.
map__anon942ecdf30211::ConversionValueMapping119   void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
120 
121   /// Drop the last mapping for the given value.
erase__anon942ecdf30211::ConversionValueMapping122   void erase(Value value) { mapping.erase(value); }
123 
124 private:
125   /// Current value mappings.
126   BlockAndValueMapping mapping;
127 };
128 } // end anonymous namespace
129 
lookupOrDefault(Value from,Type desiredType) const130 Value ConversionValueMapping::lookupOrDefault(Value from,
131                                               Type desiredType) const {
132   // If there was no desired type, simply find the leaf value.
133   if (!desiredType) {
134     // If this value had a valid mapping, unmap that value as well in the case
135     // that it was also replaced.
136     while (auto mappedValue = mapping.lookupOrNull(from))
137       from = mappedValue;
138     return from;
139   }
140 
141   // Otherwise, try to find the deepest value that has the desired type.
142   Value desiredValue;
143   do {
144     if (from.getType() == desiredType)
145       desiredValue = from;
146 
147     Value mappedValue = mapping.lookupOrNull(from);
148     if (!mappedValue)
149       break;
150     from = mappedValue;
151   } while (true);
152 
153   // If the desired value was found use it, otherwise default to the leaf value.
154   return desiredValue ? desiredValue : from;
155 }
156 
lookupOrNull(Value from) const157 Value ConversionValueMapping::lookupOrNull(Value from) const {
158   Value result = lookupOrDefault(from);
159   return result == from ? nullptr : result;
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // ArgConverter
164 //===----------------------------------------------------------------------===//
165 namespace {
166 /// This class provides a simple interface for converting the types of block
167 /// arguments. This is done by creating a new block that contains the new legal
168 /// types and extracting the block that contains the old illegal types to allow
169 /// for undoing pending rewrites in the case of failure.
170 struct ArgConverter {
ArgConverter__anon942ecdf30311::ArgConverter171   ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}
172 
173   /// This structure contains the information pertaining to an argument that has
174   /// been converted.
175   struct ConvertedArgInfo {
ConvertedArgInfo__anon942ecdf30311::ArgConverter::ConvertedArgInfo176     ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
177                      Value castValue = nullptr)
178         : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
179 
180     /// The start index of in the new argument list that contains arguments that
181     /// replace the original.
182     unsigned newArgIdx;
183 
184     /// The number of arguments that replaced the original argument.
185     unsigned newArgSize;
186 
187     /// The cast value that was created to cast from the new arguments to the
188     /// old. This only used if 'newArgSize' > 1.
189     Value castValue;
190   };
191 
192   /// This structure contains information pertaining to a block that has had its
193   /// signature converted.
194   struct ConvertedBlockInfo {
ConvertedBlockInfo__anon942ecdf30311::ArgConverter::ConvertedBlockInfo195     ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
196         : origBlock(origBlock), converter(&converter) {}
197 
198     /// The original block that was requested to have its signature converted.
199     Block *origBlock;
200 
201     /// The conversion information for each of the arguments. The information is
202     /// None if the argument was dropped during conversion.
203     SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
204 
205     /// The type converter used to convert the arguments.
206     TypeConverter *converter;
207   };
208 
209   /// Return if the signature of the given block has already been converted.
hasBeenConverted__anon942ecdf30311::ArgConverter210   bool hasBeenConverted(Block *block) const {
211     return conversionInfo.count(block) || convertedBlocks.count(block);
212   }
213 
214   /// Set the type converter to use for the given region.
setConverter__anon942ecdf30311::ArgConverter215   void setConverter(Region *region, TypeConverter *typeConverter) {
216     assert(typeConverter && "expected valid type converter");
217     regionToConverter[region] = typeConverter;
218   }
219 
220   /// Return the type converter to use for the given region, or null if there
221   /// isn't one.
getConverter__anon942ecdf30311::ArgConverter222   TypeConverter *getConverter(Region *region) {
223     return regionToConverter.lookup(region);
224   }
225 
226   //===--------------------------------------------------------------------===//
227   // Rewrite Application
228   //===--------------------------------------------------------------------===//
229 
230   /// Erase any rewrites registered for the blocks within the given operation
231   /// which is about to be removed. This merely drops the rewrites without
232   /// undoing them.
233   void notifyOpRemoved(Operation *op);
234 
235   /// Cleanup and undo any generated conversions for the arguments of block.
236   /// This method replaces the new block with the original, reverting the IR to
237   /// its original state.
238   void discardRewrites(Block *block);
239 
240   /// Fully replace uses of the old arguments with the new.
241   void applyRewrites(ConversionValueMapping &mapping);
242 
243   /// Materialize any necessary conversions for converted arguments that have
244   /// live users, using the provided `findLiveUser` to search for a user that
245   /// survives the conversion process.
246   LogicalResult
247   materializeLiveConversions(ConversionValueMapping &mapping,
248                              OpBuilder &builder,
249                              function_ref<Operation *(Value)> findLiveUser);
250 
251   //===--------------------------------------------------------------------===//
252   // Conversion
253   //===--------------------------------------------------------------------===//
254 
255   /// Attempt to convert the signature of the given block, if successful a new
256   /// block is returned containing the new arguments. Returns `block` if it did
257   /// not require conversion.
258   FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
259                                       ConversionValueMapping &mapping);
260 
261   /// Apply the given signature conversion on the given block. The new block
262   /// containing the updated signature is returned. If no conversions were
263   /// necessary, e.g. if the block has no arguments, `block` is returned.
264   /// `converter` is used to generate any necessary cast operations that
265   /// translate between the origin argument types and those specified in the
266   /// signature conversion.
267   Block *applySignatureConversion(
268       Block *block, TypeConverter &converter,
269       TypeConverter::SignatureConversion &signatureConversion,
270       ConversionValueMapping &mapping);
271 
272   /// Insert a new conversion into the cache.
273   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
274 
275   /// A collection of blocks that have had their arguments converted. This is a
276   /// map from the new replacement block, back to the original block.
277   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
278 
279   /// The set of original blocks that were converted.
280   DenseSet<Block *> convertedBlocks;
281 
282   /// A mapping from valid regions, to those containing the original blocks of a
283   /// conversion.
284   DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
285 
286   /// A mapping of regions to type converters that should be used when
287   /// converting the arguments of blocks within that region.
288   DenseMap<Region *, TypeConverter *> regionToConverter;
289 
290   /// The pattern rewriter to use when materializing conversions.
291   PatternRewriter &rewriter;
292 };
293 } // end anonymous namespace
294 
295 //===----------------------------------------------------------------------===//
296 // Rewrite Application
297 
notifyOpRemoved(Operation * op)298 void ArgConverter::notifyOpRemoved(Operation *op) {
299   if (conversionInfo.empty())
300     return;
301 
302   for (Region &region : op->getRegions()) {
303     for (Block &block : region) {
304       // Drop any rewrites from within.
305       for (Operation &nestedOp : block)
306         if (nestedOp.getNumRegions())
307           notifyOpRemoved(&nestedOp);
308 
309       // Check if this block was converted.
310       auto it = conversionInfo.find(&block);
311       if (it == conversionInfo.end())
312         continue;
313 
314       // Drop all uses of the original arguments and delete the original block.
315       Block *origBlock = it->second.origBlock;
316       for (BlockArgument arg : origBlock->getArguments())
317         arg.dropAllUses();
318       conversionInfo.erase(it);
319     }
320   }
321 }
322 
discardRewrites(Block * block)323 void ArgConverter::discardRewrites(Block *block) {
324   auto it = conversionInfo.find(block);
325   if (it == conversionInfo.end())
326     return;
327   Block *origBlock = it->second.origBlock;
328 
329   // Drop all uses of the new block arguments and replace uses of the new block.
330   for (int i = block->getNumArguments() - 1; i >= 0; --i)
331     block->getArgument(i).dropAllUses();
332   block->replaceAllUsesWith(origBlock);
333 
334   // Move the operations back the original block and the delete the new block.
335   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
336   origBlock->moveBefore(block);
337   block->erase();
338 
339   convertedBlocks.erase(origBlock);
340   conversionInfo.erase(it);
341 }
342 
applyRewrites(ConversionValueMapping & mapping)343 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
344   for (auto &info : conversionInfo) {
345     ConvertedBlockInfo &blockInfo = info.second;
346     Block *origBlock = blockInfo.origBlock;
347 
348     // Process the remapping for each of the original arguments.
349     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
350       Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
351       BlockArgument origArg = origBlock->getArgument(i);
352 
353       // Handle the case of a 1->0 value mapping.
354       if (!argInfo) {
355         if (Value newArg = mapping.lookupOrNull(origArg))
356           origArg.replaceAllUsesWith(newArg);
357         continue;
358       }
359 
360       // Otherwise this is a 1->1+ value mapping.
361       Value castValue = argInfo->castValue;
362       assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
363 
364       // If the argument is still used, replace it with the generated cast.
365       if (!origArg.use_empty())
366         origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
367     }
368   }
369 }
370 
materializeLiveConversions(ConversionValueMapping & mapping,OpBuilder & builder,function_ref<Operation * (Value)> findLiveUser)371 LogicalResult ArgConverter::materializeLiveConversions(
372     ConversionValueMapping &mapping, OpBuilder &builder,
373     function_ref<Operation *(Value)> findLiveUser) {
374   for (auto &info : conversionInfo) {
375     Block *newBlock = info.first;
376     ConvertedBlockInfo &blockInfo = info.second;
377     Block *origBlock = blockInfo.origBlock;
378 
379     // Process the remapping for each of the original arguments.
380     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
381       // FIXME: We should run the below checks even if the type conversion was
382       // 1->N, but a lot of existing lowering rely on the block argument being
383       // blindly replaced. Those usages should be updated, and this if should be
384       // removed.
385       if (blockInfo.argInfo[i])
386         continue;
387 
388       // If the type of this argument changed and the argument is still live, we
389       // need to materialize a conversion.
390       BlockArgument origArg = origBlock->getArgument(i);
391       auto argReplacementValue = mapping.lookupOrDefault(origArg);
392       bool isDroppedArg = argReplacementValue == origArg;
393       if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
394         continue;
395       Operation *liveUser = findLiveUser(origArg);
396       if (!liveUser)
397         continue;
398 
399       if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
400         rewriter.setInsertionPointAfter(result.getOwner());
401       else
402         rewriter.setInsertionPointToStart(newBlock);
403       Value newArg = blockInfo.converter->materializeSourceConversion(
404           rewriter, origArg.getLoc(), origArg.getType(),
405           isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
406       if (!newArg) {
407         InFlightDiagnostic diag =
408             emitError(origArg.getLoc())
409             << "failed to materialize conversion for block argument #" << i
410             << " that remained live after conversion, type was "
411             << origArg.getType();
412         if (!isDroppedArg)
413           diag << ", with target type " << argReplacementValue.getType();
414         diag.attachNote(liveUser->getLoc())
415             << "see existing live user here: " << *liveUser;
416         return failure();
417       }
418       mapping.map(origArg, newArg);
419     }
420   }
421   return success();
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // Conversion
426 
427 FailureOr<Block *>
convertSignature(Block * block,TypeConverter & converter,ConversionValueMapping & mapping)428 ArgConverter::convertSignature(Block *block, TypeConverter &converter,
429                                ConversionValueMapping &mapping) {
430   // Check if the block was already converted. If the block is detached,
431   // conservatively assume it is going to be deleted.
432   if (hasBeenConverted(block) || !block->getParent())
433     return block;
434 
435   // Try to convert the signature for the block with the provided converter.
436   if (auto conversion = converter.convertBlockSignature(block))
437     return applySignatureConversion(block, converter, *conversion, mapping);
438   return failure();
439 }
440 
/// Convert the signature of `block` as described by `signatureConversion`.
///
/// The conversion proceeds in these steps:
///  - All operations of `block` are moved into a new block, obtained by
///    splitting `block` at its beginning.
///  - All uses of `block` are redirected to the new block.
///  - The new block receives arguments of the converted types.
///  - Each original argument that has an input mapping is recorded in
///    `mapping`. It maps either to the replacement value supplied by the
///    conversion, or to the value produced by `converter`'s argument
///    materialization over its new arguments. If no materialization is
///    produced, the single new argument is used instead.
///  - Original arguments without an input mapping are left unmapped.
///  - Finally, the now-empty original block is moved out of the region via
///    `insertConversion`, so the rewrite can be undone.
///
/// Returns the new block, or `block` itself when it has no arguments and none
/// are added.
Block *ArgConverter::applySignatureConversion(
    Block *block, TypeConverter &converter,
    TypeConverter::SignatureConversion &signatureConversion,
    ConversionValueMapping &mapping) {
  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (origArgCount == 0 && convertedTypes.empty())
    return block;

  // Split the block at the beginning to get a new block to use for the updated
  // signature. All uses of `block` (e.g. as a branch successor) are redirected
  // to the new block.
  Block *newBlock = block->splitBlock(block->begin());
  block->replaceAllUsesWith(newBlock);

  // Add the converted argument types to the new block. Input mappings index
  // into `newArgs` via their `inputNo`/`size` slice.
  SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
  ArrayRef<Value> newArgs(newArgRange);

  // Remap each of the original arguments as determined by the signature
  // conversion. Entries of `info.argInfo` stay None for arguments that have no
  // input mapping or that are replaced by a provided value.
  ConvertedBlockInfo info(block, converter);
  info.argInfo.resize(origArgCount);

  // Materializations are created at the start of the new block. The guard
  // restores the rewriter's previous insertion point on return.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(newBlock);
  for (unsigned i = 0; i != origArgCount; ++i) {
    auto inputMap = signatureConversion.getInputMapping(i);
    // No input mapping: the argument is left unmapped here.
    if (!inputMap)
      continue;
    BlockArgument origArg = block->getArgument(i);

    // If inputMap->replacementValue is not nullptr, then the argument is
    // dropped and a replacement value is provided to be the remappedValue.
    if (inputMap->replacementValue) {
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      mapping.map(origArg, inputMap->replacementValue);
      continue;
    }

    // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
    // to pack the new values. For 1->1 mappings, if there is no materialization
    // provided, use the argument directly instead.
    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
    Value newArg = converter.materializeArgumentConversion(
        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
    if (!newArg) {
      // NOTE(review): the assert below is compiled out in release builds. In
      // that case a failed 1->N materialization silently falls back to the
      // first new argument.
      assert(replArgs.size() == 1 &&
             "couldn't materialize the result of 1->N conversion");
      newArg = replArgs.front();
    }
    mapping.map(origArg, newArg);
    info.argInfo[i] =
        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
  }

  // Move the original (now empty) block out of the region, record the
  // conversion, and return the new block.
  insertConversion(newBlock, std::move(info));
  return newBlock;
}
502 
insertConversion(Block * newBlock,ConvertedBlockInfo && info)503 void ArgConverter::insertConversion(Block *newBlock,
504                                     ConvertedBlockInfo &&info) {
505   // Get a region to insert the old block.
506   Region *region = newBlock->getParent();
507   std::unique_ptr<Region> &mappedRegion = regionMapping[region];
508   if (!mappedRegion)
509     mappedRegion = std::make_unique<Region>(region->getParentOp());
510 
511   // Move the original block to the mapped region and emplace the conversion.
512   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
513                                    info.origBlock->getIterator());
514   convertedBlocks.insert(info.origBlock);
515   conversionInfo.insert({newBlock, std::move(info)});
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // Rewriter and Translation State
520 //===----------------------------------------------------------------------===//
521 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. A snapshot
/// taken before a set of rewrites can later be used to undo those rewrites.
struct RewriterState {
  RewriterState(unsigned createdOps, unsigned replacements,
                unsigned argReplacements, unsigned blockActions,
                unsigned ignoredOperations, unsigned rootUpdates)
      : numCreatedOps(createdOps), numReplacements(replacements),
        numArgReplacements(argReplacements), numBlockActions(blockActions),
        numIgnoredOperations(ignoredOperations), numRootUpdates(rootUpdates) {}

  /// Number of operations created so far.
  unsigned numCreatedOps;

  /// Number of queued operation replacements.
  unsigned numReplacements;

  /// Number of queued block argument replacements.
  unsigned numArgReplacements;

  /// Number of block actions performed so far.
  unsigned numBlockActions;

  /// Number of operations marked as ignored.
  unsigned numIgnoredOperations;

  /// Number of operations updated in place.
  unsigned numRootUpdates;
};
552 
553 /// The state of an operation that was updated by a pattern in-place. This
554 /// contains all of the necessary information to reconstruct an operation that
555 /// was updated in place.
556 class OperationTransactionState {
557 public:
558   OperationTransactionState() = default;
OperationTransactionState(Operation * op)559   OperationTransactionState(Operation *op)
560       : op(op), loc(op->getLoc()), attrs(op->getMutableAttrDict()),
561         operands(op->operand_begin(), op->operand_end()),
562         successors(op->successor_begin(), op->successor_end()) {}
563 
564   /// Discard the transaction state and reset the state of the original
565   /// operation.
resetOperation() const566   void resetOperation() const {
567     op->setLoc(loc);
568     op->setAttrs(attrs);
569     op->setOperands(operands);
570     for (auto it : llvm::enumerate(successors))
571       op->setSuccessor(it.value(), it.index());
572   }
573 
574   /// Return the original operation of this state.
getOperation() const575   Operation *getOperation() const { return op; }
576 
577 private:
578   Operation *op;
579   LocationAttr loc;
580   MutableDictionaryAttr attrs;
581   SmallVector<Value, 8> operands;
582   SmallVector<Block *, 2> successors;
583 };
584 
585 /// This class represents one requested operation replacement via 'replaceOp' or
586 /// 'eraseOp`.
587 struct OpReplacement {
588   OpReplacement() = default;
OpReplacement__anon942ecdf30411::OpReplacement589   OpReplacement(TypeConverter *converter) : converter(converter) {}
590 
591   /// An optional type converter that can be used to materialize conversions
592   /// between the new and old values if necessary.
593   TypeConverter *converter = nullptr;
594 };
595 
/// The kind of the block action performed during the rewrite.  Actions can be
/// undone if the conversion fails. The payload recorded for each kind is held
/// in the union of `BlockAction`.
enum class BlockActionKind {
  // No extra payload is recorded.
  Create,
  // Records the block's original position (`originalPosition`).
  Erase,
  // Records the merged source block and the destination's previous last
  // operation (`mergeInfo`).
  Merge,
  // Records the block's original position (`originalPosition`).
  Move,
  // Records the block that was split (`originalBlock`).
  Split,
  // No extra payload is recorded.
  TypeConversion
};
606 
/// Original position of the given block in its parent region. During undo
/// actions, the block needs to be placed after `insertAfterBlock`.
struct BlockPosition {
  // The region that originally contained the block.
  Region *region;
  // The block that preceded it in `region`. Presumably null when the block was
  // first in the region; TODO confirm against the undo logic.
  Block *insertAfterBlock;
};
613 
/// Information needed to undo the merge actions.
/// - the source block, and
/// - the Operation that was the last operation in the dest block before the
///   merge (could be null if the dest block was empty).
struct MergeInfo {
  // The block whose contents were merged into the destination block.
  Block *sourceBlock;
  // The destination block's last operation before the merge, or null if the
  // destination was empty at that time.
  Operation *destBlockLastInst;
};
622 
623 /// The storage class for an undoable block action (one of BlockActionKind),
624 /// contains the information necessary to undo this action.
625 struct BlockAction {
getCreate__anon942ecdf30411::BlockAction626   static BlockAction getCreate(Block *block) {
627     return {BlockActionKind::Create, block, {}};
628   }
getErase__anon942ecdf30411::BlockAction629   static BlockAction getErase(Block *block, BlockPosition originalPosition) {
630     return {BlockActionKind::Erase, block, {originalPosition}};
631   }
getMerge__anon942ecdf30411::BlockAction632   static BlockAction getMerge(Block *block, Block *sourceBlock) {
633     BlockAction action{BlockActionKind::Merge, block, {}};
634     action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
635     return action;
636   }
getMove__anon942ecdf30411::BlockAction637   static BlockAction getMove(Block *block, BlockPosition originalPosition) {
638     return {BlockActionKind::Move, block, {originalPosition}};
639   }
getSplit__anon942ecdf30411::BlockAction640   static BlockAction getSplit(Block *block, Block *originalBlock) {
641     BlockAction action{BlockActionKind::Split, block, {}};
642     action.originalBlock = originalBlock;
643     return action;
644   }
getTypeConversion__anon942ecdf30411::BlockAction645   static BlockAction getTypeConversion(Block *block) {
646     return BlockAction{BlockActionKind::TypeConversion, block, {}};
647   }
648 
649   // The action kind.
650   BlockActionKind kind;
651 
652   // A pointer to the block that was created by the action.
653   Block *block;
654 
655   union {
656     // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
657     // contains a pointer to the region that originally contained the block as
658     // well as the position of the block in that region.
659     BlockPosition originalPosition;
660     // In use if kind == BlockActionKind::Split and contains a pointer to the
661     // block that was split into two parts.
662     Block *originalBlock;
663     // In use if kind == BlockActionKind::Merge, and contains the information
664     // needed to undo the merge.
665     MergeInfo mergeInfo;
666   };
667 };
668 } // end anonymous namespace
669 
670 //===----------------------------------------------------------------------===//
671 // ConversionPatternRewriterImpl
672 //===----------------------------------------------------------------------===//
673 namespace mlir {
674 namespace detail {
675 struct ConversionPatternRewriterImpl {
ConversionPatternRewriterImplmlir::detail::ConversionPatternRewriterImpl676   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
677       : argConverter(rewriter) {}
678 
679   /// Cleanup and destroy any generated rewrite operations. This method is
680   /// invoked when the conversion process fails.
681   void discardRewrites();
682 
683   /// Apply all requested operation rewrites. This method is invoked when the
684   /// conversion process succeeds.
685   void applyRewrites();
686 
687   //===--------------------------------------------------------------------===//
688   // State Management
689   //===--------------------------------------------------------------------===//
690 
691   /// Return the current state of the rewriter.
692   RewriterState getCurrentState();
693 
694   /// Reset the state of the rewriter to a previously saved point.
695   void resetState(RewriterState state);
696 
697   /// Erase any blocks that were unlinked from their regions and stored in block
698   /// actions.
699   void eraseDanglingBlocks();
700 
701   /// Undo the block actions (motions, splits) one by one in reverse order until
702   /// "numActionsToKeep" actions remains.
703   void undoBlockActions(unsigned numActionsToKeep = 0);
704 
705   /// Remap the given operands to those with potentially different types. The
706   /// provided type converter is used to ensure that the remapped types are
707   /// legal. Returns success if the operands could be remapped, failure
708   /// otherwise.
709   LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
710                             TypeConverter *converter,
711                             Operation::operand_range operands,
712                             SmallVectorImpl<Value> &remapped);
713 
714   /// Returns true if the given operation is ignored, and does not need to be
715   /// converted.
716   bool isOpIgnored(Operation *op) const;
717 
718   /// Recursively marks the nested operations under 'op' as ignored. This
719   /// removes them from being considered for legalization.
720   void markNestedOpsIgnored(Operation *op);
721 
722   //===--------------------------------------------------------------------===//
723   // Type Conversion
724   //===--------------------------------------------------------------------===//
725 
726   /// Convert the signature of the given block.
727   FailureOr<Block *> convertBlockSignature(
728       Block *block, TypeConverter &converter,
729       TypeConverter::SignatureConversion *conversion = nullptr);
730 
731   /// Apply a signature conversion on the given region.
732   Block *
733   applySignatureConversion(Region *region,
734                            TypeConverter::SignatureConversion &conversion);
735 
736   /// Convert the types of block arguments within the given region.
737   FailureOr<Block *>
738   convertRegionTypes(Region *region, TypeConverter &converter,
739                      TypeConverter::SignatureConversion *entryConversion);
740 
741   //===--------------------------------------------------------------------===//
742   // Rewriter Notification Hooks
743   //===--------------------------------------------------------------------===//
744 
745   /// PatternRewriter hook for replacing the results of an operation.
746   void notifyOpReplaced(Operation *op, ValueRange newValues);
747 
748   /// Notifies that a block is about to be erased.
749   void notifyBlockIsBeingErased(Block *block);
750 
751   /// Notifies that a block was created.
752   void notifyCreatedBlock(Block *block);
753 
754   /// Notifies that a block was split.
755   void notifySplitBlock(Block *block, Block *continuation);
756 
757   /// Notifies that `block` is being merged with `srcBlock`.
758   void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
759 
760   /// Notifies that the blocks of a region are about to be moved.
761   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
762                                         Region::iterator before);
763 
764   /// Notifies that the blocks of a region were cloned into another.
765   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
766                                    Location origRegionLoc);
767 
768   /// Notifies that a pattern match failed for the given reason.
769   LogicalResult
770   notifyMatchFailure(Location loc,
771                      function_ref<void(Diagnostic &)> reasonCallback);
772 
773   //===--------------------------------------------------------------------===//
774   // State
775   //===--------------------------------------------------------------------===//
776 
777   // Mapping between replaced values that differ in type. This happens when
778   // replacing a value with one of a different type.
779   ConversionValueMapping mapping;
780 
781   /// Utility used to convert block arguments.
782   ArgConverter argConverter;
783 
784   /// Ordered vector of all of the newly created operations during conversion.
785   std::vector<Operation *> createdOps;
786 
787   /// Ordered map of requested operation replacements.
788   llvm::MapVector<Operation *, OpReplacement> replacements;
789 
790   /// Ordered vector of any requested block argument replacements.
791   SmallVector<BlockArgument, 4> argReplacements;
792 
793   /// Ordered list of block operations (creations, splits, motions).
794   SmallVector<BlockAction, 4> blockActions;
795 
796   /// A set of operations that should no longer be considered for legalization,
797   /// but were not directly replace/erased/etc. by a pattern. These are
798   /// generally child operations of other operations who were
799   /// replaced/erased/etc. This is not meant to be an exhaustive list of all
800   /// operations, but the minimal set that can be used to detect if a given
801   /// operation should be `ignored`. For example, we may add the operations that
802   /// define non-empty regions to the set, but not any of the others. This
803   /// simplifies the amount of memory needed as we can query if the parent
804   /// operation was ignored.
805   llvm::SetVector<Operation *> ignoredOps;
806 
807   /// A transaction state for each of operations that were updated in-place.
808   SmallVector<OperationTransactionState, 4> rootUpdates;
809 
810   /// A vector of indices into `replacements` of operations that were replaced
811   /// with values with different result types than the original operation, e.g.
812   /// 1->N conversion of some kind.
813   SmallVector<unsigned, 4> operationsWithChangedResults;
814 
815   /// A default type converter, used when block conversions do not have one
816   /// explicitly provided.
817   TypeConverter defaultTypeConverter;
818 
819   /// The current conversion pattern that is being rewritten, or nullptr if
820   /// called from outside of a conversion pattern rewrite.
821   const ConversionPattern *currentConversionPattern = nullptr;
822 
823 #ifndef NDEBUG
824   /// A set of operations that have pending updates. This tracking isn't
825   /// strictly necessary, and is thus only active during debug builds for extra
826   /// verification.
827   SmallPtrSet<Operation *, 1> pendingRootUpdates;
828 
829   /// A logger used to emit diagnostics during the conversion process.
830   llvm::ScopedPrinter logger{llvm::dbgs()};
831 #endif
832 };
833 } // end namespace detail
834 } // end namespace mlir
835 
836 /// Detach any operations nested in the given operation from their parent
837 /// blocks, and erase the given operation. This can be used when the nested
838 /// operations are scheduled for erasure themselves, so deleting the regions of
839 /// the given operation together with their content would result in double-free.
840 /// This happens, for example, when rolling back op creation in the reverse
841 /// order and if the nested ops were created before the parent op. This function
842 /// does not need to collect nested ops recursively because it is expected to
843 /// also be called for each nested op when it is about to be deleted.
detachNestedAndErase(Operation * op)844 static void detachNestedAndErase(Operation *op) {
845   for (Region &region : op->getRegions()) {
846     for (Block &block : region.getBlocks()) {
847       while (!block.getOperations().empty())
848         block.getOperations().remove(block.getOperations().begin());
849       block.dropAllDefinedValueUses();
850     }
851   }
852   op->erase();
853 }
854 
discardRewrites()855 void ConversionPatternRewriterImpl::discardRewrites() {
856   // Reset any operations that were updated in place.
857   for (auto &state : rootUpdates)
858     state.resetOperation();
859 
860   undoBlockActions();
861 
862   // Remove any newly created ops.
863   for (auto *op : llvm::reverse(createdOps))
864     detachNestedAndErase(op);
865 }
866 
/// Commit a successful conversion. The steps, in order:
///  - forward uses of every replaced op result and replaced block argument to
///    the value recorded in `mapping`;
///  - erase the replaced operations;
///  - let the argument converter finalize its rewrites;
///  - free blocks that were unlinked by block erasure.
void ConversionPatternRewriterImpl::applyRewrites() {
  // Apply all of the rewrites replacements requested during conversion.
  for (auto &repl : replacements) {
    // Only results with a non-null mapped value are forwarded; results without
    // a mapping keep their existing uses.
    for (OpResult result : repl.first->getResults())
      if (Value newValue = mapping.lookupOrNull(result))
        result.replaceAllUsesWith(newValue);

    // If this operation defines any regions, drop any pending argument
    // rewrites.
    if (repl.first->getNumRegions())
      argConverter.notifyOpRemoved(repl.first);
  }

  // Apply all of the requested argument replacements.
  for (BlockArgument arg : argReplacements) {
    // Resolve the final value in the mapping chain (or `arg` itself if unmapped).
    Value repl = mapping.lookupOrDefault(arg);
    if (repl.isa<BlockArgument>()) {
      arg.replaceAllUsesWith(repl);
      continue;
    }

    // If the replacement value is an operation, we check to make sure that we
    // don't replace uses that are within the parent operation of the
    // replacement value.
    // Concretely: uses in other blocks are replaced, and uses in the defining
    // op's block are replaced only if the user comes after the defining op.
    // The defining op itself and anything before it keep using `arg`.
    Operation *replOp = repl.cast<OpResult>().getOwner();
    Block *replBlock = replOp->getBlock();
    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
      Operation *user = operand.getOwner();
      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
    });
  }

  // In a second pass, erase all of the replaced operations in reverse. This
  // allows processing nested operations before their parent region is
  // destroyed.
  for (auto &repl : llvm::reverse(replacements))
    repl.first->erase();

  argConverter.applyRewrites(mapping);

  // Now that the ops have been erased, also erase dangling blocks.
  eraseDanglingBlocks();
}
910 
911 //===----------------------------------------------------------------------===//
912 // State Management
913 
getCurrentState()914 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
915   return RewriterState(createdOps.size(), replacements.size(),
916                        argReplacements.size(), blockActions.size(),
917                        ignoredOps.size(), rootUpdates.size());
918 }
919 
/// Roll the rewriter back to `state`, a snapshot previously produced by
/// `getCurrentState`. Every log entry recorded after the snapshot is undone
/// and then truncated away.
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
  // Reset any operations that were updated in place.
  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
    rootUpdates[i].resetOperation();
  rootUpdates.resize(state.numRootUpdates);

  // Reset any replaced arguments.
  // NOTE(review): replaceUsesOfBlockArgument maps `lookupOrDefault(from)`
  // rather than `from` itself, so erasing `replacedArg` here only undoes that
  // mapping when the argument had no earlier mapping. Confirm this is intended.
  for (BlockArgument replacedArg :
       llvm::drop_begin(argReplacements, state.numArgReplacements))
    mapping.erase(replacedArg);
  argReplacements.resize(state.numArgReplacements);

  // Undo any block actions.
  // This happens before created ops are erased below. Undoing a block creation
  // only unlinks the block's operations; the ops themselves are erased later
  // through `createdOps`.
  undoBlockActions(state.numBlockActions);

  // Reset any replaced operations and undo any saved mappings.
  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
    for (auto result : repl.first->getResults())
      mapping.erase(result);
  while (replacements.size() != state.numReplacements)
    replacements.pop_back();

  // Pop all of the newly created operations.
  // Nested ops are detached rather than destroyed with their parent, because
  // they are tracked in `createdOps` as well.
  while (createdOps.size() != state.numCreatedOps) {
    detachNestedAndErase(createdOps.back());
    createdOps.pop_back();
  }

  // Pop all of the recorded ignored operations that are no longer valid.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  // Reset operations with changed results.
  // The entries are indices into `replacements`, so drop those pointing past
  // the truncated end.
  while (!operationsWithChangedResults.empty() &&
         operationsWithChangedResults.back() >= state.numReplacements)
    operationsWithChangedResults.pop_back();
}
957 
eraseDanglingBlocks()958 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
959   for (auto &action : blockActions)
960     if (action.kind == BlockActionKind::Erase)
961       delete action.block;
962 }
963 
/// Revert the recorded block actions, most recent first, until only the first
/// `numActionsToKeep` actions remain. `blockActions` is then truncated to that
/// size.
void ConversionPatternRewriterImpl::undoBlockActions(
    unsigned numActionsToKeep) {
  for (auto &action :
       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
    switch (action.kind) {
    // Delete the created block.
    case BlockActionKind::Create: {
      // Unlink all of the operations within this block, they will be deleted
      // separately.
      auto &blockOps = action.block->getOperations();
      while (!blockOps.empty())
        blockOps.remove(blockOps.begin());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Put the block (owned by action) back into its original position.
    case BlockActionKind::Erase: {
      auto &blockList = action.originalPosition.region->getBlocks();
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      // A null `insertAfterBlock` means the block had no predecessor in its
      // region, i.e. it was the first block.
      blockList.insert((insertAfterBlock
                            ? std::next(Region::iterator(insertAfterBlock))
                            : blockList.begin()),
                       action.block);
      break;
    }
    // Split the block at the position which was originally the end of the
    // destination block (owned by action), and put the instructions back into
    // the block used before the merge.
    case BlockActionKind::Merge: {
      Block *sourceBlock = action.mergeInfo.sourceBlock;
      // Without a recorded last op (presumably the destination was empty
      // before the merge), everything in the block came from the source.
      Block::iterator splitPoint =
          (action.mergeInfo.destBlockLastInst
               ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
               : action.block->begin());
      sourceBlock->getOperations().splice(sourceBlock->begin(),
                                          action.block->getOperations(),
                                          splitPoint, action.block->end());
      break;
    }
    // Move the block back to its original position.
    case BlockActionKind::Move: {
      Region *originalRegion = action.originalPosition.region;
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      // NOTE(review): unlike the Erase case, a null `insertAfterBlock` maps to
      // `end()` here. That equals `begin()` only if the region is empty at
      // this point. This presumably holds for Move actions recorded when a
      // whole region is inlined, since they are undone in reverse. Confirm
      // before adding other Move producers.
      originalRegion->getBlocks().splice(
          (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
                            : originalRegion->end()),
          action.block->getParent()->getBlocks(), action.block);
      break;
    }
    // Merge back the block that was split out.
    case BlockActionKind::Split: {
      // Append the continuation's ops to the original block, then destroy the
      // now-empty continuation block.
      action.originalBlock->getOperations().splice(
          action.originalBlock->end(), action.block->getOperations());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Undo the type conversion.
    case BlockActionKind::TypeConversion: {
      // Delegate to the argument converter for this converted block.
      argConverter.discardRewrites(action.block);
      break;
    }
    }
  }
  blockActions.resize(numActionsToKeep);
}
1031 
remapValues(Location loc,PatternRewriter & rewriter,TypeConverter * converter,Operation::operand_range operands,SmallVectorImpl<Value> & remapped)1032 LogicalResult ConversionPatternRewriterImpl::remapValues(
1033     Location loc, PatternRewriter &rewriter, TypeConverter *converter,
1034     Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
1035   remapped.reserve(llvm::size(operands));
1036 
1037   SmallVector<Type, 1> legalTypes;
1038   for (auto it : llvm::enumerate(operands)) {
1039     Value operand = it.value();
1040     Type origType = operand.getType();
1041 
1042     // If a converter was provided, get the desired legal types for this
1043     // operand.
1044     Type desiredType;
1045     if (converter) {
1046       // If there is no legal conversion, fail to match this pattern.
1047       legalTypes.clear();
1048       if (failed(converter->convertType(origType, legalTypes))) {
1049         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1050           diag << "unable to convert type for operand #" << it.index()
1051                << ", type was " << origType;
1052         });
1053       }
1054       // TODO: There currently isn't any mechanism to do 1->N type conversion
1055       // via the PatternRewriter replacement API, so for now we just ignore it.
1056       if (legalTypes.size() == 1)
1057         desiredType = legalTypes.front();
1058     } else {
1059       // TODO: What we should do here is just set `desiredType` to `origType`
1060       // and then handle the necessary type conversions after the conversion
1061       // process has finished. Unfortunately a lot of patterns currently rely on
1062       // receiving the new operands even if the types change, so we keep the
1063       // original behavior here for now until all of the patterns relying on
1064       // this get updated.
1065     }
1066     Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1067 
1068     // Handle the case where the conversion was 1->1 and the new operand type
1069     // isn't legal.
1070     Type newOperandType = newOperand.getType();
1071     if (converter && desiredType && newOperandType != desiredType) {
1072       // Attempt to materialize a conversion for this new value.
1073       newOperand = converter->materializeTargetConversion(
1074           rewriter, loc, desiredType, newOperand);
1075       if (!newOperand) {
1076         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1077           diag << "unable to materialize a conversion for "
1078                   "operand #"
1079                << it.index() << ", from " << newOperandType << " to "
1080                << desiredType;
1081         });
1082       }
1083     }
1084     remapped.push_back(newOperand);
1085   }
1086   return success();
1087 }
1088 
isOpIgnored(Operation * op) const1089 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1090   // Check to see if this operation was replaced or its parent ignored.
1091   return replacements.count(op) || ignoredOps.count(op->getParentOp());
1092 }
1093 
markNestedOpsIgnored(Operation * op)1094 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1095   // Walk this operation and collect nested operations that define non-empty
1096   // regions. We mark such operations as 'ignored' so that we know we don't have
1097   // to convert them, or their nested ops.
1098   if (op->getNumRegions() == 0)
1099     return;
1100   op->walk([&](Operation *op) {
1101     if (llvm::any_of(op->getRegions(),
1102                      [](Region &region) { return !region.empty(); }))
1103       ignoredOps.insert(op);
1104   });
1105 }
1106 
1107 //===----------------------------------------------------------------------===//
1108 // Type Conversion
1109 
convertBlockSignature(Block * block,TypeConverter & converter,TypeConverter::SignatureConversion * conversion)1110 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1111     Block *block, TypeConverter &converter,
1112     TypeConverter::SignatureConversion *conversion) {
1113   FailureOr<Block *> result =
1114       conversion ? argConverter.applySignatureConversion(block, converter,
1115                                                          *conversion, mapping)
1116                  : argConverter.convertSignature(block, converter, mapping);
1117   if (Block *newBlock = result.getValue()) {
1118     if (newBlock != block)
1119       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1120   }
1121   return result;
1122 }
1123 
applySignatureConversion(Region * region,TypeConverter::SignatureConversion & conversion)1124 Block *ConversionPatternRewriterImpl::applySignatureConversion(
1125     Region *region, TypeConverter::SignatureConversion &conversion) {
1126   if (!region->empty()) {
1127     return *convertBlockSignature(&region->front(), defaultTypeConverter,
1128                                   &conversion);
1129   }
1130   return nullptr;
1131 }
1132 
convertRegionTypes(Region * region,TypeConverter & converter,TypeConverter::SignatureConversion * entryConversion)1133 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1134     Region *region, TypeConverter &converter,
1135     TypeConverter::SignatureConversion *entryConversion) {
1136   argConverter.setConverter(region, &converter);
1137   if (region->empty())
1138     return nullptr;
1139 
1140   // Convert the arguments of each block within the region.
1141   FailureOr<Block *> newEntry =
1142       convertBlockSignature(&region->front(), converter, entryConversion);
1143   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
1144     if (failed(convertBlockSignature(&block, converter)))
1145       return failure();
1146   return newEntry;
1147 }
1148 
1149 //===----------------------------------------------------------------------===//
1150 // Rewriter Notification Hooks
1151 
notifyOpReplaced(Operation * op,ValueRange newValues)1152 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1153                                                      ValueRange newValues) {
1154   assert(newValues.size() == op->getNumResults());
1155   assert(!replacements.count(op) && "operation was already replaced");
1156 
1157   // Track if any of the results changed, e.g. erased and replaced with null.
1158   bool resultChanged = false;
1159 
1160   // Create mappings for each of the new result values.
1161   Value newValue, result;
1162   for (auto it : llvm::zip(newValues, op->getResults())) {
1163     std::tie(newValue, result) = it;
1164     if (!newValue) {
1165       resultChanged = true;
1166       continue;
1167     }
1168     // Remap, and check for any result type changes.
1169     mapping.map(result, newValue);
1170     resultChanged |= (newValue.getType() != result.getType());
1171   }
1172   if (resultChanged)
1173     operationsWithChangedResults.push_back(replacements.size());
1174 
1175   // Record the requested operation replacement.
1176   TypeConverter *converter = nullptr;
1177   if (currentConversionPattern)
1178     converter = currentConversionPattern->getTypeConverter();
1179   replacements.insert(std::make_pair(op, OpReplacement(converter)));
1180 
1181   // Mark this operation as recursively ignored so that we don't need to
1182   // convert any nested operations.
1183   markNestedOpsIgnored(op);
1184 }
1185 
notifyBlockIsBeingErased(Block * block)1186 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1187   Region *region = block->getParent();
1188   Block *origPrevBlock = block->getPrevNode();
1189   blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1190 }
1191 
notifyCreatedBlock(Block * block)1192 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1193   blockActions.push_back(BlockAction::getCreate(block));
1194 }
1195 
notifySplitBlock(Block * block,Block * continuation)1196 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
1197                                                      Block *continuation) {
1198   blockActions.push_back(BlockAction::getSplit(continuation, block));
1199 }
1200 
notifyBlocksBeingMerged(Block * block,Block * srcBlock)1201 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
1202                                                             Block *srcBlock) {
1203   blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1204 }
1205 
notifyRegionIsBeingInlinedBefore(Region & region,Region & parent,Region::iterator before)1206 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1207     Region &region, Region &parent, Region::iterator before) {
1208   if (region.empty())
1209     return;
1210   Block *laterBlock = &region.back();
1211   for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1212     blockActions.push_back(
1213         BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1214     laterBlock = &earlierBlock;
1215   }
1216   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1217 }
1218 
notifyRegionWasClonedBefore(iterator_range<Region::iterator> & blocks,Location origRegionLoc)1219 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
1220     iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1221   for (Block &block : blocks)
1222     blockActions.push_back(BlockAction::getCreate(&block));
1223 
1224   // Compute the conversion set for the inlined region.
1225   auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1226 
1227   // This original region has already had its conversion set computed, so there
1228   // shouldn't be any new failures.
1229   (void)result;
1230   assert(succeeded(result) && "expected region to have no unreachable blocks");
1231 }
1232 
notifyMatchFailure(Location loc,function_ref<void (Diagnostic &)> reasonCallback)1233 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1234     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1235   LLVM_DEBUG({
1236     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1237     reasonCallback(diag);
1238     logger.startLine() << "** Failure : " << diag.str() << "\n";
1239   });
1240   return failure();
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // ConversionPatternRewriter
1245 //===----------------------------------------------------------------------===//
1246 
ConversionPatternRewriter(MLIRContext * ctx)1247 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1248     : PatternRewriter(ctx),
1249       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
~ConversionPatternRewriter()1250 ConversionPatternRewriter::~ConversionPatternRewriter() {}
1251 
1252 /// PatternRewriter hook for replacing the results of an operation.
replaceOp(Operation * op,ValueRange newValues)1253 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1254   LLVM_DEBUG({
1255     impl->logger.startLine()
1256         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1257   });
1258   impl->notifyOpReplaced(op, newValues);
1259 }
1260 
1261 /// PatternRewriter hook for erasing a dead operation. The uses of this
1262 /// operation *must* be made dead by the end of the conversion process,
1263 /// otherwise an assert will be issued.
eraseOp(Operation * op)1264 void ConversionPatternRewriter::eraseOp(Operation *op) {
1265   LLVM_DEBUG({
1266     impl->logger.startLine()
1267         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
1268   });
1269   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1270   impl->notifyOpReplaced(op, nullRepls);
1271 }
1272 
eraseBlock(Block * block)1273 void ConversionPatternRewriter::eraseBlock(Block *block) {
1274   impl->notifyBlockIsBeingErased(block);
1275 
1276   // Mark all ops for erasure.
1277   for (Operation &op : *block)
1278     eraseOp(&op);
1279 
1280   // Unlink the block from its parent region. The block is kept in the block
1281   // action and will be actually destroyed when rewrites are applied. This
1282   // allows us to keep the operations in the block live and undo the removal by
1283   // re-inserting the block.
1284   block->getParent()->getBlocks().remove(block);
1285 }
1286 
applySignatureConversion(Region * region,TypeConverter::SignatureConversion & conversion)1287 Block *ConversionPatternRewriter::applySignatureConversion(
1288     Region *region, TypeConverter::SignatureConversion &conversion) {
1289   return impl->applySignatureConversion(region, conversion);
1290 }
1291 
convertRegionTypes(Region * region,TypeConverter & converter,TypeConverter::SignatureConversion * entryConversion)1292 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1293     Region *region, TypeConverter &converter,
1294     TypeConverter::SignatureConversion *entryConversion) {
1295   return impl->convertRegionTypes(region, converter, entryConversion);
1296 }
1297 
replaceUsesOfBlockArgument(BlockArgument from,Value to)1298 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1299                                                            Value to) {
1300   LLVM_DEBUG({
1301     Operation *parentOp = from.getOwner()->getParentOp();
1302     impl->logger.startLine() << "** Replace Argument : '" << from
1303                              << "'(in region of '" << parentOp->getName()
1304                              << "'(" << from.getOwner()->getParentOp() << ")\n";
1305   });
1306   impl->argReplacements.push_back(from);
1307   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1308 }
1309 
1310 /// Return the converted value that replaces 'key'. Return 'key' if there is
1311 /// no such a converted value.
getRemappedValue(Value key)1312 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1313   return impl->mapping.lookupOrDefault(key);
1314 }
1315 
1316 /// PatternRewriter hook for creating a new block with the given arguments.
notifyBlockCreated(Block * block)1317 void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1318   impl->notifyCreatedBlock(block);
1319 }
1320 
1321 /// PatternRewriter hook for splitting a block into two parts.
splitBlock(Block * block,Block::iterator before)1322 Block *ConversionPatternRewriter::splitBlock(Block *block,
1323                                              Block::iterator before) {
1324   auto *continuation = PatternRewriter::splitBlock(block, before);
1325   impl->notifySplitBlock(block, continuation);
1326   return continuation;
1327 }
1328 
1329 /// PatternRewriter hook for merging a block into another.
mergeBlocks(Block * source,Block * dest,ValueRange argValues)1330 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
1331                                             ValueRange argValues) {
1332   impl->notifyBlocksBeingMerged(dest, source);
1333   assert(llvm::all_of(source->getPredecessors(),
1334                       [dest](Block *succ) { return succ == dest; }) &&
1335          "expected 'source' to have no predecessors or only 'dest'");
1336   assert(argValues.size() == source->getNumArguments() &&
1337          "incorrect # of argument replacement values");
1338   for (auto it : llvm::zip(source->getArguments(), argValues))
1339     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1340   dest->getOperations().splice(dest->end(), source->getOperations());
1341   eraseBlock(source);
1342 }
1343 
1344 /// PatternRewriter hook for moving blocks out of a region.
inlineRegionBefore(Region & region,Region & parent,Region::iterator before)1345 void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1346                                                    Region &parent,
1347                                                    Region::iterator before) {
1348   impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1349   PatternRewriter::inlineRegionBefore(region, parent, before);
1350 }
1351 
1352 /// PatternRewriter hook for cloning blocks of one region into another.
cloneRegionBefore(Region & region,Region & parent,Region::iterator before,BlockAndValueMapping & mapping)1353 void ConversionPatternRewriter::cloneRegionBefore(
1354     Region &region, Region &parent, Region::iterator before,
1355     BlockAndValueMapping &mapping) {
1356   if (region.empty())
1357     return;
1358   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1359 
1360   // Collect the range of the cloned blocks.
1361   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1362   auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1363   impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1364 }
1365 
1366 /// PatternRewriter hook for creating a new operation.
notifyOperationInserted(Operation * op)1367 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1368   LLVM_DEBUG({
1369     impl->logger.startLine()
1370         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
1371   });
1372   impl->createdOps.push_back(op);
1373 }
1374 
1375 /// PatternRewriter hook for updating the root operation in-place.
startRootUpdate(Operation * op)1376 void ConversionPatternRewriter::startRootUpdate(Operation *op) {
1377 #ifndef NDEBUG
1378   impl->pendingRootUpdates.insert(op);
1379 #endif
1380   impl->rootUpdates.emplace_back(op);
1381 }
1382 
1383 /// PatternRewriter hook for updating the root operation in-place.
finalizeRootUpdate(Operation * op)1384 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
1385   // There is nothing to do here, we only need to track the operation at the
1386   // start of the update.
1387 #ifndef NDEBUG
1388   assert(impl->pendingRootUpdates.erase(op) &&
1389          "operation did not have a pending in-place update");
1390 #endif
1391 }
1392 
1393 /// PatternRewriter hook for updating the root operation in-place.
cancelRootUpdate(Operation * op)1394 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
1395 #ifndef NDEBUG
1396   assert(impl->pendingRootUpdates.erase(op) &&
1397          "operation did not have a pending in-place update");
1398 #endif
1399   // Erase the last update for this operation.
1400   auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1401   auto &rootUpdates = impl->rootUpdates;
1402   auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1403   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
1404 }
1405 
1406 /// PatternRewriter hook for notifying match failure reasons.
notifyMatchFailure(Operation * op,function_ref<void (Diagnostic &)> reasonCallback)1407 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1408     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1409   return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1410 }
1411 
1412 /// Return a reference to the internal implementation.
getImpl()1413 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1414   return *impl;
1415 }
1416 
1417 //===----------------------------------------------------------------------===//
1418 // ConversionPattern
1419 //===----------------------------------------------------------------------===//
1420 
1421 /// Attempt to match and rewrite the IR root at the specified operation.
1422 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1423 ConversionPattern::matchAndRewrite(Operation *op,
1424                                    PatternRewriter &rewriter) const {
1425   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1426   auto &rewriterImpl = dialectRewriter.getImpl();
1427 
1428   // Track the current conversion pattern in the rewriter.
1429   assert(!rewriterImpl.currentConversionPattern &&
1430          "already inside of a pattern rewrite");
1431   llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
1432       rewriterImpl.currentConversionPattern, this);
1433 
1434   // Remap the operands of the operation.
1435   SmallVector<Value, 4> operands;
1436   if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
1437                                       getTypeConverter(), op->getOperands(),
1438                                       operands))) {
1439     return failure();
1440   }
1441   return matchAndRewrite(op, operands, dialectRewriter);
1442 }
1443 
1444 //===----------------------------------------------------------------------===//
1445 // OperationLegalizer
1446 //===----------------------------------------------------------------------===//
1447 
1448 namespace {
1449 /// A set of rewrite patterns that can be used to legalize a given operation.
1450 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1451 
1452 /// This class defines a recursive operation legalizer.
1453 class OperationLegalizer {
1454 public:
1455   using LegalizationAction = ConversionTarget::LegalizationAction;
1456 
1457   OperationLegalizer(ConversionTarget &targetInfo,
1458                      const FrozenRewritePatternList &patterns);
1459 
1460   /// Returns true if the given operation is known to be illegal on the target.
1461   bool isIllegal(Operation *op) const;
1462 
1463   /// Attempt to legalize the given operation. Returns success if the operation
1464   /// was legalized, failure otherwise.
1465   LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1466 
1467   /// Returns the conversion target in use by the legalizer.
getTarget()1468   ConversionTarget &getTarget() { return target; }
1469 
1470 private:
1471   /// Attempt to legalize the given operation by folding it.
1472   LogicalResult legalizeWithFold(Operation *op,
1473                                  ConversionPatternRewriter &rewriter);
1474 
1475   /// Attempt to legalize the given operation by applying a pattern. Returns
1476   /// success if the operation was legalized, failure otherwise.
1477   LogicalResult legalizeWithPattern(Operation *op,
1478                                     ConversionPatternRewriter &rewriter);
1479 
1480   /// Return true if the given pattern may be applied to the given operation,
1481   /// false otherwise.
1482   bool canApplyPattern(Operation *op, const Pattern &pattern,
1483                        ConversionPatternRewriter &rewriter);
1484 
1485   /// Legalize the resultant IR after successfully applying the given pattern.
1486   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1487                                       ConversionPatternRewriter &rewriter,
1488                                       RewriterState &curState);
1489 
1490   /// Legalizes the actions registered during the execution of a pattern.
1491   LogicalResult legalizePatternBlockActions(Operation *op,
1492                                             ConversionPatternRewriter &rewriter,
1493                                             ConversionPatternRewriterImpl &impl,
1494                                             RewriterState &state,
1495                                             RewriterState &newState);
1496   LogicalResult legalizePatternCreatedOperations(
1497       ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1498       RewriterState &state, RewriterState &newState);
1499   LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1500                                            ConversionPatternRewriterImpl &impl,
1501                                            RewriterState &state,
1502                                            RewriterState &newState);
1503 
1504   //===--------------------------------------------------------------------===//
1505   // Cost Model
1506   //===--------------------------------------------------------------------===//
1507 
1508   /// Build an optimistic legalization graph given the provided patterns. This
1509   /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1510   /// patterns for operations that are not directly legal, but may be
1511   /// transitively legal for the current target given the provided patterns.
1512   void buildLegalizationGraph(
1513       LegalizationPatterns &anyOpLegalizerPatterns,
1514       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1515 
1516   /// Compute the benefit of each node within the computed legalization graph.
1517   /// This orders the patterns within 'legalizerPatterns' based upon two
1518   /// criteria:
1519   ///  1) Prefer patterns that have the lowest legalization depth, i.e.
1520   ///     represent the more direct mapping to the target.
1521   ///  2) When comparing patterns with the same legalization depth, prefer the
1522   ///     pattern with the highest PatternBenefit. This allows for users to
1523   ///     prefer specific legalizations over others.
1524   void computeLegalizationGraphBenefit(
1525       LegalizationPatterns &anyOpLegalizerPatterns,
1526       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1527 
1528   /// Compute the legalization depth when legalizing an operation of the given
1529   /// type.
1530   unsigned computeOpLegalizationDepth(
1531       OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1532       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1533 
1534   /// Apply the conversion cost model to the given set of patterns, and return
1535   /// the smallest legalization depth of any of the patterns. See
1536   /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1537   unsigned applyCostModelToPatterns(
1538       LegalizationPatterns &patterns,
1539       DenseMap<OperationName, unsigned> &minOpPatternDepth,
1540       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1541 
1542   /// The current set of patterns that have been applied.
1543   SmallPtrSet<const Pattern *, 8> appliedPatterns;
1544 
1545   /// The legalization information provided by the target.
1546   ConversionTarget &target;
1547 
1548   /// The pattern applicator to use for conversions.
1549   PatternApplicator applicator;
1550 };
1551 } // namespace
1552 
OperationLegalizer(ConversionTarget & targetInfo,const FrozenRewritePatternList & patterns)1553 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1554                                        const FrozenRewritePatternList &patterns)
1555     : target(targetInfo), applicator(patterns) {
1556   // The set of patterns that can be applied to illegal operations to transform
1557   // them into legal ones.
1558   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1559   LegalizationPatterns anyOpLegalizerPatterns;
1560 
1561   buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1562   computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1563 }
1564 
isIllegal(Operation * op) const1565 bool OperationLegalizer::isIllegal(Operation *op) const {
1566   // Check if the target explicitly marked this operation as illegal.
1567   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1568 }
1569 
1570 LogicalResult
legalize(Operation * op,ConversionPatternRewriter & rewriter)1571 OperationLegalizer::legalize(Operation *op,
1572                              ConversionPatternRewriter &rewriter) {
1573 #ifndef NDEBUG
1574   const char *logLineComment =
1575       "//===-------------------------------------------===//\n";
1576 
1577   auto &rewriterImpl = rewriter.getImpl();
1578 #endif
1579   LLVM_DEBUG({
1580     auto &os = rewriterImpl.logger;
1581     os.getOStream() << "\n";
1582     os.startLine() << logLineComment;
1583     os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
1584                    << ") {\n";
1585     os.indent();
1586 
1587     // If the operation has no regions, just print it here.
1588     if (op->getNumRegions() == 0) {
1589       op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
1590       os.getOStream() << "\n\n";
1591     }
1592   });
1593 
1594   // Check if this operation is legal on the target.
1595   if (auto legalityInfo = target.isLegal(op)) {
1596     LLVM_DEBUG({
1597       logSuccess(
1598           rewriterImpl.logger, "operation marked legal by the target{0}",
1599           legalityInfo->isRecursivelyLegal
1600               ? "; NOTE: operation is recursively legal; skipping internals"
1601               : "");
1602       rewriterImpl.logger.startLine() << logLineComment;
1603     });
1604 
1605     // If this operation is recursively legal, mark its children as ignored so
1606     // that we don't consider them for legalization.
1607     if (legalityInfo->isRecursivelyLegal)
1608       rewriter.getImpl().markNestedOpsIgnored(op);
1609     return success();
1610   }
1611 
1612   // Check to see if the operation is ignored and doesn't need to be converted.
1613   if (rewriter.getImpl().isOpIgnored(op)) {
1614     LLVM_DEBUG({
1615       logSuccess(rewriterImpl.logger,
1616                  "operation marked 'ignored' during conversion");
1617       rewriterImpl.logger.startLine() << logLineComment;
1618     });
1619     return success();
1620   }
1621 
1622   // If the operation isn't legal, try to fold it in-place.
1623   // TODO: Should we always try to do this, even if the op is
1624   // already legal?
1625   if (succeeded(legalizeWithFold(op, rewriter))) {
1626     LLVM_DEBUG({
1627       logSuccess(rewriterImpl.logger, "operation was folded");
1628       rewriterImpl.logger.startLine() << logLineComment;
1629     });
1630     return success();
1631   }
1632 
1633   // Otherwise, we need to apply a legalization pattern to this operation.
1634   if (succeeded(legalizeWithPattern(op, rewriter))) {
1635     LLVM_DEBUG({
1636       logSuccess(rewriterImpl.logger, "");
1637       rewriterImpl.logger.startLine() << logLineComment;
1638     });
1639     return success();
1640   }
1641 
1642   LLVM_DEBUG({
1643     logFailure(rewriterImpl.logger, "no matched legalization pattern");
1644     rewriterImpl.logger.startLine() << logLineComment;
1645   });
1646   return failure();
1647 }
1648 
1649 LogicalResult
legalizeWithFold(Operation * op,ConversionPatternRewriter & rewriter)1650 OperationLegalizer::legalizeWithFold(Operation *op,
1651                                      ConversionPatternRewriter &rewriter) {
1652   auto &rewriterImpl = rewriter.getImpl();
1653   RewriterState curState = rewriterImpl.getCurrentState();
1654 
1655   LLVM_DEBUG({
1656     rewriterImpl.logger.startLine() << "* Fold {\n";
1657     rewriterImpl.logger.indent();
1658   });
1659 
1660   // Try to fold the operation.
1661   SmallVector<Value, 2> replacementValues;
1662   rewriter.setInsertionPoint(op);
1663   if (failed(rewriter.tryFold(op, replacementValues))) {
1664     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1665     return failure();
1666   }
1667 
1668   // Insert a replacement for 'op' with the folded replacement values.
1669   rewriter.replaceOp(op, replacementValues);
1670 
1671   // Recursively legalize any new constant operations.
1672   for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1673        i != e; ++i) {
1674     Operation *cstOp = rewriterImpl.createdOps[i];
1675     if (failed(legalize(cstOp, rewriter))) {
1676       LLVM_DEBUG(logFailure(rewriterImpl.logger,
1677                             "generated constant '{0}' was illegal",
1678                             cstOp->getName()));
1679       rewriterImpl.resetState(curState);
1680       return failure();
1681     }
1682   }
1683 
1684   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1685   return success();
1686 }
1687 
1688 LogicalResult
legalizeWithPattern(Operation * op,ConversionPatternRewriter & rewriter)1689 OperationLegalizer::legalizeWithPattern(Operation *op,
1690                                         ConversionPatternRewriter &rewriter) {
1691   auto &rewriterImpl = rewriter.getImpl();
1692 
1693   // Functor that returns if the given pattern may be applied.
1694   auto canApply = [&](const Pattern &pattern) {
1695     return canApplyPattern(op, pattern, rewriter);
1696   };
1697 
1698   // Functor that cleans up the rewriter state after a pattern failed to match.
1699   RewriterState curState = rewriterImpl.getCurrentState();
1700   auto onFailure = [&](const Pattern &pattern) {
1701     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
1702     rewriterImpl.resetState(curState);
1703     appliedPatterns.erase(&pattern);
1704   };
1705 
1706   // Functor that performs additional legalization when a pattern is
1707   // successfully applied.
1708   auto onSuccess = [&](const Pattern &pattern) {
1709     auto result = legalizePatternResult(op, pattern, rewriter, curState);
1710     appliedPatterns.erase(&pattern);
1711     if (failed(result))
1712       rewriterImpl.resetState(curState);
1713     return result;
1714   };
1715 
1716   // Try to match and rewrite a pattern on this operation.
1717   return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1718                                     onSuccess);
1719 }
1720 
canApplyPattern(Operation * op,const Pattern & pattern,ConversionPatternRewriter & rewriter)1721 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1722                                          ConversionPatternRewriter &rewriter) {
1723   LLVM_DEBUG({
1724     auto &os = rewriter.getImpl().logger;
1725     os.getOStream() << "\n";
1726     os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1727     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
1728     os.getOStream() << ")' {\n";
1729     os.indent();
1730   });
1731 
1732   // Ensure that we don't cycle by not allowing the same pattern to be
1733   // applied twice in the same recursion stack if it is not known to be safe.
1734   if (!pattern.hasBoundedRewriteRecursion() &&
1735       !appliedPatterns.insert(&pattern).second) {
1736     LLVM_DEBUG(
1737         logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1738     return false;
1739   }
1740   return true;
1741 }
1742 
1743 LogicalResult
legalizePatternResult(Operation * op,const Pattern & pattern,ConversionPatternRewriter & rewriter,RewriterState & curState)1744 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1745                                           ConversionPatternRewriter &rewriter,
1746                                           RewriterState &curState) {
1747   auto &impl = rewriter.getImpl();
1748 
1749 #ifndef NDEBUG
1750   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1751 #endif
1752 
1753   // Check that the root was either replaced or updated in place.
1754   auto replacedRoot = [&] {
1755     return llvm::any_of(
1756         llvm::drop_begin(impl.replacements, curState.numReplacements),
1757         [op](auto &it) { return it.first == op; });
1758   };
1759   auto updatedRootInPlace = [&] {
1760     return llvm::any_of(
1761         llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
1762         [op](auto &state) { return state.getOperation() == op; });
1763   };
1764   (void)replacedRoot;
1765   (void)updatedRootInPlace;
1766   assert((replacedRoot() || updatedRootInPlace()) &&
1767          "expected pattern to replace the root operation");
1768 
1769   // Legalize each of the actions registered during application.
1770   RewriterState newState = impl.getCurrentState();
1771   if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
1772                                          newState)) ||
1773       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
1774       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
1775                                               newState))) {
1776     return failure();
1777   }
1778 
1779   LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
1780   return success();
1781 }
1782 
legalizePatternBlockActions(Operation * op,ConversionPatternRewriter & rewriter,ConversionPatternRewriterImpl & impl,RewriterState & state,RewriterState & newState)1783 LogicalResult OperationLegalizer::legalizePatternBlockActions(
1784     Operation *op, ConversionPatternRewriter &rewriter,
1785     ConversionPatternRewriterImpl &impl, RewriterState &state,
1786     RewriterState &newState) {
1787   SmallPtrSet<Operation *, 16> operationsToIgnore;
1788 
1789   // If the pattern moved or created any blocks, make sure the types of block
1790   // arguments get legalized.
1791   for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
1792        ++i) {
1793     auto &action = impl.blockActions[i];
1794     if (action.kind == BlockActionKind::TypeConversion ||
1795         action.kind == BlockActionKind::Erase)
1796       continue;
1797     // Only check blocks outside of the current operation.
1798     Operation *parentOp = action.block->getParentOp();
1799     if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
1800       continue;
1801 
1802     // If the region of the block has a type converter, try to convert the block
1803     // directly.
1804     if (auto *converter =
1805             impl.argConverter.getConverter(action.block->getParent())) {
1806       if (failed(impl.convertBlockSignature(action.block, *converter))) {
1807         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
1808                                            "block"));
1809         return failure();
1810       }
1811       continue;
1812     }
1813 
1814     // Otherwise, check that this operation isn't one generated by this pattern.
1815     // This is because we will attempt to legalize the parent operation, and
1816     // blocks in regions created by this pattern will already be legalized later
1817     // on. If we haven't built the set yet, build it now.
1818     if (operationsToIgnore.empty()) {
1819       auto createdOps = ArrayRef<Operation *>(impl.createdOps)
1820                             .drop_front(state.numCreatedOps);
1821       operationsToIgnore.insert(createdOps.begin(), createdOps.end());
1822     }
1823 
1824     // If this operation should be considered for re-legalization, try it.
1825     if (operationsToIgnore.insert(parentOp).second &&
1826         failed(legalize(parentOp, rewriter))) {
1827       LLVM_DEBUG(logFailure(
1828           impl.logger, "operation '{0}'({1}) became illegal after block action",
1829           parentOp->getName(), parentOp));
1830       return failure();
1831     }
1832   }
1833   return success();
1834 }
legalizePatternCreatedOperations(ConversionPatternRewriter & rewriter,ConversionPatternRewriterImpl & impl,RewriterState & state,RewriterState & newState)1835 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
1836     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1837     RewriterState &state, RewriterState &newState) {
1838   for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
1839     Operation *op = impl.createdOps[i];
1840     if (failed(legalize(op, rewriter))) {
1841       LLVM_DEBUG(logFailure(impl.logger,
1842                             "generated operation '{0}'({1}) was illegal",
1843                             op->getName(), op));
1844       return failure();
1845     }
1846   }
1847   return success();
1848 }
legalizePatternRootUpdates(ConversionPatternRewriter & rewriter,ConversionPatternRewriterImpl & impl,RewriterState & state,RewriterState & newState)1849 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
1850     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1851     RewriterState &state, RewriterState &newState) {
1852   for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
1853     Operation *op = impl.rootUpdates[i].getOperation();
1854     if (failed(legalize(op, rewriter))) {
1855       LLVM_DEBUG(logFailure(impl.logger,
1856                             "operation updated in-place '{0}' was illegal",
1857                             op->getName()));
1858       return failure();
1859     }
1860   }
1861   return success();
1862 }
1863 
//===----------------------------------------------------------------------===//
// Cost Model
//===----------------------------------------------------------------------===//
1866 
buildLegalizationGraph(LegalizationPatterns & anyOpLegalizerPatterns,DenseMap<OperationName,LegalizationPatterns> & legalizerPatterns)1867 void OperationLegalizer::buildLegalizationGraph(
1868     LegalizationPatterns &anyOpLegalizerPatterns,
1869     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1870   // A mapping between an operation and a set of operations that can be used to
1871   // generate it.
1872   DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
1873   // A mapping between an operation and any currently invalid patterns it has.
1874   DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
1875   // A worklist of patterns to consider for legality.
1876   llvm::SetVector<const Pattern *> patternWorklist;
1877 
1878   // Build the mapping from operations to the parent ops that may generate them.
1879   applicator.walkAllPatterns([&](const Pattern &pattern) {
1880     Optional<OperationName> root = pattern.getRootKind();
1881 
1882     // If the pattern has no specific root, we can't analyze the relationship
1883     // between the root op and generated operations. Given that, add all such
1884     // patterns to the legalization set.
1885     if (!root) {
1886       anyOpLegalizerPatterns.push_back(&pattern);
1887       return;
1888     }
1889 
1890     // Skip operations that are always known to be legal.
1891     if (target.getOpAction(*root) == LegalizationAction::Legal)
1892       return;
1893 
1894     // Add this pattern to the invalid set for the root op and record this root
1895     // as a parent for any generated operations.
1896     invalidPatterns[*root].insert(&pattern);
1897     for (auto op : pattern.getGeneratedOps())
1898       parentOps[op].insert(*root);
1899 
1900     // Add this pattern to the worklist.
1901     patternWorklist.insert(&pattern);
1902   });
1903 
1904   // If there are any patterns that don't have a specific root kind, we can't
1905   // make direct assumptions about what operations will never be legalized.
1906   // Note: Technically we could, but it would require an analysis that may
1907   // recurse into itself. It would be better to perform this kind of filtering
1908   // at a higher level than here anyways.
1909   if (!anyOpLegalizerPatterns.empty()) {
1910     for (const Pattern *pattern : patternWorklist)
1911       legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1912     return;
1913   }
1914 
1915   while (!patternWorklist.empty()) {
1916     auto *pattern = patternWorklist.pop_back_val();
1917 
1918     // Check to see if any of the generated operations are invalid.
1919     if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
1920           Optional<LegalizationAction> action = target.getOpAction(op);
1921           return !legalizerPatterns.count(op) &&
1922                  (!action || action == LegalizationAction::Illegal);
1923         }))
1924       continue;
1925 
1926     // Otherwise, if all of the generated operation are valid, this op is now
1927     // legal so add all of the child patterns to the worklist.
1928     legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1929     invalidPatterns[*pattern->getRootKind()].erase(pattern);
1930 
1931     // Add any invalid patterns of the parent operations to see if they have now
1932     // become legal.
1933     for (auto op : parentOps[*pattern->getRootKind()])
1934       patternWorklist.set_union(invalidPatterns[op]);
1935   }
1936 }
1937 
computeLegalizationGraphBenefit(LegalizationPatterns & anyOpLegalizerPatterns,DenseMap<OperationName,LegalizationPatterns> & legalizerPatterns)1938 void OperationLegalizer::computeLegalizationGraphBenefit(
1939     LegalizationPatterns &anyOpLegalizerPatterns,
1940     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1941   // The smallest pattern depth, when legalizing an operation.
1942   DenseMap<OperationName, unsigned> minOpPatternDepth;
1943 
1944   // For each operation that is transitively legal, compute a cost for it.
1945   for (auto &opIt : legalizerPatterns)
1946     if (!minOpPatternDepth.count(opIt.first))
1947       computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
1948                                  legalizerPatterns);
1949 
1950   // Apply the cost model to the patterns that can match any operation. Those
1951   // with a specific operation type are already resolved when computing the op
1952   // legalization depth.
1953   if (!anyOpLegalizerPatterns.empty())
1954     applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
1955                              legalizerPatterns);
1956 
1957   // Apply a cost model to the pattern applicator. We order patterns first by
1958   // depth then benefit. `legalizerPatterns` contains per-op patterns by
1959   // decreasing benefit.
1960   applicator.applyCostModel([&](const Pattern &pattern) {
1961     ArrayRef<const Pattern *> orderedPatternList;
1962     if (Optional<OperationName> rootName = pattern.getRootKind())
1963       orderedPatternList = legalizerPatterns[*rootName];
1964     else
1965       orderedPatternList = anyOpLegalizerPatterns;
1966 
1967     // If the pattern is not found, then it was removed and cannot be matched.
1968     auto it = llvm::find(orderedPatternList, &pattern);
1969     if (it == orderedPatternList.end())
1970       return PatternBenefit::impossibleToMatch();
1971 
1972     // Patterns found earlier in the list have higher benefit.
1973     return PatternBenefit(std::distance(it, orderedPatternList.end()));
1974   });
1975 }
1976 
computeOpLegalizationDepth(OperationName op,DenseMap<OperationName,unsigned> & minOpPatternDepth,DenseMap<OperationName,LegalizationPatterns> & legalizerPatterns)1977 unsigned OperationLegalizer::computeOpLegalizationDepth(
1978     OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1979     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1980   // Check for existing depth.
1981   auto depthIt = minOpPatternDepth.find(op);
1982   if (depthIt != minOpPatternDepth.end())
1983     return depthIt->second;
1984 
1985   // If a mapping for this operation does not exist, then this operation
1986   // is always legal. Return 0 as the depth for a directly legal operation.
1987   auto opPatternsIt = legalizerPatterns.find(op);
1988   if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
1989     return 0u;
1990 
1991   // Record this initial depth in case we encounter this op again when
1992   // recursively computing the depth.
1993   minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
1994 
1995   // Apply the cost model to the operation patterns, and update the minimum
1996   // depth.
1997   unsigned minDepth = applyCostModelToPatterns(
1998       opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
1999   minOpPatternDepth[op] = minDepth;
2000   return minDepth;
2001 }
2002 
applyCostModelToPatterns(LegalizationPatterns & patterns,DenseMap<OperationName,unsigned> & minOpPatternDepth,DenseMap<OperationName,LegalizationPatterns> & legalizerPatterns)2003 unsigned OperationLegalizer::applyCostModelToPatterns(
2004     LegalizationPatterns &patterns,
2005     DenseMap<OperationName, unsigned> &minOpPatternDepth,
2006     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2007   unsigned minDepth = std::numeric_limits<unsigned>::max();
2008 
2009   // Compute the depth for each pattern within the set.
2010   SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2011   patternsByDepth.reserve(patterns.size());
2012   for (const Pattern *pattern : patterns) {
2013     unsigned depth = 0;
2014     for (auto generatedOp : pattern->getGeneratedOps()) {
2015       unsigned generatedOpDepth = computeOpLegalizationDepth(
2016           generatedOp, minOpPatternDepth, legalizerPatterns);
2017       depth = std::max(depth, generatedOpDepth + 1);
2018     }
2019     patternsByDepth.emplace_back(pattern, depth);
2020 
2021     // Update the minimum depth of the pattern list.
2022     minDepth = std::min(minDepth, depth);
2023   }
2024 
2025   // If the operation only has one legalization pattern, there is no need to
2026   // sort them.
2027   if (patternsByDepth.size() == 1)
2028     return minDepth;
2029 
2030   // Sort the patterns by those likely to be the most beneficial.
2031   llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2032                        [](const std::pair<const Pattern *, unsigned> *lhs,
2033                           const std::pair<const Pattern *, unsigned> *rhs) {
2034                          // First sort by the smaller pattern legalization
2035                          // depth.
2036                          if (lhs->second != rhs->second)
2037                            return llvm::array_pod_sort_comparator<unsigned>(
2038                                &lhs->second, &rhs->second);
2039 
2040                          // Then sort by the larger pattern benefit.
2041                          auto lhsBenefit = lhs->first->getBenefit();
2042                          auto rhsBenefit = rhs->first->getBenefit();
2043                          return llvm::array_pod_sort_comparator<PatternBenefit>(
2044                              &rhsBenefit, &lhsBenefit);
2045                        });
2046 
2047   // Update the legalization pattern to use the new sorted list.
2048   patterns.clear();
2049   for (auto &patternIt : patternsByDepth)
2050     patterns.push_back(patternIt.first);
2051   return minDepth;
2052 }
2053 
2054 //===----------------------------------------------------------------------===//
2055 // OperationConverter
2056 //===----------------------------------------------------------------------===//
2057 namespace {
2058 enum OpConversionMode {
2059   // In this mode, the conversion will ignore failed conversions to allow
2060   // illegal operations to co-exist in the IR.
2061   Partial,
2062 
2063   // In this mode, all operations must be legal for the given target for the
2064   // conversion to succeed.
2065   Full,
2066 
2067   // In this mode, operations are analyzed for legality. No actual rewrites are
2068   // applied to the operations on success.
2069   Analysis,
2070 };
2071 
2072 // This class converts operations to a given conversion target via a set of
2073 // rewrite patterns. The conversion behaves differently depending on the
2074 // conversion mode.
2075 struct OperationConverter {
OperationConverter__anon942ecdf31911::OperationConverter2076   explicit OperationConverter(ConversionTarget &target,
2077                               const FrozenRewritePatternList &patterns,
2078                               OpConversionMode mode,
2079                               DenseSet<Operation *> *trackedOps = nullptr)
2080       : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2081 
2082   /// Converts the given operations to the conversion target.
2083   LogicalResult convertOperations(ArrayRef<Operation *> ops);
2084 
2085 private:
2086   /// Converts an operation with the given rewriter.
2087   LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2088 
2089   /// This method is called after the conversion process to legalize any
2090   /// remaining artifacts and complete the conversion.
2091   LogicalResult finalize(ConversionPatternRewriter &rewriter);
2092 
2093   /// Legalize the types of converted block arguments.
2094   LogicalResult
2095   legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2096                                  ConversionPatternRewriterImpl &rewriterImpl);
2097 
2098   /// Legalize an operation result that was marked as "erased".
2099   LogicalResult
2100   legalizeErasedResult(Operation *op, OpResult result,
2101                        ConversionPatternRewriterImpl &rewriterImpl);
2102 
2103   /// Legalize an operation result that was replaced with a value of a different
2104   /// type.
2105   LogicalResult
2106   legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
2107                             TypeConverter *replConverter,
2108                             ConversionPatternRewriter &rewriter,
2109                             ConversionPatternRewriterImpl &rewriterImpl);
2110 
2111   /// The legalizer to use when converting operations.
2112   OperationLegalizer opLegalizer;
2113 
2114   /// The conversion mode to use when legalizing operations.
2115   OpConversionMode mode;
2116 
2117   /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2118   /// this is populated with ops found to be legalizable to the target.
2119   /// When mode == OpConversionMode::Partial, this is populated with ops found
2120   /// *not* to be legalizable to the target.
2121   DenseSet<Operation *> *trackedOps;
2122 };
2123 } // end anonymous namespace
2124 
convert(ConversionPatternRewriter & rewriter,Operation * op)2125 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2126                                           Operation *op) {
2127   // Legalize the given operation.
2128   if (failed(opLegalizer.legalize(op, rewriter))) {
2129     // Handle the case of a failed conversion for each of the different modes.
2130     // Full conversions expect all operations to be converted.
2131     if (mode == OpConversionMode::Full)
2132       return op->emitError()
2133              << "failed to legalize operation '" << op->getName() << "'";
2134     // Partial conversions allow conversions to fail iff the operation was not
2135     // explicitly marked as illegal. If the user provided a nonlegalizableOps
2136     // set, non-legalizable ops are included.
2137     if (mode == OpConversionMode::Partial) {
2138       if (opLegalizer.isIllegal(op))
2139         return op->emitError()
2140                << "failed to legalize operation '" << op->getName()
2141                << "' that was explicitly marked illegal";
2142       if (trackedOps)
2143         trackedOps->insert(op);
2144     }
2145   } else if (mode == OpConversionMode::Analysis) {
2146     // Analysis conversions don't fail if any operations fail to legalize,
2147     // they are only interested in the operations that were successfully
2148     // legalized.
2149     trackedOps->insert(op);
2150   }
2151   return success();
2152 }
2153 
convertOperations(ArrayRef<Operation * > ops)2154 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2155   if (ops.empty())
2156     return success();
2157   ConversionTarget &target = opLegalizer.getTarget();
2158 
2159   // Compute the set of operations and blocks to convert.
2160   std::vector<Operation *> toConvert;
2161   for (auto *op : ops) {
2162     toConvert.emplace_back(op);
2163     for (auto &region : op->getRegions())
2164       if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2165                                       toConvert, &target)))
2166         return failure();
2167   }
2168 
2169   // Convert each operation and discard rewrites on failure.
2170   ConversionPatternRewriter rewriter(ops.front()->getContext());
2171   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2172   for (auto *op : toConvert)
2173     if (failed(convert(rewriter, op)))
2174       return rewriterImpl.discardRewrites(), failure();
2175 
2176   // Now that all of the operations have been converted, finalize the conversion
2177   // process to ensure any lingering conversion artifacts are cleaned up and
2178   // legalized.
2179   if (failed(finalize(rewriter)))
2180     return rewriterImpl.discardRewrites(), failure();
2181 
2182   // After a successful conversion, apply rewrites if this is not an analysis
2183   // conversion.
2184   if (mode == OpConversionMode::Analysis)
2185     rewriterImpl.discardRewrites();
2186   else
2187     rewriterImpl.applyRewrites();
2188   return success();
2189 }
2190 
2191 LogicalResult
finalize(ConversionPatternRewriter & rewriter)2192 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2193   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2194 
2195   // Legalize converted block arguments.
2196   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2197     return failure();
2198 
2199   // Process requested operation replacements.
2200   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2201        i != e; ++i) {
2202     unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2203     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2204     for (OpResult result : repl.first->getResults()) {
2205       Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2206 
2207       // If the operation result was replaced with null, all of the uses of this
2208       // value should be replaced.
2209       if (!newValue) {
2210         if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2211           return failure();
2212         continue;
2213       }
2214 
2215       // Otherwise, check to see if the type of the result changed.
2216       if (result.getType() == newValue.getType())
2217         continue;
2218 
2219       // Legalize this result.
2220       rewriter.setInsertionPoint(repl.first);
2221       if (failed(legalizeChangedResultType(repl.first, result, newValue,
2222                                            repl.second.converter, rewriter,
2223                                            rewriterImpl)))
2224         return failure();
2225 
2226       // Update the end iterator for this loop in the case it was updated
2227       // when legalizing generated conversion operations.
2228       e = rewriterImpl.operationsWithChangedResults.size();
2229     }
2230   }
2231   return success();
2232 }
2233 
legalizeConvertedArgumentTypes(ConversionPatternRewriter & rewriter,ConversionPatternRewriterImpl & rewriterImpl)2234 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2235     ConversionPatternRewriter &rewriter,
2236     ConversionPatternRewriterImpl &rewriterImpl) {
2237   // Functor used to check if all users of a value will be dead after
2238   // conversion.
2239   auto findLiveUser = [&](Value val) {
2240     auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2241       return rewriterImpl.isOpIgnored(user);
2242     });
2243     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2244   };
2245 
2246   // Materialize any necessary conversions for converted block arguments that
2247   // are still live.
2248   size_t numCreatedOps = rewriterImpl.createdOps.size();
2249   if (failed(rewriterImpl.argConverter.materializeLiveConversions(
2250           rewriterImpl.mapping, rewriter, findLiveUser)))
2251     return failure();
2252 
2253   // Legalize any newly created operations during argument materialization.
2254   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2255     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2256       return rewriterImpl.createdOps[i]->emitError()
2257              << "failed to legalize conversion operation generated for block "
2258                 "argument that remained live after conversion";
2259     }
2260   }
2261   return success();
2262 }
2263 
legalizeErasedResult(Operation * op,OpResult result,ConversionPatternRewriterImpl & rewriterImpl)2264 LogicalResult OperationConverter::legalizeErasedResult(
2265     Operation *op, OpResult result,
2266     ConversionPatternRewriterImpl &rewriterImpl) {
2267   // If the operation result was replaced with null, all of the uses of this
2268   // value should be replaced.
2269   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2270     return rewriterImpl.isOpIgnored(user);
2271   });
2272   if (liveUserIt != result.user_end()) {
2273     InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2274                               << op->getName() << "' marked as erased";
2275     diag.attachNote(liveUserIt->getLoc())
2276         << "found live user of result #" << result.getResultNumber() << ": "
2277         << *liveUserIt;
2278     return failure();
2279   }
2280   return success();
2281 }
2282 
legalizeChangedResultType(Operation * op,OpResult result,Value newValue,TypeConverter * replConverter,ConversionPatternRewriter & rewriter,ConversionPatternRewriterImpl & rewriterImpl)2283 LogicalResult OperationConverter::legalizeChangedResultType(
2284     Operation *op, OpResult result, Value newValue,
2285     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2286     ConversionPatternRewriterImpl &rewriterImpl) {
2287   // Walk the users of this value to see if there are any live users that
2288   // weren't replaced during conversion.
2289   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2290     return rewriterImpl.isOpIgnored(user);
2291   });
2292   if (liveUserIt == result.user_end())
2293     return success();
2294 
2295   // If the replacement has a type converter, attempt to materialize a
2296   // conversion back to the original type.
2297   if (!replConverter) {
2298     // TODO: We should emit an error here, similarly to the case where the
2299     // result is replaced with null. Unfortunately a lot of existing
2300     // patterns rely on this behavior, so until those patterns are updated
2301     // we keep the legacy behavior here of just forwarding the new value.
2302     return success();
2303   }
2304 
2305   // Track the number of created operations so that new ones can be legalized.
2306   size_t numCreatedOps = rewriterImpl.createdOps.size();
2307 
2308   // Materialize a conversion for this live result value.
2309   Type resultType = result.getType();
2310   Value convertedValue = replConverter->materializeSourceConversion(
2311       rewriter, op->getLoc(), resultType, newValue);
2312   if (!convertedValue) {
2313     InFlightDiagnostic diag = op->emitError()
2314                               << "failed to materialize conversion for result #"
2315                               << result.getResultNumber() << " of operation '"
2316                               << op->getName()
2317                               << "' that remained live after conversion";
2318     diag.attachNote(liveUserIt->getLoc())
2319         << "see existing live user here: " << *liveUserIt;
2320     return failure();
2321   }
2322 
2323   // Legalize all of the newly created conversion operations.
2324   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2325     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2326       return op->emitError("failed to legalize conversion operation generated ")
2327              << "for result #" << result.getResultNumber() << " of operation '"
2328              << op->getName() << "' that remained live after conversion";
2329     }
2330   }
2331 
2332   rewriterImpl.mapping.map(result, convertedValue);
2333   return success();
2334 }
2335 
2336 //===----------------------------------------------------------------------===//
2337 // Type Conversion
2338 //===----------------------------------------------------------------------===//
2339 
2340 /// Remap an input of the original signature with a new set of types. The
2341 /// new types are appended to the new signature conversion.
addInputs(unsigned origInputNo,ArrayRef<Type> types)2342 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2343                                                    ArrayRef<Type> types) {
2344   assert(!types.empty() && "expected valid types");
2345   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2346   addInputs(types);
2347 }
2348 
2349 /// Append new input types to the signature conversion, this should only be
2350 /// used if the new types are not intended to remap an existing input.
addInputs(ArrayRef<Type> types)2351 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2352   assert(!types.empty() &&
2353          "1->0 type remappings don't need to be added explicitly");
2354   argTypes.append(types.begin(), types.end());
2355 }
2356 
2357 /// Remap an input of the original signature with a range of types in the
2358 /// new signature.
remapInput(unsigned origInputNo,unsigned newInputNo,unsigned newInputCount)2359 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2360                                                     unsigned newInputNo,
2361                                                     unsigned newInputCount) {
2362   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2363   assert(newInputCount != 0 && "expected valid input count");
2364   remappedInputs[origInputNo] =
2365       InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2366 }
2367 
2368 /// Remap an input of the original signature to another `replacementValue`
2369 /// value. This would make the signature converter drop this argument.
remapInput(unsigned origInputNo,Value replacementValue)2370 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2371                                                     Value replacementValue) {
2372   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2373   remappedInputs[origInputNo] =
2374       InputMapping{origInputNo, /*size=*/0, replacementValue};
2375 }
2376 
2377 /// This hooks allows for converting a type.
convertType(Type t,SmallVectorImpl<Type> & results)2378 LogicalResult TypeConverter::convertType(Type t,
2379                                          SmallVectorImpl<Type> &results) {
2380   auto existingIt = cachedDirectConversions.find(t);
2381   if (existingIt != cachedDirectConversions.end()) {
2382     if (existingIt->second)
2383       results.push_back(existingIt->second);
2384     return success(existingIt->second != nullptr);
2385   }
2386   auto multiIt = cachedMultiConversions.find(t);
2387   if (multiIt != cachedMultiConversions.end()) {
2388     results.append(multiIt->second.begin(), multiIt->second.end());
2389     return success();
2390   }
2391 
2392   // Walk the added converters in reverse order to apply the most recently
2393   // registered first.
2394   size_t currentCount = results.size();
2395   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2396     if (Optional<LogicalResult> result = converter(t, results)) {
2397       if (!succeeded(*result)) {
2398         cachedDirectConversions.try_emplace(t, nullptr);
2399         return failure();
2400       }
2401       auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2402       if (newTypes.size() == 1)
2403         cachedDirectConversions.try_emplace(t, newTypes.front());
2404       else
2405         cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2406       return success();
2407     }
2408   }
2409   return failure();
2410 }
2411 
2412 /// This hook simplifies defining 1-1 type conversions. This function returns
2413 /// the type to convert to on success, and a null type on failure.
convertType(Type t)2414 Type TypeConverter::convertType(Type t) {
2415   // Use the multi-type result version to convert the type.
2416   SmallVector<Type, 1> results;
2417   if (failed(convertType(t, results)))
2418     return nullptr;
2419 
2420   // Check to ensure that only one type was produced.
2421   return results.size() == 1 ? results.front() : nullptr;
2422 }
2423 
2424 /// Convert the given set of types, filling 'results' as necessary. This
2425 /// returns failure if the conversion of any of the types fails, success
2426 /// otherwise.
convertTypes(ArrayRef<Type> types,SmallVectorImpl<Type> & results)2427 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
2428                                           SmallVectorImpl<Type> &results) {
2429   for (auto type : types)
2430     if (failed(convertType(type, results)))
2431       return failure();
2432   return success();
2433 }
2434 
2435 /// Return true if the given type is legal for this type converter, i.e. the
2436 /// type converts to itself.
isLegal(Type type)2437 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2438 /// Return true if the given operation has legal operand and result types.
isLegal(Operation * op)2439 bool TypeConverter::isLegal(Operation *op) {
2440   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2441 }
2442 
2443 /// Return true if the types of block arguments within the region are legal.
isLegal(Region * region)2444 bool TypeConverter::isLegal(Region *region) {
2445   return llvm::all_of(*region, [this](Block &block) {
2446     return isLegal(block.getArgumentTypes());
2447   });
2448 }
2449 
2450 /// Return true if the inputs and outputs of the given function type are
2451 /// legal.
isSignatureLegal(FunctionType ty)2452 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2453   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2454 }
2455 
2456 /// This hook allows for converting a specific argument of a signature.
convertSignatureArg(unsigned inputNo,Type type,SignatureConversion & result)2457 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2458                                                  SignatureConversion &result) {
2459   // Try to convert the given input type.
2460   SmallVector<Type, 1> convertedTypes;
2461   if (failed(convertType(type, convertedTypes)))
2462     return failure();
2463 
2464   // If this argument is being dropped, there is nothing left to do.
2465   if (convertedTypes.empty())
2466     return success();
2467 
2468   // Otherwise, add the new inputs.
2469   result.addInputs(inputNo, convertedTypes);
2470   return success();
2471 }
convertSignatureArgs(TypeRange types,SignatureConversion & result,unsigned origInputOffset)2472 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2473                                                   SignatureConversion &result,
2474                                                   unsigned origInputOffset) {
2475   for (unsigned i = 0, e = types.size(); i != e; ++i)
2476     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2477       return failure();
2478   return success();
2479 }
2480 
materializeConversion(MutableArrayRef<MaterializationCallbackFn> materializations,OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)2481 Value TypeConverter::materializeConversion(
2482     MutableArrayRef<MaterializationCallbackFn> materializations,
2483     OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2484   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
2485     if (Optional<Value> result = fn(builder, resultType, inputs, loc))
2486       return result.getValue();
2487   return nullptr;
2488 }
2489 
2490 /// This function converts the type signature of the given block, by invoking
2491 /// 'convertSignatureArg' for each argument. This function should return a valid
2492 /// conversion for the signature on success, None otherwise.
convertBlockSignature(Block * block)2493 auto TypeConverter::convertBlockSignature(Block *block)
2494     -> Optional<SignatureConversion> {
2495   SignatureConversion conversion(block->getNumArguments());
2496   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2497     return llvm::None;
2498   return conversion;
2499 }
2500 
2501 /// Create a default conversion pattern that rewrites the type signature of a
2502 /// FuncOp.
2503 namespace {
2504 struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
FuncOpSignatureConversion__anon942ecdf31f11::FuncOpSignatureConversion2505   FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
2506       : OpConversionPattern(converter, ctx) {}
2507 
2508   /// Hook for derived classes to implement combined matching and rewriting.
2509   LogicalResult
matchAndRewrite__anon942ecdf31f11::FuncOpSignatureConversion2510   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
2511                   ConversionPatternRewriter &rewriter) const override {
2512     FunctionType type = funcOp.getType();
2513 
2514     // Convert the original function types.
2515     TypeConverter::SignatureConversion result(type.getNumInputs());
2516     SmallVector<Type, 1> newResults;
2517     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
2518         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
2519         failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
2520                                            &result)))
2521       return failure();
2522 
2523     // Update the function signature in-place.
2524     rewriter.updateRootInPlace(funcOp, [&] {
2525       funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
2526                                        funcOp.getContext()));
2527     });
2528     return success();
2529   }
2530 };
2531 } // end anonymous namespace
2532 
populateFuncOpTypeConversionPattern(OwningRewritePatternList & patterns,MLIRContext * ctx,TypeConverter & converter)2533 void mlir::populateFuncOpTypeConversionPattern(
2534     OwningRewritePatternList &patterns, MLIRContext *ctx,
2535     TypeConverter &converter) {
2536   patterns.insert<FuncOpSignatureConversion>(ctx, converter);
2537 }
2538 
2539 //===----------------------------------------------------------------------===//
2540 // ConversionTarget
2541 //===----------------------------------------------------------------------===//
2542 
2543 /// Register a legality action for the given operation.
setOpAction(OperationName op,LegalizationAction action)2544 void ConversionTarget::setOpAction(OperationName op,
2545                                    LegalizationAction action) {
2546   legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
2547 }
2548 
2549 /// Register a legality action for the given dialects.
setDialectAction(ArrayRef<StringRef> dialectNames,LegalizationAction action)2550 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
2551                                         LegalizationAction action) {
2552   for (StringRef dialect : dialectNames)
2553     legalDialects[dialect] = action;
2554 }
2555 
2556 /// Get the legality action for the given operation.
getOpAction(OperationName op) const2557 auto ConversionTarget::getOpAction(OperationName op) const
2558     -> Optional<LegalizationAction> {
2559   Optional<LegalizationInfo> info = getOpInfo(op);
2560   return info ? info->action : Optional<LegalizationAction>();
2561 }
2562 
2563 /// If the given operation instance is legal on this target, a structure
2564 /// containing legality information is returned. If the operation is not legal,
2565 /// None is returned.
isLegal(Operation * op) const2566 auto ConversionTarget::isLegal(Operation *op) const
2567     -> Optional<LegalOpDetails> {
2568   Optional<LegalizationInfo> info = getOpInfo(op->getName());
2569   if (!info)
2570     return llvm::None;
2571 
2572   // Returns true if this operation instance is known to be legal.
2573   auto isOpLegal = [&] {
2574     // Handle dynamic legality either with the provided legality function, or
2575     // the default hook on the derived instance.
2576     if (info->action == LegalizationAction::Dynamic)
2577       return info->legalityFn ? (*info->legalityFn)(op)
2578                               : isDynamicallyLegal(op);
2579 
2580     // Otherwise, the operation is only legal if it was marked 'Legal'.
2581     return info->action == LegalizationAction::Legal;
2582   };
2583   if (!isOpLegal())
2584     return llvm::None;
2585 
2586   // This operation is legal, compute any additional legality information.
2587   LegalOpDetails legalityDetails;
2588   if (info->isRecursivelyLegal) {
2589     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2590     if (legalityFnIt != opRecursiveLegalityFns.end())
2591       legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2592     else
2593       legalityDetails.isRecursivelyLegal = true;
2594   }
2595   return legalityDetails;
2596 }
2597 
2598 /// Set the dynamic legality callback for the given operation.
setLegalityCallback(OperationName name,const DynamicLegalityCallbackFn & callback)2599 void ConversionTarget::setLegalityCallback(
2600     OperationName name, const DynamicLegalityCallbackFn &callback) {
2601   assert(callback && "expected valid legality callback");
2602   auto infoIt = legalOperations.find(name);
2603   assert(infoIt != legalOperations.end() &&
2604          infoIt->second.action == LegalizationAction::Dynamic &&
2605          "expected operation to already be marked as dynamically legal");
2606   infoIt->second.legalityFn = callback;
2607 }
2608 
2609 /// Set the recursive legality callback for the given operation and mark the
2610 /// operation as recursively legal.
markOpRecursivelyLegal(OperationName name,const DynamicLegalityCallbackFn & callback)2611 void ConversionTarget::markOpRecursivelyLegal(
2612     OperationName name, const DynamicLegalityCallbackFn &callback) {
2613   auto infoIt = legalOperations.find(name);
2614   assert(infoIt != legalOperations.end() &&
2615          infoIt->second.action != LegalizationAction::Illegal &&
2616          "expected operation to already be marked as legal");
2617   infoIt->second.isRecursivelyLegal = true;
2618   if (callback)
2619     opRecursiveLegalityFns[name] = callback;
2620   else
2621     opRecursiveLegalityFns.erase(name);
2622 }
2623 
2624 /// Set the dynamic legality callback for the given dialects.
setLegalityCallback(ArrayRef<StringRef> dialects,const DynamicLegalityCallbackFn & callback)2625 void ConversionTarget::setLegalityCallback(
2626     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2627   assert(callback && "expected valid legality callback");
2628   for (StringRef dialect : dialects)
2629     dialectLegalityFns[dialect] = callback;
2630 }
2631 
2632 /// Get the legalization information for the given operation.
getOpInfo(OperationName op) const2633 auto ConversionTarget::getOpInfo(OperationName op) const
2634     -> Optional<LegalizationInfo> {
2635   // Check for info for this specific operation.
2636   auto it = legalOperations.find(op);
2637   if (it != legalOperations.end())
2638     return it->second;
2639   // Check for info for the parent dialect.
2640   auto dialectIt = legalDialects.find(op.getDialect());
2641   if (dialectIt != legalDialects.end()) {
2642     Optional<DynamicLegalityCallbackFn> callback;
2643     auto dialectFn = dialectLegalityFns.find(op.getDialect());
2644     if (dialectFn != dialectLegalityFns.end())
2645       callback = dialectFn->second;
2646     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
2647                             callback};
2648   }
2649   // Otherwise, check if we mark unknown operations as dynamic.
2650   if (unknownOpsDynamicallyLegal)
2651     return LegalizationInfo{LegalizationAction::Dynamic,
2652                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
2653   return llvm::None;
2654 }
2655 
2656 //===----------------------------------------------------------------------===//
2657 // Op Conversion Entry Points
2658 //===----------------------------------------------------------------------===//
2659 
2660 /// Apply a partial conversion on the given operations and all nested
2661 /// operations. This method converts as many operations to the target as
2662 /// possible, ignoring operations that failed to legalize. This method only
2663 /// returns failure if there ops explicitly marked as illegal.
2664 /// If an `unconvertedOps` set is provided, all operations that are found not
2665 /// to be legalizable to the given `target` are placed within that set. (Note
2666 /// that if there is an op explicitly marked as illegal, the conversion
2667 /// terminates and the `unconvertedOps` set will not necessarily be complete.)
2668 LogicalResult
applyPartialConversion(ArrayRef<Operation * > ops,ConversionTarget & target,const FrozenRewritePatternList & patterns,DenseSet<Operation * > * unconvertedOps)2669 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
2670                              ConversionTarget &target,
2671                              const FrozenRewritePatternList &patterns,
2672                              DenseSet<Operation *> *unconvertedOps) {
2673   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
2674                                  unconvertedOps);
2675   return opConverter.convertOperations(ops);
2676 }
2677 LogicalResult
applyPartialConversion(Operation * op,ConversionTarget & target,const FrozenRewritePatternList & patterns,DenseSet<Operation * > * unconvertedOps)2678 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
2679                              const FrozenRewritePatternList &patterns,
2680                              DenseSet<Operation *> *unconvertedOps) {
2681   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
2682                                 unconvertedOps);
2683 }
2684 
2685 /// Apply a complete conversion on the given operations, and all nested
2686 /// operations. This method will return failure if the conversion of any
2687 /// operation fails.
2688 LogicalResult
applyFullConversion(ArrayRef<Operation * > ops,ConversionTarget & target,const FrozenRewritePatternList & patterns)2689 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
2690                           const FrozenRewritePatternList &patterns) {
2691   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
2692   return opConverter.convertOperations(ops);
2693 }
2694 LogicalResult
applyFullConversion(Operation * op,ConversionTarget & target,const FrozenRewritePatternList & patterns)2695 mlir::applyFullConversion(Operation *op, ConversionTarget &target,
2696                           const FrozenRewritePatternList &patterns) {
2697   return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
2698 }
2699 
2700 /// Apply an analysis conversion on the given operations, and all nested
2701 /// operations. This method analyzes which operations would be successfully
2702 /// converted to the target if a conversion was applied. All operations that
2703 /// were found to be legalizable to the given 'target' are placed within the
2704 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
2705 /// operations on success and only pre-existing operations are added to the set.
2706 LogicalResult
applyAnalysisConversion(ArrayRef<Operation * > ops,ConversionTarget & target,const FrozenRewritePatternList & patterns,DenseSet<Operation * > & convertedOps)2707 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
2708                               ConversionTarget &target,
2709                               const FrozenRewritePatternList &patterns,
2710                               DenseSet<Operation *> &convertedOps) {
2711   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
2712                                  &convertedOps);
2713   return opConverter.convertOperations(ops);
2714 }
2715 LogicalResult
applyAnalysisConversion(Operation * op,ConversionTarget & target,const FrozenRewritePatternList & patterns,DenseSet<Operation * > & convertedOps)2716 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
2717                               const FrozenRewritePatternList &patterns,
2718                               DenseSet<Operation *> &convertedOps) {
2719   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
2720                                  convertedOps);
2721 }
2722