mirror of
https://github.com/pybind/pybind11.git
synced 2026-04-19 22:39:09 +00:00
Added function for reloading module (#1040)
This commit is contained in:
committed by
Dean Moldovan
parent
2cf87a54d8
commit
c64e6b1670
@@ -2,6 +2,8 @@
|
||||
#include <catch.hpp>
|
||||
|
||||
#include <thread>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
@@ -216,3 +218,52 @@ TEST_CASE("Threads") {
|
||||
|
||||
REQUIRE(locals["count"].cast<int>() == num_threads);
|
||||
}
|
||||
|
||||
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
|
||||
struct scope_exit {
|
||||
std::function<void()> f_;
|
||||
explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
|
||||
~scope_exit() { if (f_) f_(); }
|
||||
};
|
||||
|
||||
TEST_CASE("Reload module from file") {
|
||||
// Disable generation of cached bytecode (.pyc files) for this test, otherwise
|
||||
// Python might pick up an old version from the cache instead of the new versions
|
||||
// of the .py files generated below
|
||||
auto sys = py::module::import("sys");
|
||||
bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
|
||||
sys.attr("dont_write_bytecode") = true;
|
||||
// Reset the value at scope exit
|
||||
scope_exit reset_dont_write_bytecode([&]() {
|
||||
sys.attr("dont_write_bytecode") = dont_write_bytecode;
|
||||
});
|
||||
|
||||
std::string module_name = "test_module_reload";
|
||||
std::string module_file = module_name + ".py";
|
||||
|
||||
// Create the module .py file
|
||||
std::ofstream test_module(module_file);
|
||||
test_module << "def test():\n";
|
||||
test_module << " return 1\n";
|
||||
test_module.close();
|
||||
// Delete the file at scope exit
|
||||
scope_exit delete_module_file([&]() {
|
||||
std::remove(module_file.c_str());
|
||||
});
|
||||
|
||||
// Import the module from file
|
||||
auto module = py::module::import(module_name.c_str());
|
||||
int result = module.attr("test")().cast<int>();
|
||||
REQUIRE(result == 1);
|
||||
|
||||
// Update the module .py file with a small change
|
||||
test_module.open(module_file);
|
||||
test_module << "def test():\n";
|
||||
test_module << " return 2\n";
|
||||
test_module.close();
|
||||
|
||||
// Reload the module
|
||||
module.reload();
|
||||
result = module.attr("test")().cast<int>();
|
||||
REQUIRE(result == 2);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user