1 //===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions for nodes of a tree structure for representing
10 // the general control flow within a pattern match.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
15 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
16 
17 #include "Predicate.h"
18 #include "mlir/Dialect/PDL/IR/PDL.h"
19 #include "llvm/ADT/MapVector.h"
20 
21 namespace mlir {
22 class ModuleOp;
23 
24 namespace pdl_to_pdl_interp {
25 
26 class MatcherNode;
27 
28 /// A PositionalPredicate is a predicate that is associated with a specific
29 /// positional value.
30 struct PositionalPredicate {
PositionalPredicatePositionalPredicate31   PositionalPredicate(Position *pos,
32                       const PredicateBuilder::Predicate &predicate)
33       : position(pos), question(predicate.first), answer(predicate.second) {}
34 
35   /// The position the predicate is applied to.
36   Position *position;
37 
38   /// The question that the predicate applies.
39   Qualifier *question;
40 
41   /// The expected answer of the predicate.
42   Qualifier *answer;
43 };
44 
45 //===----------------------------------------------------------------------===//
46 // MatcherNode
47 //===----------------------------------------------------------------------===//
48 
49 /// This class represents the base of a predicate matcher node.
50 class MatcherNode {
51 public:
52   virtual ~MatcherNode() = default;
53 
54   /// Given a module containing PDL pattern operations, generate a matcher tree
55   /// using the patterns within the given module and return the root matcher
56   /// node. `valueToPosition` is a map that is populated with the original
57   /// pdl values and their corresponding positions in the matcher tree.
58   static std::unique_ptr<MatcherNode>
59   generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
60                       DenseMap<Value, Position *> &valueToPosition);
61 
62   /// Returns the position on which the question predicate should be checked.
getPosition()63   Position *getPosition() const { return position; }
64 
65   /// Returns the predicate checked on this node.
getQuestion()66   Qualifier *getQuestion() const { return question; }
67 
68   /// Returns the node that should be visited if this, or a subsequent node
69   /// fails.
getFailureNode()70   std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; }
71 
72   /// Sets the node that should be visited if this, or a subsequent node fails.
setFailureNode(std::unique_ptr<MatcherNode> node)73   void setFailureNode(std::unique_ptr<MatcherNode> node) {
74     failureNode = std::move(node);
75   }
76 
77   /// Returns the unique type ID of this matcher instance. This should not be
78   /// used directly, and is provided to support type casting.
getMatcherTypeID()79   TypeID getMatcherTypeID() const { return matcherTypeID; }
80 
81 protected:
82   MatcherNode(TypeID matcherTypeID, Position *position = nullptr,
83               Qualifier *question = nullptr,
84               std::unique_ptr<MatcherNode> failureNode = nullptr);
85 
86 private:
87   /// The position on which the predicate should be checked.
88   Position *position;
89 
90   /// The predicate that is checked on the given position.
91   Qualifier *question;
92 
93   /// The node to visit if this node fails.
94   std::unique_ptr<MatcherNode> failureNode;
95 
96   /// An owning store for the failure node if it is owned by this node.
97   std::unique_ptr<MatcherNode> failureNodeStorage;
98 
99   /// A unique identifier for the derived matcher node, used for type casting.
100   TypeID matcherTypeID;
101 };
102 
103 //===----------------------------------------------------------------------===//
104 // BoolNode
105 
106 /// A BoolNode denotes a question with a boolean-like result. These nodes branch
107 /// to a single node on a successful result, otherwise defaulting to the failure
108 /// node.
109 struct BoolNode : public MatcherNode {
110   BoolNode(Position *position, Qualifier *question, Qualifier *answer,
111            std::unique_ptr<MatcherNode> successNode,
112            std::unique_ptr<MatcherNode> failureNode = nullptr);
113 
114   /// Returns if the given matcher node is an instance of this class, used to
115   /// support type casting.
classofBoolNode116   static bool classof(const MatcherNode *node) {
117     return node->getMatcherTypeID() == TypeID::get<BoolNode>();
118   }
119 
120   /// Returns the expected answer of this boolean node.
getAnswerBoolNode121   Qualifier *getAnswer() const { return answer; }
122 
123   /// Returns the node that should be visited on success.
getSuccessNodeBoolNode124   std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; }
125 
126 private:
127   /// The expected answer of this boolean node.
128   Qualifier *answer;
129 
130   /// The next node if this node succeeds. Otherwise, go to the failure node.
131   std::unique_ptr<MatcherNode> successNode;
132 };
133 
134 //===----------------------------------------------------------------------===//
135 // ExitNode
136 
137 /// An ExitNode is a special sentinel node that denotes the end of matcher.
138 struct ExitNode : public MatcherNode {
ExitNodeExitNode139   ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {}
140 
141   /// Returns if the given matcher node is an instance of this class, used to
142   /// support type casting.
classofExitNode143   static bool classof(const MatcherNode *node) {
144     return node->getMatcherTypeID() == TypeID::get<ExitNode>();
145   }
146 };
147 
148 //===----------------------------------------------------------------------===//
149 // SuccessNode
150 
151 /// A SuccessNode denotes that a given high level pattern has successfully been
152 /// matched. This does not terminate the matcher, as there may be multiple
153 /// successful matches.
154 struct SuccessNode : public MatcherNode {
155   explicit SuccessNode(pdl::PatternOp pattern,
156                        std::unique_ptr<MatcherNode> failureNode);
157 
158   /// Returns if the given matcher node is an instance of this class, used to
159   /// support type casting.
classofSuccessNode160   static bool classof(const MatcherNode *node) {
161     return node->getMatcherTypeID() == TypeID::get<SuccessNode>();
162   }
163 
164   /// Return the high level pattern operation that is matched with this node.
getPatternSuccessNode165   pdl::PatternOp getPattern() const { return pattern; }
166 
167 private:
168   /// The high level pattern operation that was successfully matched with this
169   /// node.
170   pdl::PatternOp pattern;
171 };
172 
173 //===----------------------------------------------------------------------===//
174 // SwitchNode
175 
176 /// A SwitchNode denotes a question with multiple potential results. These nodes
177 /// branch to a specific node based on the result of the question.
178 struct SwitchNode : public MatcherNode {
179   SwitchNode(Position *position, Qualifier *question);
180 
181   /// Returns if the given matcher node is an instance of this class, used to
182   /// support type casting.
classofSwitchNode183   static bool classof(const MatcherNode *node) {
184     return node->getMatcherTypeID() == TypeID::get<SwitchNode>();
185   }
186 
187   /// Returns the children of this switch node. The children are contained
188   /// within a mapping between the various case answers to destination matcher
189   /// nodes.
190   using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>;
getChildrenSwitchNode191   ChildMapT &getChildren() { return children; }
192 
193 private:
194   /// Switch predicate "answers" select the child. Answers that are not found
195   /// default to the failure node.
196   ChildMapT children;
197 };
198 
199 } // end namespace pdl_to_pdl_interp
200 } // end namespace mlir
201 
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
203