1 //===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file declares an applicator that applies pattern rewrites based upon a
// user-defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
15 #define MLIR_REWRITE_PATTERNAPPLICATOR_H
16 
17 #include "mlir/Rewrite/FrozenRewritePatternList.h"
18 
19 namespace mlir {
namespace detail {
// Only held through a std::unique_ptr by PatternApplicator, so an incomplete
// declaration suffices in this header.
class PDLByteCodeMutableState;
} // end namespace detail

// Only referred to by reference in this header.
class PatternRewriter;
25 
26 /// This class manages the application of a group of rewrite patterns, with a
27 /// user-provided cost model.
28 class PatternApplicator {
29 public:
30   /// The cost model dynamically assigns a PatternBenefit to a particular
31   /// pattern. Users can query contained patterns and pass analysis results to
32   /// applyCostModel. Patterns to be discarded should have a benefit of
33   /// `impossibleToMatch`.
34   using CostModel = function_ref<PatternBenefit(const Pattern &)>;
35 
36   explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList);
37   ~PatternApplicator();
38 
39   /// Attempt to match and rewrite the given op with any pattern, allowing a
40   /// predicate to decide if a pattern can be applied or not, and hooks for if
41   /// the pattern match was a success or failure.
42   ///
43   /// canApply:  called before each match and rewrite attempt; return false to
44   ///            skip pattern.
45   /// onFailure: called when a pattern fails to match to perform cleanup.
46   /// onSuccess: called when a pattern match succeeds; return failure() to
47   ///            invalidate the match and try another pattern.
48   LogicalResult
49   matchAndRewrite(Operation *op, PatternRewriter &rewriter,
50                   function_ref<bool(const Pattern &)> canApply = {},
51                   function_ref<void(const Pattern &)> onFailure = {},
52                   function_ref<LogicalResult(const Pattern &)> onSuccess = {});
53 
54   /// Apply a cost model to the patterns within this applicator.
55   void applyCostModel(CostModel model);
56 
57   /// Apply the default cost model that solely uses the pattern's static
58   /// benefit.
applyDefaultCostModel()59   void applyDefaultCostModel() {
60     applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
61   }
62 
63   /// Walk all of the patterns within the applicator.
64   void walkAllPatterns(function_ref<void(const Pattern &)> walk);
65 
66 private:
67   /// The list that owns the patterns used within this applicator.
68   const FrozenRewritePatternList &frozenPatternList;
69   /// The set of patterns to match for each operation, stable sorted by benefit.
70   DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns;
71   /// The set of patterns that may match against any operation type, stable
72   /// sorted by benefit.
73   SmallVector<const RewritePattern *, 1> anyOpPatterns;
74   /// The mutable state used during execution of the PDL bytecode.
75   std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState;
76 };
77 
78 } // end namespace mlir
79 
80 #endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
81