1 //===- NestedMatcher.cpp - NestedMatcher Impl  ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/NestedMatcher.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 
13 #include "llvm/ADT/ArrayRef.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Support/Allocator.h"
16 #include "llvm/Support/raw_ostream.h"
17 
18 using namespace mlir;
19 
allocator()20 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
21   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
22   return allocator;
23 }
24 
build(Operation * operation,ArrayRef<NestedMatch> nestedMatches)25 NestedMatch NestedMatch::build(Operation *operation,
26                                ArrayRef<NestedMatch> nestedMatches) {
27   auto *result = allocator()->Allocate<NestedMatch>();
28   auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
29   std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
30   new (result) NestedMatch();
31   result->matchedOperation = operation;
32   result->matchedChildren =
33       ArrayRef<NestedMatch>(children, nestedMatches.size());
34   return *result;
35 }
36 
allocator()37 llvm::BumpPtrAllocator *&NestedPattern::allocator() {
38   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
39   return allocator;
40 }
41 
NestedPattern(ArrayRef<NestedPattern> nested,FilterFunctionType filter)42 NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
43                              FilterFunctionType filter)
44     : nestedPatterns(), filter(filter), skip(nullptr) {
45   if (!nested.empty()) {
46     auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
47     std::uninitialized_copy(nested.begin(), nested.end(), newNested);
48     nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
49   }
50 }
51 
getDepth() const52 unsigned NestedPattern::getDepth() const {
53   if (nestedPatterns.empty()) {
54     return 1;
55   }
56   unsigned depth = 0;
57   for (auto &c : nestedPatterns) {
58     depth = std::max(depth, c.getDepth());
59   }
60   return depth + 1;
61 }
62 
63 /// Matches a single operation in the following way:
64 ///   1. checks the kind of operation against the matcher, if different then
65 ///      there is no match;
66 ///   2. calls the customizable filter function to refine the single operation
67 ///      match with extra semantic constraints;
68 ///   3. if all is good, recursively matches the nested patterns;
69 ///   4. if all nested match then the single operation matches too and is
70 ///      appended to the list of matches;
71 ///   5. TODO: Optionally applies actions (lambda), in which case we will want
72 ///      to traverse in post-order DFS to avoid invalidating iterators.
matchOne(Operation * op,SmallVectorImpl<NestedMatch> * matches)73 void NestedPattern::matchOne(Operation *op,
74                              SmallVectorImpl<NestedMatch> *matches) {
75   if (skip == op) {
76     return;
77   }
78   // Local custom filter function
79   if (!filter(*op)) {
80     return;
81   }
82 
83   if (nestedPatterns.empty()) {
84     SmallVector<NestedMatch, 8> nestedMatches;
85     matches->push_back(NestedMatch::build(op, nestedMatches));
86     return;
87   }
88   // Take a copy of each nested pattern so we can match it.
89   for (auto nestedPattern : nestedPatterns) {
90     SmallVector<NestedMatch, 8> nestedMatches;
91     // Skip elem in the walk immediately following. Without this we would
92     // essentially need to reimplement walk here.
93     nestedPattern.skip = op;
94     nestedPattern.match(op, &nestedMatches);
95     // If we could not match even one of the specified nestedPattern, early exit
96     // as this whole branch is not a match.
97     if (nestedMatches.empty()) {
98       return;
99     }
100     matches->push_back(NestedMatch::build(op, nestedMatches));
101   }
102 }
103 
isAffineForOp(Operation & op)104 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
105 
isAffineIfOp(Operation & op)106 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
107 
108 namespace mlir {
109 namespace matcher {
110 
Op(FilterFunctionType filter)111 NestedPattern Op(FilterFunctionType filter) {
112   return NestedPattern({}, filter);
113 }
114 
If(NestedPattern child)115 NestedPattern If(NestedPattern child) {
116   return NestedPattern(child, isAffineIfOp);
117 }
If(FilterFunctionType filter,NestedPattern child)118 NestedPattern If(FilterFunctionType filter, NestedPattern child) {
119   return NestedPattern(child, [filter](Operation &op) {
120     return isAffineIfOp(op) && filter(op);
121   });
122 }
If(ArrayRef<NestedPattern> nested)123 NestedPattern If(ArrayRef<NestedPattern> nested) {
124   return NestedPattern(nested, isAffineIfOp);
125 }
If(FilterFunctionType filter,ArrayRef<NestedPattern> nested)126 NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
127   return NestedPattern(nested, [filter](Operation &op) {
128     return isAffineIfOp(op) && filter(op);
129   });
130 }
131 
For(NestedPattern child)132 NestedPattern For(NestedPattern child) {
133   return NestedPattern(child, isAffineForOp);
134 }
For(FilterFunctionType filter,NestedPattern child)135 NestedPattern For(FilterFunctionType filter, NestedPattern child) {
136   return NestedPattern(
137       child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
138 }
For(ArrayRef<NestedPattern> nested)139 NestedPattern For(ArrayRef<NestedPattern> nested) {
140   return NestedPattern(nested, isAffineForOp);
141 }
For(FilterFunctionType filter,ArrayRef<NestedPattern> nested)142 NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
143   return NestedPattern(
144       nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
145 }
146 
isLoadOrStore(Operation & op)147 bool isLoadOrStore(Operation &op) {
148   return isa<AffineLoadOp, AffineStoreOp>(op);
149 }
150 
151 } // end namespace matcher
152 } // end namespace mlir
153