1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_MATCHERS_H
16 #define MLIR_MATCHERS_H
17 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/OpDefinition.h"
20 
21 namespace mlir {
22 
23 namespace detail {
24 
25 /// The matcher that matches a certain kind of Attribute and binds the value
26 /// inside the Attribute.
27 template <
28     typename AttrClass,
29     // Require AttrClass to be a derived class from Attribute and get its
30     // value type
31     typename ValueType =
32         typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value,
33                                 AttrClass>::type::ValueType,
34     // Require the ValueType is not void
35     typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
36 struct attr_value_binder {
37   ValueType *bind_value;
38 
39   /// Creates a matcher instance that binds the value to bv if match succeeds.
attr_value_binderattr_value_binder40   attr_value_binder(ValueType *bv) : bind_value(bv) {}
41 
matchattr_value_binder42   bool match(const Attribute &attr) {
43     if (auto intAttr = attr.dyn_cast<AttrClass>()) {
44       *bind_value = intAttr.getValue();
45       return true;
46     }
47     return false;
48   }
49 };
50 
51 /// The matcher that matches operations that have the `ConstantLike` trait.
52 struct constant_op_matcher {
matchconstant_op_matcher53   bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
54 };
55 
56 /// The matcher that matches operations that have the `ConstantLike` trait, and
57 /// binds the folded attribute value.
58 template <typename AttrT> struct constant_op_binder {
59   AttrT *bind_value;
60 
61   /// Creates a matcher instance that binds the constant attribute value to
62   /// bind_value if match succeeds.
constant_op_binderconstant_op_binder63   constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
64   /// Creates a matcher instance that doesn't bind if match succeeds.
constant_op_binderconstant_op_binder65   constant_op_binder() : bind_value(nullptr) {}
66 
matchconstant_op_binder67   bool match(Operation *op) {
68     if (!op->hasTrait<OpTrait::ConstantLike>())
69       return false;
70 
71     // Fold the constant to an attribute.
72     SmallVector<OpFoldResult, 1> foldedOp;
73     LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
74     (void)result;
75     assert(succeeded(result) && "expected ConstantLike op to be foldable");
76 
77     if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
78       if (bind_value)
79         *bind_value = attr;
80       return true;
81     }
82     return false;
83   }
84 };
85 
86 /// The matcher that matches a constant scalar / vector splat / tensor splat
87 /// integer operation and binds the constant integer value.
88 struct constant_int_op_binder {
89   IntegerAttr::ValueType *bind_value;
90 
91   /// Creates a matcher instance that binds the value to bv if match succeeds.
constant_int_op_binderconstant_int_op_binder92   constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
93 
matchconstant_int_op_binder94   bool match(Operation *op) {
95     Attribute attr;
96     if (!constant_op_binder<Attribute>(&attr).match(op))
97       return false;
98     auto type = op->getResult(0).getType();
99 
100     if (type.isa<IntegerType, IndexType>())
101       return attr_value_binder<IntegerAttr>(bind_value).match(attr);
102     if (type.isa<VectorType, RankedTensorType>()) {
103       if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
104         return attr_value_binder<IntegerAttr>(bind_value)
105             .match(splatAttr.getSplatValue());
106       }
107     }
108     return false;
109   }
110 };
111 
112 /// The matcher that matches a given target constant scalar / vector splat /
113 /// tensor splat integer value.
114 template <int64_t TargetValue> struct constant_int_value_matcher {
matchconstant_int_value_matcher115   bool match(Operation *op) {
116     APInt value;
117     return constant_int_op_binder(&value).match(op) && TargetValue == value;
118   }
119 };
120 
121 /// The matcher that matches anything except the given target constant scalar /
122 /// vector splat / tensor splat integer value.
123 template <int64_t TargetNotValue> struct constant_int_not_value_matcher {
matchconstant_int_not_value_matcher124   bool match(Operation *op) {
125     APInt value;
126     return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
127   }
128 };
129 
130 /// The matcher that matches a certain kind of op.
131 template <typename OpClass> struct op_matcher {
matchop_matcher132   bool match(Operation *op) { return isa<OpClass>(op); }
133 };
134 
135 /// Trait to check whether T provides a 'match' method with type
136 /// `OperationOrValue`.
137 template <typename T, typename OperationOrValue>
138 using has_operation_or_value_matcher_t =
139     decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
140 
141 /// Statically switch to a Value matcher.
142 template <typename MatcherClass>
143 typename std::enable_if_t<
144     llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
145                       Value>::value,
146     bool>
matchOperandOrValueAtIndex(Operation * op,unsigned idx,MatcherClass & matcher)147 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
148   return matcher.match(op->getOperand(idx));
149 }
150 
151 /// Statically switch to an Operation matcher.
152 template <typename MatcherClass>
153 typename std::enable_if_t<
154     llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
155                       Operation *>::value,
156     bool>
matchOperandOrValueAtIndex(Operation * op,unsigned idx,MatcherClass & matcher)157 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
158   if (auto defOp = op->getOperand(idx).getDefiningOp())
159     return matcher.match(defOp);
160   return false;
161 }
162 
163 /// Terminal matcher, always returns true.
164 struct AnyValueMatcher {
matchAnyValueMatcher165   bool match(Value op) const { return true; }
166 };
167 
168 /// Binds to a specific value and matches it.
169 struct PatternMatcherValue {
PatternMatcherValuePatternMatcherValue170   PatternMatcherValue(Value val) : value(val) {}
matchPatternMatcherValue171   bool match(Value val) const { return val == value; }
172   Value value;
173 };
174 
/// Invokes `callback(std::integral_constant<std::size_t, I>{}, std::get<I>(tuple))`
/// for every index `I` in the given index sequence, in increasing order.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // Expanding the pack inside a braced initializer guarantees left-to-right
  // evaluation; the leading 0 keeps the array non-empty for empty packs.
  int expander[] = {0, (callback(std::integral_constant<std::size_t, Is>{},
                                 std::get<Is>(tuple)),
                        0)...};
  (void)expander;
}
183 
184 template <typename... Tys, typename CallbackT>
enumerate(std::tuple<Tys...> & tuple,CallbackT && callback)185 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
186   detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
187                         std::make_index_sequence<sizeof...(Tys)>{});
188 }
189 
190 /// RecursivePatternMatcher that composes.
191 template <typename OpType, typename... OperandMatchers>
192 struct RecursivePatternMatcher {
RecursivePatternMatcherRecursivePatternMatcher193   RecursivePatternMatcher(OperandMatchers... matchers)
194       : operandMatchers(matchers...) {}
matchRecursivePatternMatcher195   bool match(Operation *op) {
196     if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
197       return false;
198     bool res = true;
199     enumerate(operandMatchers, [&](size_t index, auto &matcher) {
200       res &= matchOperandOrValueAtIndex(op, index, matcher);
201     });
202     return res;
203   }
204   std::tuple<OperandMatchers...> operandMatchers;
205 };
206 
207 } // end namespace detail
208 
209 /// Matches a constant foldable operation.
m_Constant()210 inline detail::constant_op_matcher m_Constant() {
211   return detail::constant_op_matcher();
212 }
213 
214 /// Matches a value from a constant foldable operation and writes the value to
215 /// bind_value.
216 template <typename AttrT>
m_Constant(AttrT * bind_value)217 inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
218   return detail::constant_op_binder<AttrT>(bind_value);
219 }
220 
221 /// Matches a constant scalar / vector splat / tensor splat integer one.
m_One()222 inline detail::constant_int_value_matcher<1> m_One() {
223   return detail::constant_int_value_matcher<1>();
224 }
225 
226 /// Matches the given OpClass.
m_Op()227 template <typename OpClass> inline detail::op_matcher<OpClass> m_Op() {
228   return detail::op_matcher<OpClass>();
229 }
230 
231 /// Matches a constant scalar / vector splat / tensor splat integer zero.
m_Zero()232 inline detail::constant_int_value_matcher<0> m_Zero() {
233   return detail::constant_int_value_matcher<0>();
234 }
235 
236 /// Matches a constant scalar / vector splat / tensor splat integer that is any
237 /// non-zero value.
m_NonZero()238 inline detail::constant_int_not_value_matcher<0> m_NonZero() {
239   return detail::constant_int_not_value_matcher<0>();
240 }
241 
242 /// Entry point for matching a pattern over a Value.
243 template <typename Pattern>
matchPattern(Value value,const Pattern & pattern)244 inline bool matchPattern(Value value, const Pattern &pattern) {
245   // TODO: handle other cases
246   if (auto *op = value.getDefiningOp())
247     return const_cast<Pattern &>(pattern).match(op);
248   return false;
249 }
250 
251 /// Entry point for matching a pattern over an Operation.
252 template <typename Pattern>
matchPattern(Operation * op,const Pattern & pattern)253 inline bool matchPattern(Operation *op, const Pattern &pattern) {
254   return const_cast<Pattern &>(pattern).match(op);
255 }
256 
257 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
258 /// writes the integer value to bind_value.
259 inline detail::constant_int_op_binder
m_ConstantInt(IntegerAttr::ValueType * bind_value)260 m_ConstantInt(IntegerAttr::ValueType *bind_value) {
261   return detail::constant_int_op_binder(bind_value);
262 }
263 
264 template <typename OpType, typename... Matchers>
m_Op(Matchers...matchers)265 auto m_Op(Matchers... matchers) {
266   return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
267 }
268 
269 namespace matchers {
m_Any()270 inline auto m_Any() { return detail::AnyValueMatcher(); }
m_Val(Value v)271 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
272 } // namespace matchers
273 
274 } // end namespace mlir
275 
276 #endif // MLIR_MATCHERS_H
277