1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
11 
12 #include <vector>
13 
14 #include "PybindUtils.h"
15 
16 #include "mlir-c/IR.h"
17 #include "llvm/ADT/DenseMap.h"
18 
19 namespace mlir {
20 namespace python {
21 
22 class PyBlock;
23 class PyInsertionPoint;
24 class PyLocation;
25 class DefaultingPyLocation;
26 class PyMlirContext;
27 class DefaultingPyMlirContext;
28 class PyModule;
29 class PyOperation;
30 class PyType;
31 class PyValue;
32 
33 /// Template for a reference to a concrete type which captures a python
34 /// reference to its underlying python object.
35 template <typename T>
36 class PyObjectRef {
37 public:
PyObjectRef(T * referrent,pybind11::object object)38   PyObjectRef(T *referrent, pybind11::object object)
39       : referrent(referrent), object(std::move(object)) {
40     assert(this->referrent &&
41            "cannot construct PyObjectRef with null referrent");
42     assert(this->object && "cannot construct PyObjectRef with null object");
43   }
PyObjectRef(PyObjectRef && other)44   PyObjectRef(PyObjectRef &&other)
45       : referrent(other.referrent), object(std::move(other.object)) {
46     other.referrent = nullptr;
47     assert(!other.object);
48   }
PyObjectRef(const PyObjectRef & other)49   PyObjectRef(const PyObjectRef &other)
50       : referrent(other.referrent), object(other.object /* copies */) {}
~PyObjectRef()51   ~PyObjectRef() {}
52 
getRefCount()53   int getRefCount() {
54     if (!object)
55       return 0;
56     return object.ref_count();
57   }
58 
59   /// Releases the object held by this instance, returning it.
60   /// This is the proper thing to return from a function that wants to return
61   /// the reference. Note that this does not work from initializers.
releaseObject()62   pybind11::object releaseObject() {
63     assert(referrent && object);
64     referrent = nullptr;
65     auto stolen = std::move(object);
66     return stolen;
67   }
68 
get()69   T *get() { return referrent; }
70   T *operator->() {
71     assert(referrent && object);
72     return referrent;
73   }
getObject()74   pybind11::object getObject() {
75     assert(referrent && object);
76     return object;
77   }
78   operator bool() const { return referrent && object; }
79 
80 private:
81   T *referrent;
82   pybind11::object object;
83 };
84 
85 /// Tracks an entry in the thread context stack. New entries are pushed onto
86 /// here for each with block that activates a new InsertionPoint, Context or
87 /// Location.
88 ///
89 /// Pushing either a Location or InsertionPoint also pushes its associated
90 /// Context. Pushing a Context will not modify the Location or InsertionPoint
91 /// unless if they are from a different context, in which case, they are
92 /// cleared.
93 class PyThreadContextEntry {
94 public:
95   enum class FrameKind {
96     Context,
97     InsertionPoint,
98     Location,
99   };
100 
PyThreadContextEntry(FrameKind frameKind,pybind11::object context,pybind11::object insertionPoint,pybind11::object location)101   PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
102                        pybind11::object insertionPoint,
103                        pybind11::object location)
104       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
105         location(std::move(location)), frameKind(frameKind) {}
106 
107   /// Gets the top of stack context and return nullptr if not defined.
108   static PyMlirContext *getDefaultContext();
109 
110   /// Gets the top of stack insertion point and return nullptr if not defined.
111   static PyInsertionPoint *getDefaultInsertionPoint();
112 
113   /// Gets the top of stack location and returns nullptr if not defined.
114   static PyLocation *getDefaultLocation();
115 
116   PyMlirContext *getContext();
117   PyInsertionPoint *getInsertionPoint();
118   PyLocation *getLocation();
getFrameKind()119   FrameKind getFrameKind() { return frameKind; }
120 
121   /// Stack management.
122   static PyThreadContextEntry *getTopOfStack();
123   static pybind11::object pushContext(PyMlirContext &context);
124   static void popContext(PyMlirContext &context);
125   static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
126   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
127   static pybind11::object pushLocation(PyLocation &location);
128   static void popLocation(PyLocation &location);
129 
130   /// Gets the thread local stack.
131   static std::vector<PyThreadContextEntry> &getStack();
132 
133 private:
134   static void push(FrameKind frameKind, pybind11::object context,
135                    pybind11::object insertionPoint, pybind11::object location);
136 
137   /// An object reference to the PyContext.
138   pybind11::object context;
139   /// An object reference to the current insertion point.
140   pybind11::object insertionPoint;
141   /// An object reference to the current location.
142   pybind11::object location;
143   // The kind of push that was performed.
144   FrameKind frameKind;
145 };
146 
147 /// Wrapper around MlirContext.
148 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
149 class PyMlirContext {
150 public:
151   PyMlirContext() = delete;
152   PyMlirContext(const PyMlirContext &) = delete;
153   PyMlirContext(PyMlirContext &&) = delete;
154 
155   /// For the case of a python __init__ (py::init) method, pybind11 is quite
156   /// strict about needing to return a pointer that is not yet associated to
157   /// an py::object. Since the forContext() method acts like a pool, possibly
158   /// returning a recycled context, it does not satisfy this need. The usual
159   /// way in python to accomplish such a thing is to override __new__, but
160   /// that is also not supported by pybind11. Instead, we use this entry
161   /// point which always constructs a fresh context (which cannot alias an
162   /// existing one because it is fresh).
163   static PyMlirContext *createNewContextForInit();
164 
165   /// Returns a context reference for the singleton PyMlirContext wrapper for
166   /// the given context.
167   static PyMlirContextRef forContext(MlirContext context);
168   ~PyMlirContext();
169 
170   /// Accesses the underlying MlirContext.
get()171   MlirContext get() { return context; }
172 
173   /// Gets a strong reference to this context, which will ensure it is kept
174   /// alive for the life of the reference.
getRef()175   PyMlirContextRef getRef() {
176     return PyMlirContextRef(this, pybind11::cast(this));
177   }
178 
179   /// Gets a capsule wrapping the void* within the MlirContext.
180   pybind11::object getCapsule();
181 
182   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
183   /// Note that PyMlirContext instances are uniqued, so the returned object
184   /// may be a pre-existing object. Ownership of the underlying MlirContext
185   /// is taken by calling this function.
186   static pybind11::object createFromCapsule(pybind11::object capsule);
187 
188   /// Gets the count of live context objects. Used for testing.
189   static size_t getLiveCount();
190 
191   /// Gets the count of live operations associated with this context.
192   /// Used for testing.
193   size_t getLiveOperationCount();
194 
195   /// Gets the count of live modules associated with this context.
196   /// Used for testing.
197   size_t getLiveModuleCount();
198 
199   /// Enter and exit the context manager.
200   pybind11::object contextEnter();
201   void contextExit(pybind11::object excType, pybind11::object excVal,
202                    pybind11::object excTb);
203 
204 private:
205   PyMlirContext(MlirContext context);
206   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
207   // preserving the relationship that an MlirContext maps to a single
208   // PyMlirContext wrapper. This could be replaced in the future with an
209   // extension mechanism on the MlirContext for stashing user pointers.
210   // Note that this holds a handle, which does not imply ownership.
211   // Mappings will be removed when the context is destructed.
212   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
213   static LiveContextMap &getLiveContexts();
214 
215   // Interns all live modules associated with this context. Modules tracked
216   // in this map are valid. When a module is invalidated, it is removed
217   // from this map, and while it still exists as an instance, any
218   // attempt to access it will raise an error.
219   using LiveModuleMap =
220       llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
221   LiveModuleMap liveModules;
222 
223   // Interns all live operations associated with this context. Operations
224   // tracked in this map are valid. When an operation is invalidated, it is
225   // removed from this map, and while it still exists as an instance, any
226   // attempt to access it will raise an error.
227   using LiveOperationMap =
228       llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
229   LiveOperationMap liveOperations;
230 
231   MlirContext context;
232   friend class PyModule;
233   friend class PyOperation;
234 };
235 
236 /// Used in function arguments when None should resolve to the current context
237 /// manager set instance.
238 class DefaultingPyMlirContext
239     : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
240 public:
241   using Defaulting::Defaulting;
242   static constexpr const char kTypeDescription[] =
243       "[ThreadContextAware] mlir.ir.Context";
244   static PyMlirContext &resolve();
245 };
246 
/// Base class for all objects that directly or indirectly depend on an
/// MlirContext. The lifetime of the context will extend at least to the
/// lifetime of these instances.
/// Immutable objects that depend on a context extend this directly.
///
/// The context is held through a PyMlirContextRef, i.e. a strong reference to
/// the Python context object; that reference is what keeps the context alive.
class BaseContextObject {
public:
  /// Takes over `ref`, which must be non-null (asserted in debug builds).
  BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
    assert(this->contextRef &&
           "context object constructed with null context ref");
  }

  /// Accesses the context reference.
  /// Returns a mutable reference to the held ref rather than a copy.
  PyMlirContextRef &getContext() { return contextRef; }

private:
  /// Strong reference to the context this object depends on.
  PyMlirContextRef contextRef;
};
264 
265 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
266 /// order to differentiate it from the `Dialect` base class which is extended by
267 /// plugins which extend dialect functionality through extension python code.
268 /// This should be seen as the "low-level" object and `Dialect` as the
269 /// high-level, user facing object.
270 class PyDialectDescriptor : public BaseContextObject {
271 public:
PyDialectDescriptor(PyMlirContextRef contextRef,MlirDialect dialect)272   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
273       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
274 
get()275   MlirDialect get() { return dialect; }
276 
277 private:
278   MlirDialect dialect;
279 };
280 
281 /// User-level object for accessing dialects with dotted syntax such as:
282 ///   ctx.dialect.std
283 class PyDialects : public BaseContextObject {
284 public:
PyDialects(PyMlirContextRef contextRef)285   PyDialects(PyMlirContextRef contextRef)
286       : BaseContextObject(std::move(contextRef)) {}
287 
288   MlirDialect getDialectForKey(const std::string &key, bool attrError);
289 };
290 
291 /// User-level dialect object. For dialects that have a registered extension,
292 /// this will be the base class of the extension dialect type. For un-extended,
293 /// objects of this type will be returned directly.
294 class PyDialect {
295 public:
PyDialect(pybind11::object descriptor)296   PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
297 
getDescriptor()298   pybind11::object getDescriptor() { return descriptor; }
299 
300 private:
301   pybind11::object descriptor;
302 };
303 
304 /// Wrapper around an MlirLocation.
305 class PyLocation : public BaseContextObject {
306 public:
PyLocation(PyMlirContextRef contextRef,MlirLocation loc)307   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
308       : BaseContextObject(std::move(contextRef)), loc(loc) {}
309 
MlirLocation()310   operator MlirLocation() const { return loc; }
get()311   MlirLocation get() const { return loc; }
312 
313   /// Enter and exit the context manager.
314   pybind11::object contextEnter();
315   void contextExit(pybind11::object excType, pybind11::object excVal,
316                    pybind11::object excTb);
317 
318   /// Gets a capsule wrapping the void* within the MlirContext.
319   pybind11::object getCapsule();
320 
321   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
322   /// Note that PyMlirContext instances are uniqued, so the returned object
323   /// may be a pre-existing object. Ownership of the underlying MlirContext
324   /// is taken by calling this function.
325   static PyLocation createFromCapsule(pybind11::object capsule);
326 
327 private:
328   MlirLocation loc;
329 };
330 
331 /// Used in function arguments when None should resolve to the current context
332 /// manager set instance.
333 class DefaultingPyLocation
334     : public Defaulting<DefaultingPyLocation, PyLocation> {
335 public:
336   using Defaulting::Defaulting;
337   static constexpr const char kTypeDescription[] =
338       "[ThreadContextAware] mlir.ir.Location";
339   static PyLocation &resolve();
340 
MlirLocation()341   operator MlirLocation() const { return *get(); }
342 };
343 
344 /// Wrapper around MlirModule.
345 /// This is the top-level, user-owned object that contains regions/ops/blocks.
346 class PyModule;
347 using PyModuleRef = PyObjectRef<PyModule>;
348 class PyModule : public BaseContextObject {
349 public:
350   /// Returns a PyModule reference for the given MlirModule. This may return
351   /// a pre-existing or new object.
352   static PyModuleRef forModule(MlirModule module);
353   PyModule(PyModule &) = delete;
354   PyModule(PyMlirContext &&) = delete;
355   ~PyModule();
356 
357   /// Gets the backing MlirModule.
get()358   MlirModule get() { return module; }
359 
360   /// Gets a strong reference to this module.
getRef()361   PyModuleRef getRef() {
362     return PyModuleRef(this,
363                        pybind11::reinterpret_borrow<pybind11::object>(handle));
364   }
365 
366   /// Gets a capsule wrapping the void* within the MlirModule.
367   /// Note that the module does not (yet) provide a corresponding factory for
368   /// constructing from a capsule as that would require uniquing PyModule
369   /// instances, which is not currently done.
370   pybind11::object getCapsule();
371 
372   /// Creates a PyModule from the MlirModule wrapped by a capsule.
373   /// Note that PyModule instances are uniqued, so the returned object
374   /// may be a pre-existing object. Ownership of the underlying MlirModule
375   /// is taken by calling this function.
376   static pybind11::object createFromCapsule(pybind11::object capsule);
377 
378 private:
379   PyModule(PyMlirContextRef contextRef, MlirModule module);
380   MlirModule module;
381   pybind11::handle handle;
382 };
383 
/// Base class for PyOperation and PyOpView which exposes the primary, user
/// visible methods for manipulating it.
class PyOperationBase {
public:
  virtual ~PyOperationBase() = default;
  /// Implements the bound 'print' method and helps with others.
  /// Writes to the Python file-like `fileObject`; the remaining parameters
  /// select printing options and are forwarded by the out-of-line definition.
  void print(pybind11::object fileObject, bool binary,
             llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
  /// Returns the printed form of the operation as a Python object.
  /// NOTE(review): presumably `str` when !binary and `bytes` when binary;
  /// confirm against the implementation.
  pybind11::object getAsm(bool binary,
                          llvm::Optional<int64_t> largeElementsLimit,
                          bool enableDebugInfo, bool prettyDebugInfo,
                          bool printGenericOpForm, bool useLocalScope);

  /// Each must provide access to the raw Operation.
  virtual PyOperation &getOperation() = 0;
};
401 
402 /// Wrapper around PyOperation.
403 /// Operations exist in either an attached (dependent) or detached (top-level)
404 /// state. In the detached state (as on creation), an operation is owned by
405 /// the creator and its lifetime extends either until its reference count
406 /// drops to zero or it is attached to a parent, at which point its lifetime
407 /// is bounded by its top-level parent reference.
408 class PyOperation;
409 using PyOperationRef = PyObjectRef<PyOperation>;
410 class PyOperation : public PyOperationBase, public BaseContextObject {
411 public:
412   ~PyOperation();
getOperation()413   PyOperation &getOperation() override { return *this; }
414 
415   /// Returns a PyOperation for the given MlirOperation, optionally associating
416   /// it with a parentKeepAlive.
417   static PyOperationRef
418   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
419                pybind11::object parentKeepAlive = pybind11::object());
420 
421   /// Creates a detached operation. The operation must not be associated with
422   /// any existing live operation.
423   static PyOperationRef
424   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
425                  pybind11::object parentKeepAlive = pybind11::object());
426 
427   /// Gets the backing operation.
MlirOperation()428   operator MlirOperation() const { return get(); }
get()429   MlirOperation get() const {
430     checkValid();
431     return operation;
432   }
433 
getRef()434   PyOperationRef getRef() {
435     return PyOperationRef(
436         this, pybind11::reinterpret_borrow<pybind11::object>(handle));
437   }
438 
isAttached()439   bool isAttached() { return attached; }
setAttached()440   void setAttached() {
441     assert(!attached && "operation already attached");
442     attached = true;
443   }
444   void checkValid() const;
445 
446   /// Gets the owning block or raises an exception if the operation has no
447   /// owning block.
448   PyBlock getBlock();
449 
450   /// Gets the parent operation or raises an exception if the operation has
451   /// no parent.
452   PyOperationRef getParentOperation();
453 
454   /// Creates an operation. See corresponding python docstring.
455   static pybind11::object
456   create(std::string name, llvm::Optional<std::vector<PyValue *>> operands,
457          llvm::Optional<std::vector<PyType *>> results,
458          llvm::Optional<pybind11::dict> attributes,
459          llvm::Optional<std::vector<PyBlock *>> successors, int regions,
460          DefaultingPyLocation location, pybind11::object ip);
461 
462   /// Creates an OpView suitable for this operation.
463   pybind11::object createOpView();
464 
465 private:
466   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
467   static PyOperationRef createInstance(PyMlirContextRef contextRef,
468                                        MlirOperation operation,
469                                        pybind11::object parentKeepAlive);
470 
471   MlirOperation operation;
472   pybind11::handle handle;
473   // Keeps the parent alive, regardless of whether it is an Operation or
474   // Module.
475   // TODO: As implemented, this facility is only sufficient for modeling the
476   // trivial module parent back-reference. Generalize this to also account for
477   // transitions from detached to attached and address TODOs in the
478   // ir_operation.py regarding testing corresponding lifetime guarantees.
479   pybind11::object parentKeepAlive;
480   bool attached = true;
481   bool valid = true;
482 };
483 
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
/// providing more instance-specific accessors and serve as the base class for
/// custom ODS-style operation classes. Since this class is subclassed on the
/// python side, it must present an __init__ method that operates in pure
/// python types.
class PyOpView : public PyOperationBase {
public:
  PyOpView(pybind11::object operationObject);
  PyOperation &getOperation() override { return operation; }

  static pybind11::object createRawSubclass(pybind11::object userClass);

  /// Returns a new reference to the Python operation object this view wraps.
  pybind11::object getOperationObject() { return operationObject; }

private:
  // Declaration order matters: `operation` is initialized before
  // `operationObject`, so the constructor must not bind it via the member.
  // `operation` refers into the object held alive by `operationObject`.
  PyOperation &operation;           // For efficient, cast-free access from C++
  pybind11::object operationObject; // Holds the reference.
};
502 
503 /// Wrapper around an MlirRegion.
504 /// Regions are managed completely by their containing operation. Unlike the
505 /// C++ API, the python API does not support detached regions.
506 class PyRegion {
507 public:
PyRegion(PyOperationRef parentOperation,MlirRegion region)508   PyRegion(PyOperationRef parentOperation, MlirRegion region)
509       : parentOperation(std::move(parentOperation)), region(region) {
510     assert(!mlirRegionIsNull(region) && "python region cannot be null");
511   }
512 
get()513   MlirRegion get() { return region; }
getParentOperation()514   PyOperationRef &getParentOperation() { return parentOperation; }
515 
checkValid()516   void checkValid() { return parentOperation->checkValid(); }
517 
518 private:
519   PyOperationRef parentOperation;
520   MlirRegion region;
521 };
522 
523 /// Wrapper around an MlirBlock.
524 /// Blocks are managed completely by their containing operation. Unlike the
525 /// C++ API, the python API does not support detached blocks.
526 class PyBlock {
527 public:
PyBlock(PyOperationRef parentOperation,MlirBlock block)528   PyBlock(PyOperationRef parentOperation, MlirBlock block)
529       : parentOperation(std::move(parentOperation)), block(block) {
530     assert(!mlirBlockIsNull(block) && "python block cannot be null");
531   }
532 
get()533   MlirBlock get() { return block; }
getParentOperation()534   PyOperationRef &getParentOperation() { return parentOperation; }
535 
checkValid()536   void checkValid() { return parentOperation->checkValid(); }
537 
538 private:
539   PyOperationRef parentOperation;
540   MlirBlock block;
541 };
542 
543 /// An insertion point maintains a pointer to a Block and a reference operation.
544 /// Calls to insert() will insert a new operation before the
545 /// reference operation. If the reference operation is null, then appends to
546 /// the end of the block.
547 class PyInsertionPoint {
548 public:
549   /// Creates an insertion point positioned after the last operation in the
550   /// block, but still inside the block.
551   PyInsertionPoint(PyBlock &block);
552   /// Creates an insertion point positioned before a reference operation.
553   PyInsertionPoint(PyOperationBase &beforeOperationBase);
554 
555   /// Shortcut to create an insertion point at the beginning of the block.
556   static PyInsertionPoint atBlockBegin(PyBlock &block);
557   /// Shortcut to create an insertion point before the block terminator.
558   static PyInsertionPoint atBlockTerminator(PyBlock &block);
559 
560   /// Inserts an operation.
561   void insert(PyOperationBase &operationBase);
562 
563   /// Enter and exit the context manager.
564   pybind11::object contextEnter();
565   void contextExit(pybind11::object excType, pybind11::object excVal,
566                    pybind11::object excTb);
567 
getBlock()568   PyBlock &getBlock() { return block; }
569 
570 private:
571   // Trampoline constructor that avoids null initializing members while
572   // looking up parents.
PyInsertionPoint(PyBlock block,llvm::Optional<PyOperationRef> refOperation)573   PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
574       : refOperation(std::move(refOperation)), block(std::move(block)) {}
575 
576   llvm::Optional<PyOperationRef> refOperation;
577   PyBlock block;
578 };
579 
580 /// Wrapper around the generic MlirAttribute.
581 /// The lifetime of a type is bound by the PyContext that created it.
582 class PyAttribute : public BaseContextObject {
583 public:
PyAttribute(PyMlirContextRef contextRef,MlirAttribute attr)584   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
585       : BaseContextObject(std::move(contextRef)), attr(attr) {}
586   bool operator==(const PyAttribute &other);
MlirAttribute()587   operator MlirAttribute() const { return attr; }
get()588   MlirAttribute get() const { return attr; }
589 
590   /// Gets a capsule wrapping the void* within the MlirContext.
591   pybind11::object getCapsule();
592 
593   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
594   /// Note that PyMlirContext instances are uniqued, so the returned object
595   /// may be a pre-existing object. Ownership of the underlying MlirContext
596   /// is taken by calling this function.
597   static PyAttribute createFromCapsule(pybind11::object capsule);
598 
599 private:
600   MlirAttribute attr;
601 };
602 
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
/// TODO: Refactor this and the C-API to be based on an Identifier owned
/// by the context so as to avoid ownership issues here.
class PyNamedAttribute {
public:
  /// Constructs a PyNamedAttr that retains an owned name. This should be
  /// used in any code that originates an MlirNamedAttribute from a python
  /// string.
  /// The lifetime of the PyNamedAttr must extend to the lifetime of the
  /// passed attribute.
  PyNamedAttribute(MlirAttribute attr, std::string ownedName);

  /// The C-API named attribute. Its name points into the heap-allocated
  /// `ownedName` string (see below), so it is only valid while this object
  /// is alive. Moving this object keeps it valid because the string's heap
  /// address does not change; copying is unavailable because of the
  /// unique_ptr member.
  MlirNamedAttribute namedAttr;

private:
  // Since the MlirNamedAttr contains an internal pointer to the actual
  // memory of the owned string, it must be heap allocated to remain valid.
  // Otherwise, strings that fit within the small object optimization threshold
  // will have their memory address change as the containing object is moved,
  // resulting in an invalid aliased pointer.
  std::unique_ptr<std::string> ownedName;
};
625 
626 /// Wrapper around the generic MlirType.
627 /// The lifetime of a type is bound by the PyContext that created it.
628 class PyType : public BaseContextObject {
629 public:
PyType(PyMlirContextRef contextRef,MlirType type)630   PyType(PyMlirContextRef contextRef, MlirType type)
631       : BaseContextObject(std::move(contextRef)), type(type) {}
632   bool operator==(const PyType &other);
MlirType()633   operator MlirType() const { return type; }
get()634   MlirType get() const { return type; }
635 
636   /// Gets a capsule wrapping the void* within the MlirContext.
637   pybind11::object getCapsule();
638 
639   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
640   /// Note that PyMlirContext instances are uniqued, so the returned object
641   /// may be a pre-existing object. Ownership of the underlying MlirContext
642   /// is taken by calling this function.
643   static PyType createFromCapsule(pybind11::object capsule);
644 
645 private:
646   MlirType type;
647 };
648 
649 /// Wrapper around the generic MlirValue.
650 /// Values are managed completely by the operation that resulted in their
651 /// definition. For op result value, this is the operation that defines the
652 /// value. For block argument values, this is the operation that contains the
653 /// block to which the value is an argument (blocks cannot be detached in Python
654 /// bindings so such operation always exists).
655 class PyValue {
656 public:
PyValue(PyOperationRef parentOperation,MlirValue value)657   PyValue(PyOperationRef parentOperation, MlirValue value)
658       : parentOperation(parentOperation), value(value) {}
659 
get()660   MlirValue get() { return value; }
getParentOperation()661   PyOperationRef &getParentOperation() { return parentOperation; }
662 
checkValid()663   void checkValid() { return parentOperation->checkValid(); }
664 
665 private:
666   PyOperationRef parentOperation;
667   MlirValue value;
668 };
669 
/// Adds the IR bindings to the pybind11 module `m` (defined out of line).
/// NOTE(review): the exact set of classes and functions registered is not
/// visible from this header.
void populateIRSubmodule(pybind11::module &m);
671 
672 } // namespace python
673 } // namespace mlir
674 
namespace pybind11 {
namespace detail {

// pybind11 argument casters for the Defaulting* wrappers. These let bound
// functions take a context/location argument for which None resolves to the
// instance set by the current context manager. The conversion logic itself
// lives in MlirDefaultingCaster.
template <>
struct type_caster<mlir::python::DefaultingPyMlirContext>
    : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
template <>
struct type_caster<mlir::python::DefaultingPyLocation>
    : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};

} // namespace detail
} // namespace pybind11
687 
688 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
689