1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "mlir/Dialect/PDL/IR/PDL.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdl_to_pdl_interp;
18 
19 //===----------------------------------------------------------------------===//
20 // Predicate List Building
21 //===----------------------------------------------------------------------===//
22 
23 /// Compares the depths of two positions.
comparePosDepth(Position * lhs,Position * rhs)24 static bool comparePosDepth(Position *lhs, Position *rhs) {
25   return lhs->getIndex().size() < rhs->getIndex().size();
26 }
27 
/// Collect the tree predicates anchored at the given value.
///
/// `val` is a value from the PDL pattern body and `pos` is the position at
/// which the matcher will find the corresponding IR entity. Predicates are
/// appended to `predList` in visitation order, and the first position at which
/// each value is reached is recorded in `inputs`. The kind of `pos` selects
/// which predicates are emitted and which connected values are visited next:
///   * AttributePos: non-null check, then either the attribute's type (if it
///     has one) or its constant value (if it has one).
///   * OperandPos:   non-null check, then either the input's type (for a
///     `pdl.input` with a type) or the defining `pdl.operation`, reached
///     through the operand's parent position.
///   * OperationPos: non-null check (non-root only), optional name check,
///     operand and result counts, then every attribute, operand and result.
///   * ResultPos:    non-null check, then the corresponding result type.
///   * TypePos:      constant-type check, if the type is constant.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos) {
  // Make sure this input value is accessible to the rewrite. `try_emplace`
  // keeps the position recorded on the first visit; later visits of the same
  // value leave the map unchanged and report `it.second == false`.
  auto it = inputs.try_emplace(val, pos);

  // If this is an input value that has been visited in the tree, add a
  // constraint to ensure that both instances refer to the same value. The
  // constraint is anchored at the deeper of the two positions (per
  // comparePosDepth) and compares against the shallower one. The value's own
  // predicates were emitted on its first visit, so traversal stops here.
  // Revisited values defined by other ops (e.g. `pdl.operation` results) do
  // not take this path and are traversed again at the new position.
  if (!it.second &&
      isa<pdl::AttributeOp, pdl::InputOp, pdl::TypeOp>(val.getDefiningOp())) {
    auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth);
    predList.emplace_back(minMaxPositions.second,
                          builder.getEqualTo(minMaxPositions.first));
    return;
  }

  // Check for a per-position predicate to apply.
  switch (pos->getKind()) {
  case Predicates::AttributePos: {
    assert(val.getType().isa<pdl::AttributeType>() &&
           "expected attribute type");
    pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
    predList.emplace_back(pos, builder.getIsNotNull());

    // If the attribute has a type, add a type constraint.
    if (Value type = attr.type()) {
      getTreePredicates(predList, type, builder, inputs, builder.getType(pos));

      // Check for a constant value of the attribute.
    } else if (Optional<Attribute> value = attr.value()) {
      predList.emplace_back(pos, builder.getAttributeConstraint(*value));
    }
    break;
  }
  case Predicates::OperandPos: {
    assert(val.getType().isa<pdl::ValueType>() && "expected value type");

    // Prevent traversal into a null value.
    predList.emplace_back(pos, builder.getIsNotNull());

    // If this is a typed input, add a type constraint.
    if (auto in = val.getDefiningOp<pdl::InputOp>()) {
      if (Value type = in.type()) {
        getTreePredicates(predList, type, builder, inputs,
                          builder.getType(pos));
      }

      // Otherwise, recurse into the parent node.
    } else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) {
      getTreePredicates(predList, parentOp.op(), builder, inputs,
                        builder.getParent(cast<OperandPosition>(pos)));
    }
    break;
  }
  case Predicates::OperationPos: {
    assert(val.getType().isa<pdl::OperationType>() && "expected operation");
    pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
    OperationPosition *opPos = cast<OperationPosition>(pos);

    // Ensure getDefiningOp returns a non-null operation. The root position
    // gets no such check.
    if (!opPos->isRoot())
      predList.emplace_back(pos, builder.getIsNotNull());

    // Check that this is the correct operation, if the pattern names it.
    if (Optional<StringRef> opName = op.name())
      predList.emplace_back(pos, builder.getOperationName(*opName));

    // Check that the operation has the proper number of operands and results.
    OperandRange operands = op.operands();
    ResultRange results = op.results();
    predList.emplace_back(pos, builder.getOperandCount(operands.size()));
    predList.emplace_back(pos, builder.getResultCount(results.size()));

    // Recurse into any attributes, operands, or results.
    for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
      getTreePredicates(
          predList, std::get<1>(it), builder, inputs,
          builder.getAttribute(opPos,
                               std::get<0>(it).cast<StringAttr>().getValue()));
    }
    for (auto operandIt : llvm::enumerate(operands))
      getTreePredicates(predList, operandIt.value(), builder, inputs,
                        builder.getOperand(opPos, operandIt.index()));

    // Recurse into every result, including results that are also referenced
    // elsewhere in the source tree (e.g. as an operand of another operation);
    // such a result is then visited at its result position as well.
    for (auto resultIt : llvm::enumerate(results)) {
      getTreePredicates(predList, resultIt.value(), builder, inputs,
                        builder.getResult(opPos, resultIt.index()));
    }
    break;
  }
  case Predicates::ResultPos: {
    assert(val.getType().isa<pdl::ValueType>() && "expected value type");
    pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp());

    // Prevent traversing a null value.
    predList.emplace_back(pos, builder.getIsNotNull());

    // Traverse the type constraint of the corresponding result.
    unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber();
    getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs,
                      builder.getType(pos));
    break;
  }
  case Predicates::TypePos: {
    assert(val.getType().isa<pdl::TypeType>() && "expected value type");
    pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());

    // Check for a constraint on a constant type.
    if (Optional<Type> type = typeOp.type())
      predList.emplace_back(pos, builder.getTypeConstraint(*type));
    break;
  }
  default:
    llvm_unreachable("unknown position kind");
  }
}
147 
148 /// Collect all of the predicates related to constraints within the given
149 /// pattern operation.
collectConstraintPredicates(pdl::PatternOp pattern,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)150 static void collectConstraintPredicates(
151     pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList,
152     PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) {
153   for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) {
154     OperandRange arguments = op.args();
155     ArrayAttr parameters = op.constParamsAttr();
156 
157     std::vector<Position *> allPositions;
158     allPositions.reserve(arguments.size());
159     for (Value arg : arguments)
160       allPositions.push_back(inputs.lookup(arg));
161 
162     // Push the constraint to the furthest position.
163     Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
164                                       comparePosDepth);
165     PredicateBuilder::Predicate pred =
166         builder.getConstraint(op.name(), std::move(allPositions), parameters);
167     predList.emplace_back(pos, pred);
168   }
169 }
170 
171 /// Given a pattern operation, build the set of matcher predicates necessary to
172 /// match this pattern.
buildPredicateList(pdl::PatternOp pattern,PredicateBuilder & builder,std::vector<PositionalPredicate> & predList,DenseMap<Value,Position * > & valueToPosition)173 static void buildPredicateList(pdl::PatternOp pattern,
174                                PredicateBuilder &builder,
175                                std::vector<PositionalPredicate> &predList,
176                                DenseMap<Value, Position *> &valueToPosition) {
177   getTreePredicates(predList, pattern.getRewriter().root(), builder,
178                     valueToPosition, builder.getRoot());
179   collectConstraintPredicates(pattern, predList, builder, valueToPosition);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // Pattern Predicate Tree Merging
184 //===----------------------------------------------------------------------===//
185 
186 namespace {
187 
188 /// This class represents a specific predicate applied to a position, and
189 /// provides hashing and ordering operators. This class allows for computing a
190 /// frequence sum and ordering predicates based on a cost model.
191 struct OrderedPredicate {
OrderedPredicate__anon966670330111::OrderedPredicate192   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
193       : position(ip.first), question(ip.second) {}
OrderedPredicate__anon966670330111::OrderedPredicate194   OrderedPredicate(const PositionalPredicate &ip)
195       : position(ip.position), question(ip.question) {}
196 
197   /// The position this predicate is applied to.
198   Position *position;
199 
200   /// The question that is applied by this predicate onto the position.
201   Qualifier *question;
202 
203   /// The first and second order benefit sums.
204   /// The primary sum is the number of occurrences of this predicate among all
205   /// of the patterns.
206   unsigned primary = 0;
207   /// The secondary sum is a squared summation of the primary sum of all of the
208   /// predicates within each pattern that contains this predicate. This allows
209   /// for favoring predicates that are more commonly shared within a pattern, as
210   /// opposed to those shared across patterns.
211   unsigned secondary = 0;
212 
213   /// A map between a pattern operation and the answer to the predicate question
214   /// within that pattern.
215   DenseMap<Operation *, Qualifier *> patternToAnswer;
216 
217   /// Returns true if this predicate is ordered before `other`, based on the
218   /// cost model.
operator <__anon966670330111::OrderedPredicate219   bool operator<(const OrderedPredicate &other) const {
220     // Sort by:
221     // * first and secondary order sums
222     // * lower depth
223     // * position dependency
224     // * predicate dependency.
225     auto *otherPos = other.position;
226     return std::make_tuple(other.primary, other.secondary,
227                            otherPos->getIndex().size(), otherPos->getKind(),
228                            other.question->getKind()) >
229            std::make_tuple(primary, secondary, position->getIndex().size(),
230                            position->getKind(), question->getKind());
231   }
232 };
233 
234 /// A DenseMapInfo for OrderedPredicate based solely on the position and
235 /// question.
236 struct OrderedPredicateDenseInfo {
237   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
238 
getEmptyKey__anon966670330111::OrderedPredicateDenseInfo239   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
getTombstoneKey__anon966670330111::OrderedPredicateDenseInfo240   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
isEqual__anon966670330111::OrderedPredicateDenseInfo241   static bool isEqual(const OrderedPredicate &lhs,
242                       const OrderedPredicate &rhs) {
243     return lhs.position == rhs.position && lhs.question == rhs.question;
244   }
getHashValue__anon966670330111::OrderedPredicateDenseInfo245   static unsigned getHashValue(const OrderedPredicate &p) {
246     return llvm::hash_combine(p.position, p.question);
247   }
248 };
249 
250 /// This class wraps a set of ordered predicates that are used within a specific
251 /// pattern operation.
252 struct OrderedPredicateList {
OrderedPredicateList__anon966670330111::OrderedPredicateList253   OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {}
254 
255   pdl::PatternOp pattern;
256   DenseSet<OrderedPredicate *> predicates;
257 };
258 } // end anonymous namespace
259 
260 /// Returns true if the given matcher refers to the same predicate as the given
261 /// ordered predicate. This means that the position and questions of the two
262 /// match.
isSamePredicate(MatcherNode * node,OrderedPredicate * predicate)263 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
264   return node->getPosition() == predicate->position &&
265          node->getQuestion() == predicate->question;
266 }
267 
268 /// Get or insert a child matcher for the given parent switch node, given a
269 /// predicate and parent pattern.
getOrCreateChild(SwitchNode * node,OrderedPredicate * predicate,pdl::PatternOp pattern)270 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
271                                                OrderedPredicate *predicate,
272                                                pdl::PatternOp pattern) {
273   assert(isSamePredicate(node, predicate) &&
274          "expected matcher to equal the given predicate");
275 
276   auto it = predicate->patternToAnswer.find(pattern);
277   assert(it != predicate->patternToAnswer.end() &&
278          "expected pattern to exist in predicate");
279   return node->getChildren().insert({it->second, nullptr}).first->second;
280 }
281 
282 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
283 /// order. A pattern will traverse as far as possible using common predicates
284 /// and then either diverge from the CFG or reach the end of a branch and start
285 /// creating new nodes.
propagatePattern(std::unique_ptr<MatcherNode> & node,OrderedPredicateList & list,std::vector<OrderedPredicate * >::iterator current,std::vector<OrderedPredicate * >::iterator end)286 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
287                              OrderedPredicateList &list,
288                              std::vector<OrderedPredicate *>::iterator current,
289                              std::vector<OrderedPredicate *>::iterator end) {
290   if (current == end) {
291     // We've hit the end of a pattern, so create a successful result node.
292     node = std::make_unique<SuccessNode>(list.pattern, std::move(node));
293 
294     // If the pattern doesn't contain this predicate, ignore it.
295   } else if (list.predicates.find(*current) == list.predicates.end()) {
296     propagatePattern(node, list, std::next(current), end);
297 
298     // If the current matcher node is invalid, create a new one for this
299     // position and continue propagation.
300   } else if (!node) {
301     // Create a new node at this position and continue
302     node = std::make_unique<SwitchNode>((*current)->position,
303                                         (*current)->question);
304     propagatePattern(
305         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
306         list, std::next(current), end);
307 
308     // If the matcher has already been created, and it is for this predicate we
309     // continue propagation to the child.
310   } else if (isSamePredicate(node.get(), *current)) {
311     propagatePattern(
312         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
313         list, std::next(current), end);
314 
315     // If the matcher doesn't match the current predicate, insert a branch as
316     // the common set of matchers has diverged.
317   } else {
318     propagatePattern(node->getFailureNode(), list, current, end);
319   }
320 }
321 
322 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
323 /// `node` is updated in-place if it is a switch.
foldSwitchToBool(std::unique_ptr<MatcherNode> & node)324 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
325   if (!node)
326     return;
327 
328   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
329     SwitchNode::ChildMapT &children = switchNode->getChildren();
330     for (auto &it : children)
331       foldSwitchToBool(it.second);
332 
333     // If the node only contains one child, collapse it into a boolean predicate
334     // node.
335     if (children.size() == 1) {
336       auto childIt = children.begin();
337       node = std::make_unique<BoolNode>(
338           node->getPosition(), node->getQuestion(), childIt->first,
339           std::move(childIt->second), std::move(node->getFailureNode()));
340     }
341   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
342     foldSwitchToBool(boolNode->getSuccessNode());
343   }
344 
345   foldSwitchToBool(node->getFailureNode());
346 }
347 
348 /// Insert an exit node at the end of the failure path of the `root`.
insertExitNode(std::unique_ptr<MatcherNode> * root)349 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
350   while (*root)
351     root = &(*root)->getFailureNode();
352   *root = std::make_unique<ExitNode>();
353 }
354 
355 /// Given a module containing PDL pattern operations, generate a matcher tree
356 /// using the patterns within the given module and return the root matcher node.
357 std::unique_ptr<MatcherNode>
generateMatcherTree(ModuleOp module,PredicateBuilder & builder,DenseMap<Value,Position * > & valueToPosition)358 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
359                                  DenseMap<Value, Position *> &valueToPosition) {
360   // Collect the set of predicates contained within the pattern operations of
361   // the module.
362   SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16>
363       patternsAndPredicates;
364   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
365     std::vector<PositionalPredicate> predicateList;
366     buildPredicateList(pattern, builder, predicateList, valueToPosition);
367     patternsAndPredicates.emplace_back(pattern, std::move(predicateList));
368   }
369 
370   // Associate a pattern result with each unique predicate.
371   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
372   for (auto &patternAndPredList : patternsAndPredicates) {
373     for (auto &predicate : patternAndPredList.second) {
374       auto it = uniqued.insert(predicate);
375       it.first->patternToAnswer.try_emplace(patternAndPredList.first,
376                                             predicate.answer);
377     }
378   }
379 
380   // Associate each pattern to a set of its ordered predicates for later lookup.
381   std::vector<OrderedPredicateList> lists;
382   lists.reserve(patternsAndPredicates.size());
383   for (auto &patternAndPredList : patternsAndPredicates) {
384     OrderedPredicateList list(patternAndPredList.first);
385     for (auto &predicate : patternAndPredList.second) {
386       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
387       list.predicates.insert(orderedPredicate);
388 
389       // Increment the primary sum for each reference to a particular predicate.
390       ++orderedPredicate->primary;
391     }
392     lists.push_back(std::move(list));
393   }
394 
395   // For a particular pattern, get the total primary sum and add it to the
396   // secondary sum of each predicate. Square the primary sums to emphasize
397   // shared predicates within rather than across patterns.
398   for (auto &list : lists) {
399     unsigned total = 0;
400     for (auto *predicate : list.predicates)
401       total += predicate->primary * predicate->primary;
402     for (auto *predicate : list.predicates)
403       predicate->secondary += total;
404   }
405 
406   // Sort the set of predicates now that the cost primary and secondary sums
407   // have been computed.
408   std::vector<OrderedPredicate *> ordered;
409   ordered.reserve(uniqued.size());
410   for (auto &ip : uniqued)
411     ordered.push_back(&ip);
412   std::stable_sort(
413       ordered.begin(), ordered.end(),
414       [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
415 
416   // Build the matchers for each of the pattern predicate lists.
417   std::unique_ptr<MatcherNode> root;
418   for (OrderedPredicateList &list : lists)
419     propagatePattern(root, list, ordered.begin(), ordered.end());
420 
421   // Collapse the graph and insert the exit node.
422   foldSwitchToBool(root);
423   insertExitNode(&root);
424   return root;
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // MatcherNode
429 //===----------------------------------------------------------------------===//
430 
MatcherNode(TypeID matcherTypeID,Position * p,Qualifier * q,std::unique_ptr<MatcherNode> failureNode)431 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
432                          std::unique_ptr<MatcherNode> failureNode)
433     : position(p), question(q), failureNode(std::move(failureNode)),
434       matcherTypeID(matcherTypeID) {}
435 
436 //===----------------------------------------------------------------------===//
437 // BoolNode
438 //===----------------------------------------------------------------------===//
439 
BoolNode(Position * position,Qualifier * question,Qualifier * answer,std::unique_ptr<MatcherNode> successNode,std::unique_ptr<MatcherNode> failureNode)440 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
441                    std::unique_ptr<MatcherNode> successNode,
442                    std::unique_ptr<MatcherNode> failureNode)
443     : MatcherNode(TypeID::get<BoolNode>(), position, question,
444                   std::move(failureNode)),
445       answer(answer), successNode(std::move(successNode)) {}
446 
447 //===----------------------------------------------------------------------===//
448 // SuccessNode
449 //===----------------------------------------------------------------------===//
450 
SuccessNode(pdl::PatternOp pattern,std::unique_ptr<MatcherNode> failureNode)451 SuccessNode::SuccessNode(pdl::PatternOp pattern,
452                          std::unique_ptr<MatcherNode> failureNode)
453     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
454                   /*question=*/nullptr, std::move(failureNode)),
455       pattern(pattern) {}
456 
457 //===----------------------------------------------------------------------===//
458 // SwitchNode
459 //===----------------------------------------------------------------------===//
460 
SwitchNode(Position * position,Qualifier * question)461 SwitchNode::SwitchNode(Position *position, Qualifier *question)
462     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
463