1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 
14 using namespace mlir;
15 using namespace mlir::pdl_interp;
16 
17 //===----------------------------------------------------------------------===//
18 // PDLInterp Dialect
19 //===----------------------------------------------------------------------===//
20 
initialize()21 void PDLInterpDialect::initialize() {
22   addOperations<
23 #define GET_OP_LIST
24 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
25       >();
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // pdl_interp::CreateOperationOp
30 //===----------------------------------------------------------------------===//
31 
parseCreateOperationOp(OpAsmParser & p,OperationState & state)32 static ParseResult parseCreateOperationOp(OpAsmParser &p,
33                                           OperationState &state) {
34   if (p.parseOptionalAttrDict(state.attributes))
35     return failure();
36   Builder &builder = p.getBuilder();
37 
38   // Parse the operation name.
39   StringAttr opName;
40   if (p.parseAttribute(opName, "name", state.attributes))
41     return failure();
42 
43   // Parse the operands.
44   SmallVector<OpAsmParser::OperandType, 4> operands;
45   if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() ||
46       p.resolveOperands(operands, builder.getType<pdl::ValueType>(),
47                         state.operands))
48     return failure();
49 
50   // Parse the attributes.
51   SmallVector<Attribute, 4> attrNames;
52   if (succeeded(p.parseOptionalLBrace())) {
53     SmallVector<OpAsmParser::OperandType, 4> attrOps;
54     do {
55       StringAttr nameAttr;
56       OpAsmParser::OperandType operand;
57       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
58           p.parseOperand(operand))
59         return failure();
60       attrNames.push_back(nameAttr);
61       attrOps.push_back(operand);
62     } while (succeeded(p.parseOptionalComma()));
63 
64     if (p.parseRBrace() ||
65         p.resolveOperands(attrOps, builder.getType<pdl::AttributeType>(),
66                           state.operands))
67       return failure();
68   }
69   state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
70   state.addTypes(builder.getType<pdl::OperationType>());
71 
72   // Parse the result types.
73   SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
74   if (p.parseArrow())
75     return failure();
76   if (succeeded(p.parseOptionalLParen())) {
77     if (p.parseRParen())
78       return failure();
79   } else if (p.parseOperandList(opResultTypes) ||
80              p.resolveOperands(opResultTypes, builder.getType<pdl::TypeType>(),
81                                state.operands)) {
82     return failure();
83   }
84 
85   int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
86                                    static_cast<int32_t>(attrNames.size()),
87                                    static_cast<int32_t>(opResultTypes.size())};
88   state.addAttribute("operand_segment_sizes",
89                      builder.getI32VectorAttr(operandSegmentSizes));
90   return success();
91 }
92 
print(OpAsmPrinter & p,CreateOperationOp op)93 static void print(OpAsmPrinter &p, CreateOperationOp op) {
94   p << "pdl_interp.create_operation ";
95   p.printOptionalAttrDict(op.getAttrs(),
96                           {"attributeNames", "name", "operand_segment_sizes"});
97   p << '"' << op.name() << "\"(" << op.operands() << ')';
98 
99   // Emit the optional attributes.
100   ArrayAttr attrNames = op.attributeNames();
101   if (!attrNames.empty()) {
102     Operation::operand_range attrArgs = op.attributes();
103     p << " {";
104     interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
105                     [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
106     p << '}';
107   }
108 
109   // Print the result type constraints of the operation.
110   auto types = op.types();
111   if (types.empty())
112     p << " -> ()";
113   else
114     p << " -> " << op.types();
115 }
116 
117 //===----------------------------------------------------------------------===//
118 // TableGen Auto-Generated Op and Interface Definitions
119 //===----------------------------------------------------------------------===//
120 
121 #define GET_OP_CLASSES
122 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
123