//===- NestedMatcher.h - Nested matcher for Function ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
10 #define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
11 
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Operation.h"
14 #include "llvm/Support/Allocator.h"
15 
16 namespace mlir {
17 
18 class NestedPattern;
19 class Operation;
20 
/// A NestedPattern captures nested patterns in the IR.
/// It is used in conjunction with a scoped NestedPatternContext, which owns an
/// llvm::BumpPtrAllocator that handles memory allocations efficiently and
/// avoids ownership issues.
///
/// In order to use NestedPatterns, first create a scoped context.
/// When the context goes out of scope, everything is freed.
/// This design simplifies the API by avoiding references to the context and
/// makes it clear that references to matchers must not escape.
///
/// Example:
///   {
///      NestedPatternContext context;
///      using namespace matcher;
///      auto gemmLike = For(isParallelLoop,
///                          For(isParallelLoop,
///                              For(isReductionLoop, Op(isLoadOrStore))));
///      SmallVector<NestedMatch, 8> matches;
///      gemmLike.match(f, &matches);
///      // do work on matches
///   }  // everything is freed
///
40 /// Nested abstraction for matching results.
41 /// Provides access to the nested Operation* captured by a Matcher.
42 ///
43 /// A NestedMatch contains an Operation* and the children NestedMatch and is
44 /// thus cheap to copy. NestedMatch is stored in a scoped bumper allocator whose
45 /// lifetime is managed by an RAII NestedPatternContext.
46 class NestedMatch {
47 public:
48   static NestedMatch build(Operation *operation,
49                            ArrayRef<NestedMatch> nestedMatches);
50   NestedMatch(const NestedMatch &) = default;
51   NestedMatch &operator=(const NestedMatch &) = default;
52 
53   explicit operator bool() { return matchedOperation != nullptr; }
54 
getMatchedOperation()55   Operation *getMatchedOperation() { return matchedOperation; }
getMatchedChildren()56   ArrayRef<NestedMatch> getMatchedChildren() { return matchedChildren; }
57 
58 private:
59   friend class NestedPattern;
60   friend class NestedPatternContext;
61 
62   /// Underlying global bump allocator managed by a NestedPatternContext.
63   static llvm::BumpPtrAllocator *&allocator();
64 
65   NestedMatch() = default;
66 
67   /// Payload, holds a NestedMatch and all its children along this branch.
68   Operation *matchedOperation;
69   ArrayRef<NestedMatch> matchedChildren;
70 };
71 
72 /// A NestedPattern is a nested operation walker that:
73 ///   1. recursively matches a substructure in the tree;
74 ///   2. uses a filter function to refine matches with extra semantic
75 ///      constraints (passed via a lambda of type FilterFunctionType);
76 ///   3. TODO: optionally applies actions (lambda).
77 ///
78 /// Nested patterns are meant to capture imperfectly nested loops while matching
79 /// properties over the whole loop nest. For instance, in vectorization we are
80 /// interested in capturing all the imperfectly nested loops of a certain type
81 /// and such that all the load and stores have certain access patterns along the
82 /// loops' induction variables). Such NestedMatches are first captured using the
83 /// `match` function and are later processed to analyze properties and apply
84 /// transformations in a non-greedy way.
85 ///
86 /// The NestedMatches captured in the IR can grow large, especially after
87 /// aggressive unrolling. As experience has shown, it is generally better to use
88 /// a plain walk over operations to match flat patterns but the current
89 /// implementation is competitive nonetheless.
90 using FilterFunctionType = std::function<bool(Operation &)>;
defaultFilterFunction(Operation &)91 inline bool defaultFilterFunction(Operation &) { return true; }
92 class NestedPattern {
93 public:
94   NestedPattern(ArrayRef<NestedPattern> nested,
95                 FilterFunctionType filter = defaultFilterFunction);
96   NestedPattern(const NestedPattern &) = default;
97   NestedPattern &operator=(const NestedPattern &) = default;
98 
99   /// Returns all the top-level matches in `func`.
match(FuncOp func,SmallVectorImpl<NestedMatch> * matches)100   void match(FuncOp func, SmallVectorImpl<NestedMatch> *matches) {
101     func.walk([&](Operation *op) { matchOne(op, matches); });
102   }
103 
104   /// Returns all the top-level matches in `op`.
match(Operation * op,SmallVectorImpl<NestedMatch> * matches)105   void match(Operation *op, SmallVectorImpl<NestedMatch> *matches) {
106     op->walk([&](Operation *child) { matchOne(child, matches); });
107   }
108 
109   /// Returns the depth of the pattern.
110   unsigned getDepth() const;
111 
112 private:
113   friend class NestedPatternContext;
114   friend class NestedMatch;
115   friend struct State;
116 
117   /// Underlying global bump allocator managed by a NestedPatternContext.
118   static llvm::BumpPtrAllocator *&allocator();
119 
120   /// Matches this pattern against a single `op` and fills matches with the
121   /// result.
122   void matchOne(Operation *op, SmallVectorImpl<NestedMatch> *matches);
123 
124   /// Nested patterns to be matched.
125   ArrayRef<NestedPattern> nestedPatterns;
126 
127   /// Extra filter function to apply to prune patterns as the IR is walked.
128   FilterFunctionType filter;
129 
130   /// skip is an implementation detail needed so that we can implement match
131   /// without switching on the type of the Operation. The idea is that a
132   /// NestedPattern first checks if it matches locally and then recursively
133   /// applies its nested matchers to its elem->nested. Since we want to rely on
134   /// the existing operation walking functionality rather than duplicate
135   /// it, we allow an off-by-one traversal to account for the fact that we
136   /// write:
137   ///
138   ///  void match(Operation *elem) {
139   ///    for (auto &c : getNestedPatterns()) {
140   ///      NestedPattern childPattern(...);
141   ///                                  ^~~~ Needs off-by-one skip.
142   ///
143   Operation *skip;
144 };
145 
146 /// RAII structure to transparently manage the bump allocator for
147 /// NestedPattern and NestedMatch classes. This avoids passing a context to
148 /// all the API functions.
149 class NestedPatternContext {
150 public:
NestedPatternContext()151   NestedPatternContext() {
152     assert(NestedMatch::allocator() == nullptr &&
153            "Only a single NestedPatternContext is supported");
154     assert(NestedPattern::allocator() == nullptr &&
155            "Only a single NestedPatternContext is supported");
156     NestedMatch::allocator() = &allocator;
157     NestedPattern::allocator() = &allocator;
158   }
~NestedPatternContext()159   ~NestedPatternContext() {
160     NestedMatch::allocator() = nullptr;
161     NestedPattern::allocator() = nullptr;
162   }
163   llvm::BumpPtrAllocator allocator;
164 };
165 
166 namespace matcher {
167 // Syntactic sugar NestedPattern builder functions.
168 NestedPattern Op(FilterFunctionType filter = defaultFilterFunction);
169 NestedPattern If(NestedPattern child);
170 NestedPattern If(FilterFunctionType filter, NestedPattern child);
171 NestedPattern If(ArrayRef<NestedPattern> nested = {});
172 NestedPattern If(FilterFunctionType filter,
173                  ArrayRef<NestedPattern> nested = {});
174 NestedPattern For(NestedPattern child);
175 NestedPattern For(FilterFunctionType filter, NestedPattern child);
176 NestedPattern For(ArrayRef<NestedPattern> nested = {});
177 NestedPattern For(FilterFunctionType filter,
178                   ArrayRef<NestedPattern> nested = {});
179 
180 bool isParallelLoop(Operation &op);
181 bool isReductionLoop(Operation &op);
182 bool isLoadOrStore(Operation &op);
183 
184 } // end namespace matcher
185 } // end namespace mlir
186 
187 #endif // MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
188