1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <tuple>
10
11 #include "PybindUtils.h"
12
13 #include "Globals.h"
14 #include "IRModules.h"
15 #include "Pass.h"
16
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20
21 // -----------------------------------------------------------------------------
22 // PyGlobals
23 // -----------------------------------------------------------------------------
24
25 PyGlobals *PyGlobals::instance = nullptr;
26
PyGlobals()27 PyGlobals::PyGlobals() {
28 assert(!instance && "PyGlobals already constructed");
29 instance = this;
30 }
31
~PyGlobals()32 PyGlobals::~PyGlobals() { instance = nullptr; }
33
loadDialectModule(llvm::StringRef dialectNamespace)34 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
35 py::gil_scoped_acquire();
36 if (loadedDialectModulesCache.contains(dialectNamespace))
37 return;
38 // Since re-entrancy is possible, make a copy of the search prefixes.
39 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
40 py::object loaded;
41 for (std::string moduleName : localSearchPrefixes) {
42 moduleName.push_back('.');
43 moduleName.append(dialectNamespace.data(), dialectNamespace.size());
44
45 try {
46 py::gil_scoped_release();
47 loaded = py::module::import(moduleName.c_str());
48 } catch (py::error_already_set &e) {
49 if (e.matches(PyExc_ModuleNotFoundError)) {
50 continue;
51 } else {
52 throw;
53 }
54 }
55 break;
56 }
57
58 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
59 // may have occurred, which may do anything.
60 loadedDialectModulesCache.insert(dialectNamespace);
61 }
62
registerDialectImpl(const std::string & dialectNamespace,py::object pyClass)63 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
64 py::object pyClass) {
65 py::gil_scoped_acquire();
66 py::object &found = dialectClassMap[dialectNamespace];
67 if (found) {
68 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
69 dialectNamespace +
70 "' is already registered.");
71 }
72 found = std::move(pyClass);
73 }
74
registerOperationImpl(const std::string & operationName,py::object pyClass,py::object rawOpViewClass)75 void PyGlobals::registerOperationImpl(const std::string &operationName,
76 py::object pyClass,
77 py::object rawOpViewClass) {
78 py::gil_scoped_acquire();
79 py::object &found = operationClassMap[operationName];
80 if (found) {
81 throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
82 operationName +
83 "' is already registered.");
84 }
85 found = std::move(pyClass);
86 rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
87 }
88
89 llvm::Optional<py::object>
lookupDialectClass(const std::string & dialectNamespace)90 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
91 py::gil_scoped_acquire();
92 loadDialectModule(dialectNamespace);
93 // Fast match against the class map first (common case).
94 const auto foundIt = dialectClassMap.find(dialectNamespace);
95 if (foundIt != dialectClassMap.end()) {
96 if (foundIt->second.is_none())
97 return llvm::None;
98 assert(foundIt->second && "py::object is defined");
99 return foundIt->second;
100 }
101
102 // Not found and loading did not yield a registration. Negative cache.
103 dialectClassMap[dialectNamespace] = py::none();
104 return llvm::None;
105 }
106
107 llvm::Optional<pybind11::object>
lookupRawOpViewClass(llvm::StringRef operationName)108 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
109 {
110 py::gil_scoped_acquire();
111 auto foundIt = rawOpViewClassMapCache.find(operationName);
112 if (foundIt != rawOpViewClassMapCache.end()) {
113 if (foundIt->second.is_none())
114 return llvm::None;
115 assert(foundIt->second && "py::object is defined");
116 return foundIt->second;
117 }
118 }
119
120 // Not found. Load the dialect namespace.
121 auto split = operationName.split('.');
122 llvm::StringRef dialectNamespace = split.first;
123 loadDialectModule(dialectNamespace);
124
125 // Attempt to find from the canonical map and cache.
126 {
127 py::gil_scoped_acquire();
128 auto foundIt = rawOpViewClassMap.find(operationName);
129 if (foundIt != rawOpViewClassMap.end()) {
130 if (foundIt->second.is_none())
131 return llvm::None;
132 assert(foundIt->second && "py::object is defined");
133 // Positive cache.
134 rawOpViewClassMapCache[operationName] = foundIt->second;
135 return foundIt->second;
136 } else {
137 // Negative cache.
138 rawOpViewClassMap[operationName] = py::none();
139 return llvm::None;
140 }
141 }
142 }
143
clearImportCache()144 void PyGlobals::clearImportCache() {
145 py::gil_scoped_acquire();
146 loadedDialectModulesCache.clear();
147 rawOpViewClassMapCache.clear();
148 }
149
150 // -----------------------------------------------------------------------------
151 // Module initialization.
152 // -----------------------------------------------------------------------------
153
PYBIND11_MODULE(_mlir, m) {
  m.doc() = "MLIR Python Native Extension";

  // Adds a dialect search prefix, then invalidates the import caches so the
  // new prefix is consulted on the next lookup.
  auto appendDialectSearchPrefix = [](PyGlobals &self,
                                      std::string moduleName) {
    self.getDialectSearchPrefixes().push_back(std::move(moduleName));
    self.clearImportCache();
  };

  py::class_<PyGlobals> globalsClass(m, "_Globals");
  globalsClass.def_property("dialect_search_modules",
                            &PyGlobals::getDialectSearchPrefixes,
                            &PyGlobals::setDialectSearchPrefixes);
  globalsClass.def("append_dialect_search_prefix", appendDialectSearchPrefix);
  globalsClass.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
                   "Testing hook for directly registering a dialect");
  globalsClass.def("_register_operation_impl",
                   &PyGlobals::registerOperationImpl,
                   "Testing hook for directly registering an operation");

  // The globals object is handed to Python with take_ownership: letting
  // Python manage it ensures it is destroyed, and releases its Python
  // resources, properly.
  PyGlobals *globals = new PyGlobals;
  m.attr("globals") =
      py::cast(globals, py::return_value_policy::take_ownership);

  // @register_dialect: records the class under its DIALECT_NAMESPACE
  // attribute and hands the class back unchanged.
  auto registerDialect = [](py::object pyClass) {
    auto dialectNamespace =
        pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
    PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
    return pyClass;
  };
  m.def("register_dialect", registerDialect,
        "Class decorator for registering a custom Dialect wrapper");

  // @register_operation(dialectClass): returns a decorator that registers an
  // op class (and a generated "Raw" subclass) under its OPERATION_NAME, then
  // attaches the op class to `dialectClass` by its __name__.
  auto registerOperation = [](py::object dialectClass) -> py::cpp_function {
    auto decorator = [dialectClass](py::object opClass) -> py::object {
      auto operationName = opClass.attr("OPERATION_NAME").cast<std::string>();
      // Special "Raw" subclass that passes construction straight through to
      // the OpView parent, bypassing the intermediate child's __init__.
      auto rawSubclass = PyOpView::createRawSubclass(opClass);
      PyGlobals::get().registerOperationImpl(operationName, opClass,
                                             rawSubclass);

      // Dict-stuff the new opClass by name onto the dialect class.
      py::object opClassName = opClass.attr("__name__");
      dialectClass.attr(opClassName) = opClass;

      opClass.attr("_Raw") = rawSubclass;
      return opClass;
    };
    return py::cpp_function(decorator);
  };
  m.def("register_operation", registerOperation,
        "Class decorator for registering a custom Operation wrapper");

  // Native submodules.
  auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
  populateIRSubmodule(irModule);

  auto passModule =
      m.def_submodule("passmanager", "MLIR Pass Management Bindings");
  populatePassManagerSubmodule(passModule);
}
220