//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file provides a simple and efficient mechanism for performing general // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's // include/llvm/IR/PatternMatch.h. // //===----------------------------------------------------------------------===// #ifndef MLIR_MATCHERS_H #define MLIR_MATCHERS_H #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" namespace mlir { namespace detail { /// The matcher that matches a certain kind of Attribute and binds the value /// inside the Attribute. template < typename AttrClass, // Require AttrClass to be a derived class from Attribute and get its // value type typename ValueType = typename std::enable_if::value, AttrClass>::type::ValueType, // Require the ValueType is not void typename = typename std::enable_if::value>::type> struct attr_value_binder { ValueType *bind_value; /// Creates a matcher instance that binds the value to bv if match succeeds. attr_value_binder(ValueType *bv) : bind_value(bv) {} bool match(const Attribute &attr) { if (auto intAttr = attr.dyn_cast()) { *bind_value = intAttr.getValue(); return true; } return false; } }; /// The matcher that matches operations that have the `ConstantLike` trait. struct constant_op_matcher { bool match(Operation *op) { return op->hasTrait(); } }; /// The matcher that matches operations that have the `ConstantLike` trait, and /// binds the folded attribute value. template struct constant_op_binder { AttrT *bind_value; /// Creates a matcher instance that binds the constant attribute value to /// bind_value if match succeeds. constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {} /// Creates a matcher instance that doesn't bind if match succeeds. constant_op_binder() : bind_value(nullptr) {} bool match(Operation *op) { if (!op->hasTrait()) return false; // Fold the constant to an attribute. SmallVector foldedOp; LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp); (void)result; assert(succeeded(result) && "expected ConstantLike op to be foldable"); if (auto attr = foldedOp.front().get().dyn_cast()) { if (bind_value) *bind_value = attr; return true; } return false; } }; /// The matcher that matches a constant scalar / vector splat / tensor splat /// integer operation and binds the constant integer value. struct constant_int_op_binder { IntegerAttr::ValueType *bind_value; /// Creates a matcher instance that binds the value to bv if match succeeds. constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} bool match(Operation *op) { Attribute attr; if (!constant_op_binder(&attr).match(op)) return false; auto type = op->getResult(0).getType(); if (type.isa()) return attr_value_binder(bind_value).match(attr); if (type.isa()) { if (auto splatAttr = attr.dyn_cast()) { return attr_value_binder(bind_value) .match(splatAttr.getSplatValue()); } } return false; } }; /// The matcher that matches a given target constant scalar / vector splat / /// tensor splat integer value. template struct constant_int_value_matcher { bool match(Operation *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetValue == value; } }; /// The matcher that matches anything except the given target constant scalar / /// vector splat / tensor splat integer value. template struct constant_int_not_value_matcher { bool match(Operation *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetNotValue != value; } }; /// The matcher that matches a certain kind of op. template struct op_matcher { bool match(Operation *op) { return isa(op); } }; /// Trait to check whether T provides a 'match' method with type /// `OperationOrValue`. template using has_operation_or_value_matcher_t = decltype(std::declval().match(std::declval())); /// Statically switch to a Value matcher. template typename std::enable_if_t< llvm::is_detected::value, bool> matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { return matcher.match(op->getOperand(idx)); } /// Statically switch to an Operation matcher. template typename std::enable_if_t< llvm::is_detected::value, bool> matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { if (auto defOp = op->getOperand(idx).getDefiningOp()) return matcher.match(defOp); return false; } /// Terminal matcher, always returns true. struct AnyValueMatcher { bool match(Value op) const { return true; } }; /// Binds to a specific value and matches it. struct PatternMatcherValue { PatternMatcherValue(Value val) : value(val) {} bool match(Value val) const { return val == value; } Value value; }; template constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence) { (void)std::initializer_list{ 0, (callback(std::integral_constant{}, std::get(tuple)), 0)...}; } template constexpr void enumerate(std::tuple &tuple, CallbackT &&callback) { detail::enumerateImpl(tuple, std::forward(callback), std::make_index_sequence{}); } /// RecursivePatternMatcher that composes. template struct RecursivePatternMatcher { RecursivePatternMatcher(OperandMatchers... matchers) : operandMatchers(matchers...) {} bool match(Operation *op) { if (!isa(op) || op->getNumOperands() != sizeof...(OperandMatchers)) return false; bool res = true; enumerate(operandMatchers, [&](size_t index, auto &matcher) { res &= matchOperandOrValueAtIndex(op, index, matcher); }); return res; } std::tuple operandMatchers; }; } // end namespace detail /// Matches a constant foldable operation. inline detail::constant_op_matcher m_Constant() { return detail::constant_op_matcher(); } /// Matches a value from a constant foldable operation and writes the value to /// bind_value. template inline detail::constant_op_binder m_Constant(AttrT *bind_value) { return detail::constant_op_binder(bind_value); } /// Matches a constant scalar / vector splat / tensor splat integer one. inline detail::constant_int_value_matcher<1> m_One() { return detail::constant_int_value_matcher<1>(); } /// Matches the given OpClass. template inline detail::op_matcher m_Op() { return detail::op_matcher(); } /// Matches a constant scalar / vector splat / tensor splat integer zero. inline detail::constant_int_value_matcher<0> m_Zero() { return detail::constant_int_value_matcher<0>(); } /// Matches a constant scalar / vector splat / tensor splat integer that is any /// non-zero value. inline detail::constant_int_not_value_matcher<0> m_NonZero() { return detail::constant_int_not_value_matcher<0>(); } /// Entry point for matching a pattern over a Value. template inline bool matchPattern(Value value, const Pattern &pattern) { // TODO: handle other cases if (auto *op = value.getDefiningOp()) return const_cast(pattern).match(op); return false; } /// Entry point for matching a pattern over an Operation. template inline bool matchPattern(Operation *op, const Pattern &pattern) { return const_cast(pattern).match(op); } /// Matches a constant holding a scalar/vector/tensor integer (splat) and /// writes the integer value to bind_value. inline detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value) { return detail::constant_int_op_binder(bind_value); } template auto m_Op(Matchers... matchers) { return detail::RecursivePatternMatcher(matchers...); } namespace matchers { inline auto m_Any() { return detail::AnyValueMatcher(); } inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); } } // namespace matchers } // end namespace mlir #endif // MLIR_MATCHERS_H