1 #include <pybind11/embed.h>
2 
3 #ifdef _MSC_VER
4 // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch
5 // 2.0.1; this should be fixed in the next catch release after 2.0.1).
6 #  pragma warning(disable: 4996)
7 #endif
8 
9 #include <catch.hpp>
10 
11 #include <thread>
12 #include <fstream>
13 #include <functional>
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 
/// Abstract C++ base class that the `widget_module` bindings expose to Python.
///
/// Stores a message fixed at construction time and leaves `the_answer()` to be
/// supplied by a concrete subclass.
class Widget {
public:
    /// Stores `message`. The argument is taken by value and moved into the
    /// member so that a caller passing a temporary pays for no extra copy.
    /// `explicit` rules out accidental implicit conversions from std::string.
    explicit Widget(std::string message) : message(std::move(message)) { }
    virtual ~Widget() = default;

    /// Returns a copy of the message given at construction.
    std::string the_message() const { return message; }

    /// Value provided by the concrete subclass.
    virtual int the_answer() const = 0;

private:
    std::string message;
};
29 
30 class PyWidget final : public Widget {
31     using Widget::Widget;
32 
the_answer() const33     int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); }
34 };
35 
// Embedded module that registers `Widget` (together with `PyWidget`) and a small
// integer `add` function.
PYBIND11_EMBEDDED_MODULE(widget_module, m) {
    py::class_<Widget, PyWidget> widget_class(m, "Widget");
    widget_class.def(py::init<std::string>());
    widget_class.def_property_readonly("the_message", &Widget::the_message);

    m.def("add", [](int lhs, int rhs) { return lhs + rhs; });
}
43 
44 PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
45     throw std::runtime_error("C++ Error");
46 }
47 
48 PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
49     auto d = py::dict();
50     d["missing"].cast<py::object>();
51 }
52 
53 TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
54     auto module_ = py::module_::import("test_interpreter");
55     REQUIRE(py::hasattr(module_, "DerivedWidget"));
56 
57     auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module_.attr("__dict__"));
58     py::exec(R"(
59         widget = DerivedWidget("{} - {}".format(hello, x))
60         message = widget.the_message
61     )", py::globals(), locals);
62     REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
63 
64     auto py_widget = module_.attr("DerivedWidget")("The question");
65     auto message = py_widget.attr("the_message");
66     REQUIRE(message.cast<std::string>() == "The question");
67 
68     const auto &cpp_widget = py_widget.cast<const Widget &>();
69     REQUIRE(cpp_widget.the_answer() == 42);
70 }
71 
72 TEST_CASE("Import error handling") {
73     REQUIRE_NOTHROW(py::module_::import("widget_module"));
74     REQUIRE_THROWS_WITH(py::module_::import("throw_exception"),
75                         "ImportError: C++ Error");
76     REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"),
77                         Catch::Contains("ImportError: KeyError"));
78 }
79 
80 TEST_CASE("There can be only one interpreter") {
81     static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
82     static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
83     static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
84     static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
85 
86     REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
87     REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
88 
89     py::finalize_interpreter();
90     REQUIRE_NOTHROW(py::scoped_interpreter());
91     {
92         auto pyi1 = py::scoped_interpreter();
93         auto pyi2 = std::move(pyi1);
94     }
95     py::initialize_interpreter();
96 }
97 
has_pybind11_internals_builtin()98 bool has_pybind11_internals_builtin() {
99     auto builtins = py::handle(PyEval_GetBuiltins());
100     return builtins.contains(PYBIND11_INTERNALS_ID);
101 };
102 
has_pybind11_internals_static()103 bool has_pybind11_internals_static() {
104     auto **&ipp = py::detail::get_internals_pp();
105     return ipp && *ipp;
106 }
107 
108 TEST_CASE("Restart the interpreter") {
109     // Verify pre-restart state.
110     REQUIRE(py::module_::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
111     REQUIRE(has_pybind11_internals_builtin());
112     REQUIRE(has_pybind11_internals_static());
113     REQUIRE(py::module_::import("external_module").attr("A")(123).attr("value").cast<int>() == 123);
114 
115     // local and foreign module internals should point to the same internals:
116     REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
117             py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());
118 
119     // Restart the interpreter.
120     py::finalize_interpreter();
121     REQUIRE(Py_IsInitialized() == 0);
122 
123     py::initialize_interpreter();
124     REQUIRE(Py_IsInitialized() == 1);
125 
126     // Internals are deleted after a restart.
127     REQUIRE_FALSE(has_pybind11_internals_builtin());
128     REQUIRE_FALSE(has_pybind11_internals_static());
129     pybind11::detail::get_internals();
130     REQUIRE(has_pybind11_internals_builtin());
131     REQUIRE(has_pybind11_internals_static());
132     REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
133             py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());
134 
135     // Make sure that an interpreter with no get_internals() created until finalize still gets the
136     // internals destroyed
137     py::finalize_interpreter();
138     py::initialize_interpreter();
139     bool ran = false;
140     py::module_::import("__main__").attr("internals_destroy_test") =
__anonba480a4a0202(void *ran) 141         py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; });
142     REQUIRE_FALSE(has_pybind11_internals_builtin());
143     REQUIRE_FALSE(has_pybind11_internals_static());
144     REQUIRE_FALSE(ran);
145     py::finalize_interpreter();
146     REQUIRE(ran);
147     py::initialize_interpreter();
148     REQUIRE_FALSE(has_pybind11_internals_builtin());
149     REQUIRE_FALSE(has_pybind11_internals_static());
150 
151     // C++ modules can be reloaded.
152     auto cpp_module = py::module_::import("widget_module");
153     REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
154 
155     // C++ type information is reloaded and can be used in python modules.
156     auto py_module = py::module_::import("test_interpreter");
157     auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
158     REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
159 }
160 
161 TEST_CASE("Subinterpreter") {
162     // Add tags to the modules in the main interpreter and test the basics.
163     py::module_::import("__main__").attr("main_tag") = "main interpreter";
164     {
165         auto m = py::module_::import("widget_module");
166         m.attr("extension_module_tag") = "added to module in main interpreter";
167 
168         REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
169     }
170     REQUIRE(has_pybind11_internals_builtin());
171     REQUIRE(has_pybind11_internals_static());
172 
173     /// Create and switch to a subinterpreter.
174     auto main_tstate = PyThreadState_Get();
175     auto sub_tstate = Py_NewInterpreter();
176 
177     // Subinterpreters get their own copy of builtins. detail::get_internals() still
178     // works by returning from the static variable, i.e. all interpreters share a single
179     // global pybind11::internals;
180     REQUIRE_FALSE(has_pybind11_internals_builtin());
181     REQUIRE(has_pybind11_internals_static());
182 
183     // Modules tags should be gone.
184     REQUIRE_FALSE(py::hasattr(py::module_::import("__main__"), "tag"));
185     {
186         auto m = py::module_::import("widget_module");
187         REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
188 
189         // Function bindings should still work.
190         REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
191     }
192 
193     // Restore main interpreter.
194     Py_EndInterpreter(sub_tstate);
195     PyThreadState_Swap(main_tstate);
196 
197     REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag"));
198     REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag"));
199 }
200 
201 TEST_CASE("Execution frame") {
202     // When the interpreter is embedded, there is no execution frame, but `py::exec`
203     // should still function by using reasonable globals: `__main__.__dict__`.
204     py::exec("var = dict(number=42)");
205     REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
206 }
207 
208 TEST_CASE("Threads") {
209     // Restart interpreter to ensure threads are not initialized
210     py::finalize_interpreter();
211     py::initialize_interpreter();
212     REQUIRE_FALSE(has_pybind11_internals_static());
213 
214     constexpr auto num_threads = 10;
215     auto locals = py::dict("count"_a=0);
216 
217     {
218         py::gil_scoped_release gil_release{};
219         REQUIRE(has_pybind11_internals_static());
220 
221         auto threads = std::vector<std::thread>();
222         for (auto i = 0; i < num_threads; ++i) {
__anonba480a4a0302() 223             threads.emplace_back([&]() {
224                 py::gil_scoped_acquire gil{};
225                 locals["count"] = locals["count"].cast<int>() + 1;
226             });
227         }
228 
229         for (auto &thread : threads) {
230             thread.join();
231         }
232     }
233 
234     REQUIRE(locals["count"].cast<int>() == num_threads);
235 }
236 
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
//
// Runs the stored callable, if there is one, when the object is destroyed.
// Copying is disabled: a copy would hold the same callable and invoke it a
// second time when it goes out of scope.
struct scope_exit {
    std::function<void()> f_;

    // Takes ownership of `f`; an empty std::function makes the guard a no-op.
    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}

    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;

    ~scope_exit() {
        if (f_) {
            f_();
        }
    }
};
243 
244 TEST_CASE("Reload module from file") {
245     // Disable generation of cached bytecode (.pyc files) for this test, otherwise
246     // Python might pick up an old version from the cache instead of the new versions
247     // of the .py files generated below
248     auto sys = py::module_::import("sys");
249     bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
250     sys.attr("dont_write_bytecode") = true;
251     // Reset the value at scope exit
__anonba480a4a0402() 252     scope_exit reset_dont_write_bytecode([&]() {
253         sys.attr("dont_write_bytecode") = dont_write_bytecode;
254     });
255 
256     std::string module_name = "test_module_reload";
257     std::string module_file = module_name + ".py";
258 
259     // Create the module .py file
260     std::ofstream test_module(module_file);
261     test_module << "def test():\n";
262     test_module << "    return 1\n";
263     test_module.close();
264     // Delete the file at scope exit
__anonba480a4a0502() 265     scope_exit delete_module_file([&]() {
266         std::remove(module_file.c_str());
267     });
268 
269     // Import the module from file
270     auto module_ = py::module_::import(module_name.c_str());
271     int result = module_.attr("test")().cast<int>();
272     REQUIRE(result == 1);
273 
274     // Update the module .py file with a small change
275     test_module.open(module_file);
276     test_module << "def test():\n";
277     test_module << "    return 2\n";
278     test_module.close();
279 
280     // Reload the module
281     module_.reload();
282     result = module_.attr("test")().cast<int>();
283     REQUIRE(result == 2);
284 }
285