1 //===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an applicator that applies pattern rewrites based upon a 10 // user defined cost model. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H 15 #define MLIR_REWRITE_PATTERNAPPLICATOR_H 16 17 #include "mlir/Rewrite/FrozenRewritePatternList.h" 18 19 namespace mlir { 20 class PatternRewriter; 21 22 namespace detail { 23 class PDLByteCodeMutableState; 24 } // end namespace detail 25 26 /// This class manages the application of a group of rewrite patterns, with a 27 /// user-provided cost model. 28 class PatternApplicator { 29 public: 30 /// The cost model dynamically assigns a PatternBenefit to a particular 31 /// pattern. Users can query contained patterns and pass analysis results to 32 /// applyCostModel. Patterns to be discarded should have a benefit of 33 /// `impossibleToMatch`. 34 using CostModel = function_ref<PatternBenefit(const Pattern &)>; 35 36 explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList); 37 ~PatternApplicator(); 38 39 /// Attempt to match and rewrite the given op with any pattern, allowing a 40 /// predicate to decide if a pattern can be applied or not, and hooks for if 41 /// the pattern match was a success or failure. 42 /// 43 /// canApply: called before each match and rewrite attempt; return false to 44 /// skip pattern. 45 /// onFailure: called when a pattern fails to match to perform cleanup. 46 /// onSuccess: called when a pattern match succeeds; return failure() to 47 /// invalidate the match and try another pattern. 48 LogicalResult 49 matchAndRewrite(Operation *op, PatternRewriter &rewriter, 50 function_ref<bool(const Pattern &)> canApply = {}, 51 function_ref<void(const Pattern &)> onFailure = {}, 52 function_ref<LogicalResult(const Pattern &)> onSuccess = {}); 53 54 /// Apply a cost model to the patterns within this applicator. 55 void applyCostModel(CostModel model); 56 57 /// Apply the default cost model that solely uses the pattern's static 58 /// benefit. applyDefaultCostModel()59 void applyDefaultCostModel() { 60 applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); }); 61 } 62 63 /// Walk all of the patterns within the applicator. 64 void walkAllPatterns(function_ref<void(const Pattern &)> walk); 65 66 private: 67 /// The list that owns the patterns used within this applicator. 68 const FrozenRewritePatternList &frozenPatternList; 69 /// The set of patterns to match for each operation, stable sorted by benefit. 70 DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns; 71 /// The set of patterns that may match against any operation type, stable 72 /// sorted by benefit. 73 SmallVector<const RewritePattern *, 1> anyOpPatterns; 74 /// The mutable state used during execution of the PDL bytecode. 75 std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState; 76 }; 77 78 } // end namespace mlir 79 80 #endif // MLIR_REWRITE_PATTERNAPPLICATOR_H 81