1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/StringSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdl;
18 
19 //===----------------------------------------------------------------------===//
20 // PDLDialect
21 //===----------------------------------------------------------------------===//
22 
initialize()23 void PDLDialect::initialize() {
24   addOperations<
25 #define GET_OP_LIST
26 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
27       >();
28   addTypes<AttributeType, OperationType, TypeType, ValueType>();
29 }
30 
parseType(DialectAsmParser & parser) const31 Type PDLDialect::parseType(DialectAsmParser &parser) const {
32   StringRef keyword;
33   if (parser.parseKeyword(&keyword))
34     return Type();
35 
36   Builder &builder = parser.getBuilder();
37   Type result = StringSwitch<Type>(keyword)
38                     .Case("attribute", builder.getType<AttributeType>())
39                     .Case("operation", builder.getType<OperationType>())
40                     .Case("type", builder.getType<TypeType>())
41                     .Case("value", builder.getType<ValueType>())
42                     .Default(Type());
43   if (!result)
44     parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `")
45         << keyword << "'";
46   return result;
47 }
48 
printType(Type type,DialectAsmPrinter & printer) const49 void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const {
50   if (type.isa<AttributeType>())
51     printer << "attribute";
52   else if (type.isa<OperationType>())
53     printer << "operation";
54   else if (type.isa<TypeType>())
55     printer << "type";
56   else if (type.isa<ValueType>())
57     printer << "value";
58   else
59     llvm_unreachable("unknown 'pdl' type");
60 }
61 
62 /// Returns true if the given operation is used by a "binding" pdl operation
63 /// within the main matcher body of a `pdl.pattern`.
64 static LogicalResult
verifyHasBindingUseInMatcher(Operation * op,StringRef bindableContextStr="`pdl.operation`")65 verifyHasBindingUseInMatcher(Operation *op,
66                              StringRef bindableContextStr = "`pdl.operation`") {
67   // If the pattern is not a pattern, there is nothing to do.
68   if (!isa<PatternOp>(op->getParentOp()))
69     return success();
70   Block *matcherBlock = op->getBlock();
71   for (Operation *user : op->getUsers()) {
72     if (user->getBlock() != matcherBlock)
73       continue;
74     if (isa<AttributeOp, InputOp, OperationOp, RewriteOp>(user))
75       return success();
76   }
77   return op->emitOpError()
78          << "expected a bindable (i.e. " << bindableContextStr
79          << ") user when defined in the matcher body of a `pdl.pattern`";
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // pdl::ApplyConstraintOp
84 //===----------------------------------------------------------------------===//
85 
verify(ApplyConstraintOp op)86 static LogicalResult verify(ApplyConstraintOp op) {
87   if (op.getNumOperands() == 0)
88     return op.emitOpError("expected at least one argument");
89   return success();
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // pdl::AttributeOp
94 //===----------------------------------------------------------------------===//
95 
verify(AttributeOp op)96 static LogicalResult verify(AttributeOp op) {
97   Value attrType = op.type();
98   Optional<Attribute> attrValue = op.value();
99 
100   if (!attrValue && isa<RewriteOp>(op->getParentOp()))
101     return op.emitOpError("expected constant value when specified within a "
102                           "`pdl.rewrite`");
103   if (attrValue && attrType)
104     return op.emitOpError("expected only one of [`type`, `value`] to be set");
105   return verifyHasBindingUseInMatcher(op);
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // pdl::InputOp
110 //===----------------------------------------------------------------------===//
111 
verify(InputOp op)112 static LogicalResult verify(InputOp op) {
113   return verifyHasBindingUseInMatcher(op);
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // pdl::OperationOp
118 //===----------------------------------------------------------------------===//
119 
parseOperationOp(OpAsmParser & p,OperationState & state)120 static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
121   Builder &builder = p.getBuilder();
122 
123   // Parse the optional operation name.
124   bool startsWithOperands = succeeded(p.parseOptionalLParen());
125   bool startsWithAttributes =
126       !startsWithOperands && succeeded(p.parseOptionalLBrace());
127   bool startsWithOpName = false;
128   if (!startsWithAttributes && !startsWithOperands) {
129     StringAttr opName;
130     OptionalParseResult opNameResult =
131         p.parseOptionalAttribute(opName, "name", state.attributes);
132     startsWithOpName = opNameResult.hasValue();
133     if (startsWithOpName && failed(*opNameResult))
134       return failure();
135   }
136 
137   // Parse the operands.
138   SmallVector<OpAsmParser::OperandType, 4> operands;
139   if (startsWithOperands ||
140       (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) {
141     if (p.parseOperandList(operands) || p.parseRParen() ||
142         p.resolveOperands(operands, builder.getType<ValueType>(),
143                           state.operands))
144       return failure();
145   }
146 
147   // Parse the attributes.
148   SmallVector<Attribute, 4> attrNames;
149   if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) {
150     SmallVector<OpAsmParser::OperandType, 4> attrOps;
151     do {
152       StringAttr nameAttr;
153       OpAsmParser::OperandType operand;
154       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
155           p.parseOperand(operand))
156         return failure();
157       attrNames.push_back(nameAttr);
158       attrOps.push_back(operand);
159     } while (succeeded(p.parseOptionalComma()));
160 
161     if (p.parseRBrace() ||
162         p.resolveOperands(attrOps, builder.getType<AttributeType>(),
163                           state.operands))
164       return failure();
165   }
166   state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
167   state.addTypes(builder.getType<OperationType>());
168 
169   // Parse the result types.
170   SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
171   if (succeeded(p.parseOptionalArrow())) {
172     if (p.parseOperandList(opResultTypes) ||
173         p.resolveOperands(opResultTypes, builder.getType<TypeType>(),
174                           state.operands))
175       return failure();
176     state.types.append(opResultTypes.size(), builder.getType<ValueType>());
177   }
178 
179   if (p.parseOptionalAttrDict(state.attributes))
180     return failure();
181 
182   int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
183                                    static_cast<int32_t>(attrNames.size()),
184                                    static_cast<int32_t>(opResultTypes.size())};
185   state.addAttribute("operand_segment_sizes",
186                      builder.getI32VectorAttr(operandSegmentSizes));
187   return success();
188 }
189 
print(OpAsmPrinter & p,OperationOp op)190 static void print(OpAsmPrinter &p, OperationOp op) {
191   p << "pdl.operation ";
192   if (Optional<StringRef> name = op.name())
193     p << '"' << *name << '"';
194 
195   auto operandValues = op.operands();
196   if (!operandValues.empty())
197     p << '(' << operandValues << ')';
198 
199   // Emit the optional attributes.
200   ArrayAttr attrNames = op.attributeNames();
201   if (!attrNames.empty()) {
202     Operation::operand_range attrArgs = op.attributes();
203     p << " {";
204     interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
205                     [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
206     p << '}';
207   }
208 
209   // Print the result type constraints of the operation.
210   if (!op.results().empty())
211     p << " -> " << op.types();
212   p.printOptionalAttrDict(op.getAttrs(),
213                           {"attributeNames", "name", "operand_segment_sizes"});
214 }
215 
216 /// Verifies that the result types of this operation, defined within a
217 /// `pdl.rewrite`, can be inferred.
verifyResultTypesAreInferrable(OperationOp op,ResultRange opResults,OperandRange resultTypes)218 static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
219                                                     ResultRange opResults,
220                                                     OperandRange resultTypes) {
221   // Functor that returns if the given use can be used to infer a type.
222   Block *rewriterBlock = op->getBlock();
223   auto canInferTypeFromUse = [&](OpOperand &use) {
224     // If the use is within a ReplaceOp and isn't the operation being replaced
225     // (i.e. is not the first operand of the replacement), we can infer a type.
226     ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
227     if (!replOpUser || use.getOperandNumber() == 0)
228       return false;
229     // Make sure the replaced operation was defined before this one.
230     Operation *replacedOp = replOpUser.operation().getDefiningOp();
231     return replacedOp->getBlock() != rewriterBlock ||
232            replacedOp->isBeforeInBlock(op);
233   };
234 
235   // Check to see if the uses of the operation itself can be used to infer
236   // types.
237   if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
238     return success();
239 
240   // Otherwise, make sure each of the types can be inferred.
241   for (int i : llvm::seq<int>(0, opResults.size())) {
242     Operation *resultTypeOp = resultTypes[i].getDefiningOp();
243     assert(resultTypeOp && "expected valid result type operation");
244 
245     // If the op was defined by a `create_native`, it is guaranteed to be
246     // usable.
247     if (isa<CreateNativeOp>(resultTypeOp))
248       continue;
249 
250     // If the type is already constrained, there is nothing to do.
251     TypeOp typeOp = cast<TypeOp>(resultTypeOp);
252     if (typeOp.type())
253       continue;
254 
255     // If the type operation was defined in the matcher and constrains the
256     // result of an input operation, it can be used.
257     auto constrainsInputOp = [rewriterBlock](Operation *user) {
258       return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
259     };
260     if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
261       continue;
262 
263     // Otherwise, check to see if any uses of the result can infer the type.
264     if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse))
265       continue;
266     return op
267         .emitOpError("must have inferable or constrained result types when "
268                      "nested within `pdl.rewrite`")
269         .attachNote()
270         .append("result type #", i, " was not constrained");
271   }
272   return success();
273 }
274 
verify(OperationOp op)275 static LogicalResult verify(OperationOp op) {
276   bool isWithinRewrite = isa<RewriteOp>(op->getParentOp());
277   if (isWithinRewrite && !op.name())
278     return op.emitOpError("must have an operation name when nested within "
279                           "a `pdl.rewrite`");
280   ArrayAttr attributeNames = op.attributeNames();
281   auto attributeValues = op.attributes();
282   if (attributeNames.size() != attributeValues.size()) {
283     return op.emitOpError()
284            << "expected the same number of attribute values and attribute "
285               "names, got "
286            << attributeNames.size() << " names and " << attributeValues.size()
287            << " values";
288   }
289 
290   OperandRange resultTypes = op.types();
291   auto opResults = op.results();
292   if (resultTypes.size() != opResults.size()) {
293     return op.emitOpError() << "expected the same number of result values and "
294                                "result type constraints, got "
295                             << opResults.size() << " results and "
296                             << resultTypes.size() << " constraints";
297   }
298 
299   // If the operation is within a rewrite body and doesn't have type inferrence,
300   // ensure that the result types can be resolved.
301   if (isWithinRewrite && !op.hasTypeInference()) {
302     if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes)))
303       return failure();
304   }
305 
306   return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`");
307 }
308 
hasTypeInference()309 bool OperationOp::hasTypeInference() {
310   Optional<StringRef> opName = name();
311   if (!opName)
312     return false;
313 
314   OperationName name(*opName, getContext());
315   if (const AbstractOperation *op = name.getAbstractOperation())
316     return op->getInterface<InferTypeOpInterface>();
317   return false;
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // pdl::PatternOp
322 //===----------------------------------------------------------------------===//
323 
verify(PatternOp pattern)324 static LogicalResult verify(PatternOp pattern) {
325   Region &body = pattern.body();
326   auto *term = body.front().getTerminator();
327   if (!isa<RewriteOp>(term)) {
328     return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
329         .attachNote(term->getLoc())
330         .append("see terminator defined here");
331   }
332 
333   // Check that all values defined in the top-level pattern are referenced at
334   // least once in the source tree.
335   WalkResult result = body.walk([&](Operation *op) -> WalkResult {
336     if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
337       pattern
338           .emitOpError("expected only `pdl` operations within the pattern body")
339           .attachNote(op->getLoc())
340           .append("see non-`pdl` operation defined here");
341       return WalkResult::interrupt();
342     }
343     return WalkResult::advance();
344   });
345   return failure(result.wasInterrupted());
346 }
347 
build(OpBuilder & builder,OperationState & state,Optional<StringRef> rootKind,Optional<uint16_t> benefit,Optional<StringRef> name)348 void PatternOp::build(OpBuilder &builder, OperationState &state,
349                       Optional<StringRef> rootKind, Optional<uint16_t> benefit,
350                       Optional<StringRef> name) {
351   build(builder, state,
352         rootKind ? builder.getStringAttr(*rootKind) : StringAttr(),
353         builder.getI16IntegerAttr(benefit ? *benefit : 0),
354         name ? builder.getStringAttr(*name) : StringAttr());
355   builder.createBlock(state.addRegion());
356 }
357 
358 /// Returns the rewrite operation of this pattern.
getRewriter()359 RewriteOp PatternOp::getRewriter() {
360   return cast<RewriteOp>(body().front().getTerminator());
361 }
362 
363 /// Return the root operation kind that this pattern matches, or None if
364 /// there isn't a specific root.
getRootKind()365 Optional<StringRef> PatternOp::getRootKind() {
366   OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp());
367   return rootOp.name();
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // pdl::ReplaceOp
372 //===----------------------------------------------------------------------===//
373 
verify(ReplaceOp op)374 static LogicalResult verify(ReplaceOp op) {
375   auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp());
376   auto sourceOpResults = sourceOp.results();
377   auto replValues = op.replValues();
378 
379   if (Value replOpVal = op.replOperation()) {
380     auto replOp = cast<OperationOp>(replOpVal.getDefiningOp());
381     auto replOpResults = replOp.results();
382     if (sourceOpResults.size() != replOpResults.size()) {
383       return op.emitOpError()
384              << "expected source operation to have the same number of results "
385                 "as the replacement operation, replacement operation provided "
386              << replOpResults.size() << " but expected "
387              << sourceOpResults.size();
388     }
389 
390     if (!replValues.empty()) {
391       return op.emitOpError() << "expected no replacement values to be provided"
392                                  " when the replacement operation is present";
393     }
394 
395     return success();
396   }
397 
398   if (sourceOpResults.size() != replValues.size()) {
399     return op.emitOpError()
400            << "expected source operation to have the same number of results "
401               "as the provided replacement values, found "
402            << replValues.size() << " replacement values but expected "
403            << sourceOpResults.size();
404   }
405 
406   return success();
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // pdl::RewriteOp
411 //===----------------------------------------------------------------------===//
412 
verify(RewriteOp op)413 static LogicalResult verify(RewriteOp op) {
414   Region &rewriteRegion = op.body();
415 
416   // Handle the case where the rewrite is external.
417   if (op.name()) {
418     if (!rewriteRegion.empty()) {
419       return op.emitOpError()
420              << "expected rewrite region to be empty when rewrite is external";
421     }
422     return success();
423   }
424 
425   // Otherwise, check that the rewrite region only contains a single block.
426   if (rewriteRegion.empty()) {
427     return op.emitOpError() << "expected rewrite region to be non-empty if "
428                                "external name is not specified";
429   }
430 
431   // Check that no additional arguments were provided.
432   if (!op.externalArgs().empty()) {
433     return op.emitOpError() << "expected no external arguments when the "
434                                "rewrite is specified inline";
435   }
436   if (op.externalConstParams()) {
437     return op.emitOpError() << "expected no external constant parameters when "
438                                "the rewrite is specified inline";
439   }
440 
441   return success();
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // pdl::TypeOp
446 //===----------------------------------------------------------------------===//
447 
verify(TypeOp op)448 static LogicalResult verify(TypeOp op) {
449   return verifyHasBindingUseInMatcher(
450       op, "`pdl.attribute`, `pdl.input`, or `pdl.operation`");
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // TableGen'd op method definitions
455 //===----------------------------------------------------------------------===//
456 
457 #define GET_OP_CLASSES
458 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
459