1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21
22 using namespace mlir;
23 using namespace mlir::tblgen;
24
/// File header emitted once at the top of every generated Python module. It
/// imports the `_cext` extension module and the helper functions referenced by
/// the accessor and builder templates below, and aliases `_cext.ir` as `_ir`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

import array
from . import _cext
from . import _segmented_accessor, _equally_sized_accessor, _get_default_loc_context
_ir = _cext.ir
)Py";
34
/// Template for the dialect class, emitted once per generated module:
///   {0} is the dialect namespace.
/// The class is registered through `_cext.register_dialect`.
constexpr const char *dialectClassTemplate = R"Py(
@_cext.register_dialect
class _Dialect(_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for an operation class, registered against the `_Dialect` class
/// emitted from `dialectClassTemplate`:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_cext.register_operation(_Dialect)
class {0}(_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
53
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Used when no variable-length group precedes the element, so its position
/// is known statically.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent). The length
/// of the only variable-length group is inferred from the total number of
/// elements, and the static position is shifted by that length minus one.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + variadic_group_length - 1]
)Py";
77
/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This is only valid when the optional element is the only variable-length
/// group. Every other group then holds exactly one element, so the optional
/// element is present iff there are as many elements as groups, and absent
/// (one element fewer) otherwise.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
89
/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The length of the variadic group is inferred from the total number of
/// elements, assuming every other group contains exactly one element.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + variadic_group_length]
)Py";
101
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// It binds `start` and `pg`, which the second part of the template (appended
/// directly after this one, with no newline in between) uses as the first
/// index of the current group and the number of elements per variadic group.
/// The element container must be accessed through `self.operation`; a bare
/// `operation` is not defined inside the generated property.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
112
/// Second part of the template for equally-sized case, accessing a single
/// element. It is appended right after `opVariadicEqualPrefixTemplate` and reads
/// the `start` value bound there:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group. It is appended right after `opVariadicEqualPrefixTemplate` and slices
/// `pg` elements starting at `start`, both bound there:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
126
/// Template for an attribute-sized group accessor. Group sizes are read from
/// the `{1}_segment_sizes` attribute of the operation:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; the accessor yields the element, or None when the
/// segment is empty:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
147
/// Template for an operation attribute getter; the raw attribute is wrapped
/// into its Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
180
/// Template for an operation attribute setter; raises ValueError when asked to
/// set a mandatory attribute to None:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (if present):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. A truthy value sets a
/// fresh UnitAttr; setting to None or False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation (no presence check is performed here):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
227
228 static llvm::cl::OptionCategory
229 clOpPythonBindingCat("Options for -gen-python-op-bindings");
230
231 static llvm::cl::opt<std::string>
232 clDialectName("bind-dialect",
233 llvm::cl::desc("The dialect to run the generator for"),
234 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
235
236 using AttributeClasses = DenseMap<StringRef, StringRef>;
237
238 /// Checks whether `str` is a Python keyword.
isPythonKeyword(StringRef str)239 static bool isPythonKeyword(StringRef str) {
240 static llvm::StringSet<> keywords(
241 {"and", "as", "assert", "break", "class", "continue",
242 "def", "del", "elif", "else", "except", "finally",
243 "for", "from", "global", "if", "import", "in",
244 "is", "lambda", "nonlocal", "not", "or", "pass",
245 "raise", "return", "try", "while", "with", "yield"});
246 return keywords.contains(str);
247 };
248
249 /// Modifies the `name` in a way that it becomes suitable for Python bindings
250 /// (does not change the `name` if it already is suitable) and returns the
251 /// modified version.
sanitizeName(StringRef name)252 static std::string sanitizeName(StringRef name) {
253 if (isPythonKeyword(name))
254 return (name + "_").str();
255 return name.str();
256 }
257
attrSizedTraitForKind(const char * kind)258 static std::string attrSizedTraitForKind(const char *kind) {
259 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
260 llvm::StringRef(kind).take_front().upper(),
261 llvm::StringRef(kind).drop_front());
262 }
263
264 /// Emits accessors to "elements" of an Op definition. Currently, the supported
265 /// elements are operands and results, indicated by `kind`, which must be either
266 /// `operand` or `result` and is used verbatim in the emitted code.
emitElementAccessors(const Operator & op,raw_ostream & os,const char * kind,llvm::function_ref<unsigned (const Operator &)> getNumVariadic,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)267 static void emitElementAccessors(
268 const Operator &op, raw_ostream &os, const char *kind,
269 llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
270 llvm::function_ref<int(const Operator &)> getNumElements,
271 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
272 getElement) {
273 assert(llvm::is_contained(
274 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
275 "unsupported kind");
276
277 // Traits indicating how to process variadic elements.
278 std::string sameSizeTrait =
279 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
280 llvm::StringRef(kind).take_front().upper(),
281 llvm::StringRef(kind).drop_front());
282 std::string attrSizedTrait = attrSizedTraitForKind(kind);
283
284 unsigned numVariadic = getNumVariadic(op);
285
286 // If there is only one variadic element group, its size can be inferred from
287 // the total number of elements. If there are none, the generation is
288 // straightforward.
289 if (numVariadic <= 1) {
290 bool seenVariableLength = false;
291 for (int i = 0, e = getNumElements(op); i < e; ++i) {
292 const NamedTypeConstraint &element = getElement(op, i);
293 if (element.isVariableLength())
294 seenVariableLength = true;
295 if (element.name.empty())
296 continue;
297 if (element.isVariableLength()) {
298 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
299 : opOneVariadicTemplate,
300 sanitizeName(element.name), kind,
301 getNumElements(op), i);
302 } else if (seenVariableLength) {
303 os << llvm::formatv(opSingleAfterVariableTemplate,
304 sanitizeName(element.name), kind,
305 getNumElements(op), i);
306 } else {
307 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
308 i);
309 }
310 }
311 return;
312 }
313
314 // Handle the operations where variadic groups have the same size.
315 if (op.getTrait(sameSizeTrait)) {
316 int numPrecedingSimple = 0;
317 int numPrecedingVariadic = 0;
318 for (int i = 0, e = getNumElements(op); i < e; ++i) {
319 const NamedTypeConstraint &element = getElement(op, i);
320 if (!element.name.empty()) {
321 os << llvm::formatv(opVariadicEqualPrefixTemplate,
322 sanitizeName(element.name), kind, numVariadic,
323 numPrecedingSimple, numPrecedingVariadic);
324 os << llvm::formatv(element.isVariableLength()
325 ? opVariadicEqualVariadicTemplate
326 : opVariadicEqualSimpleTemplate,
327 kind);
328 }
329 if (element.isVariableLength())
330 ++numPrecedingVariadic;
331 else
332 ++numPrecedingSimple;
333 }
334 return;
335 }
336
337 // Handle the operations where the size of groups (variadic or not) is
338 // provided as an attribute. For non-variadic elements, make sure to return
339 // an element rather than a singleton container.
340 if (op.getTrait(attrSizedTrait)) {
341 for (int i = 0, e = getNumElements(op); i < e; ++i) {
342 const NamedTypeConstraint &element = getElement(op, i);
343 if (element.name.empty())
344 continue;
345 std::string trailing;
346 if (!element.isVariableLength())
347 trailing = "[0]";
348 else if (element.isOptional())
349 trailing = std::string(
350 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
351 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
352 kind, i, trailing);
353 }
354 return;
355 }
356
357 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
358 }
359
360 /// Free function helpers accessing Operator components.
getNumOperands(const Operator & op)361 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
getOperand(const Operator & op,int i)362 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
363 return op.getOperand(i);
364 }
getNumResults(const Operator & op)365 static int getNumResults(const Operator &op) { return op.getNumResults(); }
getResult(const Operator & op,int i)366 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
367 return op.getResult(i);
368 }
369
370 /// Emits accessors to Op operands.
emitOperandAccessors(const Operator & op,raw_ostream & os)371 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
372 auto getNumVariadic = [](const Operator &oper) {
373 return oper.getNumVariableLengthOperands();
374 };
375 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
376 getOperand);
377 }
378
379 /// Emits accessors Op results.
emitResultAccessors(const Operator & op,raw_ostream & os)380 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
381 auto getNumVariadic = [](const Operator &oper) {
382 return oper.getNumVariableLengthResults();
383 };
384 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
385 getResult);
386 }
387
388 /// Emits accessors to Op attributes.
emitAttributeAccessors(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)389 static void emitAttributeAccessors(const Operator &op,
390 const AttributeClasses &attributeClasses,
391 raw_ostream &os) {
392 for (const auto &namedAttr : op.getAttributes()) {
393 // Skip "derived" attributes because they are just C++ functions that we
394 // don't currently expose.
395 if (namedAttr.attr.isDerivedAttr())
396 continue;
397
398 if (namedAttr.name.empty())
399 continue;
400
401 std::string sanitizedName = sanitizeName(namedAttr.name);
402
403 // Unit attributes are handled specially.
404 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
405 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
406 namedAttr.name);
407 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
408 namedAttr.name);
409 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
410 namedAttr.name);
411 continue;
412 }
413
414 // Other kinds of attributes need a mapping to a Python type.
415 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
416 continue;
417
418 StringRef pythonType =
419 attributeClasses.lookup(namedAttr.attr.getStorageType());
420 if (namedAttr.attr.isOptional()) {
421 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
422 pythonType, namedAttr.name);
423 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
424 namedAttr.name);
425 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
426 namedAttr.name);
427 } else {
428 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
429 namedAttr.name);
430 os << llvm::formatv(attributeSetterTemplate, sanitizedName,
431 namedAttr.name);
432 // Non-optional attributes cannot be deleted.
433 }
434 }
435 }
436
/// Template for the default auto-generated builder.
///   {0} is the operation name;
///   {1} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {2} is the code populating `operands`, `results` and `attributes` fields.
/// `{{}` is the formatv escape for a literal `{}`. The lines substituted for
/// {2} must be joined with a newline plus the indentation of the `__init__`
/// body so that each lands inside the method.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {1}):
    operands = []
    results = []
    attributes = {{}
    {2}
    super().__init__(_ir.Operation.create(
      "{0}", attributes=attributes, operands=operands, results=results,
      loc=loc, ip=ip))
)Py";
452
/// Template for appending a single element to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";

/// Template for appending an optional element to the operand/result list; the
/// element is skipped when the argument is None.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalAppendTemplate =
    "if {1} is not None: {0}s.append({1})";

/// Template for appending a variadic element to the operand/result list; the
/// argument is unpacked, so any iterable is accepted.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";

/// Template for setting up the segment sizes buffer.
///   {0} is either 'operand' or 'result'.
/// NOTE(review): the 'L' typecode is C `unsigned long` (at least 4 bytes,
/// platform-dependent width); confirm it matches the element type expected
/// for `{0}_segment_sizes`.
constexpr const char *segmentDeclarationTemplate =
    "{0}_segment_sizes = array.array('L')";

/// Template for attaching segment sizes to the attribute list, as a dense
/// elements attribute built from the buffer declared above.
///   {0} is either 'operand' or 'result'.
constexpr const char *segmentAttributeTemplate =
    R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes,
      context=_get_default_loc_context(loc)))Py";

/// Template for appending the unit size to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name (only emitted as a trailing comment).
constexpr const char *singleElementSegmentTemplate =
    "{0}_segment_sizes.append(1) # {1}";

/// Template for appending 0/1 for an optional element to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalSegmentTemplate =
    "{0}_segment_sizes.append(0 if {1} is None else 1)";

/// Template for appending the length of a variadic group to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicSegmentTemplate =
    "{0}_segment_sizes.append(len({1}))";
495
/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder; the
/// attribute is only set when the argument is not None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder. The
/// attribute is added when the argument is truthy and omitted otherwise. The
/// UnitAttr is created with the context returned by
/// `_get_default_loc_context(loc)`.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get(
      _get_default_loc_context(loc)))Py";
510
511 /// Populates `builderArgs` with the Python-compatible names of builder function
512 /// arguments, first the results, then the intermixed attributes and operands in
513 /// the same order as they appear in the `arguments` field of the op definition.
514 /// Additionally, `operandNames` is populated with names of operands in their
515 /// order of appearance.
516 static void
populateBuilderArgs(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & operandNames)517 populateBuilderArgs(const Operator &op,
518 llvm::SmallVectorImpl<std::string> &builderArgs,
519 llvm::SmallVectorImpl<std::string> &operandNames) {
520 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
521 std::string name = op.getResultName(i).str();
522 if (name.empty())
523 name = llvm::formatv("_gen_res_{0}", i);
524 name = sanitizeName(name);
525 builderArgs.push_back(name);
526 }
527 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
528 std::string name = op.getArgName(i).str();
529 if (name.empty())
530 name = llvm::formatv("_gen_arg_{0}", i);
531 name = sanitizeName(name);
532 builderArgs.push_back(name);
533 if (!op.getArg(i).is<NamedAttribute *>())
534 operandNames.push_back(name);
535 }
536 }
537
538 /// Populates `builderLines` with additional lines that are required in the
539 /// builder to set up operation attributes. `argNames` is expected to contain
540 /// the names of builder arguments that correspond to op arguments, i.e. to the
541 /// operands and attributes in the same order as they appear in the `arguments`
542 /// field.
543 static void
populateBuilderLinesAttr(const Operator & op,llvm::ArrayRef<std::string> argNames,llvm::SmallVectorImpl<std::string> & builderLines)544 populateBuilderLinesAttr(const Operator &op,
545 llvm::ArrayRef<std::string> argNames,
546 llvm::SmallVectorImpl<std::string> &builderLines) {
547 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
548 Argument arg = op.getArg(i);
549 auto *attribute = arg.dyn_cast<NamedAttribute *>();
550 if (!attribute)
551 continue;
552
553 // Unit attributes are handled specially.
554 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
555 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
556 attribute->name, argNames[i]));
557 continue;
558 }
559
560 builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
561 ? initOptionalAttributeTemplate
562 : initAttributeTemplate,
563 attribute->name, argNames[i]));
564 }
565 }
566
567 /// Populates `builderLines` with additional lines that are required in the
568 /// builder. `kind` must be either "operand" or "result". `names` contains the
569 /// names of init arguments that correspond to the elements.
populateBuilderLines(const Operator & op,const char * kind,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)570 static void populateBuilderLines(
571 const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
572 llvm::SmallVectorImpl<std::string> &builderLines,
573 llvm::function_ref<int(const Operator &)> getNumElements,
574 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
575 getElement) {
576 // The segment sizes buffer only has to be populated if there attr-sized
577 // segments trait is present.
578 bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
579 if (includeSegments)
580 builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind));
581
582 // For each element, find or generate a name.
583 for (int i = 0, e = getNumElements(op); i < e; ++i) {
584 const NamedTypeConstraint &element = getElement(op, i);
585 std::string name = names[i];
586
587 // Choose the formatting string based on the element kind.
588 llvm::StringRef formatString, segmentFormatString;
589 if (!element.isVariableLength()) {
590 formatString = singleElementAppendTemplate;
591 segmentFormatString = singleElementSegmentTemplate;
592 } else if (element.isOptional()) {
593 formatString = optionalAppendTemplate;
594 segmentFormatString = optionalSegmentTemplate;
595 } else {
596 assert(element.isVariadic() && "unhandled element group type");
597 formatString = variadicAppendTemplate;
598 segmentFormatString = variadicSegmentTemplate;
599 }
600
601 // Add the lines.
602 builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
603 if (includeSegments)
604 builderLines.push_back(
605 llvm::formatv(segmentFormatString.data(), kind, name));
606 }
607
608 if (includeSegments)
609 builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind));
610 }
611
612 /// Emits a default builder constructing an operation from the list of its
613 /// result types, followed by a list of its operands.
emitDefaultOpBuilder(const Operator & op,raw_ostream & os)614 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
615 // If we are asked to skip default builders, comply.
616 if (op.skipDefaultBuilders())
617 return;
618
619 llvm::SmallVector<std::string, 8> builderArgs;
620 llvm::SmallVector<std::string, 8> builderLines;
621 llvm::SmallVector<std::string, 4> operandArgNames;
622 builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
623 op.getNumNativeAttributes());
624 populateBuilderArgs(op, builderArgs, operandArgNames);
625 populateBuilderLines(
626 op, "result",
627 llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
628 builderLines, getNumResults, getResult);
629 populateBuilderLines(op, "operand", operandArgNames, builderLines,
630 getNumOperands, getOperand);
631 populateBuilderLinesAttr(
632 op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
633 builderLines);
634
635 builderArgs.push_back("loc=None");
636 builderArgs.push_back("ip=None");
637 os << llvm::formatv(initTemplate, op.getOperationName(),
638 llvm::join(builderArgs, ", "),
639 llvm::join(builderLines, "\n "));
640 }
641
constructAttributeMapping(const llvm::RecordKeeper & records,AttributeClasses & attributeClasses)642 static void constructAttributeMapping(const llvm::RecordKeeper &records,
643 AttributeClasses &attributeClasses) {
644 for (const llvm::Record *rec :
645 records.getAllDerivedDefinitions("PythonAttr")) {
646 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
647 rec->getValueAsString("pythonType").trim());
648 }
649 }
650
651 /// Emits bindings for a specific Op to the given output stream.
emitOpBindings(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)652 static void emitOpBindings(const Operator &op,
653 const AttributeClasses &attributeClasses,
654 raw_ostream &os) {
655 os << llvm::formatv(opClassTemplate, op.getCppClassName(),
656 op.getOperationName());
657 emitDefaultOpBuilder(op, os);
658 emitOperandAccessors(op, os);
659 emitAttributeAccessors(op, attributeClasses, os);
660 emitResultAccessors(op, os);
661 }
662
663 /// Emits bindings for the dialect specified in the command line, including file
664 /// headers and utilities. Returns `false` on success to comply with Tablegen
665 /// registration requirements.
emitAllOps(const llvm::RecordKeeper & records,raw_ostream & os)666 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
667 if (clDialectName.empty())
668 llvm::PrintFatalError("dialect name not provided");
669
670 AttributeClasses attributeClasses;
671 constructAttributeMapping(records, attributeClasses);
672
673 os << fileHeader;
674 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
675 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
676 Operator op(rec);
677 if (op.getDialectName() == clDialectName.getValue())
678 emitOpBindings(op, attributeClasses, os);
679 }
680 return false;
681 }
682
683 static GenRegistration
684 genPythonBindings("gen-python-op-bindings",
685 "Generate Python bindings for MLIR Ops", &emitAllOps);
686