1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/TypeSwitch.h"
20
21 using namespace mlir;
22 using namespace mlir::pdl_to_pdl_interp;
23
24 //===----------------------------------------------------------------------===//
25 // PatternLowering
26 //===----------------------------------------------------------------------===//
27
28 namespace {
29 /// This class generators operations within the PDL Interpreter dialect from a
30 /// given module containing PDL pattern operations.
31 struct PatternLowering {
32 public:
33 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
34
35 /// Generate code for matching and rewriting based on the pattern operations
36 /// within the module.
37 void lower(ModuleOp module);
38
39 private:
40 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
41 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
42
43 /// Generate interpreter operations for the tree rooted at the given matcher
44 /// node.
45 Block *generateMatcher(MatcherNode &node);
46
47 /// Get or create an access to the provided positional value within the
48 /// current block.
49 Value getValueAt(Block *cur, Position *pos);
50
51 /// Create an interpreter predicate operation, branching to the provided true
52 /// and false destinations.
53 void generatePredicate(Block *currentBlock, Qualifier *question,
54 Qualifier *answer, Value val, Block *trueDest,
55 Block *falseDest);
56
57 /// Create an interpreter switch predicate operation, with a provided default
58 /// and several case destinations.
59 void generateSwitch(Block *currentBlock, Qualifier *question, Value val,
60 Block *defaultDest,
61 ArrayRef<std::pair<Qualifier *, Block *>> dests);
62
63 /// Create the interpreter operations to record a successful pattern match.
64 void generateRecordMatch(Block *currentBlock, Block *nextBlock,
65 pdl::PatternOp pattern);
66
67 /// Generate a rewriter function for the given pattern operation, and returns
68 /// a reference to that function.
69 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
70 SmallVectorImpl<Position *> &usedMatchValues);
71
72 /// Generate the rewriter code for the given operation.
73 void generateRewriter(pdl::AttributeOp attrOp,
74 DenseMap<Value, Value> &rewriteValues,
75 function_ref<Value(Value)> mapRewriteValue);
76 void generateRewriter(pdl::EraseOp eraseOp,
77 DenseMap<Value, Value> &rewriteValues,
78 function_ref<Value(Value)> mapRewriteValue);
79 void generateRewriter(pdl::OperationOp operationOp,
80 DenseMap<Value, Value> &rewriteValues,
81 function_ref<Value(Value)> mapRewriteValue);
82 void generateRewriter(pdl::CreateNativeOp createNativeOp,
83 DenseMap<Value, Value> &rewriteValues,
84 function_ref<Value(Value)> mapRewriteValue);
85 void generateRewriter(pdl::ReplaceOp replaceOp,
86 DenseMap<Value, Value> &rewriteValues,
87 function_ref<Value(Value)> mapRewriteValue);
88 void generateRewriter(pdl::TypeOp typeOp,
89 DenseMap<Value, Value> &rewriteValues,
90 function_ref<Value(Value)> mapRewriteValue);
91
92 /// Generate the values used for resolving the result types of an operation
93 /// created within a dag rewriter region.
94 void generateOperationResultTypeRewriter(
95 pdl::OperationOp op, SmallVectorImpl<Value> &types,
96 DenseMap<Value, Value> &rewriteValues,
97 function_ref<Value(Value)> mapRewriteValue);
98
99 /// A builder to use when generating interpreter operations.
100 OpBuilder builder;
101
102 /// The matcher function used for all match related logic within PDL patterns.
103 FuncOp matcherFunc;
104
105 /// The rewriter module containing the all rewrite related logic within PDL
106 /// patterns.
107 ModuleOp rewriterModule;
108
109 /// The symbol table of the rewriter module used for insertion.
110 SymbolTable rewriterSymbolTable;
111
112 /// A scoped map connecting a position with the corresponding interpreter
113 /// value.
114 ValueMap values;
115
116 /// A stack of blocks used as the failure destination for matcher nodes that
117 /// don't have an explicit failure path.
118 SmallVector<Block *, 8> failureBlockStack;
119
120 /// A mapping between values defined in a pattern match, and the corresponding
121 /// positional value.
122 DenseMap<Value, Position *> valueToPosition;
123
124 /// The set of operation values whose whose location will be used for newly
125 /// generated operations.
126 llvm::SetVector<Value> locOps;
127 };
128 } // end anonymous namespace
129
PatternLowering(FuncOp matcherFunc,ModuleOp rewriterModule)130 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
131 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
132 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
133
lower(ModuleOp module)134 void PatternLowering::lower(ModuleOp module) {
135 PredicateUniquer predicateUniquer;
136 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
137
138 // Define top-level scope for the arguments to the matcher function.
139 ValueMapScope topLevelValueScope(values);
140
141 // Insert the root operation, i.e. argument to the matcher, at the root
142 // position.
143 Block *matcherEntryBlock = matcherFunc.addEntryBlock();
144 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
145
146 // Generate a root matcher node from the provided PDL module.
147 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
148 module, predicateBuilder, valueToPosition);
149 Block *firstMatcherBlock = generateMatcher(*root);
150
151 // After generation, merged the first matched block into the entry.
152 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
153 firstMatcherBlock->getOperations());
154 firstMatcherBlock->erase();
155 }
156
generateMatcher(MatcherNode & node)157 Block *PatternLowering::generateMatcher(MatcherNode &node) {
158 // Push a new scope for the values used by this matcher.
159 Block *block = matcherFunc.addBlock();
160 ValueMapScope scope(values);
161
162 // If this is the return node, simply insert the corresponding interpreter
163 // finalize.
164 if (isa<ExitNode>(node)) {
165 builder.setInsertionPointToEnd(block);
166 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
167 return block;
168 }
169
170 // If this node contains a position, get the corresponding value for this
171 // block.
172 Position *position = node.getPosition();
173 Value val = position ? getValueAt(block, position) : Value();
174
175 // Get the next block in the match sequence.
176 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
177 Block *nextBlock;
178 if (failureNode) {
179 nextBlock = generateMatcher(*failureNode);
180 failureBlockStack.push_back(nextBlock);
181 } else {
182 assert(!failureBlockStack.empty() && "expected valid failure block");
183 nextBlock = failureBlockStack.back();
184 }
185
186 // If this value corresponds to an operation, record that we are going to use
187 // its location as part of a fused location.
188 bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
189 if (isOperationValue)
190 locOps.insert(val);
191
192 // Generate code for a boolean predicate node.
193 if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
194 auto *child = generateMatcher(*boolNode->getSuccessNode());
195 generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
196 child, nextBlock);
197
198 // Generate code for a switch node.
199 } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
200 // Collect the next blocks for all of the children and generate a switch.
201 llvm::MapVector<Qualifier *, Block *> children;
202 for (auto &it : switchNode->getChildren())
203 children.insert({it.first, generateMatcher(*it.second)});
204 generateSwitch(block, node.getQuestion(), val, nextBlock,
205 children.takeVector());
206
207 // Generate code for a success node.
208 } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
209 generateRecordMatch(block, nextBlock, successNode->getPattern());
210 }
211
212 if (failureNode)
213 failureBlockStack.pop_back();
214 if (isOperationValue)
215 locOps.remove(val);
216 return block;
217 }
218
getValueAt(Block * cur,Position * pos)219 Value PatternLowering::getValueAt(Block *cur, Position *pos) {
220 if (Value val = values.lookup(pos))
221 return val;
222
223 // Get the value for the parent position.
224 Value parentVal = getValueAt(cur, pos->getParent());
225
226 // TODO: Use a location from the position.
227 Location loc = parentVal.getLoc();
228 builder.setInsertionPointToEnd(cur);
229 Value value;
230 switch (pos->getKind()) {
231 case Predicates::OperationPos:
232 value = builder.create<pdl_interp::GetDefiningOpOp>(
233 loc, builder.getType<pdl::OperationType>(), parentVal);
234 break;
235 case Predicates::OperandPos: {
236 auto *operandPos = cast<OperandPosition>(pos);
237 value = builder.create<pdl_interp::GetOperandOp>(
238 loc, builder.getType<pdl::ValueType>(), parentVal,
239 operandPos->getOperandNumber());
240 break;
241 }
242 case Predicates::AttributePos: {
243 auto *attrPos = cast<AttributePosition>(pos);
244 value = builder.create<pdl_interp::GetAttributeOp>(
245 loc, builder.getType<pdl::AttributeType>(), parentVal,
246 attrPos->getName().strref());
247 break;
248 }
249 case Predicates::TypePos: {
250 if (parentVal.getType().isa<pdl::ValueType>())
251 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
252 else
253 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
254 break;
255 }
256 case Predicates::ResultPos: {
257 auto *resPos = cast<ResultPosition>(pos);
258 value = builder.create<pdl_interp::GetResultOp>(
259 loc, builder.getType<pdl::ValueType>(), parentVal,
260 resPos->getResultNumber());
261 break;
262 }
263 default:
264 llvm_unreachable("Generating unknown Position getter");
265 break;
266 }
267 values.insert(pos, value);
268 return value;
269 }
270
generatePredicate(Block * currentBlock,Qualifier * question,Qualifier * answer,Value val,Block * trueDest,Block * falseDest)271 void PatternLowering::generatePredicate(Block *currentBlock,
272 Qualifier *question, Qualifier *answer,
273 Value val, Block *trueDest,
274 Block *falseDest) {
275 builder.setInsertionPointToEnd(currentBlock);
276 Location loc = val.getLoc();
277 switch (question->getKind()) {
278 case Predicates::IsNotNullQuestion:
279 builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
280 break;
281 case Predicates::OperationNameQuestion: {
282 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
283 builder.create<pdl_interp::CheckOperationNameOp>(
284 loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
285 break;
286 }
287 case Predicates::TypeQuestion: {
288 auto *ans = cast<TypeAnswer>(answer);
289 builder.create<pdl_interp::CheckTypeOp>(
290 loc, val, TypeAttr::get(ans->getValue()), trueDest, falseDest);
291 break;
292 }
293 case Predicates::AttributeQuestion: {
294 auto *ans = cast<AttributeAnswer>(answer);
295 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
296 trueDest, falseDest);
297 break;
298 }
299 case Predicates::OperandCountQuestion: {
300 auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
301 builder.create<pdl_interp::CheckOperandCountOp>(
302 loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
303 break;
304 }
305 case Predicates::ResultCountQuestion: {
306 auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
307 builder.create<pdl_interp::CheckResultCountOp>(
308 loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
309 break;
310 }
311 case Predicates::EqualToQuestion: {
312 auto *equalToQuestion = cast<EqualToQuestion>(question);
313 builder.create<pdl_interp::AreEqualOp>(
314 loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
315 trueDest, falseDest);
316 break;
317 }
318 case Predicates::ConstraintQuestion: {
319 auto *cstQuestion = cast<ConstraintQuestion>(question);
320 SmallVector<Value, 2> args;
321 for (Position *position : std::get<1>(cstQuestion->getValue()))
322 args.push_back(getValueAt(currentBlock, position));
323 builder.create<pdl_interp::ApplyConstraintOp>(
324 loc, std::get<0>(cstQuestion->getValue()), args,
325 std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
326 falseDest);
327 break;
328 }
329 default:
330 llvm_unreachable("Generating unknown Predicate operation");
331 }
332 }
333
334 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,ArrayRef<std::pair<Qualifier *,Block * >> dests)335 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
336 ArrayRef<std::pair<Qualifier *, Block *>> dests) {
337 std::vector<ValT> values;
338 std::vector<Block *> blocks;
339 values.reserve(dests.size());
340 blocks.reserve(dests.size());
341 for (const auto &it : dests) {
342 blocks.push_back(it.second);
343 values.push_back(cast<PredT>(it.first)->getValue());
344 }
345 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
346 }
347
generateSwitch(Block * currentBlock,Qualifier * question,Value val,Block * defaultDest,ArrayRef<std::pair<Qualifier *,Block * >> dests)348 void PatternLowering::generateSwitch(
349 Block *currentBlock, Qualifier *question, Value val, Block *defaultDest,
350 ArrayRef<std::pair<Qualifier *, Block *>> dests) {
351 builder.setInsertionPointToEnd(currentBlock);
352 switch (question->getKind()) {
353 case Predicates::OperandCountQuestion:
354 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
355 int32_t>(val, defaultDest, builder, dests);
356 case Predicates::ResultCountQuestion:
357 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
358 int32_t>(val, defaultDest, builder, dests);
359 case Predicates::OperationNameQuestion:
360 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
361 OperationNameAnswer>(val, defaultDest, builder,
362 dests);
363 case Predicates::TypeQuestion:
364 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
365 val, defaultDest, builder, dests);
366 case Predicates::AttributeQuestion:
367 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
368 val, defaultDest, builder, dests);
369 default:
370 llvm_unreachable("Generating unknown switch predicate.");
371 }
372 }
373
generateRecordMatch(Block * currentBlock,Block * nextBlock,pdl::PatternOp pattern)374 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
375 pdl::PatternOp pattern) {
376 // Generate a rewriter for the pattern this success node represents, and track
377 // any values used from the match region.
378 SmallVector<Position *, 8> usedMatchValues;
379 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
380
381 // Process any values used in the rewrite that are defined in the match.
382 std::vector<Value> mappedMatchValues;
383 mappedMatchValues.reserve(usedMatchValues.size());
384 for (Position *position : usedMatchValues)
385 mappedMatchValues.push_back(getValueAt(currentBlock, position));
386
387 // Collect the set of operations generated by the rewriter.
388 SmallVector<StringRef, 4> generatedOps;
389 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
390 generatedOps.push_back(*op.name());
391 ArrayAttr generatedOpsAttr;
392 if (!generatedOps.empty())
393 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
394
395 // Grab the root kind if present.
396 StringAttr rootKindAttr;
397 if (Optional<StringRef> rootKind = pattern.getRootKind())
398 rootKindAttr = builder.getStringAttr(*rootKind);
399
400 builder.setInsertionPointToEnd(currentBlock);
401 builder.create<pdl_interp::RecordMatchOp>(
402 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
403 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
404 nextBlock);
405 }
406
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)407 SymbolRefAttr PatternLowering::generateRewriter(
408 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
409 FuncOp rewriterFunc =
410 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
411 builder.getFunctionType(llvm::None, llvm::None));
412 rewriterSymbolTable.insert(rewriterFunc);
413
414 // Generate the rewriter function body.
415 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
416
417 // Map an input operand of the pattern to a generated interpreter value.
418 DenseMap<Value, Value> rewriteValues;
419 auto mapRewriteValue = [&](Value oldValue) {
420 Value &newValue = rewriteValues[oldValue];
421 if (newValue)
422 return newValue;
423
424 // Prefer materializing constants directly when possible.
425 Operation *oldOp = oldValue.getDefiningOp();
426 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
427 if (Attribute value = attrOp.valueAttr()) {
428 return newValue = builder.create<pdl_interp::CreateAttributeOp>(
429 attrOp.getLoc(), value);
430 }
431 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
432 if (TypeAttr type = typeOp.typeAttr()) {
433 return newValue = builder.create<pdl_interp::CreateTypeOp>(
434 typeOp.getLoc(), type);
435 }
436 }
437
438 // Otherwise, add this as an input to the rewriter.
439 Position *inputPos = valueToPosition.lookup(oldValue);
440 assert(inputPos && "expected value to be a pattern input");
441 usedMatchValues.push_back(inputPos);
442 return newValue = rewriterFunc.front().addArgument(oldValue.getType());
443 };
444
445 // If this is a custom rewriter, simply dispatch to the registered rewrite
446 // method.
447 pdl::RewriteOp rewriter = pattern.getRewriter();
448 if (StringAttr rewriteName = rewriter.nameAttr()) {
449 Value root = mapRewriteValue(rewriter.root());
450 SmallVector<Value, 4> args = llvm::to_vector<4>(
451 llvm::map_range(rewriter.externalArgs(), mapRewriteValue));
452 builder.create<pdl_interp::ApplyRewriteOp>(
453 rewriter.getLoc(), rewriteName, root, args,
454 rewriter.externalConstParamsAttr());
455 } else {
456 // Otherwise this is a dag rewriter defined using PDL operations.
457 for (Operation &rewriteOp : *rewriter.getBody()) {
458 llvm::TypeSwitch<Operation *>(&rewriteOp)
459 .Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
460 pdl::OperationOp, pdl::ReplaceOp, pdl::TypeOp>([&](auto op) {
461 this->generateRewriter(op, rewriteValues, mapRewriteValue);
462 });
463 }
464 }
465
466 // Update the signature of the rewrite function.
467 rewriterFunc.setType(builder.getFunctionType(
468 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
469 /*results=*/llvm::None));
470
471 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
472 return builder.getSymbolRefAttr(
473 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
474 builder.getSymbolRefAttr(rewriterFunc));
475 }
476
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)477 void PatternLowering::generateRewriter(
478 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
479 function_ref<Value(Value)> mapRewriteValue) {
480 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
481 attrOp.getLoc(), attrOp.valueAttr());
482 rewriteValues[attrOp] = newAttr;
483 }
484
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)485 void PatternLowering::generateRewriter(
486 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
487 function_ref<Value(Value)> mapRewriteValue) {
488 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
489 mapRewriteValue(eraseOp.operation()));
490 }
491
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)492 void PatternLowering::generateRewriter(
493 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
494 function_ref<Value(Value)> mapRewriteValue) {
495 SmallVector<Value, 4> operands;
496 for (Value operand : operationOp.operands())
497 operands.push_back(mapRewriteValue(operand));
498
499 SmallVector<Value, 4> attributes;
500 for (Value attr : operationOp.attributes())
501 attributes.push_back(mapRewriteValue(attr));
502
503 SmallVector<Value, 2> types;
504 generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
505 mapRewriteValue);
506
507 // Create the new operation.
508 Location loc = operationOp.getLoc();
509 Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
510 loc, *operationOp.name(), types, operands, attributes,
511 operationOp.attributeNames());
512 rewriteValues[operationOp.op()] = createdOp;
513
514 // Make all of the new operation results available.
515 OperandRange resultTypes = operationOp.types();
516 for (auto it : llvm::enumerate(operationOp.results())) {
517 Value getResultVal = builder.create<pdl_interp::GetResultOp>(
518 loc, builder.getType<pdl::ValueType>(), createdOp, it.index());
519 rewriteValues[it.value()] = getResultVal;
520
521 // If any of the types have not been resolved, make those available as well.
522 Value &type = rewriteValues[resultTypes[it.index()]];
523 if (!type)
524 type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
525 }
526 }
527
generateRewriter(pdl::CreateNativeOp createNativeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)528 void PatternLowering::generateRewriter(
529 pdl::CreateNativeOp createNativeOp, DenseMap<Value, Value> &rewriteValues,
530 function_ref<Value(Value)> mapRewriteValue) {
531 SmallVector<Value, 2> arguments;
532 for (Value argument : createNativeOp.args())
533 arguments.push_back(mapRewriteValue(argument));
534 Value result = builder.create<pdl_interp::CreateNativeOp>(
535 createNativeOp.getLoc(), createNativeOp.result().getType(),
536 createNativeOp.nameAttr(), arguments, createNativeOp.constParamsAttr());
537 rewriteValues[createNativeOp] = result;
538 }
539
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)540 void PatternLowering::generateRewriter(
541 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
542 function_ref<Value(Value)> mapRewriteValue) {
543 // If the replacement was another operation, get its results. `pdl` allows
544 // for using an operation for simplicitly, but the interpreter isn't as
545 // user facing.
546 ValueRange origOperands;
547 if (Value replOp = replaceOp.replOperation())
548 origOperands = cast<pdl::OperationOp>(replOp.getDefiningOp()).results();
549 else
550 origOperands = replaceOp.replValues();
551
552 // If there are no replacement values, just create an erase instead.
553 if (origOperands.empty()) {
554 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
555 mapRewriteValue(replaceOp.operation()));
556 return;
557 }
558
559 SmallVector<Value, 4> replOperands;
560 for (Value operand : origOperands)
561 replOperands.push_back(mapRewriteValue(operand));
562 builder.create<pdl_interp::ReplaceOp>(
563 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
564 }
565
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)566 void PatternLowering::generateRewriter(
567 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
568 function_ref<Value(Value)> mapRewriteValue) {
569 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
570 // type.
571 if (TypeAttr typeAttr = typeOp.typeAttr()) {
572 Value newType =
573 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
574 rewriteValues[typeOp] = newType;
575 }
576 }
577
generateOperationResultTypeRewriter(pdl::OperationOp op,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)578 void PatternLowering::generateOperationResultTypeRewriter(
579 pdl::OperationOp op, SmallVectorImpl<Value> &types,
580 DenseMap<Value, Value> &rewriteValues,
581 function_ref<Value(Value)> mapRewriteValue) {
582 // Functor that returns if the given use can be used to infer a type.
583 Block *rewriterBlock = op->getBlock();
584 auto getReplacedOperationFrom = [&](OpOperand &use) -> Operation * {
585 // Check that the use corresponds to a ReplaceOp and that it is the
586 // replacement value, not the operation being replaced.
587 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
588 if (!replOpUser || use.getOperandNumber() == 0)
589 return nullptr;
590 // Make sure the replaced operation was defined before this one.
591 Operation *replacedOp = replOpUser.operation().getDefiningOp();
592 if (replacedOp->getBlock() != rewriterBlock ||
593 replacedOp->isBeforeInBlock(op))
594 return replacedOp;
595 return nullptr;
596 };
597
598 // If non-None/non-Null, this is an operation that is replaced by `op`.
599 // If Null, there is no full replacement operation for `op`.
600 // If None, a replacement operation hasn't been searched for.
601 Optional<Operation *> fullReplacedOperation;
602 bool hasTypeInference = op.hasTypeInference();
603 auto resultTypeValues = op.types();
604 types.reserve(resultTypeValues.size());
605 for (auto it : llvm::enumerate(op.results())) {
606 Value result = it.value(), resultType = resultTypeValues[it.index()];
607
608 // Check for an already translated value.
609 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
610 types.push_back(existingRewriteValue);
611 continue;
612 }
613
614 // Check for an input from the matcher.
615 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
616 types.push_back(mapRewriteValue(resultType));
617 continue;
618 }
619
620 // Check if the operation has type inference support.
621 if (hasTypeInference) {
622 types.push_back(builder.create<pdl_interp::InferredTypeOp>(op.getLoc()));
623 continue;
624 }
625
626 // Look for an operation that was replaced by `op`. The result type will be
627 // inferred from the result that was replaced. There is guaranteed to be a
628 // replacement for either the op, or this specific result. Note that this is
629 // guaranteed by the verifier of `pdl::OperationOp`.
630 Operation *replacedOp = nullptr;
631 if (!fullReplacedOperation.hasValue()) {
632 for (OpOperand &use : op.op().getUses())
633 if ((replacedOp = getReplacedOperationFrom(use)))
634 break;
635 fullReplacedOperation = replacedOp;
636 } else {
637 replacedOp = fullReplacedOperation.getValue();
638 }
639 // Infer from the result, as there was no fully replaced op.
640 if (!replacedOp) {
641 for (OpOperand &use : result.getUses())
642 if ((replacedOp = getReplacedOperationFrom(use)))
643 break;
644 assert(replacedOp && "expected replaced op to infer a result type from");
645 }
646
647 auto replOpOp = cast<pdl::OperationOp>(replacedOp);
648 types.push_back(mapRewriteValue(replOpOp.types()[it.index()]));
649 }
650 }
651
652 //===----------------------------------------------------------------------===//
653 // Conversion Pass
654 //===----------------------------------------------------------------------===//
655
656 namespace {
657 struct PDLToPDLInterpPass
658 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
659 void runOnOperation() final;
660 };
661 } // namespace
662
663 /// Convert the given module containing PDL pattern operations into a PDL
664 /// Interpreter operations.
runOnOperation()665 void PDLToPDLInterpPass::runOnOperation() {
666 ModuleOp module = getOperation();
667
668 // Create the main matcher function This function contains all of the match
669 // related functionality from patterns in the module.
670 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
671 FuncOp matcherFunc = builder.create<FuncOp>(
672 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
673 builder.getFunctionType(builder.getType<pdl::OperationType>(),
674 /*results=*/llvm::None),
675 /*attrs=*/llvm::None);
676
677 // Create a nested module to hold the functions invoked for rewriting the IR
678 // after a successful match.
679 ModuleOp rewriterModule = builder.create<ModuleOp>(
680 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
681
682 // Generate the code for the patterns within the module.
683 PatternLowering generator(matcherFunc, rewriterModule);
684 generator.lower(module);
685
686 // After generation, delete all of the pattern operations.
687 for (pdl::PatternOp pattern :
688 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
689 pattern.erase();
690 }
691
createPDLToPDLInterpPass()692 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
693 return std::make_unique<PDLToPDLInterpPass>();
694 }
695