1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// Preamble emitted once at the top of every generated Python module: an
/// "autogenerated" banner followed by the imports that the generated classes
/// refer to (`array`, `_cext`, the accessor helpers and the `_ir` alias).
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

import array
from . import _cext
from . import _segmented_accessor, _equally_sized_accessor, _get_default_loc_context
_ir = _cext.ir
)Py";

/// Template for dialect class, registered through `_cext.register_dialect`;
/// the op classes emitted later register themselves against `_Dialect`.
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_cext.register_dialect
class _Dialect(_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for the header of an operation class deriving from `_ir.OpView`;
/// the builder and accessors are appended after it as class members.
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_cext.register_operation(_Dialect)
class {0}(_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
53 
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Only used when no variable-length group precedes the element, so that its
/// index among the op's operands/results is known statically.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent). It is
/// only used when the op has at most one variable-length group, so the other
/// {2} - 1 groups hold exactly one element each. The variable-length group
/// therefore spans `len - ({2} - 1)` elements. The current element is shifted
/// from {3} by that length minus the one slot the group occupies in the
/// group list.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + variadic_group_length - 1]
)Py";
77 
/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Only used when the optional element is the single variable-length group of
/// the op. Every other group holds exactly one element, so the element list
/// has exactly {2} entries when the optional element is present and {2} - 1
/// entries when it is absent. The accessor returns None in the latter case.
/// The whole conditional expression is kept on one line: a bare line break
/// inside a Python expression (outside brackets) is a syntax error.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
89 
/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Every other group holds exactly one element, so the variadic group spans
/// `len - ({2} - 1)` entries starting at index {3}. It is returned as a
/// slice of the element list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + variadic_group_length]
)Py";
101 
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// Binds `start` and `pg`, which are consumed by the second part of the
/// template (opVariadicEqualSimpleTemplate or
/// opVariadicEqualVariadicTemplate). The element list is read through
/// `self.operation`; a bare `operation` name is not defined inside the
/// generated method.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
112 
/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
/// Relies on `start` bound by opVariadicEqualPrefixTemplate.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// Relies on `start` and `pg` bound by opVariadicEqualPrefixTemplate. The
/// group is returned as the slice of `pg` elements beginning at `start`.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a suffix appended to the returned group:
///       - "[0]" for a single-element group;
///       - empty for a variadic group;
///       - the expansion of opVariadicSegmentOptionalTrailingTemplate for an
///         optional element.
/// The element list, the `{1}_segment_sizes` attribute and the group position
/// are handed to `_segmented_accessor` to obtain the group.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
/// Appended to `return {0}_range`, it yields the group's first element when
/// the group is non-empty and None otherwise.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
147 
/// Template for an operation attribute getter. The raw attribute is wrapped
/// in (converted through) the mapped Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter, returning None when
/// the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; mandatory attributes cannot be
/// set to None, which raises a ValueError instead:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (if present):
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute. Any truthy value stores a fresh UnitAttr, and
/// any falsy value removes the attribute if present:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation. Unlike the setters, the emitted code does
/// not check for presence before the `del`:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
227 
/// Command-line option category grouping the options of this generator.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// Name of the dialect whose ops are emitted; ops whose dialect name differs
/// are skipped. Must be provided (an empty value is a fatal error in
/// emitAllOps).
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// Maps a whitespace-trimmed C++ attribute storage type to the name of the
/// Python class wrapping it, as declared by `PythonAttr` records.
using AttributeClasses = DenseMap<StringRef, StringRef>;
237 
238 /// Checks whether `str` is a Python keyword.
isPythonKeyword(StringRef str)239 static bool isPythonKeyword(StringRef str) {
240   static llvm::StringSet<> keywords(
241       {"and",   "as",     "assert",   "break", "class",  "continue",
242        "def",   "del",    "elif",     "else",  "except", "finally",
243        "for",   "from",   "global",   "if",    "import", "in",
244        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
245        "raise", "return", "try",      "while", "with",   "yield"});
246   return keywords.contains(str);
247 };
248 
249 /// Modifies the `name` in a way that it becomes suitable for Python bindings
250 /// (does not change the `name` if it already is suitable) and returns the
251 /// modified version.
sanitizeName(StringRef name)252 static std::string sanitizeName(StringRef name) {
253   if (isPythonKeyword(name))
254     return (name + "_").str();
255   return name.str();
256 }
257 
attrSizedTraitForKind(const char * kind)258 static std::string attrSizedTraitForKind(const char *kind) {
259   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
260                        llvm::StringRef(kind).take_front().upper(),
261                        llvm::StringRef(kind).drop_front());
262 }
263 
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
///
/// One Python property is emitted per *named* element. Unnamed elements emit
/// nothing but still count when computing the positions of their neighbors.
/// The supported layouts are checked in this order:
///   1. at most one variable-length group: positions are derived at runtime
///      from the total number of elements;
///   2. several variadic groups and the SameVariadic{Kind}Size trait: the
///      accessor delegates to `_equally_sized_accessor`;
///   3. the AttrSized{Kind}Segments trait: the accessor delegates to
///      `_segmented_accessor`, reading the `{kind}_segment_sizes` attribute.
/// Any other structure is reported through llvm::PrintFatalError.
///
/// `getNumVariadic`, `getNumElements` and `getElement` abstract over whether
/// operands or results of `op` are being processed.
static void emitElementAccessors(
    const Operator &op, raw_ostream &os, const char *kind,
    llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
    llvm::function_ref<int(const Operator &)> getNumElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement) {
  assert(llvm::is_contained(
             llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
         "unsupported kind");

  // Traits indicating how to process variadic elements.
  std::string sameSizeTrait =
      llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
                    llvm::StringRef(kind).take_front().upper(),
                    llvm::StringRef(kind).drop_front());
  std::string attrSizedTrait = attrSizedTraitForKind(kind);

  unsigned numVariadic = getNumVariadic(op);

  // If there is only one variadic element group, its size can be inferred from
  // the total number of elements. If there are none, the generation is
  // straightforward.
  if (numVariadic <= 1) {
    bool seenVariableLength = false;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      // Tracked before the name check: an unnamed variable-length group still
      // shifts the runtime position of every element that follows it.
      if (element.isVariableLength())
        seenVariableLength = true;
      if (element.name.empty())
        continue;
      if (element.isVariableLength()) {
        os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
                                                 : opOneVariadicTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else if (seenVariableLength) {
        // The index depends on the runtime length of the preceding group.
        os << llvm::formatv(opSingleAfterVariableTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else {
        // No variable-length group before this element: static index.
        os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
                            i);
      }
    }
    return;
  }

  // Handle the operations where variadic groups have the same size.
  if (op.getTrait(sameSizeTrait)) {
    int numPrecedingSimple = 0;
    int numPrecedingVariadic = 0;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.name.empty()) {
        // The prefix binds `start`/`pg`; the suffix returns either a single
        // element or a slice depending on the element kind.
        os << llvm::formatv(opVariadicEqualPrefixTemplate,
                            sanitizeName(element.name), kind, numVariadic,
                            numPrecedingSimple, numPrecedingVariadic);
        os << llvm::formatv(element.isVariableLength()
                                ? opVariadicEqualVariadicTemplate
                                : opVariadicEqualSimpleTemplate,
                            kind);
      }
      // Counters advance for unnamed elements too, since they still occupy a
      // group.
      if (element.isVariableLength())
        ++numPrecedingVariadic;
      else
        ++numPrecedingSimple;
    }
    return;
  }

  // Handle the operations where the size of groups (variadic or not) is
  // provided as an attribute. For non-variadic elements, make sure to return
  // an element rather than a singleton container.
  if (op.getTrait(attrSizedTrait)) {
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.name.empty())
        continue;
      // Suffix appended to the segment range: "[0]" for single elements,
      // nothing for variadic groups, a None-returning guard for optionals.
      std::string trailing;
      if (!element.isVariableLength())
        trailing = "[0]";
      else if (element.isOptional())
        trailing = std::string(
            llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
      os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
                          kind, i, trailing);
    }
    return;
  }

  // Several variadic groups but neither trait tells how to split them.
  llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}
359 
360 /// Free function helpers accessing Operator components.
getNumOperands(const Operator & op)361 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
getOperand(const Operator & op,int i)362 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
363   return op.getOperand(i);
364 }
getNumResults(const Operator & op)365 static int getNumResults(const Operator &op) { return op.getNumResults(); }
getResult(const Operator & op,int i)366 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
367   return op.getResult(i);
368 }
369 
370 /// Emits accessors to Op operands.
emitOperandAccessors(const Operator & op,raw_ostream & os)371 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
372   auto getNumVariadic = [](const Operator &oper) {
373     return oper.getNumVariableLengthOperands();
374   };
375   emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
376                        getOperand);
377 }
378 
379 /// Emits accessors Op results.
emitResultAccessors(const Operator & op,raw_ostream & os)380 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
381   auto getNumVariadic = [](const Operator &oper) {
382     return oper.getNumVariableLengthResults();
383   };
384   emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
385                        getResult);
386 }
387 
388 /// Emits accessors to Op attributes.
emitAttributeAccessors(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)389 static void emitAttributeAccessors(const Operator &op,
390                                    const AttributeClasses &attributeClasses,
391                                    raw_ostream &os) {
392   for (const auto &namedAttr : op.getAttributes()) {
393     // Skip "derived" attributes because they are just C++ functions that we
394     // don't currently expose.
395     if (namedAttr.attr.isDerivedAttr())
396       continue;
397 
398     if (namedAttr.name.empty())
399       continue;
400 
401     std::string sanitizedName = sanitizeName(namedAttr.name);
402 
403     // Unit attributes are handled specially.
404     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
405       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
406                           namedAttr.name);
407       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
408                           namedAttr.name);
409       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
410                           namedAttr.name);
411       continue;
412     }
413 
414     // Other kinds of attributes need a mapping to a Python type.
415     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
416       continue;
417 
418     StringRef pythonType =
419         attributeClasses.lookup(namedAttr.attr.getStorageType());
420     if (namedAttr.attr.isOptional()) {
421       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
422                           pythonType, namedAttr.name);
423       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
424                           namedAttr.name);
425       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
426                           namedAttr.name);
427     } else {
428       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
429                           namedAttr.name);
430       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
431                           namedAttr.name);
432       // Non-optional attributes cannot be deleted.
433     }
434   }
435 }
436 
/// Template for the default auto-generated builder.
///   {0} is the operation name;
///   {1} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {2} is the code populating `operands`, `results` and `attributes` fields.
///
/// `{{` is formatv's escape for a literal `{`, so the emitted line reads
/// `attributes = {}`.
///
/// The body refers to `loc` and `ip` directly, which is why {1} must declare
/// them.
///
/// {2} is placed after a 4-space indent. Multi-line code must therefore be
/// joined with "\n    " to stay inside the function body.
///
/// NOTE(review): a builder argument named `operands`, `results` or
/// `attributes` would be rebound by the initializations below before {2}
/// uses it; confirm ODS names cannot collide with these locals.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {1}):
    operands = []
    results = []
    attributes = {{}
    {2}
    super().__init__(_ir.Operation.create(
      "{0}", attributes=attributes, operands=operands, results=results,
      loc=loc, ip=ip))
)Py";
452 
/// Template for appending a single element to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";

/// Template for appending an optional element to the operand/result list; a
/// None argument appends nothing.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalAppendTemplate =
    "if {1} is not None: {0}s.append({1})";

/// Template for appending a variadic element to the operand/result list; the
/// argument may be any iterable, which is unpacked into the list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";

/// Template for setting up the segment sizes buffer.
///   {0} is either 'operand' or 'result'.
/// NOTE(review): Python's array typecode 'L' is a C `unsigned long`, whose
/// size is platform-dependent (4 bytes on Windows, 8 on LP64 systems);
/// confirm that this matches the element type expected for the
/// `{0}_segment_sizes` attribute.
constexpr const char *segmentDeclarationTemplate =
    "{0}_segment_sizes = array.array('L')";

/// Template for attaching segment sizes to the attribute list, converting the
/// buffer into a DenseElementsAttr in the context derived from `loc`.
///   {0} is either 'operand' or 'result'.
constexpr const char *segmentAttributeTemplate =
    R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes,
      context=_get_default_loc_context(loc)))Py";

/// Template for appending the unit size to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name (only emitted as a trailing Python comment).
constexpr const char *singleElementSegmentTemplate =
    "{0}_segment_sizes.append(1) # {1}";

/// Template for appending 0/1 for an optional element to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalSegmentTemplate =
    "{0}_segment_sizes.append(0 if {1} is None else 1)";

/// Template for appending the length of a variadic group to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicSegmentTemplate =
    "{0}_segment_sizes.append(len({1}))";
495 
/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder; a
/// None argument leaves the attribute unset.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder. A truthy
/// argument adds a UnitAttr created in the context derived from `loc`; a
/// falsy one leaves the attribute unset.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get(
      _get_default_loc_context(loc)))Py";
510 
511 /// Populates `builderArgs` with the Python-compatible names of builder function
512 /// arguments, first the results, then the intermixed attributes and operands in
513 /// the same order as they appear in the `arguments` field of the op definition.
514 /// Additionally, `operandNames` is populated with names of operands in their
515 /// order of appearance.
516 static void
populateBuilderArgs(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & operandNames)517 populateBuilderArgs(const Operator &op,
518                     llvm::SmallVectorImpl<std::string> &builderArgs,
519                     llvm::SmallVectorImpl<std::string> &operandNames) {
520   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
521     std::string name = op.getResultName(i).str();
522     if (name.empty())
523       name = llvm::formatv("_gen_res_{0}", i);
524     name = sanitizeName(name);
525     builderArgs.push_back(name);
526   }
527   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
528     std::string name = op.getArgName(i).str();
529     if (name.empty())
530       name = llvm::formatv("_gen_arg_{0}", i);
531     name = sanitizeName(name);
532     builderArgs.push_back(name);
533     if (!op.getArg(i).is<NamedAttribute *>())
534       operandNames.push_back(name);
535   }
536 }
537 
538 /// Populates `builderLines` with additional lines that are required in the
539 /// builder to set up operation attributes. `argNames` is expected to contain
540 /// the names of builder arguments that correspond to op arguments, i.e. to the
541 /// operands and attributes in the same order as they appear in the `arguments`
542 /// field.
543 static void
populateBuilderLinesAttr(const Operator & op,llvm::ArrayRef<std::string> argNames,llvm::SmallVectorImpl<std::string> & builderLines)544 populateBuilderLinesAttr(const Operator &op,
545                          llvm::ArrayRef<std::string> argNames,
546                          llvm::SmallVectorImpl<std::string> &builderLines) {
547   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
548     Argument arg = op.getArg(i);
549     auto *attribute = arg.dyn_cast<NamedAttribute *>();
550     if (!attribute)
551       continue;
552 
553     // Unit attributes are handled specially.
554     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
555       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
556                                            attribute->name, argNames[i]));
557       continue;
558     }
559 
560     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
561                                              ? initOptionalAttributeTemplate
562                                              : initAttributeTemplate,
563                                          attribute->name, argNames[i]));
564   }
565 }
566 
567 /// Populates `builderLines` with additional lines that are required in the
568 /// builder. `kind` must be either "operand" or "result". `names` contains the
569 /// names of init arguments that correspond to the elements.
populateBuilderLines(const Operator & op,const char * kind,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)570 static void populateBuilderLines(
571     const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
572     llvm::SmallVectorImpl<std::string> &builderLines,
573     llvm::function_ref<int(const Operator &)> getNumElements,
574     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
575         getElement) {
576   // The segment sizes buffer only has to be populated if there attr-sized
577   // segments trait is present.
578   bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
579   if (includeSegments)
580     builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind));
581 
582   // For each element, find or generate a name.
583   for (int i = 0, e = getNumElements(op); i < e; ++i) {
584     const NamedTypeConstraint &element = getElement(op, i);
585     std::string name = names[i];
586 
587     // Choose the formatting string based on the element kind.
588     llvm::StringRef formatString, segmentFormatString;
589     if (!element.isVariableLength()) {
590       formatString = singleElementAppendTemplate;
591       segmentFormatString = singleElementSegmentTemplate;
592     } else if (element.isOptional()) {
593       formatString = optionalAppendTemplate;
594       segmentFormatString = optionalSegmentTemplate;
595     } else {
596       assert(element.isVariadic() && "unhandled element group type");
597       formatString = variadicAppendTemplate;
598       segmentFormatString = variadicSegmentTemplate;
599     }
600 
601     // Add the lines.
602     builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
603     if (includeSegments)
604       builderLines.push_back(
605           llvm::formatv(segmentFormatString.data(), kind, name));
606   }
607 
608   if (includeSegments)
609     builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind));
610 }
611 
612 /// Emits a default builder constructing an operation from the list of its
613 /// result types, followed by a list of its operands.
emitDefaultOpBuilder(const Operator & op,raw_ostream & os)614 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
615   // If we are asked to skip default builders, comply.
616   if (op.skipDefaultBuilders())
617     return;
618 
619   llvm::SmallVector<std::string, 8> builderArgs;
620   llvm::SmallVector<std::string, 8> builderLines;
621   llvm::SmallVector<std::string, 4> operandArgNames;
622   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
623                       op.getNumNativeAttributes());
624   populateBuilderArgs(op, builderArgs, operandArgNames);
625   populateBuilderLines(
626       op, "result",
627       llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
628       builderLines, getNumResults, getResult);
629   populateBuilderLines(op, "operand", operandArgNames, builderLines,
630                        getNumOperands, getOperand);
631   populateBuilderLinesAttr(
632       op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
633       builderLines);
634 
635   builderArgs.push_back("loc=None");
636   builderArgs.push_back("ip=None");
637   os << llvm::formatv(initTemplate, op.getOperationName(),
638                       llvm::join(builderArgs, ", "),
639                       llvm::join(builderLines, "\n    "));
640 }
641 
constructAttributeMapping(const llvm::RecordKeeper & records,AttributeClasses & attributeClasses)642 static void constructAttributeMapping(const llvm::RecordKeeper &records,
643                                       AttributeClasses &attributeClasses) {
644   for (const llvm::Record *rec :
645        records.getAllDerivedDefinitions("PythonAttr")) {
646     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
647                                  rec->getValueAsString("pythonType").trim());
648   }
649 }
650 
651 /// Emits bindings for a specific Op to the given output stream.
emitOpBindings(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)652 static void emitOpBindings(const Operator &op,
653                            const AttributeClasses &attributeClasses,
654                            raw_ostream &os) {
655   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
656                       op.getOperationName());
657   emitDefaultOpBuilder(op, os);
658   emitOperandAccessors(op, os);
659   emitAttributeAccessors(op, attributeClasses, os);
660   emitResultAccessors(op, os);
661 }
662 
663 /// Emits bindings for the dialect specified in the command line, including file
664 /// headers and utilities. Returns `false` on success to comply with Tablegen
665 /// registration requirements.
emitAllOps(const llvm::RecordKeeper & records,raw_ostream & os)666 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
667   if (clDialectName.empty())
668     llvm::PrintFatalError("dialect name not provided");
669 
670   AttributeClasses attributeClasses;
671   constructAttributeMapping(records, attributeClasses);
672 
673   os << fileHeader;
674   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
675   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
676     Operator op(rec);
677     if (op.getDialectName() == clDialectName.getValue())
678       emitOpBindings(op, attributeClasses, os);
679   }
680   return false;
681 }
682 
/// Registers this backend under the `-gen-python-op-bindings` generator flag;
/// the dialect to bind is selected with `-bind-dialect`.
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
686