Replace use of py::handle to store global_registry

Use py::gil_safe_call_once_and_store facility pybind11 provides.
This commit is contained in:
Oleksandr Pavlyk
2025-12-09 14:02:42 -06:00
parent 39c29026fd
commit 8ff0557ad8

View File

@@ -124,7 +124,8 @@ struct nvbench_run_error : std::runtime_error
// that are defined for the base class
using std::runtime_error::runtime_error;
};
py::handle benchmark_exc{};
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> exc_storage;
void run_interruptible(nvbench::option_parser &parser)
{
@@ -223,18 +224,18 @@ public:
}
catch (py::error_already_set &e)
{
py::raise_from(e, benchmark_exc.ptr(), "Python error raised ");
py::raise_from(e, exc_storage.get_stored().ptr(), "Python error raised ");
throw py::error_already_set();
}
catch (const std::exception &e)
{
const std::string &exc_message = e.what();
py::set_error(benchmark_exc, exc_message.c_str());
py::set_error(exc_storage.get_stored(), exc_message.c_str());
throw py::error_already_set();
}
catch (...)
{
py::set_error(benchmark_exc, "Caught unknown exception in nvbench_main");
py::set_error(exc_storage.get_stored(), "Caught unknown exception in nvbench_main");
throw py::error_already_set();
}
}
@@ -1162,11 +1163,12 @@ PYBIND11_MODULE(PYBIND11_MODULE_NAME, m)
static constexpr const char *exception_nvbench_runtime_error_doc = R"XXXX(
An exception raised if running benchmarks encounters an error
)XXXX";
py::object benchmark_exc_ =
py::exception<nvbench_run_error>(m, "NVBenchRuntimeError", PyExc_RuntimeError);
benchmark_exc_.attr("__doc__") = exception_nvbench_runtime_error_doc;
benchmark_exc = benchmark_exc_.release();
exc_storage.call_once_and_store_result([&]() {
py::object benchmark_exc_ =
py::exception<nvbench_run_error>(m, "NVBenchRuntimeError", PyExc_RuntimeError);
benchmark_exc_.attr("__doc__") = exception_nvbench_runtime_error_doc;
return benchmark_exc_;
});
// ATTN: nvbench::benchmark_manager is a singleton, it is exposed through
// GlobalBenchmarkRegistry class
@@ -1175,7 +1177,7 @@ An exception raised if running benchmarks encounters an error
py::nodelete{});
// function register
auto func_register_impl = [&](py::object fn) { return std::ref(global_registry->add_bench(fn)); };
auto func_register_impl = [](py::object fn) { return std::ref(global_registry->add_bench(fn)); };
static constexpr const char *func_register_doc = R"XXXX(
Register benchmark function of type Callable[[nvbench.State], None]
)XXXX";
@@ -1210,7 +1212,7 @@ Register benchmark function of type Callable[[nvbench.State], None]
// Testing utilities
m.def("test_cpp_exception", []() { throw nvbench_run_error("Test"); });
m.def("test_py_exception", []() {
py::set_error(benchmark_exc, "Test");
py::set_error(exc_storage.get_stored(), "Test");
throw py::error_already_set();
});
}