1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/StringSwitch.h"
15
16 using namespace mlir;
17 using namespace mlir::pdl;
18
19 //===----------------------------------------------------------------------===//
20 // PDLDialect
21 //===----------------------------------------------------------------------===//
22
initialize()23 void PDLDialect::initialize() {
24 addOperations<
25 #define GET_OP_LIST
26 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
27 >();
28 addTypes<AttributeType, OperationType, TypeType, ValueType>();
29 }
30
parseType(DialectAsmParser & parser) const31 Type PDLDialect::parseType(DialectAsmParser &parser) const {
32 StringRef keyword;
33 if (parser.parseKeyword(&keyword))
34 return Type();
35
36 Builder &builder = parser.getBuilder();
37 Type result = StringSwitch<Type>(keyword)
38 .Case("attribute", builder.getType<AttributeType>())
39 .Case("operation", builder.getType<OperationType>())
40 .Case("type", builder.getType<TypeType>())
41 .Case("value", builder.getType<ValueType>())
42 .Default(Type());
43 if (!result)
44 parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `")
45 << keyword << "'";
46 return result;
47 }
48
printType(Type type,DialectAsmPrinter & printer) const49 void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const {
50 if (type.isa<AttributeType>())
51 printer << "attribute";
52 else if (type.isa<OperationType>())
53 printer << "operation";
54 else if (type.isa<TypeType>())
55 printer << "type";
56 else if (type.isa<ValueType>())
57 printer << "value";
58 else
59 llvm_unreachable("unknown 'pdl' type");
60 }
61
62 /// Returns true if the given operation is used by a "binding" pdl operation
63 /// within the main matcher body of a `pdl.pattern`.
64 static LogicalResult
verifyHasBindingUseInMatcher(Operation * op,StringRef bindableContextStr="`pdl.operation`")65 verifyHasBindingUseInMatcher(Operation *op,
66 StringRef bindableContextStr = "`pdl.operation`") {
67 // If the pattern is not a pattern, there is nothing to do.
68 if (!isa<PatternOp>(op->getParentOp()))
69 return success();
70 Block *matcherBlock = op->getBlock();
71 for (Operation *user : op->getUsers()) {
72 if (user->getBlock() != matcherBlock)
73 continue;
74 if (isa<AttributeOp, InputOp, OperationOp, RewriteOp>(user))
75 return success();
76 }
77 return op->emitOpError()
78 << "expected a bindable (i.e. " << bindableContextStr
79 << ") user when defined in the matcher body of a `pdl.pattern`";
80 }
81
82 //===----------------------------------------------------------------------===//
83 // pdl::ApplyConstraintOp
84 //===----------------------------------------------------------------------===//
85
verify(ApplyConstraintOp op)86 static LogicalResult verify(ApplyConstraintOp op) {
87 if (op.getNumOperands() == 0)
88 return op.emitOpError("expected at least one argument");
89 return success();
90 }
91
92 //===----------------------------------------------------------------------===//
93 // pdl::AttributeOp
94 //===----------------------------------------------------------------------===//
95
verify(AttributeOp op)96 static LogicalResult verify(AttributeOp op) {
97 Value attrType = op.type();
98 Optional<Attribute> attrValue = op.value();
99
100 if (!attrValue && isa<RewriteOp>(op->getParentOp()))
101 return op.emitOpError("expected constant value when specified within a "
102 "`pdl.rewrite`");
103 if (attrValue && attrType)
104 return op.emitOpError("expected only one of [`type`, `value`] to be set");
105 return verifyHasBindingUseInMatcher(op);
106 }
107
108 //===----------------------------------------------------------------------===//
109 // pdl::InputOp
110 //===----------------------------------------------------------------------===//
111
verify(InputOp op)112 static LogicalResult verify(InputOp op) {
113 return verifyHasBindingUseInMatcher(op);
114 }
115
116 //===----------------------------------------------------------------------===//
117 // pdl::OperationOp
118 //===----------------------------------------------------------------------===//
119
parseOperationOp(OpAsmParser & p,OperationState & state)120 static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
121 Builder &builder = p.getBuilder();
122
123 // Parse the optional operation name.
124 bool startsWithOperands = succeeded(p.parseOptionalLParen());
125 bool startsWithAttributes =
126 !startsWithOperands && succeeded(p.parseOptionalLBrace());
127 bool startsWithOpName = false;
128 if (!startsWithAttributes && !startsWithOperands) {
129 StringAttr opName;
130 OptionalParseResult opNameResult =
131 p.parseOptionalAttribute(opName, "name", state.attributes);
132 startsWithOpName = opNameResult.hasValue();
133 if (startsWithOpName && failed(*opNameResult))
134 return failure();
135 }
136
137 // Parse the operands.
138 SmallVector<OpAsmParser::OperandType, 4> operands;
139 if (startsWithOperands ||
140 (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) {
141 if (p.parseOperandList(operands) || p.parseRParen() ||
142 p.resolveOperands(operands, builder.getType<ValueType>(),
143 state.operands))
144 return failure();
145 }
146
147 // Parse the attributes.
148 SmallVector<Attribute, 4> attrNames;
149 if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) {
150 SmallVector<OpAsmParser::OperandType, 4> attrOps;
151 do {
152 StringAttr nameAttr;
153 OpAsmParser::OperandType operand;
154 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
155 p.parseOperand(operand))
156 return failure();
157 attrNames.push_back(nameAttr);
158 attrOps.push_back(operand);
159 } while (succeeded(p.parseOptionalComma()));
160
161 if (p.parseRBrace() ||
162 p.resolveOperands(attrOps, builder.getType<AttributeType>(),
163 state.operands))
164 return failure();
165 }
166 state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
167 state.addTypes(builder.getType<OperationType>());
168
169 // Parse the result types.
170 SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
171 if (succeeded(p.parseOptionalArrow())) {
172 if (p.parseOperandList(opResultTypes) ||
173 p.resolveOperands(opResultTypes, builder.getType<TypeType>(),
174 state.operands))
175 return failure();
176 state.types.append(opResultTypes.size(), builder.getType<ValueType>());
177 }
178
179 if (p.parseOptionalAttrDict(state.attributes))
180 return failure();
181
182 int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
183 static_cast<int32_t>(attrNames.size()),
184 static_cast<int32_t>(opResultTypes.size())};
185 state.addAttribute("operand_segment_sizes",
186 builder.getI32VectorAttr(operandSegmentSizes));
187 return success();
188 }
189
print(OpAsmPrinter & p,OperationOp op)190 static void print(OpAsmPrinter &p, OperationOp op) {
191 p << "pdl.operation ";
192 if (Optional<StringRef> name = op.name())
193 p << '"' << *name << '"';
194
195 auto operandValues = op.operands();
196 if (!operandValues.empty())
197 p << '(' << operandValues << ')';
198
199 // Emit the optional attributes.
200 ArrayAttr attrNames = op.attributeNames();
201 if (!attrNames.empty()) {
202 Operation::operand_range attrArgs = op.attributes();
203 p << " {";
204 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
205 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
206 p << '}';
207 }
208
209 // Print the result type constraints of the operation.
210 if (!op.results().empty())
211 p << " -> " << op.types();
212 p.printOptionalAttrDict(op.getAttrs(),
213 {"attributeNames", "name", "operand_segment_sizes"});
214 }
215
216 /// Verifies that the result types of this operation, defined within a
217 /// `pdl.rewrite`, can be inferred.
verifyResultTypesAreInferrable(OperationOp op,ResultRange opResults,OperandRange resultTypes)218 static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
219 ResultRange opResults,
220 OperandRange resultTypes) {
221 // Functor that returns if the given use can be used to infer a type.
222 Block *rewriterBlock = op->getBlock();
223 auto canInferTypeFromUse = [&](OpOperand &use) {
224 // If the use is within a ReplaceOp and isn't the operation being replaced
225 // (i.e. is not the first operand of the replacement), we can infer a type.
226 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
227 if (!replOpUser || use.getOperandNumber() == 0)
228 return false;
229 // Make sure the replaced operation was defined before this one.
230 Operation *replacedOp = replOpUser.operation().getDefiningOp();
231 return replacedOp->getBlock() != rewriterBlock ||
232 replacedOp->isBeforeInBlock(op);
233 };
234
235 // Check to see if the uses of the operation itself can be used to infer
236 // types.
237 if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
238 return success();
239
240 // Otherwise, make sure each of the types can be inferred.
241 for (int i : llvm::seq<int>(0, opResults.size())) {
242 Operation *resultTypeOp = resultTypes[i].getDefiningOp();
243 assert(resultTypeOp && "expected valid result type operation");
244
245 // If the op was defined by a `create_native`, it is guaranteed to be
246 // usable.
247 if (isa<CreateNativeOp>(resultTypeOp))
248 continue;
249
250 // If the type is already constrained, there is nothing to do.
251 TypeOp typeOp = cast<TypeOp>(resultTypeOp);
252 if (typeOp.type())
253 continue;
254
255 // If the type operation was defined in the matcher and constrains the
256 // result of an input operation, it can be used.
257 auto constrainsInputOp = [rewriterBlock](Operation *user) {
258 return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
259 };
260 if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
261 continue;
262
263 // Otherwise, check to see if any uses of the result can infer the type.
264 if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse))
265 continue;
266 return op
267 .emitOpError("must have inferable or constrained result types when "
268 "nested within `pdl.rewrite`")
269 .attachNote()
270 .append("result type #", i, " was not constrained");
271 }
272 return success();
273 }
274
verify(OperationOp op)275 static LogicalResult verify(OperationOp op) {
276 bool isWithinRewrite = isa<RewriteOp>(op->getParentOp());
277 if (isWithinRewrite && !op.name())
278 return op.emitOpError("must have an operation name when nested within "
279 "a `pdl.rewrite`");
280 ArrayAttr attributeNames = op.attributeNames();
281 auto attributeValues = op.attributes();
282 if (attributeNames.size() != attributeValues.size()) {
283 return op.emitOpError()
284 << "expected the same number of attribute values and attribute "
285 "names, got "
286 << attributeNames.size() << " names and " << attributeValues.size()
287 << " values";
288 }
289
290 OperandRange resultTypes = op.types();
291 auto opResults = op.results();
292 if (resultTypes.size() != opResults.size()) {
293 return op.emitOpError() << "expected the same number of result values and "
294 "result type constraints, got "
295 << opResults.size() << " results and "
296 << resultTypes.size() << " constraints";
297 }
298
299 // If the operation is within a rewrite body and doesn't have type inferrence,
300 // ensure that the result types can be resolved.
301 if (isWithinRewrite && !op.hasTypeInference()) {
302 if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes)))
303 return failure();
304 }
305
306 return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`");
307 }
308
hasTypeInference()309 bool OperationOp::hasTypeInference() {
310 Optional<StringRef> opName = name();
311 if (!opName)
312 return false;
313
314 OperationName name(*opName, getContext());
315 if (const AbstractOperation *op = name.getAbstractOperation())
316 return op->getInterface<InferTypeOpInterface>();
317 return false;
318 }
319
320 //===----------------------------------------------------------------------===//
321 // pdl::PatternOp
322 //===----------------------------------------------------------------------===//
323
verify(PatternOp pattern)324 static LogicalResult verify(PatternOp pattern) {
325 Region &body = pattern.body();
326 auto *term = body.front().getTerminator();
327 if (!isa<RewriteOp>(term)) {
328 return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
329 .attachNote(term->getLoc())
330 .append("see terminator defined here");
331 }
332
333 // Check that all values defined in the top-level pattern are referenced at
334 // least once in the source tree.
335 WalkResult result = body.walk([&](Operation *op) -> WalkResult {
336 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
337 pattern
338 .emitOpError("expected only `pdl` operations within the pattern body")
339 .attachNote(op->getLoc())
340 .append("see non-`pdl` operation defined here");
341 return WalkResult::interrupt();
342 }
343 return WalkResult::advance();
344 });
345 return failure(result.wasInterrupted());
346 }
347
build(OpBuilder & builder,OperationState & state,Optional<StringRef> rootKind,Optional<uint16_t> benefit,Optional<StringRef> name)348 void PatternOp::build(OpBuilder &builder, OperationState &state,
349 Optional<StringRef> rootKind, Optional<uint16_t> benefit,
350 Optional<StringRef> name) {
351 build(builder, state,
352 rootKind ? builder.getStringAttr(*rootKind) : StringAttr(),
353 builder.getI16IntegerAttr(benefit ? *benefit : 0),
354 name ? builder.getStringAttr(*name) : StringAttr());
355 builder.createBlock(state.addRegion());
356 }
357
358 /// Returns the rewrite operation of this pattern.
getRewriter()359 RewriteOp PatternOp::getRewriter() {
360 return cast<RewriteOp>(body().front().getTerminator());
361 }
362
363 /// Return the root operation kind that this pattern matches, or None if
364 /// there isn't a specific root.
getRootKind()365 Optional<StringRef> PatternOp::getRootKind() {
366 OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp());
367 return rootOp.name();
368 }
369
370 //===----------------------------------------------------------------------===//
371 // pdl::ReplaceOp
372 //===----------------------------------------------------------------------===//
373
verify(ReplaceOp op)374 static LogicalResult verify(ReplaceOp op) {
375 auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp());
376 auto sourceOpResults = sourceOp.results();
377 auto replValues = op.replValues();
378
379 if (Value replOpVal = op.replOperation()) {
380 auto replOp = cast<OperationOp>(replOpVal.getDefiningOp());
381 auto replOpResults = replOp.results();
382 if (sourceOpResults.size() != replOpResults.size()) {
383 return op.emitOpError()
384 << "expected source operation to have the same number of results "
385 "as the replacement operation, replacement operation provided "
386 << replOpResults.size() << " but expected "
387 << sourceOpResults.size();
388 }
389
390 if (!replValues.empty()) {
391 return op.emitOpError() << "expected no replacement values to be provided"
392 " when the replacement operation is present";
393 }
394
395 return success();
396 }
397
398 if (sourceOpResults.size() != replValues.size()) {
399 return op.emitOpError()
400 << "expected source operation to have the same number of results "
401 "as the provided replacement values, found "
402 << replValues.size() << " replacement values but expected "
403 << sourceOpResults.size();
404 }
405
406 return success();
407 }
408
409 //===----------------------------------------------------------------------===//
410 // pdl::RewriteOp
411 //===----------------------------------------------------------------------===//
412
verify(RewriteOp op)413 static LogicalResult verify(RewriteOp op) {
414 Region &rewriteRegion = op.body();
415
416 // Handle the case where the rewrite is external.
417 if (op.name()) {
418 if (!rewriteRegion.empty()) {
419 return op.emitOpError()
420 << "expected rewrite region to be empty when rewrite is external";
421 }
422 return success();
423 }
424
425 // Otherwise, check that the rewrite region only contains a single block.
426 if (rewriteRegion.empty()) {
427 return op.emitOpError() << "expected rewrite region to be non-empty if "
428 "external name is not specified";
429 }
430
431 // Check that no additional arguments were provided.
432 if (!op.externalArgs().empty()) {
433 return op.emitOpError() << "expected no external arguments when the "
434 "rewrite is specified inline";
435 }
436 if (op.externalConstParams()) {
437 return op.emitOpError() << "expected no external constant parameters when "
438 "the rewrite is specified inline";
439 }
440
441 return success();
442 }
443
444 //===----------------------------------------------------------------------===//
445 // pdl::TypeOp
446 //===----------------------------------------------------------------------===//
447
verify(TypeOp op)448 static LogicalResult verify(TypeOp op) {
449 return verifyHasBindingUseInMatcher(
450 op, "`pdl.attribute`, `pdl.input`, or `pdl.operation`");
451 }
452
453 //===----------------------------------------------------------------------===//
454 // TableGen'd op method definitions
455 //===----------------------------------------------------------------------===//
456
457 #define GET_OP_CLASSES
458 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
459