1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModules.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Registration.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <pybind11/stl.h>
20 
21 namespace py = pybind11;
22 using namespace mlir;
23 using namespace mlir::python;
24 
25 using llvm::SmallVector;
26 using llvm::StringRef;
27 using llvm::Twine;
28 
29 //------------------------------------------------------------------------------
30 // Docstrings (trivial, non-duplicated docstrings are included inline).
31 //------------------------------------------------------------------------------
32 
// Python docstring: parsing a Type from its textual assembly form.
static constexpr char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Python docstring: building a file/line/column Location.
static constexpr char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Python docstring: parsing a whole module from its textual assembly form.
static constexpr char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring: creating a new (detached) operation.
static constexpr char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
69 
// Python docstring for Operation.print(). The keyword names listed here must
// match the keyword arguments exposed by the binding (e.g. use_local_scope,
// which maps to the useLocalScope parameter of PyOperationBase::print).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

If the operation fails verification (and print_generic_op_form is False), a
comment noting the failure is written first and the generic op form is printed
instead.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elides elements attributes that have more
    than this number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
88 
// Python docstring: Operation.get_asm(), the string-returning variant of print.
static constexpr char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Python docstring: str(operation) with default printing options.
static constexpr char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for the various dump() methods.
static constexpr char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Python docstring: BlockList.append().
static constexpr char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Python docstring: str(value).
static constexpr char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
127 
128 //------------------------------------------------------------------------------
129 // Utilities.
130 //------------------------------------------------------------------------------
131 
132 /// Checks whether the given type is an integer or float type.
mlirTypeIsAIntegerOrFloat(MlirType type)133 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
134   return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
135          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
136 }
137 
138 static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)139 createCustomDialectWrapper(const std::string &dialectNamespace,
140                            py::object dialectDescriptor) {
141   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
142   if (!dialectClass) {
143     // Use the base class.
144     return py::cast(PyDialect(std::move(dialectDescriptor)));
145   }
146 
147   // Create the custom implementation.
148   return (*dialectClass)(std::move(dialectDescriptor));
149 }
150 
toMlirStringRef(const std::string & s)151 static MlirStringRef toMlirStringRef(const std::string &s) {
152   return mlirStringRefCreate(s.data(), s.size());
153 }
154 
155 //------------------------------------------------------------------------------
156 // Collections.
157 //------------------------------------------------------------------------------
158 
159 namespace {
160 
161 class PyRegionIterator {
162 public:
PyRegionIterator(PyOperationRef operation)163   PyRegionIterator(PyOperationRef operation)
164       : operation(std::move(operation)) {}
165 
dunderIter()166   PyRegionIterator &dunderIter() { return *this; }
167 
dunderNext()168   PyRegion dunderNext() {
169     operation->checkValid();
170     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
171       throw py::stop_iteration();
172     }
173     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
174     return PyRegion(operation, region);
175   }
176 
bind(py::module & m)177   static void bind(py::module &m) {
178     py::class_<PyRegionIterator>(m, "RegionIterator")
179         .def("__iter__", &PyRegionIterator::dunderIter)
180         .def("__next__", &PyRegionIterator::dunderNext);
181   }
182 
183 private:
184   PyOperationRef operation;
185   int nextIndex = 0;
186 };
187 
188 /// Regions of an op are fixed length and indexed numerically so are represented
189 /// with a sequence-like container.
190 class PyRegionList {
191 public:
PyRegionList(PyOperationRef operation)192   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
193 
dunderLen()194   intptr_t dunderLen() {
195     operation->checkValid();
196     return mlirOperationGetNumRegions(operation->get());
197   }
198 
dunderGetItem(intptr_t index)199   PyRegion dunderGetItem(intptr_t index) {
200     // dunderLen checks validity.
201     if (index < 0 || index >= dunderLen()) {
202       throw SetPyError(PyExc_IndexError,
203                        "attempt to access out of bounds region");
204     }
205     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
206     return PyRegion(operation, region);
207   }
208 
bind(py::module & m)209   static void bind(py::module &m) {
210     py::class_<PyRegionList>(m, "ReqionSequence")
211         .def("__len__", &PyRegionList::dunderLen)
212         .def("__getitem__", &PyRegionList::dunderGetItem);
213   }
214 
215 private:
216   PyOperationRef operation;
217 };
218 
219 class PyBlockIterator {
220 public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)221   PyBlockIterator(PyOperationRef operation, MlirBlock next)
222       : operation(std::move(operation)), next(next) {}
223 
dunderIter()224   PyBlockIterator &dunderIter() { return *this; }
225 
dunderNext()226   PyBlock dunderNext() {
227     operation->checkValid();
228     if (mlirBlockIsNull(next)) {
229       throw py::stop_iteration();
230     }
231 
232     PyBlock returnBlock(operation, next);
233     next = mlirBlockGetNextInRegion(next);
234     return returnBlock;
235   }
236 
bind(py::module & m)237   static void bind(py::module &m) {
238     py::class_<PyBlockIterator>(m, "BlockIterator")
239         .def("__iter__", &PyBlockIterator::dunderIter)
240         .def("__next__", &PyBlockIterator::dunderNext);
241   }
242 
243 private:
244   PyOperationRef operation;
245   MlirBlock next;
246 };
247 
248 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
249 /// we present them as a more full-featured list-like container but optimize
250 /// it for forward iteration. Blocks are always owned by a region.
251 class PyBlockList {
252 public:
PyBlockList(PyOperationRef operation,MlirRegion region)253   PyBlockList(PyOperationRef operation, MlirRegion region)
254       : operation(std::move(operation)), region(region) {}
255 
dunderIter()256   PyBlockIterator dunderIter() {
257     operation->checkValid();
258     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
259   }
260 
dunderLen()261   intptr_t dunderLen() {
262     operation->checkValid();
263     intptr_t count = 0;
264     MlirBlock block = mlirRegionGetFirstBlock(region);
265     while (!mlirBlockIsNull(block)) {
266       count += 1;
267       block = mlirBlockGetNextInRegion(block);
268     }
269     return count;
270   }
271 
dunderGetItem(intptr_t index)272   PyBlock dunderGetItem(intptr_t index) {
273     operation->checkValid();
274     if (index < 0) {
275       throw SetPyError(PyExc_IndexError,
276                        "attempt to access out of bounds block");
277     }
278     MlirBlock block = mlirRegionGetFirstBlock(region);
279     while (!mlirBlockIsNull(block)) {
280       if (index == 0) {
281         return PyBlock(operation, block);
282       }
283       block = mlirBlockGetNextInRegion(block);
284       index -= 1;
285     }
286     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
287   }
288 
appendBlock(py::args pyArgTypes)289   PyBlock appendBlock(py::args pyArgTypes) {
290     operation->checkValid();
291     llvm::SmallVector<MlirType, 4> argTypes;
292     argTypes.reserve(pyArgTypes.size());
293     for (auto &pyArg : pyArgTypes) {
294       argTypes.push_back(pyArg.cast<PyType &>());
295     }
296 
297     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
298     mlirRegionAppendOwnedBlock(region, block);
299     return PyBlock(operation, block);
300   }
301 
bind(py::module & m)302   static void bind(py::module &m) {
303     py::class_<PyBlockList>(m, "BlockList")
304         .def("__getitem__", &PyBlockList::dunderGetItem)
305         .def("__iter__", &PyBlockList::dunderIter)
306         .def("__len__", &PyBlockList::dunderLen)
307         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
308   }
309 
310 private:
311   PyOperationRef operation;
312   MlirRegion region;
313 };
314 
315 class PyOperationIterator {
316 public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)317   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
318       : parentOperation(std::move(parentOperation)), next(next) {}
319 
dunderIter()320   PyOperationIterator &dunderIter() { return *this; }
321 
dunderNext()322   py::object dunderNext() {
323     parentOperation->checkValid();
324     if (mlirOperationIsNull(next)) {
325       throw py::stop_iteration();
326     }
327 
328     PyOperationRef returnOperation =
329         PyOperation::forOperation(parentOperation->getContext(), next);
330     next = mlirOperationGetNextInBlock(next);
331     return returnOperation->createOpView();
332   }
333 
bind(py::module & m)334   static void bind(py::module &m) {
335     py::class_<PyOperationIterator>(m, "OperationIterator")
336         .def("__iter__", &PyOperationIterator::dunderIter)
337         .def("__next__", &PyOperationIterator::dunderNext);
338   }
339 
340 private:
341   PyOperationRef parentOperation;
342   MlirOperation next;
343 };
344 
345 /// Operations are exposed by the C-API as a forward-only linked list. In
346 /// Python, we present them as a more full-featured list-like container but
347 /// optimize it for forward iteration. Iterable operations are always owned
348 /// by a block.
349 class PyOperationList {
350 public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)351   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
352       : parentOperation(std::move(parentOperation)), block(block) {}
353 
dunderIter()354   PyOperationIterator dunderIter() {
355     parentOperation->checkValid();
356     return PyOperationIterator(parentOperation,
357                                mlirBlockGetFirstOperation(block));
358   }
359 
dunderLen()360   intptr_t dunderLen() {
361     parentOperation->checkValid();
362     intptr_t count = 0;
363     MlirOperation childOp = mlirBlockGetFirstOperation(block);
364     while (!mlirOperationIsNull(childOp)) {
365       count += 1;
366       childOp = mlirOperationGetNextInBlock(childOp);
367     }
368     return count;
369   }
370 
dunderGetItem(intptr_t index)371   py::object dunderGetItem(intptr_t index) {
372     parentOperation->checkValid();
373     if (index < 0) {
374       throw SetPyError(PyExc_IndexError,
375                        "attempt to access out of bounds operation");
376     }
377     MlirOperation childOp = mlirBlockGetFirstOperation(block);
378     while (!mlirOperationIsNull(childOp)) {
379       if (index == 0) {
380         return PyOperation::forOperation(parentOperation->getContext(), childOp)
381             ->createOpView();
382       }
383       childOp = mlirOperationGetNextInBlock(childOp);
384       index -= 1;
385     }
386     throw SetPyError(PyExc_IndexError,
387                      "attempt to access out of bounds operation");
388   }
389 
bind(py::module & m)390   static void bind(py::module &m) {
391     py::class_<PyOperationList>(m, "OperationList")
392         .def("__getitem__", &PyOperationList::dunderGetItem)
393         .def("__iter__", &PyOperationList::dunderIter)
394         .def("__len__", &PyOperationList::dunderLen);
395   }
396 
397 private:
398   PyOperationRef parentOperation;
399   MlirBlock block;
400 };
401 
402 } // namespace
403 
404 //------------------------------------------------------------------------------
405 // PyMlirContext
406 //------------------------------------------------------------------------------
407 
PyMlirContext(MlirContext context)408 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
409   py::gil_scoped_acquire acquire;
410   auto &liveContexts = getLiveContexts();
411   liveContexts[context.ptr] = this;
412 }
413 
~PyMlirContext()414 PyMlirContext::~PyMlirContext() {
415   // Note that the only public way to construct an instance is via the
416   // forContext method, which always puts the associated handle into
417   // liveContexts.
418   py::gil_scoped_acquire acquire;
419   getLiveContexts().erase(context.ptr);
420   mlirContextDestroy(context);
421 }
422 
getCapsule()423 py::object PyMlirContext::getCapsule() {
424   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
425 }
426 
createFromCapsule(py::object capsule)427 py::object PyMlirContext::createFromCapsule(py::object capsule) {
428   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
429   if (mlirContextIsNull(rawContext))
430     throw py::error_already_set();
431   return forContext(rawContext).releaseObject();
432 }
433 
createNewContextForInit()434 PyMlirContext *PyMlirContext::createNewContextForInit() {
435   MlirContext context = mlirContextCreate();
436   mlirRegisterAllDialects(context);
437   return new PyMlirContext(context);
438 }
439 
forContext(MlirContext context)440 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
441   py::gil_scoped_acquire acquire;
442   auto &liveContexts = getLiveContexts();
443   auto it = liveContexts.find(context.ptr);
444   if (it == liveContexts.end()) {
445     // Create.
446     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
447     py::object pyRef = py::cast(unownedContextWrapper);
448     assert(pyRef && "cast to py::object failed");
449     liveContexts[context.ptr] = unownedContextWrapper;
450     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
451   }
452   // Use existing.
453   py::object pyRef = py::cast(it->second);
454   return PyMlirContextRef(it->second, std::move(pyRef));
455 }
456 
getLiveContexts()457 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
458   static LiveContextMap liveContexts;
459   return liveContexts;
460 }
461 
getLiveCount()462 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
463 
getLiveOperationCount()464 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
465 
getLiveModuleCount()466 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
467 
contextEnter()468 pybind11::object PyMlirContext::contextEnter() {
469   return PyThreadContextEntry::pushContext(*this);
470 }
471 
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)472 void PyMlirContext::contextExit(pybind11::object excType,
473                                 pybind11::object excVal,
474                                 pybind11::object excTb) {
475   PyThreadContextEntry::popContext(*this);
476 }
477 
resolve()478 PyMlirContext &DefaultingPyMlirContext::resolve() {
479   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
480   if (!context) {
481     throw SetPyError(
482         PyExc_RuntimeError,
483         "An MLIR function requires a Context but none was provided in the call "
484         "or from the surrounding environment. Either pass to the function with "
485         "a 'context=' argument or establish a default using 'with Context():'");
486   }
487   return *context;
488 }
489 
490 //------------------------------------------------------------------------------
491 // PyThreadContextEntry management
492 //------------------------------------------------------------------------------
493 
getStack()494 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
495   static thread_local std::vector<PyThreadContextEntry> stack;
496   return stack;
497 }
498 
getTopOfStack()499 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
500   auto &stack = getStack();
501   if (stack.empty())
502     return nullptr;
503   return &stack.back();
504 }
505 
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)506 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
507                                 py::object insertionPoint,
508                                 py::object location) {
509   auto &stack = getStack();
510   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
511                      std::move(location));
512   // If the new stack has more than one entry and the context of the new top
513   // entry matches the previous, copy the insertionPoint and location from the
514   // previous entry if missing from the new top entry.
515   if (stack.size() > 1) {
516     auto &prev = *(stack.rbegin() + 1);
517     auto &current = stack.back();
518     if (current.context.is(prev.context)) {
519       // Default non-context objects from the previous entry.
520       if (!current.insertionPoint)
521         current.insertionPoint = prev.insertionPoint;
522       if (!current.location)
523         current.location = prev.location;
524     }
525   }
526 }
527 
getContext()528 PyMlirContext *PyThreadContextEntry::getContext() {
529   if (!context)
530     return nullptr;
531   return py::cast<PyMlirContext *>(context);
532 }
533 
getInsertionPoint()534 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
535   if (!insertionPoint)
536     return nullptr;
537   return py::cast<PyInsertionPoint *>(insertionPoint);
538 }
539 
getLocation()540 PyLocation *PyThreadContextEntry::getLocation() {
541   if (!location)
542     return nullptr;
543   return py::cast<PyLocation *>(location);
544 }
545 
getDefaultContext()546 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
547   auto *tos = getTopOfStack();
548   return tos ? tos->getContext() : nullptr;
549 }
550 
getDefaultInsertionPoint()551 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
552   auto *tos = getTopOfStack();
553   return tos ? tos->getInsertionPoint() : nullptr;
554 }
555 
getDefaultLocation()556 PyLocation *PyThreadContextEntry::getDefaultLocation() {
557   auto *tos = getTopOfStack();
558   return tos ? tos->getLocation() : nullptr;
559 }
560 
pushContext(PyMlirContext & context)561 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
562   py::object contextObj = py::cast(context);
563   push(FrameKind::Context, /*context=*/contextObj,
564        /*insertionPoint=*/py::object(),
565        /*location=*/py::object());
566   return contextObj;
567 }
568 
popContext(PyMlirContext & context)569 void PyThreadContextEntry::popContext(PyMlirContext &context) {
570   auto &stack = getStack();
571   if (stack.empty())
572     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
573   auto &tos = stack.back();
574   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
575     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
576   stack.pop_back();
577 }
578 
579 py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)580 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
581   py::object contextObj =
582       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
583   py::object insertionPointObj = py::cast(insertionPoint);
584   push(FrameKind::InsertionPoint,
585        /*context=*/contextObj,
586        /*insertionPoint=*/insertionPointObj,
587        /*location=*/py::object());
588   return insertionPointObj;
589 }
590 
popInsertionPoint(PyInsertionPoint & insertionPoint)591 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
592   auto &stack = getStack();
593   if (stack.empty())
594     throw SetPyError(PyExc_RuntimeError,
595                      "Unbalanced InsertionPoint enter/exit");
596   auto &tos = stack.back();
597   if (tos.frameKind != FrameKind::InsertionPoint &&
598       tos.getInsertionPoint() != &insertionPoint)
599     throw SetPyError(PyExc_RuntimeError,
600                      "Unbalanced InsertionPoint enter/exit");
601   stack.pop_back();
602 }
603 
pushLocation(PyLocation & location)604 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
605   py::object contextObj = location.getContext().getObject();
606   py::object locationObj = py::cast(location);
607   push(FrameKind::Location, /*context=*/contextObj,
608        /*insertionPoint=*/py::object(),
609        /*location=*/locationObj);
610   return locationObj;
611 }
612 
popLocation(PyLocation & location)613 void PyThreadContextEntry::popLocation(PyLocation &location) {
614   auto &stack = getStack();
615   if (stack.empty())
616     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
617   auto &tos = stack.back();
618   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
619     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
620   stack.pop_back();
621 }
622 
623 //------------------------------------------------------------------------------
624 // PyDialect, PyDialectDescriptor, PyDialects
625 //------------------------------------------------------------------------------
626 
getDialectForKey(const std::string & key,bool attrError)627 MlirDialect PyDialects::getDialectForKey(const std::string &key,
628                                          bool attrError) {
629   // If the "std" dialect was asked for, substitute the empty namespace :(
630   static const std::string emptyKey;
631   const std::string *canonKey = key == "std" ? &emptyKey : &key;
632   MlirDialect dialect = mlirContextGetOrLoadDialect(
633       getContext()->get(), {canonKey->data(), canonKey->size()});
634   if (mlirDialectIsNull(dialect)) {
635     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
636                      Twine("Dialect '") + key + "' not found");
637   }
638   return dialect;
639 }
640 
641 //------------------------------------------------------------------------------
642 // PyLocation
643 //------------------------------------------------------------------------------
644 
getCapsule()645 py::object PyLocation::getCapsule() {
646   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
647 }
648 
createFromCapsule(py::object capsule)649 PyLocation PyLocation::createFromCapsule(py::object capsule) {
650   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
651   if (mlirLocationIsNull(rawLoc))
652     throw py::error_already_set();
653   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
654                     rawLoc);
655 }
656 
contextEnter()657 py::object PyLocation::contextEnter() {
658   return PyThreadContextEntry::pushLocation(*this);
659 }
660 
contextExit(py::object excType,py::object excVal,py::object excTb)661 void PyLocation::contextExit(py::object excType, py::object excVal,
662                              py::object excTb) {
663   PyThreadContextEntry::popLocation(*this);
664 }
665 
resolve()666 PyLocation &DefaultingPyLocation::resolve() {
667   auto *location = PyThreadContextEntry::getDefaultLocation();
668   if (!location) {
669     throw SetPyError(
670         PyExc_RuntimeError,
671         "An MLIR function requires a Location but none was provided in the "
672         "call or from the surrounding environment. Either pass to the function "
673         "with a 'loc=' argument or establish a default using 'with loc:'");
674   }
675   return *location;
676 }
677 
678 //------------------------------------------------------------------------------
679 // PyModule
680 //------------------------------------------------------------------------------
681 
PyModule(PyMlirContextRef contextRef,MlirModule module)682 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
683     : BaseContextObject(std::move(contextRef)), module(module) {}
684 
~PyModule()685 PyModule::~PyModule() {
686   py::gil_scoped_acquire acquire;
687   auto &liveModules = getContext()->liveModules;
688   assert(liveModules.count(module.ptr) == 1 &&
689          "destroying module not in live map");
690   liveModules.erase(module.ptr);
691   mlirModuleDestroy(module);
692 }
693 
forModule(MlirModule module)694 PyModuleRef PyModule::forModule(MlirModule module) {
695   MlirContext context = mlirModuleGetContext(module);
696   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
697 
698   py::gil_scoped_acquire acquire;
699   auto &liveModules = contextRef->liveModules;
700   auto it = liveModules.find(module.ptr);
701   if (it == liveModules.end()) {
702     // Create.
703     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
704     // Note that the default return value policy on cast is automatic_reference,
705     // which does not take ownership (delete will not be called).
706     // Just be explicit.
707     py::object pyRef =
708         py::cast(unownedModule, py::return_value_policy::take_ownership);
709     unownedModule->handle = pyRef;
710     liveModules[module.ptr] =
711         std::make_pair(unownedModule->handle, unownedModule);
712     return PyModuleRef(unownedModule, std::move(pyRef));
713   }
714   // Use existing.
715   PyModule *existing = it->second.second;
716   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
717   return PyModuleRef(existing, std::move(pyRef));
718 }
719 
createFromCapsule(py::object capsule)720 py::object PyModule::createFromCapsule(py::object capsule) {
721   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
722   if (mlirModuleIsNull(rawModule))
723     throw py::error_already_set();
724   return forModule(rawModule).releaseObject();
725 }
726 
getCapsule()727 py::object PyModule::getCapsule() {
728   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
729 }
730 
731 //------------------------------------------------------------------------------
732 // PyOperation
733 //------------------------------------------------------------------------------
734 
PyOperation(PyMlirContextRef contextRef,MlirOperation operation)735 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
736     : BaseContextObject(std::move(contextRef)), operation(operation) {}
737 
~PyOperation()738 PyOperation::~PyOperation() {
739   auto &liveOperations = getContext()->liveOperations;
740   assert(liveOperations.count(operation.ptr) == 1 &&
741          "destroying operation not in live map");
742   liveOperations.erase(operation.ptr);
743   if (!isAttached()) {
744     mlirOperationDestroy(operation);
745   }
746 }
747 
createInstance(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)748 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
749                                            MlirOperation operation,
750                                            py::object parentKeepAlive) {
751   auto &liveOperations = contextRef->liveOperations;
752   // Create.
753   PyOperation *unownedOperation =
754       new PyOperation(std::move(contextRef), operation);
755   // Note that the default return value policy on cast is automatic_reference,
756   // which does not take ownership (delete will not be called).
757   // Just be explicit.
758   py::object pyRef =
759       py::cast(unownedOperation, py::return_value_policy::take_ownership);
760   unownedOperation->handle = pyRef;
761   if (parentKeepAlive) {
762     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
763   }
764   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
765   return PyOperationRef(unownedOperation, std::move(pyRef));
766 }
767 
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)768 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
769                                          MlirOperation operation,
770                                          py::object parentKeepAlive) {
771   auto &liveOperations = contextRef->liveOperations;
772   auto it = liveOperations.find(operation.ptr);
773   if (it == liveOperations.end()) {
774     // Create.
775     return createInstance(std::move(contextRef), operation,
776                           std::move(parentKeepAlive));
777   }
778   // Use existing.
779   PyOperation *existing = it->second.second;
780   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
781   return PyOperationRef(existing, std::move(pyRef));
782 }
783 
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)784 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
785                                            MlirOperation operation,
786                                            py::object parentKeepAlive) {
787   auto &liveOperations = contextRef->liveOperations;
788   assert(liveOperations.count(operation.ptr) == 0 &&
789          "cannot create detached operation that already exists");
790   (void)liveOperations;
791 
792   PyOperationRef created = createInstance(std::move(contextRef), operation,
793                                           std::move(parentKeepAlive));
794   created->attached = false;
795   return created;
796 }
797 
checkValid() const798 void PyOperation::checkValid() const {
799   if (!valid) {
800     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
801   }
802 }
803 
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)804 void PyOperationBase::print(py::object fileObject, bool binary,
805                             llvm::Optional<int64_t> largeElementsLimit,
806                             bool enableDebugInfo, bool prettyDebugInfo,
807                             bool printGenericOpForm, bool useLocalScope) {
808   PyOperation &operation = getOperation();
809   operation.checkValid();
810   if (fileObject.is_none())
811     fileObject = py::module::import("sys").attr("stdout");
812 
813   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
814     fileObject.attr("write")("// Verification failed, printing generic form\n");
815     printGenericOpForm = true;
816   }
817 
818   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
819   if (largeElementsLimit)
820     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
821   if (enableDebugInfo)
822     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
823   if (printGenericOpForm)
824     mlirOpPrintingFlagsPrintGenericOpForm(flags);
825 
826   PyFileAccumulator accum(fileObject, binary);
827   py::gil_scoped_release();
828   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
829                               accum.getUserData());
830   mlirOpPrintingFlagsDestroy(flags);
831 }
832 
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)833 py::object PyOperationBase::getAsm(bool binary,
834                                    llvm::Optional<int64_t> largeElementsLimit,
835                                    bool enableDebugInfo, bool prettyDebugInfo,
836                                    bool printGenericOpForm,
837                                    bool useLocalScope) {
838   py::object fileObject;
839   if (binary) {
840     fileObject = py::module::import("io").attr("BytesIO")();
841   } else {
842     fileObject = py::module::import("io").attr("StringIO")();
843   }
844   print(fileObject, /*binary=*/binary,
845         /*largeElementsLimit=*/largeElementsLimit,
846         /*enableDebugInfo=*/enableDebugInfo,
847         /*prettyDebugInfo=*/prettyDebugInfo,
848         /*printGenericOpForm=*/printGenericOpForm,
849         /*useLocalScope=*/useLocalScope);
850 
851   return fileObject.attr("getvalue")();
852 }
853 
getParentOperation()854 PyOperationRef PyOperation::getParentOperation() {
855   if (!isAttached())
856     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
857   MlirOperation operation = mlirOperationGetParentOperation(get());
858   if (mlirOperationIsNull(operation))
859     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
860   return PyOperation::forOperation(getContext(), operation);
861 }
862 
getBlock()863 PyBlock PyOperation::getBlock() {
864   PyOperationRef parentOperation = getParentOperation();
865   MlirBlock block = mlirOperationGetBlock(get());
866   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
867   return PyBlock{std::move(parentOperation), block};
868 }
869 
create(std::string name,llvm::Optional<std::vector<PyValue * >> operands,llvm::Optional<std::vector<PyType * >> results,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,int regions,DefaultingPyLocation location,py::object maybeIp)870 py::object PyOperation::create(
871     std::string name, llvm::Optional<std::vector<PyValue *>> operands,
872     llvm::Optional<std::vector<PyType *>> results,
873     llvm::Optional<py::dict> attributes,
874     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
875     DefaultingPyLocation location, py::object maybeIp) {
876   llvm::SmallVector<MlirValue, 4> mlirOperands;
877   llvm::SmallVector<MlirType, 4> mlirResults;
878   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
879   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
880 
881   // General parameter validation.
882   if (regions < 0)
883     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
884 
885   // Unpack/validate operands.
886   if (operands) {
887     mlirOperands.reserve(operands->size());
888     for (PyValue *operand : *operands) {
889       if (!operand)
890         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
891       mlirOperands.push_back(operand->get());
892     }
893   }
894 
895   // Unpack/validate results.
896   if (results) {
897     mlirResults.reserve(results->size());
898     for (PyType *result : *results) {
899       // TODO: Verify result type originate from the same context.
900       if (!result)
901         throw SetPyError(PyExc_ValueError, "result type cannot be None");
902       mlirResults.push_back(*result);
903     }
904   }
905   // Unpack/validate attributes.
906   if (attributes) {
907     mlirAttributes.reserve(attributes->size());
908     for (auto &it : *attributes) {
909       std::string key;
910       try {
911         key = it.first.cast<std::string>();
912       } catch (py::cast_error &err) {
913         std::string msg = "Invalid attribute key (not a string) when "
914                           "attempting to create the operation \"" +
915                           name + "\" (" + err.what() + ")";
916         throw py::cast_error(msg);
917       }
918       try {
919         auto &attribute = it.second.cast<PyAttribute &>();
920         // TODO: Verify attribute originates from the same context.
921         mlirAttributes.emplace_back(std::move(key), attribute);
922       } catch (py::reference_cast_error &) {
923         // This exception seems thrown when the value is "None".
924         std::string msg =
925             "Found an invalid (`None`?) attribute value for the key \"" + key +
926             "\" when attempting to create the operation \"" + name + "\"";
927         throw py::cast_error(msg);
928       } catch (py::cast_error &err) {
929         std::string msg = "Invalid attribute value for the key \"" + key +
930                           "\" when attempting to create the operation \"" +
931                           name + "\" (" + err.what() + ")";
932         throw py::cast_error(msg);
933       }
934     }
935   }
936   // Unpack/validate successors.
937   if (successors) {
938     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
939     mlirSuccessors.reserve(successors->size());
940     for (auto *successor : *successors) {
941       // TODO: Verify successor originate from the same context.
942       if (!successor)
943         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
944       mlirSuccessors.push_back(successor->get());
945     }
946   }
947 
948   // Apply unpacked/validated to the operation state. Beyond this
949   // point, exceptions cannot be thrown or else the state will leak.
950   MlirOperationState state =
951       mlirOperationStateGet(toMlirStringRef(name), location);
952   if (!mlirOperands.empty())
953     mlirOperationStateAddOperands(&state, mlirOperands.size(),
954                                   mlirOperands.data());
955   if (!mlirResults.empty())
956     mlirOperationStateAddResults(&state, mlirResults.size(),
957                                  mlirResults.data());
958   if (!mlirAttributes.empty()) {
959     // Note that the attribute names directly reference bytes in
960     // mlirAttributes, so that vector must not be changed from here
961     // on.
962     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
963     mlirNamedAttributes.reserve(mlirAttributes.size());
964     for (auto &it : mlirAttributes)
965       mlirNamedAttributes.push_back(
966           mlirNamedAttributeGet(toMlirStringRef(it.first), it.second));
967     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
968                                     mlirNamedAttributes.data());
969   }
970   if (!mlirSuccessors.empty())
971     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
972                                     mlirSuccessors.data());
973   if (regions) {
974     llvm::SmallVector<MlirRegion, 4> mlirRegions;
975     mlirRegions.resize(regions);
976     for (int i = 0; i < regions; ++i)
977       mlirRegions[i] = mlirRegionCreate();
978     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
979                                       mlirRegions.data());
980   }
981 
982   // Construct the operation.
983   MlirOperation operation = mlirOperationCreate(&state);
984   PyOperationRef created =
985       PyOperation::createDetached(location->getContext(), operation);
986 
987   // InsertPoint active?
988   if (!maybeIp.is(py::cast(false))) {
989     PyInsertionPoint *ip;
990     if (maybeIp.is_none()) {
991       ip = PyThreadContextEntry::getDefaultInsertionPoint();
992     } else {
993       ip = py::cast<PyInsertionPoint *>(maybeIp);
994     }
995     if (ip)
996       ip->insert(*created.get());
997   }
998 
999   return created->createOpView();
1000 }
1001 
createOpView()1002 py::object PyOperation::createOpView() {
1003   MlirIdentifier ident = mlirOperationGetName(get());
1004   MlirStringRef identStr = mlirIdentifierStr(ident);
1005   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1006       StringRef(identStr.data, identStr.length));
1007   if (opViewClass)
1008     return (*opViewClass)(getRef().getObject());
1009   return py::cast(PyOpView(getRef().getObject()));
1010 }
1011 
PyOpView(py::object operationObject)1012 PyOpView::PyOpView(py::object operationObject)
1013     // Casting through the PyOperationBase base-class and then back to the
1014     // Operation lets us accept any PyOperationBase subclass.
1015     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1016       operationObject(operation.getRef().getObject()) {}
1017 
createRawSubclass(py::object userClass)1018 py::object PyOpView::createRawSubclass(py::object userClass) {
1019   // This is... a little gross. The typical pattern is to have a pure python
1020   // class that extends OpView like:
1021   //   class AddFOp(_cext.ir.OpView):
1022   //     def __init__(self, loc, lhs, rhs):
1023   //       operation = loc.context.create_operation(
1024   //           "addf", lhs, rhs, results=[lhs.type])
1025   //       super().__init__(operation)
1026   //
1027   // I.e. The goal of the user facing type is to provide a nice constructor
1028   // that has complete freedom for the op under construction. This is at odds
1029   // with our other desire to sometimes create this object by just passing an
1030   // operation (to initialize the base class). We could do *arg and **kwargs
1031   // munging to try to make it work, but instead, we synthesize a new class
1032   // on the fly which extends this user class (AddFOp in this example) and
1033   // *give it* the base class's __init__ method, thus bypassing the
1034   // intermediate subclass's __init__ method entirely. While slightly,
1035   // underhanded, this is safe/legal because the type hierarchy has not changed
1036   // (we just added a new leaf) and we aren't mucking around with __new__.
1037   // Typically, this new class will be stored on the original as "_Raw" and will
1038   // be used for casts and other things that need a variant of the class that
1039   // is initialized purely from an operation.
1040   py::object parentMetaclass =
1041       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1042   py::dict attributes;
1043   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1044   // now.
1045   //   auto opViewType = py::type::of<PyOpView>();
1046   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1047   attributes["__init__"] = opViewType.attr("__init__");
1048   py::str origName = userClass.attr("__name__");
1049   py::str newName = py::str("_") + origName;
1050   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1051 }
1052 
1053 //------------------------------------------------------------------------------
1054 // PyInsertionPoint.
1055 //------------------------------------------------------------------------------
1056 
PyInsertionPoint(PyBlock & block)1057 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1058 
PyInsertionPoint(PyOperationBase & beforeOperationBase)1059 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1060     : refOperation(beforeOperationBase.getOperation().getRef()),
1061       block((*refOperation)->getBlock()) {}
1062 
insert(PyOperationBase & operationBase)1063 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1064   PyOperation &operation = operationBase.getOperation();
1065   if (operation.isAttached())
1066     throw SetPyError(PyExc_ValueError,
1067                      "Attempt to insert operation that is already attached");
1068   block.getParentOperation()->checkValid();
1069   MlirOperation beforeOp = {nullptr};
1070   if (refOperation) {
1071     // Insert before operation.
1072     (*refOperation)->checkValid();
1073     beforeOp = (*refOperation)->get();
1074   }
1075   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1076   operation.setAttached();
1077 }
1078 
atBlockBegin(PyBlock & block)1079 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1080   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1081   if (mlirOperationIsNull(firstOp)) {
1082     // Just insert at end.
1083     return PyInsertionPoint(block);
1084   }
1085 
1086   // Insert before first op.
1087   PyOperationRef firstOpRef = PyOperation::forOperation(
1088       block.getParentOperation()->getContext(), firstOp);
1089   return PyInsertionPoint{block, std::move(firstOpRef)};
1090 }
1091 
atBlockTerminator(PyBlock & block)1092 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1093   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1094   if (mlirOperationIsNull(terminator))
1095     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1096   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1097       block.getParentOperation()->getContext(), terminator);
1098   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1099 }
1100 
contextEnter()1101 py::object PyInsertionPoint::contextEnter() {
1102   return PyThreadContextEntry::pushInsertionPoint(*this);
1103 }
1104 
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)1105 void PyInsertionPoint::contextExit(pybind11::object excType,
1106                                    pybind11::object excVal,
1107                                    pybind11::object excTb) {
1108   PyThreadContextEntry::popInsertionPoint(*this);
1109 }
1110 
1111 //------------------------------------------------------------------------------
1112 // PyAttribute.
1113 //------------------------------------------------------------------------------
1114 
operator ==(const PyAttribute & other)1115 bool PyAttribute::operator==(const PyAttribute &other) {
1116   return mlirAttributeEqual(attr, other.attr);
1117 }
1118 
getCapsule()1119 py::object PyAttribute::getCapsule() {
1120   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1121 }
1122 
createFromCapsule(py::object capsule)1123 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1124   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1125   if (mlirAttributeIsNull(rawAttr))
1126     throw py::error_already_set();
1127   return PyAttribute(
1128       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1129 }
1130 
1131 //------------------------------------------------------------------------------
1132 // PyNamedAttribute.
1133 //------------------------------------------------------------------------------
1134 
PyNamedAttribute(MlirAttribute attr,std::string ownedName)1135 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1136     : ownedName(new std::string(std::move(ownedName))) {
1137   namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr);
1138 }
1139 
1140 //------------------------------------------------------------------------------
1141 // PyType.
1142 //------------------------------------------------------------------------------
1143 
operator ==(const PyType & other)1144 bool PyType::operator==(const PyType &other) {
1145   return mlirTypeEqual(type, other.type);
1146 }
1147 
getCapsule()1148 py::object PyType::getCapsule() {
1149   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1150 }
1151 
createFromCapsule(py::object capsule)1152 PyType PyType::createFromCapsule(py::object capsule) {
1153   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1154   if (mlirTypeIsNull(rawType))
1155     throw py::error_already_set();
1156   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1157                 rawType);
1158 }
1159 
1160 //------------------------------------------------------------------------------
1161 // PyValue and subclases.
1162 //------------------------------------------------------------------------------
1163 
1164 namespace {
1165 /// CRTP base class for Python MLIR values that subclass Value and should be
1166 /// castable from it. The value hierarchy is one level deep and is not supposed
1167 /// to accommodate other levels unless core MLIR changes.
1168 template <typename DerivedTy>
1169 class PyConcreteValue : public PyValue {
1170 public:
1171   // Derived classes must define statics for:
1172   //   IsAFunctionTy isaFunction
1173   //   const char *pyClassName
1174   // and redefine bindDerived.
1175   using ClassTy = py::class_<DerivedTy, PyValue>;
1176   using IsAFunctionTy = bool (*)(MlirValue);
1177 
1178   PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1179   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1180       : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1181   PyConcreteValue(PyValue &orig)
1182       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1183 
1184   /// Attempts to cast the original value to the derived type and throws on
1185   /// type mismatches.
castFrom(PyValue & orig)1186   static MlirValue castFrom(PyValue &orig) {
1187     if (!DerivedTy::isaFunction(orig.get())) {
1188       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1189       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1190                                              DerivedTy::pyClassName +
1191                                              " (from " + origRepr + ")");
1192     }
1193     return orig.get();
1194   }
1195 
1196   /// Binds the Python module objects to functions of this class.
bind(py::module & m)1197   static void bind(py::module &m) {
1198     auto cls = ClassTy(m, DerivedTy::pyClassName);
1199     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1200     DerivedTy::bindDerived(cls);
1201   }
1202 
1203   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1204   static void bindDerived(ClassTy &m) {}
1205 };
1206 
1207 /// Python wrapper for MlirBlockArgument.
1208 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1209 public:
1210   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1211   static constexpr const char *pyClassName = "BlockArgument";
1212   using PyConcreteValue::PyConcreteValue;
1213 
bindDerived(ClassTy & c)1214   static void bindDerived(ClassTy &c) {
1215     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1216       return PyBlock(self.getParentOperation(),
1217                      mlirBlockArgumentGetOwner(self.get()));
1218     });
1219     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1220       return mlirBlockArgumentGetArgNumber(self.get());
1221     });
1222     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1223       return mlirBlockArgumentSetType(self.get(), type);
1224     });
1225   }
1226 };
1227 
1228 /// Python wrapper for MlirOpResult.
1229 class PyOpResult : public PyConcreteValue<PyOpResult> {
1230 public:
1231   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1232   static constexpr const char *pyClassName = "OpResult";
1233   using PyConcreteValue::PyConcreteValue;
1234 
bindDerived(ClassTy & c)1235   static void bindDerived(ClassTy &c) {
1236     c.def_property_readonly("owner", [](PyOpResult &self) {
1237       assert(
1238           mlirOperationEqual(self.getParentOperation()->get(),
1239                              mlirOpResultGetOwner(self.get())) &&
1240           "expected the owner of the value in Python to match that in the IR");
1241       return self.getParentOperation();
1242     });
1243     c.def_property_readonly("result_number", [](PyOpResult &self) {
1244       return mlirOpResultGetResultNumber(self.get());
1245     });
1246   }
1247 };
1248 
1249 /// A list of block arguments. Internally, these are stored as consecutive
1250 /// elements, random access is cheap. The argument list is associated with the
1251 /// operation that contains the block (detached blocks are not allowed in
1252 /// Python bindings) and extends its lifetime.
1253 class PyBlockArgumentList {
1254 public:
PyBlockArgumentList(PyOperationRef operation,MlirBlock block)1255   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1256       : operation(std::move(operation)), block(block) {}
1257 
1258   /// Returns the length of the block argument list.
dunderLen()1259   intptr_t dunderLen() {
1260     operation->checkValid();
1261     return mlirBlockGetNumArguments(block);
1262   }
1263 
1264   /// Returns `index`-th element of the block argument list.
dunderGetItem(intptr_t index)1265   PyBlockArgument dunderGetItem(intptr_t index) {
1266     if (index < 0 || index >= dunderLen()) {
1267       throw SetPyError(PyExc_IndexError,
1268                        "attempt to access out of bounds region");
1269     }
1270     PyValue value(operation, mlirBlockGetArgument(block, index));
1271     return PyBlockArgument(value);
1272   }
1273 
1274   /// Defines a Python class in the bindings.
bind(py::module & m)1275   static void bind(py::module &m) {
1276     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1277         .def("__len__", &PyBlockArgumentList::dunderLen)
1278         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1279   }
1280 
1281 private:
1282   PyOperationRef operation;
1283   MlirBlock block;
1284 };
1285 
1286 /// A list of operation operands. Internally, these are stored as consecutive
1287 /// elements, random access is cheap. The result list is associated with the
1288 /// operation whose results these are, and extends the lifetime of this
1289 /// operation.
1290 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1291 public:
1292   static constexpr const char *pyClassName = "OpOperandList";
1293 
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1294   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1295                   intptr_t length = -1, intptr_t step = 1)
1296       : Sliceable(startIndex,
1297                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1298                                : length,
1299                   step),
1300         operation(operation) {}
1301 
getNumElements()1302   intptr_t getNumElements() {
1303     operation->checkValid();
1304     return mlirOperationGetNumOperands(operation->get());
1305   }
1306 
getElement(intptr_t pos)1307   PyValue getElement(intptr_t pos) {
1308     return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1309   }
1310 
slice(intptr_t startIndex,intptr_t length,intptr_t step)1311   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1312     return PyOpOperandList(operation, startIndex, length, step);
1313   }
1314 
1315 private:
1316   PyOperationRef operation;
1317 };
1318 
1319 /// A list of operation results. Internally, these are stored as consecutive
1320 /// elements, random access is cheap. The result list is associated with the
1321 /// operation whose results these are, and extends the lifetime of this
1322 /// operation.
1323 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1324 public:
1325   static constexpr const char *pyClassName = "OpResultList";
1326 
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1327   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1328                  intptr_t length = -1, intptr_t step = 1)
1329       : Sliceable(startIndex,
1330                   length == -1 ? mlirOperationGetNumResults(operation->get())
1331                                : length,
1332                   step),
1333         operation(operation) {}
1334 
getNumElements()1335   intptr_t getNumElements() {
1336     operation->checkValid();
1337     return mlirOperationGetNumResults(operation->get());
1338   }
1339 
getElement(intptr_t index)1340   PyOpResult getElement(intptr_t index) {
1341     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1342     return PyOpResult(value);
1343   }
1344 
slice(intptr_t startIndex,intptr_t length,intptr_t step)1345   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1346     return PyOpResultList(operation, startIndex, length, step);
1347   }
1348 
1349 private:
1350   PyOperationRef operation;
1351 };
1352 
1353 /// A list of operation attributes. Can be indexed by name, producing
1354 /// attributes, or by index, producing named attributes.
1355 class PyOpAttributeMap {
1356 public:
PyOpAttributeMap(PyOperationRef operation)1357   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1358 
dunderGetItemNamed(const std::string & name)1359   PyAttribute dunderGetItemNamed(const std::string &name) {
1360     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1361                                                          toMlirStringRef(name));
1362     if (mlirAttributeIsNull(attr)) {
1363       throw SetPyError(PyExc_KeyError,
1364                        "attempt to access a non-existent attribute");
1365     }
1366     return PyAttribute(operation->getContext(), attr);
1367   }
1368 
dunderGetItemIndexed(intptr_t index)1369   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1370     if (index < 0 || index >= dunderLen()) {
1371       throw SetPyError(PyExc_IndexError,
1372                        "attempt to access out of bounds attribute");
1373     }
1374     MlirNamedAttribute namedAttr =
1375         mlirOperationGetAttribute(operation->get(), index);
1376     return PyNamedAttribute(namedAttr.attribute,
1377                             std::string(namedAttr.name.data));
1378   }
1379 
dunderSetItem(const std::string & name,PyAttribute attr)1380   void dunderSetItem(const std::string &name, PyAttribute attr) {
1381     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1382                                     attr);
1383   }
1384 
dunderDelItem(const std::string & name)1385   void dunderDelItem(const std::string &name) {
1386     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1387                                                      toMlirStringRef(name));
1388     if (!removed)
1389       throw SetPyError(PyExc_KeyError,
1390                        "attempt to delete a non-existent attribute");
1391   }
1392 
dunderLen()1393   intptr_t dunderLen() {
1394     return mlirOperationGetNumAttributes(operation->get());
1395   }
1396 
dunderContains(const std::string & name)1397   bool dunderContains(const std::string &name) {
1398     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1399         operation->get(), toMlirStringRef(name)));
1400   }
1401 
bind(py::module & m)1402   static void bind(py::module &m) {
1403     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1404         .def("__contains__", &PyOpAttributeMap::dunderContains)
1405         .def("__len__", &PyOpAttributeMap::dunderLen)
1406         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1407         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1408         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1409         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1410   }
1411 
1412 private:
1413   PyOperationRef operation;
1414 };
1415 
1416 } // end namespace
1417 
1418 //------------------------------------------------------------------------------
1419 // Builtin attribute subclasses.
1420 //------------------------------------------------------------------------------
1421 
1422 namespace {
1423 
1424 /// CRTP base classes for Python attributes that subclass Attribute and should
1425 /// be castable from it (i.e. via something like StringAttr(attr)).
1426 /// By default, attribute class hierarchies are one level deep (i.e. a
1427 /// concrete attribute class extends PyAttribute); however, intermediate
1428 /// python-visible base classes can be modeled by specifying a BaseTy.
1429 template <typename DerivedTy, typename BaseTy = PyAttribute>
1430 class PyConcreteAttribute : public BaseTy {
1431 public:
1432   // Derived classes must define statics for:
1433   //   IsAFunctionTy isaFunction
1434   //   const char *pyClassName
1435   using ClassTy = py::class_<DerivedTy, BaseTy>;
1436   using IsAFunctionTy = bool (*)(MlirAttribute);
1437 
1438   PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)1439   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
1440       : BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute & orig)1441   PyConcreteAttribute(PyAttribute &orig)
1442       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
1443 
castFrom(PyAttribute & orig)1444   static MlirAttribute castFrom(PyAttribute &orig) {
1445     if (!DerivedTy::isaFunction(orig)) {
1446       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1447       throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
1448                                              DerivedTy::pyClassName +
1449                                              " (from " + origRepr + ")");
1450     }
1451     return orig;
1452   }
1453 
bind(py::module & m)1454   static void bind(py::module &m) {
1455     auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
1456     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
1457     DerivedTy::bindDerived(cls);
1458   }
1459 
1460   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1461   static void bindDerived(ClassTy &m) {}
1462 };
1463 
1464 /// Float Point Attribute subclass - FloatAttr.
1465 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
1466 public:
1467   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
1468   static constexpr const char *pyClassName = "FloatAttr";
1469   using PyConcreteAttribute::PyConcreteAttribute;
1470 
bindDerived(ClassTy & c)1471   static void bindDerived(ClassTy &c) {
1472     c.def_static(
1473         "get",
1474         [](PyType &type, double value, DefaultingPyLocation loc) {
1475           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
1476           // TODO: Rework error reporting once diagnostic engine is exposed
1477           // in C API.
1478           if (mlirAttributeIsNull(attr)) {
1479             throw SetPyError(PyExc_ValueError,
1480                              Twine("invalid '") +
1481                                  py::repr(py::cast(type)).cast<std::string>() +
1482                                  "' and expected floating point type.");
1483           }
1484           return PyFloatAttribute(type.getContext(), attr);
1485         },
1486         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
1487         "Gets an uniqued float point attribute associated to a type");
1488     c.def_static(
1489         "get_f32",
1490         [](double value, DefaultingPyMlirContext context) {
1491           MlirAttribute attr = mlirFloatAttrDoubleGet(
1492               context->get(), mlirF32TypeGet(context->get()), value);
1493           return PyFloatAttribute(context->getRef(), attr);
1494         },
1495         py::arg("value"), py::arg("context") = py::none(),
1496         "Gets an uniqued float point attribute associated to a f32 type");
1497     c.def_static(
1498         "get_f64",
1499         [](double value, DefaultingPyMlirContext context) {
1500           MlirAttribute attr = mlirFloatAttrDoubleGet(
1501               context->get(), mlirF64TypeGet(context->get()), value);
1502           return PyFloatAttribute(context->getRef(), attr);
1503         },
1504         py::arg("value"), py::arg("context") = py::none(),
1505         "Gets an uniqued float point attribute associated to a f64 type");
1506     c.def_property_readonly(
1507         "value",
1508         [](PyFloatAttribute &self) {
1509           return mlirFloatAttrGetValueDouble(self);
1510         },
1511         "Returns the value of the float point attribute");
1512   }
1513 };
1514 
1515 /// Integer Attribute subclass - IntegerAttr.
1516 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
1517 public:
1518   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
1519   static constexpr const char *pyClassName = "IntegerAttr";
1520   using PyConcreteAttribute::PyConcreteAttribute;
1521 
bindDerived(ClassTy & c)1522   static void bindDerived(ClassTy &c) {
1523     c.def_static(
1524         "get",
1525         [](PyType &type, int64_t value) {
1526           MlirAttribute attr = mlirIntegerAttrGet(type, value);
1527           return PyIntegerAttribute(type.getContext(), attr);
1528         },
1529         py::arg("type"), py::arg("value"),
1530         "Gets an uniqued integer attribute associated to a type");
1531     c.def_property_readonly(
1532         "value",
1533         [](PyIntegerAttribute &self) {
1534           return mlirIntegerAttrGetValueInt(self);
1535         },
1536         "Returns the value of the integer attribute");
1537   }
1538 };
1539 
1540 /// Bool Attribute subclass - BoolAttr.
1541 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
1542 public:
1543   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
1544   static constexpr const char *pyClassName = "BoolAttr";
1545   using PyConcreteAttribute::PyConcreteAttribute;
1546 
bindDerived(ClassTy & c)1547   static void bindDerived(ClassTy &c) {
1548     c.def_static(
1549         "get",
1550         [](bool value, DefaultingPyMlirContext context) {
1551           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
1552           return PyBoolAttribute(context->getRef(), attr);
1553         },
1554         py::arg("value"), py::arg("context") = py::none(),
1555         "Gets an uniqued bool attribute");
1556     c.def_property_readonly(
1557         "value",
1558         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
1559         "Returns the value of the bool attribute");
1560   }
1561 };
1562 
1563 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
1564 public:
1565   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
1566   static constexpr const char *pyClassName = "StringAttr";
1567   using PyConcreteAttribute::PyConcreteAttribute;
1568 
bindDerived(ClassTy & c)1569   static void bindDerived(ClassTy &c) {
1570     c.def_static(
1571         "get",
1572         [](std::string value, DefaultingPyMlirContext context) {
1573           MlirAttribute attr =
1574               mlirStringAttrGet(context->get(), toMlirStringRef(value));
1575           return PyStringAttribute(context->getRef(), attr);
1576         },
1577         py::arg("value"), py::arg("context") = py::none(),
1578         "Gets a uniqued string attribute");
1579     c.def_static(
1580         "get_typed",
1581         [](PyType &type, std::string value) {
1582           MlirAttribute attr =
1583               mlirStringAttrTypedGet(type, toMlirStringRef(value));
1584           return PyStringAttribute(type.getContext(), attr);
1585         },
1586 
1587         "Gets a uniqued string attribute associated to a type");
1588     c.def_property_readonly(
1589         "value",
1590         [](PyStringAttribute &self) {
1591           MlirStringRef stringRef = mlirStringAttrGetValue(self);
1592           return py::str(stringRef.data, stringRef.length);
1593         },
1594         "Returns the value of the string attribute");
1595   }
1596 };
1597 
1598 // TODO: Support construction of bool elements.
1599 // TODO: Support construction of string elements.
1600 class PyDenseElementsAttribute
1601     : public PyConcreteAttribute<PyDenseElementsAttribute> {
1602 public:
1603   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
1604   static constexpr const char *pyClassName = "DenseElementsAttr";
1605   using PyConcreteAttribute::PyConcreteAttribute;
1606 
1607   static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,DefaultingPyMlirContext contextWrapper)1608   getFromBuffer(py::buffer array, bool signless,
1609                 DefaultingPyMlirContext contextWrapper) {
1610     // Request a contiguous view. In exotic cases, this will cause a copy.
1611     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
1612     Py_buffer *view = new Py_buffer();
1613     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
1614       delete view;
1615       throw py::error_already_set();
1616     }
1617     py::buffer_info arrayInfo(view);
1618 
1619     MlirContext context = contextWrapper->get();
1620     // Switch on the types that can be bulk loaded between the Python and
1621     // MLIR-C APIs.
1622     // See: https://docs.python.org/3/library/struct.html#format-characters
1623     if (arrayInfo.format == "f") {
1624       // f32
1625       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
1626       return PyDenseElementsAttribute(
1627           contextWrapper->getRef(),
1628           bulkLoad(context, mlirDenseElementsAttrFloatGet,
1629                    mlirF32TypeGet(context), arrayInfo));
1630     } else if (arrayInfo.format == "d") {
1631       // f64
1632       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
1633       return PyDenseElementsAttribute(
1634           contextWrapper->getRef(),
1635           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
1636                    mlirF64TypeGet(context), arrayInfo));
1637     } else if (isSignedIntegerFormat(arrayInfo.format)) {
1638       if (arrayInfo.itemsize == 4) {
1639         // i32
1640         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
1641                                         : mlirIntegerTypeSignedGet(context, 32);
1642         return PyDenseElementsAttribute(contextWrapper->getRef(),
1643                                         bulkLoad(context,
1644                                                  mlirDenseElementsAttrInt32Get,
1645                                                  elementType, arrayInfo));
1646       } else if (arrayInfo.itemsize == 8) {
1647         // i64
1648         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
1649                                         : mlirIntegerTypeSignedGet(context, 64);
1650         return PyDenseElementsAttribute(contextWrapper->getRef(),
1651                                         bulkLoad(context,
1652                                                  mlirDenseElementsAttrInt64Get,
1653                                                  elementType, arrayInfo));
1654       }
1655     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
1656       if (arrayInfo.itemsize == 4) {
1657         // unsigned i32
1658         MlirType elementType = signless
1659                                    ? mlirIntegerTypeGet(context, 32)
1660                                    : mlirIntegerTypeUnsignedGet(context, 32);
1661         return PyDenseElementsAttribute(contextWrapper->getRef(),
1662                                         bulkLoad(context,
1663                                                  mlirDenseElementsAttrUInt32Get,
1664                                                  elementType, arrayInfo));
1665       } else if (arrayInfo.itemsize == 8) {
1666         // unsigned i64
1667         MlirType elementType = signless
1668                                    ? mlirIntegerTypeGet(context, 64)
1669                                    : mlirIntegerTypeUnsignedGet(context, 64);
1670         return PyDenseElementsAttribute(contextWrapper->getRef(),
1671                                         bulkLoad(context,
1672                                                  mlirDenseElementsAttrUInt64Get,
1673                                                  elementType, arrayInfo));
1674       }
1675     }
1676 
1677     // TODO: Fall back to string-based get.
1678     std::string message = "unimplemented array format conversion from format: ";
1679     message.append(arrayInfo.format);
1680     throw SetPyError(PyExc_ValueError, message);
1681   }
1682 
getSplat(PyType shapedType,PyAttribute & elementAttr)1683   static PyDenseElementsAttribute getSplat(PyType shapedType,
1684                                            PyAttribute &elementAttr) {
1685     auto contextWrapper =
1686         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
1687     if (!mlirAttributeIsAInteger(elementAttr) &&
1688         !mlirAttributeIsAFloat(elementAttr)) {
1689       std::string message = "Illegal element type for DenseElementsAttr: ";
1690       message.append(py::repr(py::cast(elementAttr)));
1691       throw SetPyError(PyExc_ValueError, message);
1692     }
1693     if (!mlirTypeIsAShaped(shapedType) ||
1694         !mlirShapedTypeHasStaticShape(shapedType)) {
1695       std::string message =
1696           "Expected a static ShapedType for the shaped_type parameter: ";
1697       message.append(py::repr(py::cast(shapedType)));
1698       throw SetPyError(PyExc_ValueError, message);
1699     }
1700     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
1701     MlirType attrType = mlirAttributeGetType(elementAttr);
1702     if (!mlirTypeEqual(shapedElementType, attrType)) {
1703       std::string message =
1704           "Shaped element type and attribute type must be equal: shaped=";
1705       message.append(py::repr(py::cast(shapedType)));
1706       message.append(", element=");
1707       message.append(py::repr(py::cast(elementAttr)));
1708       throw SetPyError(PyExc_ValueError, message);
1709     }
1710 
1711     MlirAttribute elements =
1712         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
1713     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
1714   }
1715 
dunderLen()1716   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
1717 
accessBuffer()1718   py::buffer_info accessBuffer() {
1719     MlirType shapedType = mlirAttributeGetType(*this);
1720     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
1721 
1722     if (mlirTypeIsAF32(elementType)) {
1723       // f32
1724       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
1725     } else if (mlirTypeIsAF64(elementType)) {
1726       // f64
1727       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
1728     } else if (mlirTypeIsAInteger(elementType) &&
1729                mlirIntegerTypeGetWidth(elementType) == 32) {
1730       if (mlirIntegerTypeIsSignless(elementType) ||
1731           mlirIntegerTypeIsSigned(elementType)) {
1732         // i32
1733         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
1734       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
1735         // unsigned i32
1736         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
1737       }
1738     } else if (mlirTypeIsAInteger(elementType) &&
1739                mlirIntegerTypeGetWidth(elementType) == 64) {
1740       if (mlirIntegerTypeIsSignless(elementType) ||
1741           mlirIntegerTypeIsSigned(elementType)) {
1742         // i64
1743         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
1744       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
1745         // unsigned i64
1746         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
1747       }
1748     }
1749 
1750     std::string message = "unimplemented array format.";
1751     throw SetPyError(PyExc_ValueError, message);
1752   }
1753 
bindDerived(ClassTy & c)1754   static void bindDerived(ClassTy &c) {
1755     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1756         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
1757                     py::arg("array"), py::arg("signless") = true,
1758                     py::arg("context") = py::none(),
1759                     "Gets from a buffer or ndarray")
1760         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
1761                     py::arg("shaped_type"), py::arg("element_attr"),
1762                     "Gets a DenseElementsAttr where all values are the same")
1763         .def_property_readonly("is_splat",
1764                                [](PyDenseElementsAttribute &self) -> bool {
1765                                  return mlirDenseElementsAttrIsSplat(self);
1766                                })
1767         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
1768   }
1769 
1770 private:
1771   template <typename ElementTy>
1772   static MlirAttribute
bulkLoad(MlirContext context,MlirAttribute (* ctor)(MlirType,intptr_t,ElementTy *),MlirType mlirElementType,py::buffer_info & arrayInfo)1773   bulkLoad(MlirContext context,
1774            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
1775            MlirType mlirElementType, py::buffer_info &arrayInfo) {
1776     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
1777                                   arrayInfo.shape.begin() + arrayInfo.ndim);
1778     auto shapedType =
1779         mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
1780     intptr_t numElements = arrayInfo.size;
1781     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
1782     return ctor(shapedType, numElements, contents);
1783   }
1784 
isUnsignedIntegerFormat(const std::string & format)1785   static bool isUnsignedIntegerFormat(const std::string &format) {
1786     if (format.empty())
1787       return false;
1788     char code = format[0];
1789     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1790            code == 'Q';
1791   }
1792 
isSignedIntegerFormat(const std::string & format)1793   static bool isSignedIntegerFormat(const std::string &format) {
1794     if (format.empty())
1795       return false;
1796     char code = format[0];
1797     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1798            code == 'q';
1799   }
1800 
1801   template <typename Type>
bufferInfo(MlirType shapedType,Type (* value)(MlirAttribute,intptr_t))1802   py::buffer_info bufferInfo(MlirType shapedType,
1803                              Type (*value)(MlirAttribute, intptr_t)) {
1804     intptr_t rank = mlirShapedTypeGetRank(shapedType);
1805     // Prepare the data for the buffer_info.
1806     // Buffer is configured for read-only access below.
1807     Type *data = static_cast<Type *>(
1808         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1809     // Prepare the shape for the buffer_info.
1810     SmallVector<intptr_t, 4> shape;
1811     for (intptr_t i = 0; i < rank; ++i)
1812       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1813     // Prepare the strides for the buffer_info.
1814     SmallVector<intptr_t, 4> strides;
1815     intptr_t strideFactor = 1;
1816     for (intptr_t i = 1; i < rank; ++i) {
1817       strideFactor = 1;
1818       for (intptr_t j = i; j < rank; ++j) {
1819         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1820       }
1821       strides.push_back(sizeof(Type) * strideFactor);
1822     }
1823     strides.push_back(sizeof(Type));
1824     return py::buffer_info(data, sizeof(Type),
1825                            py::format_descriptor<Type>::format(), rank, shape,
1826                            strides, /*readonly=*/true);
1827   }
1828 }; // namespace
1829 
1830 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
1831 /// (and boolean) values. Supports element access.
1832 class PyDenseIntElementsAttribute
1833     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1834                                  PyDenseElementsAttribute> {
1835 public:
1836   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1837   static constexpr const char *pyClassName = "DenseIntElementsAttr";
1838   using PyConcreteAttribute::PyConcreteAttribute;
1839 
1840   /// Returns the element at the given linear position. Asserts if the index is
1841   /// out of range.
dunderGetItem(intptr_t pos)1842   py::int_ dunderGetItem(intptr_t pos) {
1843     if (pos < 0 || pos >= dunderLen()) {
1844       throw SetPyError(PyExc_IndexError,
1845                        "attempt to access out of bounds element");
1846     }
1847 
1848     MlirType type = mlirAttributeGetType(*this);
1849     type = mlirShapedTypeGetElementType(type);
1850     assert(mlirTypeIsAInteger(type) &&
1851            "expected integer element type in dense int elements attribute");
1852     // Dispatch element extraction to an appropriate C function based on the
1853     // elemental type of the attribute. py::int_ is implicitly constructible
1854     // from any C++ integral type and handles bitwidth correctly.
1855     // TODO: consider caching the type properties in the constructor to avoid
1856     // querying them on each element access.
1857     unsigned width = mlirIntegerTypeGetWidth(type);
1858     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1859     if (isUnsigned) {
1860       if (width == 1) {
1861         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1862       }
1863       if (width == 32) {
1864         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
1865       }
1866       if (width == 64) {
1867         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1868       }
1869     } else {
1870       if (width == 1) {
1871         return mlirDenseElementsAttrGetBoolValue(*this, pos);
1872       }
1873       if (width == 32) {
1874         return mlirDenseElementsAttrGetInt32Value(*this, pos);
1875       }
1876       if (width == 64) {
1877         return mlirDenseElementsAttrGetInt64Value(*this, pos);
1878       }
1879     }
1880     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
1881   }
1882 
bindDerived(ClassTy & c)1883   static void bindDerived(ClassTy &c) {
1884     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1885   }
1886 };
1887 
1888 /// Refinement of PyDenseElementsAttribute for attributes containing
1889 /// floating-point values. Supports element access.
1890 class PyDenseFPElementsAttribute
1891     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1892                                  PyDenseElementsAttribute> {
1893 public:
1894   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1895   static constexpr const char *pyClassName = "DenseFPElementsAttr";
1896   using PyConcreteAttribute::PyConcreteAttribute;
1897 
dunderGetItem(intptr_t pos)1898   py::float_ dunderGetItem(intptr_t pos) {
1899     if (pos < 0 || pos >= dunderLen()) {
1900       throw SetPyError(PyExc_IndexError,
1901                        "attempt to access out of bounds element");
1902     }
1903 
1904     MlirType type = mlirAttributeGetType(*this);
1905     type = mlirShapedTypeGetElementType(type);
1906     // Dispatch element extraction to an appropriate C function based on the
1907     // elemental type of the attribute. py::float_ is implicitly constructible
1908     // from float and double.
1909     // TODO: consider caching the type properties in the constructor to avoid
1910     // querying them on each element access.
1911     if (mlirTypeIsAF32(type)) {
1912       return mlirDenseElementsAttrGetFloatValue(*this, pos);
1913     }
1914     if (mlirTypeIsAF64(type)) {
1915       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1916     }
1917     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
1918   }
1919 
bindDerived(ClassTy & c)1920   static void bindDerived(ClassTy &c) {
1921     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1922   }
1923 };
1924 
1925 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1926 public:
1927   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1928   static constexpr const char *pyClassName = "TypeAttr";
1929   using PyConcreteAttribute::PyConcreteAttribute;
1930 
bindDerived(ClassTy & c)1931   static void bindDerived(ClassTy &c) {
1932     c.def_static(
1933         "get",
1934         [](PyType value, DefaultingPyMlirContext context) {
1935           MlirAttribute attr = mlirTypeAttrGet(value.get());
1936           return PyTypeAttribute(context->getRef(), attr);
1937         },
1938         py::arg("value"), py::arg("context") = py::none(),
1939         "Gets a uniqued Type attribute");
1940     c.def_property_readonly("value", [](PyTypeAttribute &self) {
1941       return PyType(self.getContext()->getRef(),
1942                     mlirTypeAttrGetValue(self.get()));
1943     });
1944   }
1945 };
1946 
1947 /// Unit Attribute subclass. Unit attributes don't have values.
1948 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1949 public:
1950   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1951   static constexpr const char *pyClassName = "UnitAttr";
1952   using PyConcreteAttribute::PyConcreteAttribute;
1953 
bindDerived(ClassTy & c)1954   static void bindDerived(ClassTy &c) {
1955     c.def_static(
1956         "get",
1957         [](DefaultingPyMlirContext context) {
1958           return PyUnitAttribute(context->getRef(),
1959                                  mlirUnitAttrGet(context->get()));
1960         },
1961         py::arg("context") = py::none(), "Create a Unit attribute.");
1962   }
1963 };
1964 
1965 } // namespace
1966 
1967 //------------------------------------------------------------------------------
1968 // Builtin type subclasses.
1969 //------------------------------------------------------------------------------
1970 
1971 namespace {
1972 
1973 /// CRTP base classes for Python types that subclass Type and should be
1974 /// castable from it (i.e. via something like IntegerType(t)).
1975 /// By default, type class hierarchies are one level deep (i.e. a
1976 /// concrete type class extends PyType); however, intermediate python-visible
1977 /// base classes can be modeled by specifying a BaseTy.
1978 template <typename DerivedTy, typename BaseTy = PyType>
1979 class PyConcreteType : public BaseTy {
1980 public:
1981   // Derived classes must define statics for:
1982   //   IsAFunctionTy isaFunction
1983   //   const char *pyClassName
1984   using ClassTy = py::class_<DerivedTy, BaseTy>;
1985   using IsAFunctionTy = bool (*)(MlirType);
1986 
1987   PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef,MlirType t)1988   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
1989       : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType & orig)1990   PyConcreteType(PyType &orig)
1991       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
1992 
castFrom(PyType & orig)1993   static MlirType castFrom(PyType &orig) {
1994     if (!DerivedTy::isaFunction(orig)) {
1995       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1996       throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
1997                                              DerivedTy::pyClassName +
1998                                              " (from " + origRepr + ")");
1999     }
2000     return orig;
2001   }
2002 
bind(py::module & m)2003   static void bind(py::module &m) {
2004     auto cls = ClassTy(m, DerivedTy::pyClassName);
2005     cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
2006     DerivedTy::bindDerived(cls);
2007   }
2008 
2009   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)2010   static void bindDerived(ClassTy &m) {}
2011 };
2012 
2013 class PyIntegerType : public PyConcreteType<PyIntegerType> {
2014 public:
2015   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
2016   static constexpr const char *pyClassName = "IntegerType";
2017   using PyConcreteType::PyConcreteType;
2018 
bindDerived(ClassTy & c)2019   static void bindDerived(ClassTy &c) {
2020     c.def_static(
2021         "get_signless",
2022         [](unsigned width, DefaultingPyMlirContext context) {
2023           MlirType t = mlirIntegerTypeGet(context->get(), width);
2024           return PyIntegerType(context->getRef(), t);
2025         },
2026         py::arg("width"), py::arg("context") = py::none(),
2027         "Create a signless integer type");
2028     c.def_static(
2029         "get_signed",
2030         [](unsigned width, DefaultingPyMlirContext context) {
2031           MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
2032           return PyIntegerType(context->getRef(), t);
2033         },
2034         py::arg("width"), py::arg("context") = py::none(),
2035         "Create a signed integer type");
2036     c.def_static(
2037         "get_unsigned",
2038         [](unsigned width, DefaultingPyMlirContext context) {
2039           MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
2040           return PyIntegerType(context->getRef(), t);
2041         },
2042         py::arg("width"), py::arg("context") = py::none(),
2043         "Create an unsigned integer type");
2044     c.def_property_readonly(
2045         "width",
2046         [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
2047         "Returns the width of the integer type");
2048     c.def_property_readonly(
2049         "is_signless",
2050         [](PyIntegerType &self) -> bool {
2051           return mlirIntegerTypeIsSignless(self);
2052         },
2053         "Returns whether this is a signless integer");
2054     c.def_property_readonly(
2055         "is_signed",
2056         [](PyIntegerType &self) -> bool {
2057           return mlirIntegerTypeIsSigned(self);
2058         },
2059         "Returns whether this is a signed integer");
2060     c.def_property_readonly(
2061         "is_unsigned",
2062         [](PyIntegerType &self) -> bool {
2063           return mlirIntegerTypeIsUnsigned(self);
2064         },
2065         "Returns whether this is an unsigned integer");
2066   }
2067 };
2068 
2069 /// Index Type subclass - IndexType.
2070 class PyIndexType : public PyConcreteType<PyIndexType> {
2071 public:
2072   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
2073   static constexpr const char *pyClassName = "IndexType";
2074   using PyConcreteType::PyConcreteType;
2075 
bindDerived(ClassTy & c)2076   static void bindDerived(ClassTy &c) {
2077     c.def_static(
2078         "get",
2079         [](DefaultingPyMlirContext context) {
2080           MlirType t = mlirIndexTypeGet(context->get());
2081           return PyIndexType(context->getRef(), t);
2082         },
2083         py::arg("context") = py::none(), "Create a index type.");
2084   }
2085 };
2086 
2087 /// Floating Point Type subclass - BF16Type.
2088 class PyBF16Type : public PyConcreteType<PyBF16Type> {
2089 public:
2090   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
2091   static constexpr const char *pyClassName = "BF16Type";
2092   using PyConcreteType::PyConcreteType;
2093 
bindDerived(ClassTy & c)2094   static void bindDerived(ClassTy &c) {
2095     c.def_static(
2096         "get",
2097         [](DefaultingPyMlirContext context) {
2098           MlirType t = mlirBF16TypeGet(context->get());
2099           return PyBF16Type(context->getRef(), t);
2100         },
2101         py::arg("context") = py::none(), "Create a bf16 type.");
2102   }
2103 };
2104 
2105 /// Floating Point Type subclass - F16Type.
2106 class PyF16Type : public PyConcreteType<PyF16Type> {
2107 public:
2108   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
2109   static constexpr const char *pyClassName = "F16Type";
2110   using PyConcreteType::PyConcreteType;
2111 
bindDerived(ClassTy & c)2112   static void bindDerived(ClassTy &c) {
2113     c.def_static(
2114         "get",
2115         [](DefaultingPyMlirContext context) {
2116           MlirType t = mlirF16TypeGet(context->get());
2117           return PyF16Type(context->getRef(), t);
2118         },
2119         py::arg("context") = py::none(), "Create a f16 type.");
2120   }
2121 };
2122 
2123 /// Floating Point Type subclass - F32Type.
2124 class PyF32Type : public PyConcreteType<PyF32Type> {
2125 public:
2126   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
2127   static constexpr const char *pyClassName = "F32Type";
2128   using PyConcreteType::PyConcreteType;
2129 
bindDerived(ClassTy & c)2130   static void bindDerived(ClassTy &c) {
2131     c.def_static(
2132         "get",
2133         [](DefaultingPyMlirContext context) {
2134           MlirType t = mlirF32TypeGet(context->get());
2135           return PyF32Type(context->getRef(), t);
2136         },
2137         py::arg("context") = py::none(), "Create a f32 type.");
2138   }
2139 };
2140 
2141 /// Floating Point Type subclass - F64Type.
2142 class PyF64Type : public PyConcreteType<PyF64Type> {
2143 public:
2144   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
2145   static constexpr const char *pyClassName = "F64Type";
2146   using PyConcreteType::PyConcreteType;
2147 
bindDerived(ClassTy & c)2148   static void bindDerived(ClassTy &c) {
2149     c.def_static(
2150         "get",
2151         [](DefaultingPyMlirContext context) {
2152           MlirType t = mlirF64TypeGet(context->get());
2153           return PyF64Type(context->getRef(), t);
2154         },
2155         py::arg("context") = py::none(), "Create a f64 type.");
2156   }
2157 };
2158 
2159 /// None Type subclass - NoneType.
2160 class PyNoneType : public PyConcreteType<PyNoneType> {
2161 public:
2162   static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
2163   static constexpr const char *pyClassName = "NoneType";
2164   using PyConcreteType::PyConcreteType;
2165 
bindDerived(ClassTy & c)2166   static void bindDerived(ClassTy &c) {
2167     c.def_static(
2168         "get",
2169         [](DefaultingPyMlirContext context) {
2170           MlirType t = mlirNoneTypeGet(context->get());
2171           return PyNoneType(context->getRef(), t);
2172         },
2173         py::arg("context") = py::none(), "Create a none type.");
2174   }
2175 };
2176 
2177 /// Complex Type subclass - ComplexType.
2178 class PyComplexType : public PyConcreteType<PyComplexType> {
2179 public:
2180   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
2181   static constexpr const char *pyClassName = "ComplexType";
2182   using PyConcreteType::PyConcreteType;
2183 
bindDerived(ClassTy & c)2184   static void bindDerived(ClassTy &c) {
2185     c.def_static(
2186         "get",
2187         [](PyType &elementType) {
2188           // The element must be a floating point or integer scalar type.
2189           if (mlirTypeIsAIntegerOrFloat(elementType)) {
2190             MlirType t = mlirComplexTypeGet(elementType);
2191             return PyComplexType(elementType.getContext(), t);
2192           }
2193           throw SetPyError(
2194               PyExc_ValueError,
2195               Twine("invalid '") +
2196                   py::repr(py::cast(elementType)).cast<std::string>() +
2197                   "' and expected floating point or integer type.");
2198         },
2199         "Create a complex type");
2200     c.def_property_readonly(
2201         "element_type",
2202         [](PyComplexType &self) -> PyType {
2203           MlirType t = mlirComplexTypeGetElementType(self);
2204           return PyType(self.getContext(), t);
2205         },
2206         "Returns element type.");
2207   }
2208 };
2209 
2210 class PyShapedType : public PyConcreteType<PyShapedType> {
2211 public:
2212   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
2213   static constexpr const char *pyClassName = "ShapedType";
2214   using PyConcreteType::PyConcreteType;
2215 
bindDerived(ClassTy & c)2216   static void bindDerived(ClassTy &c) {
2217     c.def_property_readonly(
2218         "element_type",
2219         [](PyShapedType &self) {
2220           MlirType t = mlirShapedTypeGetElementType(self);
2221           return PyType(self.getContext(), t);
2222         },
2223         "Returns the element type of the shaped type.");
2224     c.def_property_readonly(
2225         "has_rank",
2226         [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
2227         "Returns whether the given shaped type is ranked.");
2228     c.def_property_readonly(
2229         "rank",
2230         [](PyShapedType &self) {
2231           self.requireHasRank();
2232           return mlirShapedTypeGetRank(self);
2233         },
2234         "Returns the rank of the given ranked shaped type.");
2235     c.def_property_readonly(
2236         "has_static_shape",
2237         [](PyShapedType &self) -> bool {
2238           return mlirShapedTypeHasStaticShape(self);
2239         },
2240         "Returns whether the given shaped type has a static shape.");
2241     c.def(
2242         "is_dynamic_dim",
2243         [](PyShapedType &self, intptr_t dim) -> bool {
2244           self.requireHasRank();
2245           return mlirShapedTypeIsDynamicDim(self, dim);
2246         },
2247         "Returns whether the dim-th dimension of the given shaped type is "
2248         "dynamic.");
2249     c.def(
2250         "get_dim_size",
2251         [](PyShapedType &self, intptr_t dim) {
2252           self.requireHasRank();
2253           return mlirShapedTypeGetDimSize(self, dim);
2254         },
2255         "Returns the dim-th dimension of the given ranked shaped type.");
2256     c.def_static(
2257         "is_dynamic_size",
2258         [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
2259         "Returns whether the given dimension size indicates a dynamic "
2260         "dimension.");
2261     c.def(
2262         "is_dynamic_stride_or_offset",
2263         [](PyShapedType &self, int64_t val) -> bool {
2264           self.requireHasRank();
2265           return mlirShapedTypeIsDynamicStrideOrOffset(val);
2266         },
2267         "Returns whether the given value is used as a placeholder for dynamic "
2268         "strides and offsets in shaped types.");
2269   }
2270 
2271 private:
requireHasRank()2272   void requireHasRank() {
2273     if (!mlirShapedTypeHasRank(*this)) {
2274       throw SetPyError(
2275           PyExc_ValueError,
2276           "calling this method requires that the type has a rank.");
2277     }
2278   }
2279 };
2280 
2281 /// Vector Type subclass - VectorType.
2282 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
2283 public:
2284   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
2285   static constexpr const char *pyClassName = "VectorType";
2286   using PyConcreteType::PyConcreteType;
2287 
bindDerived(ClassTy & c)2288   static void bindDerived(ClassTy &c) {
2289     c.def_static(
2290         "get",
2291         [](std::vector<int64_t> shape, PyType &elementType,
2292            DefaultingPyLocation loc) {
2293           MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
2294                                                 elementType, loc);
2295           // TODO: Rework error reporting once diagnostic engine is exposed
2296           // in C API.
2297           if (mlirTypeIsNull(t)) {
2298             throw SetPyError(
2299                 PyExc_ValueError,
2300                 Twine("invalid '") +
2301                     py::repr(py::cast(elementType)).cast<std::string>() +
2302                     "' and expected floating point or integer type.");
2303           }
2304           return PyVectorType(elementType.getContext(), t);
2305         },
2306         py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
2307         "Create a vector type");
2308   }
2309 };
2310 
2311 /// Ranked Tensor Type subclass - RankedTensorType.
2312 class PyRankedTensorType
2313     : public PyConcreteType<PyRankedTensorType, PyShapedType> {
2314 public:
2315   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2316   static constexpr const char *pyClassName = "RankedTensorType";
2317   using PyConcreteType::PyConcreteType;
2318 
bindDerived(ClassTy & c)2319   static void bindDerived(ClassTy &c) {
2320     c.def_static(
2321         "get",
2322         [](std::vector<int64_t> shape, PyType &elementType,
2323            DefaultingPyLocation loc) {
2324           MlirType t = mlirRankedTensorTypeGetChecked(
2325               shape.size(), shape.data(), elementType, loc);
2326           // TODO: Rework error reporting once diagnostic engine is exposed
2327           // in C API.
2328           if (mlirTypeIsNull(t)) {
2329             throw SetPyError(
2330                 PyExc_ValueError,
2331                 Twine("invalid '") +
2332                     py::repr(py::cast(elementType)).cast<std::string>() +
2333                     "' and expected floating point, integer, vector or "
2334                     "complex "
2335                     "type.");
2336           }
2337           return PyRankedTensorType(elementType.getContext(), t);
2338         },
2339         py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
2340         "Create a ranked tensor type");
2341   }
2342 };
2343 
2344 /// Unranked Tensor Type subclass - UnrankedTensorType.
2345 class PyUnrankedTensorType
2346     : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
2347 public:
2348   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
2349   static constexpr const char *pyClassName = "UnrankedTensorType";
2350   using PyConcreteType::PyConcreteType;
2351 
bindDerived(ClassTy & c)2352   static void bindDerived(ClassTy &c) {
2353     c.def_static(
2354         "get",
2355         [](PyType &elementType, DefaultingPyLocation loc) {
2356           MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
2357           // TODO: Rework error reporting once diagnostic engine is exposed
2358           // in C API.
2359           if (mlirTypeIsNull(t)) {
2360             throw SetPyError(
2361                 PyExc_ValueError,
2362                 Twine("invalid '") +
2363                     py::repr(py::cast(elementType)).cast<std::string>() +
2364                     "' and expected floating point, integer, vector or "
2365                     "complex "
2366                     "type.");
2367           }
2368           return PyUnrankedTensorType(elementType.getContext(), t);
2369         },
2370         py::arg("element_type"), py::arg("loc") = py::none(),
2371         "Create a unranked tensor type");
2372   }
2373 };
2374 
2375 /// Ranked MemRef Type subclass - MemRefType.
2376 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
2377 public:
2378   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2379   static constexpr const char *pyClassName = "MemRefType";
2380   using PyConcreteType::PyConcreteType;
2381 
bindDerived(ClassTy & c)2382   static void bindDerived(ClassTy &c) {
2383     // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
2384     // once the affine map binding is completed.
2385     c.def_static(
2386          "get_contiguous_memref",
2387          // TODO: Make the location optional and create a default location.
2388          [](PyType &elementType, std::vector<int64_t> shape,
2389             unsigned memorySpace, DefaultingPyLocation loc) {
2390            MlirType t = mlirMemRefTypeContiguousGetChecked(
2391                elementType, shape.size(), shape.data(), memorySpace, loc);
2392            // TODO: Rework error reporting once diagnostic engine is exposed
2393            // in C API.
2394            if (mlirTypeIsNull(t)) {
2395              throw SetPyError(
2396                  PyExc_ValueError,
2397                  Twine("invalid '") +
2398                      py::repr(py::cast(elementType)).cast<std::string>() +
2399                      "' and expected floating point, integer, vector or "
2400                      "complex "
2401                      "type.");
2402            }
2403            return PyMemRefType(elementType.getContext(), t);
2404          },
2405          py::arg("element_type"), py::arg("shape"), py::arg("memory_space"),
2406          py::arg("loc") = py::none(), "Create a memref type")
2407         .def_property_readonly(
2408             "num_affine_maps",
2409             [](PyMemRefType &self) -> intptr_t {
2410               return mlirMemRefTypeGetNumAffineMaps(self);
2411             },
2412             "Returns the number of affine layout maps in the given MemRef "
2413             "type.")
2414         .def_property_readonly(
2415             "memory_space",
2416             [](PyMemRefType &self) -> unsigned {
2417               return mlirMemRefTypeGetMemorySpace(self);
2418             },
2419             "Returns the memory space of the given MemRef type.");
2420   }
2421 };
2422 
2423 /// Unranked MemRef Type subclass - UnrankedMemRefType.
2424 class PyUnrankedMemRefType
2425     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
2426 public:
2427   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
2428   static constexpr const char *pyClassName = "UnrankedMemRefType";
2429   using PyConcreteType::PyConcreteType;
2430 
bindDerived(ClassTy & c)2431   static void bindDerived(ClassTy &c) {
2432     c.def_static(
2433          "get",
2434          [](PyType &elementType, unsigned memorySpace,
2435             DefaultingPyLocation loc) {
2436            MlirType t =
2437                mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
2438            // TODO: Rework error reporting once diagnostic engine is exposed
2439            // in C API.
2440            if (mlirTypeIsNull(t)) {
2441              throw SetPyError(
2442                  PyExc_ValueError,
2443                  Twine("invalid '") +
2444                      py::repr(py::cast(elementType)).cast<std::string>() +
2445                      "' and expected floating point, integer, vector or "
2446                      "complex "
2447                      "type.");
2448            }
2449            return PyUnrankedMemRefType(elementType.getContext(), t);
2450          },
2451          py::arg("element_type"), py::arg("memory_space"),
2452          py::arg("loc") = py::none(), "Create a unranked memref type")
2453         .def_property_readonly(
2454             "memory_space",
2455             [](PyUnrankedMemRefType &self) -> unsigned {
2456               return mlirUnrankedMemrefGetMemorySpace(self);
2457             },
2458             "Returns the memory space of the given Unranked MemRef type.");
2459   }
2460 };
2461 
2462 /// Tuple Type subclass - TupleType.
2463 class PyTupleType : public PyConcreteType<PyTupleType> {
2464 public:
2465   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
2466   static constexpr const char *pyClassName = "TupleType";
2467   using PyConcreteType::PyConcreteType;
2468 
bindDerived(ClassTy & c)2469   static void bindDerived(ClassTy &c) {
2470     c.def_static(
2471         "get_tuple",
2472         [](py::list elementList, DefaultingPyMlirContext context) {
2473           intptr_t num = py::len(elementList);
2474           // Mapping py::list to SmallVector.
2475           SmallVector<MlirType, 4> elements;
2476           for (auto element : elementList)
2477             elements.push_back(element.cast<PyType>());
2478           MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
2479           return PyTupleType(context->getRef(), t);
2480         },
2481         py::arg("elements"), py::arg("context") = py::none(),
2482         "Create a tuple type");
2483     c.def(
2484         "get_type",
2485         [](PyTupleType &self, intptr_t pos) -> PyType {
2486           MlirType t = mlirTupleTypeGetType(self, pos);
2487           return PyType(self.getContext(), t);
2488         },
2489         "Returns the pos-th type in the tuple type.");
2490     c.def_property_readonly(
2491         "num_types",
2492         [](PyTupleType &self) -> intptr_t {
2493           return mlirTupleTypeGetNumTypes(self);
2494         },
2495         "Returns the number of types contained in a tuple.");
2496   }
2497 };
2498 
2499 /// Function type.
2500 class PyFunctionType : public PyConcreteType<PyFunctionType> {
2501 public:
2502   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
2503   static constexpr const char *pyClassName = "FunctionType";
2504   using PyConcreteType::PyConcreteType;
2505 
bindDerived(ClassTy & c)2506   static void bindDerived(ClassTy &c) {
2507     c.def_static(
2508         "get",
2509         [](std::vector<PyType> inputs, std::vector<PyType> results,
2510            DefaultingPyMlirContext context) {
2511           SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
2512           SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
2513           MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
2514                                            inputsRaw.data(), resultsRaw.size(),
2515                                            resultsRaw.data());
2516           return PyFunctionType(context->getRef(), t);
2517         },
2518         py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
2519         "Gets a FunctionType from a list of input and result types");
2520     c.def_property_readonly(
2521         "inputs",
2522         [](PyFunctionType &self) {
2523           MlirType t = self;
2524           auto contextRef = self.getContext();
2525           py::list types;
2526           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
2527                ++i) {
2528             types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
2529           }
2530           return types;
2531         },
2532         "Returns the list of input types in the FunctionType.");
2533     c.def_property_readonly(
2534         "results",
2535         [](PyFunctionType &self) {
2536           auto contextRef = self.getContext();
2537           py::list types;
2538           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
2539                ++i) {
2540             types.append(
2541                 PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
2542           }
2543           return types;
2544         },
2545         "Returns the list of result types in the FunctionType.");
2546   }
2547 };
2548 
2549 } // namespace
2550 
2551 //------------------------------------------------------------------------------
2552 // Populates the pybind11 IR submodule.
2553 //------------------------------------------------------------------------------
2554 
populateIRSubmodule(py::module & m)2555 void mlir::python::populateIRSubmodule(py::module &m) {
2556   //----------------------------------------------------------------------------
2557   // Mapping of MlirContext
2558   //----------------------------------------------------------------------------
2559   py::class_<PyMlirContext>(m, "Context")
2560       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2561       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2562       .def("_get_context_again",
2563            [](PyMlirContext &self) {
2564              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2565              return ref.releaseObject();
2566            })
2567       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2568       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2569       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2570                              &PyMlirContext::getCapsule)
2571       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2572       .def("__enter__", &PyMlirContext::contextEnter)
2573       .def("__exit__", &PyMlirContext::contextExit)
2574       .def_property_readonly_static(
2575           "current",
2576           [](py::object & /*class*/) {
2577             auto *context = PyThreadContextEntry::getDefaultContext();
2578             if (!context)
2579               throw SetPyError(PyExc_ValueError, "No current Context");
2580             return context;
2581           },
2582           "Gets the Context bound to the current thread or raises ValueError")
2583       .def_property_readonly(
2584           "dialects",
2585           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2586           "Gets a container for accessing dialects by name")
2587       .def_property_readonly(
2588           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2589           "Alias for 'dialect'")
2590       .def(
2591           "get_dialect_descriptor",
2592           [=](PyMlirContext &self, std::string &name) {
2593             MlirDialect dialect = mlirContextGetOrLoadDialect(
2594                 self.get(), {name.data(), name.size()});
2595             if (mlirDialectIsNull(dialect)) {
2596               throw SetPyError(PyExc_ValueError,
2597                                Twine("Dialect '") + name + "' not found");
2598             }
2599             return PyDialectDescriptor(self.getRef(), dialect);
2600           },
2601           "Gets or loads a dialect by name, returning its descriptor object")
2602       .def_property(
2603           "allow_unregistered_dialects",
2604           [](PyMlirContext &self) -> bool {
2605             return mlirContextGetAllowUnregisteredDialects(self.get());
2606           },
2607           [](PyMlirContext &self, bool value) {
2608             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2609           });
2610 
2611   //----------------------------------------------------------------------------
2612   // Mapping of PyDialectDescriptor
2613   //----------------------------------------------------------------------------
2614   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2615       .def_property_readonly("namespace",
2616                              [](PyDialectDescriptor &self) {
2617                                MlirStringRef ns =
2618                                    mlirDialectGetNamespace(self.get());
2619                                return py::str(ns.data, ns.length);
2620                              })
2621       .def("__repr__", [](PyDialectDescriptor &self) {
2622         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2623         std::string repr("<DialectDescriptor ");
2624         repr.append(ns.data, ns.length);
2625         repr.append(">");
2626         return repr;
2627       });
2628 
2629   //----------------------------------------------------------------------------
2630   // Mapping of PyDialects
2631   //----------------------------------------------------------------------------
2632   py::class_<PyDialects>(m, "Dialects")
2633       .def("__getitem__",
2634            [=](PyDialects &self, std::string keyName) {
2635              MlirDialect dialect =
2636                  self.getDialectForKey(keyName, /*attrError=*/false);
2637              py::object descriptor =
2638                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2639              return createCustomDialectWrapper(keyName, std::move(descriptor));
2640            })
2641       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2642         MlirDialect dialect =
2643             self.getDialectForKey(attrName, /*attrError=*/true);
2644         py::object descriptor =
2645             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2646         return createCustomDialectWrapper(attrName, std::move(descriptor));
2647       });
2648 
2649   //----------------------------------------------------------------------------
2650   // Mapping of PyDialect
2651   //----------------------------------------------------------------------------
2652   py::class_<PyDialect>(m, "Dialect")
2653       .def(py::init<py::object>(), "descriptor")
2654       .def_property_readonly(
2655           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2656       .def("__repr__", [](py::object self) {
2657         auto clazz = self.attr("__class__");
2658         return py::str("<Dialect ") +
2659                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2660                clazz.attr("__module__") + py::str(".") +
2661                clazz.attr("__name__") + py::str(")>");
2662       });
2663 
2664   //----------------------------------------------------------------------------
2665   // Mapping of Location
2666   //----------------------------------------------------------------------------
2667   py::class_<PyLocation>(m, "Location")
2668       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2669       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2670       .def("__enter__", &PyLocation::contextEnter)
2671       .def("__exit__", &PyLocation::contextExit)
2672       .def("__eq__",
2673            [](PyLocation &self, PyLocation &other) -> bool {
2674              return mlirLocationEqual(self, other);
2675            })
2676       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2677       .def_property_readonly_static(
2678           "current",
2679           [](py::object & /*class*/) {
2680             auto *loc = PyThreadContextEntry::getDefaultLocation();
2681             if (!loc)
2682               throw SetPyError(PyExc_ValueError, "No current Location");
2683             return loc;
2684           },
2685           "Gets the Location bound to the current thread or raises ValueError")
2686       .def_static(
2687           "unknown",
2688           [](DefaultingPyMlirContext context) {
2689             return PyLocation(context->getRef(),
2690                               mlirLocationUnknownGet(context->get()));
2691           },
2692           py::arg("context") = py::none(),
2693           "Gets a Location representing an unknown location")
2694       .def_static(
2695           "file",
2696           [](std::string filename, int line, int col,
2697              DefaultingPyMlirContext context) {
2698             return PyLocation(
2699                 context->getRef(),
2700                 mlirLocationFileLineColGet(
2701                     context->get(), toMlirStringRef(filename), line, col));
2702           },
2703           py::arg("filename"), py::arg("line"), py::arg("col"),
2704           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2705       .def_property_readonly(
2706           "context",
2707           [](PyLocation &self) { return self.getContext().getObject(); },
2708           "Context that owns the Location")
2709       .def("__repr__", [](PyLocation &self) {
2710         PyPrintAccumulator printAccum;
2711         mlirLocationPrint(self, printAccum.getCallback(),
2712                           printAccum.getUserData());
2713         return printAccum.join();
2714       });
2715 
2716   //----------------------------------------------------------------------------
2717   // Mapping of Module
2718   //----------------------------------------------------------------------------
2719   py::class_<PyModule>(m, "Module")
2720       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2721       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2722       .def_static(
2723           "parse",
2724           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2725             MlirModule module = mlirModuleCreateParse(
2726                 context->get(), toMlirStringRef(moduleAsm));
2727             // TODO: Rework error reporting once diagnostic engine is exposed
2728             // in C API.
2729             if (mlirModuleIsNull(module)) {
2730               throw SetPyError(
2731                   PyExc_ValueError,
2732                   "Unable to parse module assembly (see diagnostics)");
2733             }
2734             return PyModule::forModule(module).releaseObject();
2735           },
2736           py::arg("asm"), py::arg("context") = py::none(),
2737           kModuleParseDocstring)
2738       .def_static(
2739           "create",
2740           [](DefaultingPyLocation loc) {
2741             MlirModule module = mlirModuleCreateEmpty(loc);
2742             return PyModule::forModule(module).releaseObject();
2743           },
2744           py::arg("loc") = py::none(), "Creates an empty module")
2745       .def_property_readonly(
2746           "context",
2747           [](PyModule &self) { return self.getContext().getObject(); },
2748           "Context that created the Module")
2749       .def_property_readonly(
2750           "operation",
2751           [](PyModule &self) {
2752             return PyOperation::forOperation(self.getContext(),
2753                                              mlirModuleGetOperation(self.get()),
2754                                              self.getRef().releaseObject())
2755                 .releaseObject();
2756           },
2757           "Accesses the module as an operation")
2758       .def_property_readonly(
2759           "body",
2760           [](PyModule &self) {
2761             PyOperationRef module_op = PyOperation::forOperation(
2762                 self.getContext(), mlirModuleGetOperation(self.get()),
2763                 self.getRef().releaseObject());
2764             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2765             return returnBlock;
2766           },
2767           "Return the block for this module")
2768       .def(
2769           "dump",
2770           [](PyModule &self) {
2771             mlirOperationDump(mlirModuleGetOperation(self.get()));
2772           },
2773           kDumpDocstring)
2774       .def(
2775           "__str__",
2776           [](PyModule &self) {
2777             MlirOperation operation = mlirModuleGetOperation(self.get());
2778             PyPrintAccumulator printAccum;
2779             mlirOperationPrint(operation, printAccum.getCallback(),
2780                                printAccum.getUserData());
2781             return printAccum.join();
2782           },
2783           kOperationStrDunderDocstring);
2784 
2785   //----------------------------------------------------------------------------
2786   // Mapping of Operation.
2787   //----------------------------------------------------------------------------
2788   py::class_<PyOperationBase>(m, "_OperationBase")
2789       .def("__eq__",
2790            [](PyOperationBase &self, PyOperationBase &other) {
2791              return &self.getOperation() == &other.getOperation();
2792            })
2793       .def("__eq__",
2794            [](PyOperationBase &self, py::object other) { return false; })
2795       .def_property_readonly("attributes",
2796                              [](PyOperationBase &self) {
2797                                return PyOpAttributeMap(
2798                                    self.getOperation().getRef());
2799                              })
2800       .def_property_readonly("operands",
2801                              [](PyOperationBase &self) {
2802                                return PyOpOperandList(
2803                                    self.getOperation().getRef());
2804                              })
2805       .def_property_readonly("regions",
2806                              [](PyOperationBase &self) {
2807                                return PyRegionList(
2808                                    self.getOperation().getRef());
2809                              })
2810       .def_property_readonly(
2811           "results",
2812           [](PyOperationBase &self) {
2813             return PyOpResultList(self.getOperation().getRef());
2814           },
2815           "Returns the list of Operation results.")
2816       .def_property_readonly(
2817           "result",
2818           [](PyOperationBase &self) {
2819             auto &operation = self.getOperation();
2820             auto numResults = mlirOperationGetNumResults(operation);
2821             if (numResults != 1) {
2822               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2823               throw SetPyError(
2824                   PyExc_ValueError,
2825                   Twine("Cannot call .result on operation ") +
2826                       StringRef(name.data, name.length) + " which has " +
2827                       Twine(numResults) +
2828                       " results (it is only valid for operations with a "
2829                       "single result)");
2830             }
2831             return PyOpResult(operation.getRef(),
2832                               mlirOperationGetResult(operation, 0));
2833           },
2834           "Shortcut to get an op result if it has only one (throws an error "
2835           "otherwise).")
2836       .def("__iter__",
2837            [](PyOperationBase &self) {
2838              return PyRegionIterator(self.getOperation().getRef());
2839            })
2840       .def(
2841           "__str__",
2842           [](PyOperationBase &self) {
2843             return self.getAsm(/*binary=*/false,
2844                                /*largeElementsLimit=*/llvm::None,
2845                                /*enableDebugInfo=*/false,
2846                                /*prettyDebugInfo=*/false,
2847                                /*printGenericOpForm=*/false,
2848                                /*useLocalScope=*/false);
2849           },
2850           "Returns the assembly form of the operation.")
2851       .def("print", &PyOperationBase::print,
2852            // Careful: Lots of arguments must match up with print method.
2853            py::arg("file") = py::none(), py::arg("binary") = false,
2854            py::arg("large_elements_limit") = py::none(),
2855            py::arg("enable_debug_info") = false,
2856            py::arg("pretty_debug_info") = false,
2857            py::arg("print_generic_op_form") = false,
2858            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2859       .def("get_asm", &PyOperationBase::getAsm,
2860            // Careful: Lots of arguments must match up with get_asm method.
2861            py::arg("binary") = false,
2862            py::arg("large_elements_limit") = py::none(),
2863            py::arg("enable_debug_info") = false,
2864            py::arg("pretty_debug_info") = false,
2865            py::arg("print_generic_op_form") = false,
2866            py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
2867 
2868   py::class_<PyOperation, PyOperationBase>(m, "Operation")
2869       .def_static("create", &PyOperation::create, py::arg("name"),
2870                   py::arg("operands") = py::none(),
2871                   py::arg("results") = py::none(),
2872                   py::arg("attributes") = py::none(),
2873                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2874                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2875                   kOperationCreateDocstring)
2876       .def_property_readonly(
2877           "context",
2878           [](PyOperation &self) { return self.getContext().getObject(); },
2879           "Context that owns the Operation")
2880       .def_property_readonly("opview", &PyOperation::createOpView);
2881 
2882   py::class_<PyOpView, PyOperationBase>(m, "OpView")
2883       .def(py::init<py::object>())
2884       .def_property_readonly("operation", &PyOpView::getOperationObject)
2885       .def_property_readonly(
2886           "context",
2887           [](PyOpView &self) {
2888             return self.getOperation().getContext().getObject();
2889           },
2890           "Context that owns the Operation")
2891       .def("__str__",
2892            [](PyOpView &self) { return py::str(self.getOperationObject()); });
2893 
2894   //----------------------------------------------------------------------------
2895   // Mapping of PyRegion.
2896   //----------------------------------------------------------------------------
2897   py::class_<PyRegion>(m, "Region")
2898       .def_property_readonly(
2899           "blocks",
2900           [](PyRegion &self) {
2901             return PyBlockList(self.getParentOperation(), self.get());
2902           },
2903           "Returns a forward-optimized sequence of blocks.")
2904       .def(
2905           "__iter__",
2906           [](PyRegion &self) {
2907             self.checkValid();
2908             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2909             return PyBlockIterator(self.getParentOperation(), firstBlock);
2910           },
2911           "Iterates over blocks in the region.")
2912       .def("__eq__",
2913            [](PyRegion &self, PyRegion &other) {
2914              return self.get().ptr == other.get().ptr;
2915            })
2916       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2917 
2918   //----------------------------------------------------------------------------
2919   // Mapping of PyBlock.
2920   //----------------------------------------------------------------------------
2921   py::class_<PyBlock>(m, "Block")
2922       .def_property_readonly(
2923           "arguments",
2924           [](PyBlock &self) {
2925             return PyBlockArgumentList(self.getParentOperation(), self.get());
2926           },
2927           "Returns a list of block arguments.")
2928       .def_property_readonly(
2929           "operations",
2930           [](PyBlock &self) {
2931             return PyOperationList(self.getParentOperation(), self.get());
2932           },
2933           "Returns a forward-optimized sequence of operations.")
2934       .def(
2935           "__iter__",
2936           [](PyBlock &self) {
2937             self.checkValid();
2938             MlirOperation firstOperation =
2939                 mlirBlockGetFirstOperation(self.get());
2940             return PyOperationIterator(self.getParentOperation(),
2941                                        firstOperation);
2942           },
2943           "Iterates over operations in the block.")
2944       .def("__eq__",
2945            [](PyBlock &self, PyBlock &other) {
2946              return self.get().ptr == other.get().ptr;
2947            })
2948       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2949       .def(
2950           "__str__",
2951           [](PyBlock &self) {
2952             self.checkValid();
2953             PyPrintAccumulator printAccum;
2954             mlirBlockPrint(self.get(), printAccum.getCallback(),
2955                            printAccum.getUserData());
2956             return printAccum.join();
2957           },
2958           "Returns the assembly form of the block.");
2959 
2960   //----------------------------------------------------------------------------
2961   // Mapping of PyInsertionPoint.
2962   //----------------------------------------------------------------------------
2963 
2964   py::class_<PyInsertionPoint>(m, "InsertionPoint")
2965       .def(py::init<PyBlock &>(), py::arg("block"),
2966            "Inserts after the last operation but still inside the block.")
2967       .def("__enter__", &PyInsertionPoint::contextEnter)
2968       .def("__exit__", &PyInsertionPoint::contextExit)
2969       .def_property_readonly_static(
2970           "current",
2971           [](py::object & /*class*/) {
2972             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2973             if (!ip)
2974               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2975             return ip;
2976           },
2977           "Gets the InsertionPoint bound to the current thread or raises "
2978           "ValueError if none has been set")
2979       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2980            "Inserts before a referenced operation.")
2981       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2982                   py::arg("block"), "Inserts at the beginning of the block.")
2983       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2984                   py::arg("block"), "Inserts before the block terminator.")
2985       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2986            "Inserts an operation.");
2987 
2988   //----------------------------------------------------------------------------
2989   // Mapping of PyAttribute.
2990   //----------------------------------------------------------------------------
2991   py::class_<PyAttribute>(m, "Attribute")
2992       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2993                              &PyAttribute::getCapsule)
2994       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2995       .def_static(
2996           "parse",
2997           [](std::string attrSpec, DefaultingPyMlirContext context) {
2998             MlirAttribute type = mlirAttributeParseGet(
2999                 context->get(), toMlirStringRef(attrSpec));
3000             // TODO: Rework error reporting once diagnostic engine is exposed
3001             // in C API.
3002             if (mlirAttributeIsNull(type)) {
3003               throw SetPyError(PyExc_ValueError,
3004                                Twine("Unable to parse attribute: '") +
3005                                    attrSpec + "'");
3006             }
3007             return PyAttribute(context->getRef(), type);
3008           },
3009           py::arg("asm"), py::arg("context") = py::none(),
3010           "Parses an attribute from an assembly form")
3011       .def_property_readonly(
3012           "context",
3013           [](PyAttribute &self) { return self.getContext().getObject(); },
3014           "Context that owns the Attribute")
3015       .def_property_readonly("type",
3016                              [](PyAttribute &self) {
3017                                return PyType(self.getContext()->getRef(),
3018                                              mlirAttributeGetType(self));
3019                              })
3020       .def(
3021           "get_named",
3022           [](PyAttribute &self, std::string name) {
3023             return PyNamedAttribute(self, std::move(name));
3024           },
3025           py::keep_alive<0, 1>(), "Binds a name to the attribute")
3026       .def("__eq__",
3027            [](PyAttribute &self, PyAttribute &other) { return self == other; })
3028       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3029       .def(
3030           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3031           kDumpDocstring)
3032       .def(
3033           "__str__",
3034           [](PyAttribute &self) {
3035             PyPrintAccumulator printAccum;
3036             mlirAttributePrint(self, printAccum.getCallback(),
3037                                printAccum.getUserData());
3038             return printAccum.join();
3039           },
3040           "Returns the assembly form of the Attribute.")
3041       .def("__repr__", [](PyAttribute &self) {
3042         // Generally, assembly formats are not printed for __repr__ because
3043         // this can cause exceptionally long debug output and exceptions.
3044         // However, attribute values are generally considered useful and are
3045         // printed. This may need to be re-evaluated if debug dumps end up
3046         // being excessive.
3047         PyPrintAccumulator printAccum;
3048         printAccum.parts.append("Attribute(");
3049         mlirAttributePrint(self, printAccum.getCallback(),
3050                            printAccum.getUserData());
3051         printAccum.parts.append(")");
3052         return printAccum.join();
3053       });
3054 
3055   //----------------------------------------------------------------------------
3056   // Mapping of PyNamedAttribute
3057   //----------------------------------------------------------------------------
3058   py::class_<PyNamedAttribute>(m, "NamedAttribute")
3059       .def("__repr__",
3060            [](PyNamedAttribute &self) {
3061              PyPrintAccumulator printAccum;
3062              printAccum.parts.append("NamedAttribute(");
3063              printAccum.parts.append(self.namedAttr.name.data);
3064              printAccum.parts.append("=");
3065              mlirAttributePrint(self.namedAttr.attribute,
3066                                 printAccum.getCallback(),
3067                                 printAccum.getUserData());
3068              printAccum.parts.append(")");
3069              return printAccum.join();
3070            })
3071       .def_property_readonly(
3072           "name",
3073           [](PyNamedAttribute &self) {
3074             return py::str(self.namedAttr.name.data,
3075                            self.namedAttr.name.length);
3076           },
3077           "The name of the NamedAttribute binding")
3078       .def_property_readonly(
3079           "attr",
3080           [](PyNamedAttribute &self) {
3081             // TODO: When named attribute is removed/refactored, also remove
3082             // this constructor (it does an inefficient table lookup).
3083             auto contextRef = PyMlirContext::forContext(
3084                 mlirAttributeGetContext(self.namedAttr.attribute));
3085             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3086           },
3087           py::keep_alive<0, 1>(),
3088           "The underlying generic attribute of the NamedAttribute binding");
3089 
  // Builtin attribute bindings.
  // Each helper presumably registers its Python attribute subclass on `m`.
  // NOTE(review): order is likely significant. pybind11 requires a base class
  // to be registered before classes derived from it, and e.g. the DenseInt/
  // DenseFP elements attributes probably derive from PyDenseElementsAttribute.
  // Keep this sequence (and keep it after the "Attribute" class above)
  // unless verified.
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyBoolAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyDenseElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
3100 
3101   //----------------------------------------------------------------------------
3102   // Mapping of PyType.
3103   //----------------------------------------------------------------------------
3104   py::class_<PyType>(m, "Type")
3105       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3106       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3107       .def_static(
3108           "parse",
3109           [](std::string typeSpec, DefaultingPyMlirContext context) {
3110             MlirType type =
3111                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3112             // TODO: Rework error reporting once diagnostic engine is exposed
3113             // in C API.
3114             if (mlirTypeIsNull(type)) {
3115               throw SetPyError(PyExc_ValueError,
3116                                Twine("Unable to parse type: '") + typeSpec +
3117                                    "'");
3118             }
3119             return PyType(context->getRef(), type);
3120           },
3121           py::arg("asm"), py::arg("context") = py::none(),
3122           kContextParseTypeDocstring)
3123       .def_property_readonly(
3124           "context", [](PyType &self) { return self.getContext().getObject(); },
3125           "Context that owns the Type")
3126       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3127       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3128       .def(
3129           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3130       .def(
3131           "__str__",
3132           [](PyType &self) {
3133             PyPrintAccumulator printAccum;
3134             mlirTypePrint(self, printAccum.getCallback(),
3135                           printAccum.getUserData());
3136             return printAccum.join();
3137           },
3138           "Returns the assembly form of the type.")
3139       .def("__repr__", [](PyType &self) {
3140         // Generally, assembly formats are not printed for __repr__ because
3141         // this can cause exceptionally long debug output and exceptions.
3142         // However, types are an exception as they typically have compact
3143         // assembly forms and printing them is useful.
3144         PyPrintAccumulator printAccum;
3145         printAccum.parts.append("Type(");
3146         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3147         printAccum.parts.append(")");
3148         return printAccum.join();
3149       });
3150 
  // Builtin type bindings.
  // Each helper presumably registers its Python type subclass on `m`.
  // NOTE(review): order is likely significant. pybind11 requires a base class
  // to be registered before classes derived from it, and the vector/tensor/
  // memref types probably derive from PyShapedType. Keep this sequence (and
  // keep it after the "Type" class above) unless verified.
  PyIntegerType::bind(m);
  PyIndexType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyUnrankedMemRefType::bind(m);
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
3168 
3169   //----------------------------------------------------------------------------
3170   // Mapping of Value.
3171   //----------------------------------------------------------------------------
3172   py::class_<PyValue>(m, "Value")
3173       .def_property_readonly(
3174           "context",
3175           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3176           "Context in which the value lives.")
3177       .def(
3178           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3179           kDumpDocstring)
3180       .def("__eq__",
3181            [](PyValue &self, PyValue &other) {
3182              return self.get().ptr == other.get().ptr;
3183            })
3184       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3185       .def(
3186           "__str__",
3187           [](PyValue &self) {
3188             PyPrintAccumulator printAccum;
3189             printAccum.parts.append("Value(");
3190             mlirValuePrint(self.get(), printAccum.getCallback(),
3191                            printAccum.getUserData());
3192             printAccum.parts.append(")");
3193             return printAccum.join();
3194           },
3195           kValueDunderStrDocstring)
3196       .def_property_readonly("type", [](PyValue &self) {
3197         return PyType(self.getParentOperation()->getContext(),
3198                       mlirValueGetType(self.get()));
3199       });
  // Value subclasses. NOTE(review): these presumably derive from the "Value"
  // class registered just above, so they must stay after it (pybind11 needs
  // base classes registered first).
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: the list/iterator helper classes returned by
  // properties bound earlier in this function (e.g. Region.blocks returns a
  // PyBlockList, Block.__iter__ returns a PyOperationIterator).
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);
3214 }
3215