1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_PATTERNMATCHER_H
10 #define MLIR_PATTERNMATCHER_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 
15 namespace mlir {
16 
17 class PatternRewriter;
18 
19 //===----------------------------------------------------------------------===//
20 // PatternBenefit class
21 //===----------------------------------------------------------------------===//
22 
/// A unitless measure of how worthwhile it is to apply a pattern, ranging from
/// 0 (hardly any benefit) up to roughly 65K. The customary unit is the
/// "number of operations matched" by the pattern.
///
/// One value is reserved as a sentinel describing patterns that can never
/// match.
///
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /// A default-constructed benefit is the "impossible to match" sentinel.
  PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used for patterns that can never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }

  /// Return true if this benefit is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const {
    return representation == ImpossibleToMatchSentinel;
  }

  /// Return the benefit of the corresponding pattern when it is able to match.
  /// If the corresponding pattern isImpossibleToMatch(), this aborts.
  unsigned short getBenefit() const;

  // Benefits are totally ordered by their raw representation; the sentinel is
  // the largest representable value.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const {
    return representation != rhs.representation;
  }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const {
    return representation > rhs.representation;
  }
  bool operator<=(const PatternBenefit &rhs) const {
    return representation <= rhs.representation;
  }
  bool operator>=(const PatternBenefit &rhs) const {
    return representation >= rhs.representation;
  }

private:
  /// Raw benefit value; ImpossibleToMatchSentinel marks "never matches".
  unsigned short representation;
};
60 
61 //===----------------------------------------------------------------------===//
62 // Pattern
63 //===----------------------------------------------------------------------===//
64 
65 /// This class contains all of the data related to a pattern, but does not
66 /// contain any methods or logic for the actual matching. This class is solely
67 /// used to interface with the metadata of a pattern, such as the benefit or
68 /// root operation.
69 class Pattern {
70 public:
71   /// Return a list of operations that may be generated when rewriting an
72   /// operation instance with this pattern.
getGeneratedOps()73   ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
74 
75   /// Return the root node that this pattern matches. Patterns that can match
76   /// multiple root types return None.
getRootKind()77   Optional<OperationName> getRootKind() const { return rootKind; }
78 
79   /// Return the benefit (the inverse of "cost") of matching this pattern.  The
80   /// benefit of a Pattern is always static - rewrites that may have dynamic
81   /// benefit can be instantiated multiple times (different Pattern instances)
82   /// for each benefit that they may return, and be guarded by different match
83   /// condition predicates.
getBenefit()84   PatternBenefit getBenefit() const { return benefit; }
85 
86   /// Returns true if this pattern is known to result in recursive application,
87   /// i.e. this pattern may generate IR that also matches this pattern, but is
88   /// known to bound the recursion. This signals to a rewrite driver that it is
89   /// safe to apply this pattern recursively to generated IR.
hasBoundedRewriteRecursion()90   bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
91 
92 protected:
93   /// This class acts as a special tag that makes the desire to match "any"
94   /// operation type explicit. This helps to avoid unnecessary usages of this
95   /// feature, and ensures that the user is making a conscious decision.
96   struct MatchAnyOpTypeTag {};
97 
98   /// Construct a pattern with a certain benefit that matches the operation
99   /// with the given root name.
100   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
101   /// Construct a pattern with a certain benefit that matches any operation
102   /// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
103   /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
104   /// always be supplied here.
105   Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag);
106   /// Construct a pattern with a certain benefit that matches the operation with
107   /// the given root name. `generatedNames` contains the names of operations
108   /// that may be generated during a successful rewrite.
109   Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
110           PatternBenefit benefit, MLIRContext *context);
111   /// Construct a pattern that may match any operation type. `generatedNames`
112   /// contains the names of operations that may be generated during a successful
113   /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
114   /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
115   /// always be supplied here.
116   Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
117           MLIRContext *context, MatchAnyOpTypeTag tag);
118 
119   /// Set the flag detailing if this pattern has bounded rewrite recursion or
120   /// not.
121   void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
122     hasBoundedRecursion = hasBoundedRecursionArg;
123   }
124 
125 private:
126   /// A list of the potential operations that may be generated when rewriting
127   /// an op with this pattern.
128   SmallVector<OperationName, 2> generatedOps;
129 
130   /// The root operation of the pattern. If the pattern matches a specific
131   /// operation, this contains the name of that operation. Contains None
132   /// otherwise.
133   Optional<OperationName> rootKind;
134 
135   /// The expected benefit of matching this pattern.
136   const PatternBenefit benefit;
137 
138   /// A boolean flag of whether this pattern has bounded recursion or not.
139   bool hasBoundedRecursion = false;
140 };
141 
142 //===----------------------------------------------------------------------===//
143 // RewritePattern
144 //===----------------------------------------------------------------------===//
145 
146 /// RewritePattern is the common base class for all DAG to DAG replacements.
147 /// There are two possible usages of this class:
148 ///   * Multi-step RewritePattern with "match" and "rewrite"
149 ///     - By overloading the "match" and "rewrite" functions, the user can
150 ///       separate the concerns of matching and rewriting.
151 ///   * Single-step RewritePattern with "matchAndRewrite"
152 ///     - By overloading the "matchAndRewrite" function, the user can perform
153 ///       the rewrite in the same call as the match.
154 ///
155 class RewritePattern : public Pattern {
156 public:
~RewritePattern()157   virtual ~RewritePattern() {}
158 
159   /// Rewrite the IR rooted at the specified operation with the result of
160   /// this pattern, generating any new operations with the specified
161   /// builder.  If an unexpected error is encountered (an internal
162   /// compiler error), it is emitted through the normal MLIR diagnostic
163   /// hooks and the IR is left in a valid state.
164   virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
165 
166   /// Attempt to match against code rooted at the specified operation,
167   /// which is the same operation code as getRootKind().
168   virtual LogicalResult match(Operation *op) const;
169 
170   /// Attempt to match against code rooted at the specified operation,
171   /// which is the same operation code as getRootKind(). If successful, this
172   /// function will automatically perform the rewrite.
matchAndRewrite(Operation * op,PatternRewriter & rewriter)173   virtual LogicalResult matchAndRewrite(Operation *op,
174                                         PatternRewriter &rewriter) const {
175     if (succeeded(match(op))) {
176       rewrite(op, rewriter);
177       return success();
178     }
179     return failure();
180   }
181 
182 protected:
183   /// Inherit the base constructors from `Pattern`.
184   using Pattern::Pattern;
185 
186   /// An anchor for the virtual table.
187   virtual void anchor();
188 };
189 
190 /// OpRewritePattern is a wrapper around RewritePattern that allows for
191 /// matching and rewriting against an instance of a derived operation class as
192 /// opposed to a raw Operation.
193 template <typename SourceOp>
194 struct OpRewritePattern : public RewritePattern {
195   /// Patterns must specify the root operation name they match against, and can
196   /// also specify the benefit of the pattern matching.
197   OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
RewritePatternOpRewritePattern198       : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
199 
200   /// Wrappers around the RewritePattern methods that pass the derived op type.
rewriteOpRewritePattern201   void rewrite(Operation *op, PatternRewriter &rewriter) const final {
202     rewrite(cast<SourceOp>(op), rewriter);
203   }
matchOpRewritePattern204   LogicalResult match(Operation *op) const final {
205     return match(cast<SourceOp>(op));
206   }
matchAndRewriteOpRewritePattern207   LogicalResult matchAndRewrite(Operation *op,
208                                 PatternRewriter &rewriter) const final {
209     return matchAndRewrite(cast<SourceOp>(op), rewriter);
210   }
211 
212   /// Rewrite and Match methods that operate on the SourceOp type. These must be
213   /// overridden by the derived pattern class.
rewriteOpRewritePattern214   virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
215     llvm_unreachable("must override rewrite or matchAndRewrite");
216   }
matchOpRewritePattern217   virtual LogicalResult match(SourceOp op) const {
218     llvm_unreachable("must override match or matchAndRewrite");
219   }
matchAndRewriteOpRewritePattern220   virtual LogicalResult matchAndRewrite(SourceOp op,
221                                         PatternRewriter &rewriter) const {
222     if (succeeded(match(op))) {
223       rewrite(op, rewriter);
224       return success();
225     }
226     return failure();
227   }
228 };
229 
230 //===----------------------------------------------------------------------===//
231 // PDLPatternModule
232 //===----------------------------------------------------------------------===//
233 
234 //===----------------------------------------------------------------------===//
235 // PDLValue
236 
237 /// Storage type of byte-code interpreter values. These are passed to constraint
238 /// functions as arguments.
239 class PDLValue {
240   /// The internal implementation type when the value is an Attribute,
241   /// Operation*, or Type. See `impl` below for more details.
242   using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>;
243 
244 public:
PDLValue(const PDLValue & other)245   PDLValue(const PDLValue &other) : impl(other.impl) {}
impl()246   PDLValue(std::nullptr_t = nullptr) : impl() {}
PDLValue(Attribute value)247   PDLValue(Attribute value) : impl(value) {}
PDLValue(Operation * value)248   PDLValue(Operation *value) : impl(value) {}
PDLValue(Type value)249   PDLValue(Type value) : impl(value) {}
PDLValue(Value value)250   PDLValue(Value value) : impl(value) {}
251 
252   /// Returns true if the type of the held value is `T`.
253   template <typename T>
isa()254   std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const {
255     return impl.is<Value>();
256   }
257   template <typename T>
isa()258   std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const {
259     auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
260     return attrOpTypeImpl && attrOpTypeImpl.is<T>();
261   }
262 
263   /// Attempt to dynamically cast this value to type `T`, returns null if this
264   /// value is not an instance of `T`.
265   template <typename T>
dyn_cast()266   std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const {
267     return impl.dyn_cast<T>();
268   }
269   template <typename T>
dyn_cast()270   std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const {
271     auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
272     return attrOpTypeImpl && attrOpTypeImpl.dyn_cast<T>();
273   }
274 
275   /// Cast this value to type `T`, asserts if this value is not an instance of
276   /// `T`.
277   template <typename T>
cast()278   std::enable_if_t<std::is_same<T, Value>::value, T> cast() const {
279     return impl.get<T>();
280   }
281   template <typename T>
cast()282   std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const {
283     return impl.get<AttrOpTypeImplT>().get<T>();
284   }
285 
286   /// Get an opaque pointer to the value.
getAsOpaquePointer()287   void *getAsOpaquePointer() { return impl.getOpaqueValue(); }
288 
289   /// Print this value to the provided output stream.
290   void print(raw_ostream &os);
291 
292 private:
293   /// The internal opaque representation of a PDLValue. We use a nested
294   /// PointerUnion structure here because `Value` only has 1 low bit
295   /// available, where as the remaining types all have 3.
296   llvm::PointerUnion<AttrOpTypeImplT, Value> impl;
297 };
298 
299 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
300   value.print(os);
301   return os;
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // PDLPatternModule
306 
307 /// A generic PDL pattern constraint function. This function applies a
308 /// constraint to a given set of opaque PDLValue entities. The second parameter
309 /// is a set of constant value parameters specified in Attribute form. Returns
310 /// success if the constraint successfully held, failure otherwise.
311 using PDLConstraintFunction = std::function<LogicalResult(
312     ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
313 /// A native PDL creation function. This function creates a new PDLValue given
314 /// a set of existing PDL values, a set of constant parameters specified in
315 /// Attribute form, and a PatternRewriter. Returns the newly created PDLValue.
316 using PDLCreateFunction =
317     std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
318 /// A native PDL rewrite function. This function rewrites the given root
319 /// operation using the provided PatternRewriter. This method is only invoked
320 /// when the corresponding match was successful.
321 using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>,
322                                               ArrayAttr, PatternRewriter &)>;
323 /// A generic PDL pattern constraint function. This function applies a
324 /// constraint to a given opaque PDLValue entity. The second parameter is a set
325 /// of constant value parameters specified in Attribute form. Returns success if
326 /// the constraint successfully held, failure otherwise.
327 using PDLSingleEntityConstraintFunction =
328     std::function<LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)>;
329 
330 /// This class contains all of the necessary data for a set of PDL patterns, or
331 /// pattern rewrites specified in the form of the PDL dialect. This PDL module
332 /// contained by this pattern may contain any number of `pdl.pattern`
333 /// operations.
334 class PDLPatternModule {
335 public:
336   PDLPatternModule() = default;
337 
338   /// Construct a PDL pattern with the given module.
PDLPatternModule(OwningModuleRef pdlModule)339   PDLPatternModule(OwningModuleRef pdlModule)
340       : pdlModule(std::move(pdlModule)) {}
341 
342   /// Merge the state in `other` into this pattern module.
343   void mergeIn(PDLPatternModule &&other);
344 
345   /// Return the internal PDL module of this pattern.
getModule()346   ModuleOp getModule() { return pdlModule.get(); }
347 
348   //===--------------------------------------------------------------------===//
349   // Function Registry
350 
351   /// Register a constraint function.
352   void registerConstraintFunction(StringRef name,
353                                   PDLConstraintFunction constraintFn);
354   /// Register a single entity constraint function.
355   template <typename SingleEntityFn>
356   std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
357                                        ArrayAttr, PatternRewriter &>::value>
registerConstraintFunction(StringRef name,SingleEntityFn && constraintFn)358   registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
359     registerConstraintFunction(
360         name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
361                   ArrayRef<PDLValue> values, ArrayAttr constantParams,
362                   PatternRewriter &rewriter) {
363           assert(values.size() == 1 &&
364                  "expected values to have a single entity");
365           return constraintFn(values[0], constantParams, rewriter);
366         });
367   }
368 
369   /// Register a creation function.
370   void registerCreateFunction(StringRef name, PDLCreateFunction createFn);
371 
372   /// Register a rewrite function.
373   void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
374 
375   /// Return the set of the registered constraint functions.
getConstraintFunctions()376   const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
377     return constraintFunctions;
378   }
takeConstraintFunctions()379   llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
380     return constraintFunctions;
381   }
382   /// Return the set of the registered create functions.
getCreateFunctions()383   const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const {
384     return createFunctions;
385   }
takeCreateFunctions()386   llvm::StringMap<PDLCreateFunction> takeCreateFunctions() {
387     return createFunctions;
388   }
389   /// Return the set of the registered rewrite functions.
getRewriteFunctions()390   const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
391     return rewriteFunctions;
392   }
takeRewriteFunctions()393   llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
394     return rewriteFunctions;
395   }
396 
397   /// Clear out the patterns and functions within this module.
clear()398   void clear() {
399     pdlModule = nullptr;
400     constraintFunctions.clear();
401     createFunctions.clear();
402     rewriteFunctions.clear();
403   }
404 
405 private:
406   /// The module containing the `pdl.pattern` operations.
407   OwningModuleRef pdlModule;
408 
409   /// The external functions referenced from within the PDL module.
410   llvm::StringMap<PDLConstraintFunction> constraintFunctions;
411   llvm::StringMap<PDLCreateFunction> createFunctions;
412   llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
413 };
414 
415 //===----------------------------------------------------------------------===//
416 // PatternRewriter
417 //===----------------------------------------------------------------------===//
418 
/// This class coordinates the application of a pattern to the current function,
/// providing a way to create operations and keep track of what gets deleted.
///
/// This class serves two purposes:
///  1) it is the interface that patterns interact with to make mutations to the
///     IR they are being applied to.
///  2) It is a base class that clients of the PatternMatcher use when they want
///     to apply patterns and observe their effects (e.g. to keep worklists or
///     other data structures up to date).
///
/// A PatternRewriter is both an OpBuilder and an OpBuilder::Listener; its
/// constructor installs `this` as the builder's listener.
class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
public:
  /// Move the blocks that belong to "region" before the given position in
  /// another region "parent". The two regions must be different. The caller
  /// is responsible for creating or updating the operation transferring flow
  /// of control to the region and passing it the correct block arguments.
  virtual void inlineRegionBefore(Region &region, Region &parent,
                                  Region::iterator before);
  void inlineRegionBefore(Region &region, Block *before);

  /// Clone the blocks that belong to "region" before the given position in
  /// another region "parent". The two regions must be different. The caller is
  /// responsible for creating or updating the operation transferring flow of
  /// control to the region and passing it the correct block arguments.
  virtual void cloneRegionBefore(Region &region, Region &parent,
                                 Region::iterator before,
                                 BlockAndValueMapping &mapping);
  void cloneRegionBefore(Region &region, Region &parent,
                         Region::iterator before);
  void cloneRegionBefore(Region &region, Block *before);

  /// This method performs the final replacement for a pattern, where the
  /// results of the operation are updated to use the specified list of SSA
  /// values.
  virtual void replaceOp(Operation *op, ValueRange newValues);

  /// Replace `op` with a new operation of type `OpTy`, created (without
  /// verification) at `op`'s location from `args`. The result values of the
  /// two ops must be of the same types.
  template <typename OpTy, typename... Args>
  void replaceOpWithNewOp(Operation *op, Args &&... args) {
    auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
    replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
  }

  /// This method erases an operation that is known to have no uses.
  virtual void eraseOp(Operation *op);

  /// This method erases all operations in a block.
  virtual void eraseBlock(Block *block);

  /// Merge the operations of block 'source' into the end of block 'dest'.
  /// 'source's predecessors must either be empty or only contain 'dest`.
  /// 'argValues' is used to replace the block arguments of 'source' after
  /// merging.
  virtual void mergeBlocks(Block *source, Block *dest,
                           ValueRange argValues = llvm::None);

  /// Merge the operations of block 'source' before the operation 'op'. The
  /// source block should not have existing predecessors or successors.
  void mergeBlockBefore(Block *source, Operation *op,
                        ValueRange argValues = llvm::None);

  /// Split the operations starting at "before" (inclusive) out of the given
  /// block into a new block, and return it.
  virtual Block *splitBlock(Block *block, Block::iterator before);

  /// This method is used to notify the rewriter that an in-place operation
  /// modification is about to happen. A call to this function *must* be
  /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
  /// This is a minor efficiency win (it avoids creating a new operation and
  /// removing the old one) but also often allows simpler code in the client.
  /// The base implementation does nothing.
  virtual void startRootUpdate(Operation *op) {}

  /// This method is used to signal the end of a root update on the given
  /// operation. This can only be called on operations that were provided to a
  /// call to `startRootUpdate`. The base implementation does nothing.
  virtual void finalizeRootUpdate(Operation *op) {}

  /// This method cancels a pending root update. This can only be called on
  /// operations that were provided to a call to `startRootUpdate`. The base
  /// implementation does nothing.
  virtual void cancelRootUpdate(Operation *op) {}

  /// This method is a utility wrapper around a root update of an operation. It
  /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
  /// callable.
  template <typename CallableT>
  void updateRootInPlace(Operation *root, CallableT &&callable) {
    startRootUpdate(root);
    callable();
    finalizeRootUpdate(root);
  }

  /// Notify the pattern rewriter that the pattern is failing to match the given
  /// operation, and provide a callback to populate a diagnostic with the reason
  /// why the failure occurred. This method allows for derived rewriters to
  /// optionally hook into the reason why a pattern failed, and display it to
  /// users. This overload only accepts callbacks that are not convertible to
  /// Twine. When NDEBUG is defined, the callback is never invoked and failure()
  /// is returned directly; otherwise it is forwarded to the protected virtual
  /// overload.
  template <typename CallbackT>
  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
#ifndef NDEBUG
    return notifyMatchFailure(op,
                              function_ref<void(Diagnostic &)>(reasonCallback));
#else
    return failure();
#endif
  }
  /// Convenience overloads that use `msg` as the diagnostic text of the match
  /// failure.
  LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
    return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
  }
  LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
    return notifyMatchFailure(op, Twine(msg));
  }

protected:
  /// Initialize the builder with this rewriter as the listener.
  explicit PatternRewriter(MLIRContext *ctx)
      : OpBuilder(ctx, /*listener=*/this) {}
  ~PatternRewriter() override;

  /// These are the callback methods that subclasses can choose to implement if
  /// they would like to be notified about certain types of mutations.

  /// Notify the pattern rewriter that the specified operation is about to be
  /// replaced with another set of operations.  This is called before the uses
  /// of the operation have been changed.
  virtual void notifyRootReplaced(Operation *op) {}

  /// This is called on an operation that a pattern match is removing, right
  /// before the operation is deleted.  At this point, the operation has zero
  /// uses.
  virtual void notifyOperationRemoved(Operation *op) {}

  /// Notify the pattern rewriter that the pattern is failing to match the given
  /// operation, and provide a callback to populate a diagnostic with the reason
  /// why the failure occurred. This method allows for derived rewriters to
  /// optionally hook into the reason why a pattern failed, and display it to
  /// users. The base implementation ignores the callback and returns failure().
  virtual LogicalResult
  notifyMatchFailure(Operation *op,
                     function_ref<void(Diagnostic &)> reasonCallback) {
    return failure();
  }

private:
  /// 'op' and 'newOp' are known to have the same number of results, replace the
  /// uses of op with uses of newOp.
  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};
568 
569 //===----------------------------------------------------------------------===//
570 // OwningRewritePatternList
571 //===----------------------------------------------------------------------===//
572 
573 class OwningRewritePatternList {
574   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
575 
576 public:
577   OwningRewritePatternList() = default;
578 
579   /// Construct a OwningRewritePatternList populated with the given pattern.
OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern)580   OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) {
581     nativePatterns.emplace_back(std::move(pattern));
582   }
OwningRewritePatternList(PDLPatternModule && pattern)583   OwningRewritePatternList(PDLPatternModule &&pattern)
584       : pdlPatterns(std::move(pattern)) {}
585 
586   /// Return the native patterns held in this list.
getNativePatterns()587   NativePatternListT &getNativePatterns() { return nativePatterns; }
588 
589   /// Return the PDL patterns held in this list.
getPDLPatterns()590   PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
591 
592   /// Clear out all of the held patterns in this list.
clear()593   void clear() {
594     nativePatterns.clear();
595     pdlPatterns.clear();
596   }
597 
598   //===--------------------------------------------------------------------===//
599   // Pattern Insertion
600   //===--------------------------------------------------------------------===//
601 
602   /// Add an instance of each of the pattern types 'Ts' to the pattern list with
603   /// the given arguments. Return a reference to `this` for chaining insertions.
604   /// Note: ConstructorArg is necessary here to separate the two variadic lists.
605   template <typename... Ts, typename ConstructorArg,
606             typename... ConstructorArgs,
607             typename = std::enable_if_t<sizeof...(Ts) != 0>>
insert(ConstructorArg && arg,ConstructorArgs &&...args)608   OwningRewritePatternList &insert(ConstructorArg &&arg,
609                                    ConstructorArgs &&...args) {
610     // The following expands a call to emplace_back for each of the pattern
611     // types 'Ts'. This magic is necessary due to a limitation in the places
612     // that a parameter pack can be expanded in c++11.
613     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
614     (void)std::initializer_list<int>{0, (insertImpl<Ts>(arg, args...), 0)...};
615     return *this;
616   }
617 
618   /// Add an instance of each of the pattern types 'Ts'. Return a reference to
619   /// `this` for chaining insertions.
insert()620   template <typename... Ts> OwningRewritePatternList &insert() {
621     (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...};
622     return *this;
623   }
624 
625   /// Add the given native pattern to the pattern list. Return a reference to
626   /// `this` for chaining insertions.
insert(std::unique_ptr<RewritePattern> pattern)627   OwningRewritePatternList &insert(std::unique_ptr<RewritePattern> pattern) {
628     nativePatterns.emplace_back(std::move(pattern));
629     return *this;
630   }
631 
632   /// Add the given PDL pattern to the pattern list. Return a reference to
633   /// `this` for chaining insertions.
insert(PDLPatternModule && pattern)634   OwningRewritePatternList &insert(PDLPatternModule &&pattern) {
635     pdlPatterns.mergeIn(std::move(pattern));
636     return *this;
637   }
638 
639 private:
640   /// Add an instance of the pattern type 'T'. Return a reference to `this` for
641   /// chaining insertions.
642   template <typename T, typename... Args>
643   std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
insertImpl(Args &&...args)644   insertImpl(Args &&...args) {
645     nativePatterns.emplace_back(
646         std::make_unique<T>(std::forward<Args>(args)...));
647   }
648   template <typename T, typename... Args>
649   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
insertImpl(Args &&...args)650   insertImpl(Args &&...args) {
651     pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
652   }
653 
654   NativePatternListT nativePatterns;
655   PDLPatternModule pdlPatterns;
656 };
657 
658 } // end namespace mlir
659 
#endif // MLIR_PATTERNMATCHER_H
661