diff --git a/python/src/py_nvbench.cpp b/python/src/py_nvbench.cpp index 8856e8e..766f983 100644 --- a/python/src/py_nvbench.cpp +++ b/python/src/py_nvbench.cpp @@ -35,8 +35,8 @@ namespace py = pybind11; -namespace -{ +// namespace +//{ struct PyObjectDeleter { @@ -61,6 +61,8 @@ struct PyObjectDeleter } }; +namespace +{ struct benchmark_wrapper_t { @@ -91,7 +93,14 @@ struct benchmark_wrapper_t auto arg = py::cast(std::ref(state), py::return_value_policy::reference); // Execute Python callable - (*m_fn)(arg); + try + { + (*m_fn)(arg); + } + catch (const py::error_already_set &e) + { + throw nvbench::stop_runner_loop(e.what()); + } } private: @@ -99,6 +108,7 @@ private: // since copy constructor must be const (benchmark::do_clone is const member method) std::shared_ptr m_fn; }; +} // namespace // Use struct to ensure public inheritance struct nvbench_run_error : std::runtime_error @@ -183,17 +193,20 @@ public: } catch (py::error_already_set &e) { + std::cout << "Caught error_already_set\n"; py::raise_from(e, benchmark_exc.ptr(), "Python error raised "); throw py::error_already_set(); } catch (const std::exception &e) { const std::string &exc_message = e.what(); + std::cout << "Caught std::exception " << exc_message << std::endl; py::set_error(benchmark_exc, exc_message.c_str()); throw py::error_already_set(); } catch (...) { + std::cout << "Got fall-through exception\n"; py::set_error(benchmark_exc, "Caught unknown exception in nvbench_main"); throw py::error_already_set(); } @@ -222,7 +235,7 @@ py::dict py_get_axis_values(const nvbench::state &state) // essentially a global variable, but allocated on the heap during module initialization std::unique_ptr global_registry{}; -} // end of anonymous namespace +//} // end of anonymous namespace // ========================================== // PLEASE KEEP IN SYNC WITH __init__.pyi FILE @@ -255,6 +268,7 @@ PYBIND11_MODULE(_nvbench, m) return std::make_pair(std::size_t{0}, reinterpret_cast(s.get_stream())); }); + py_cuda_stream_cls.def("addressof", [](const nvbench::cuda_stream &s) -> std::size_t { return reinterpret_cast(s.get_stream()); }); @@ -295,6 +309,7 @@ PYBIND11_MODULE(_nvbench, m) auto py_benchmark_cls = py::class_(m, "Benchmark"); py_benchmark_cls.def("get_name", &nvbench::benchmark_base::get_name); + py_benchmark_cls.def( "add_int64_axis", [](nvbench::benchmark_base &self, std::string name, std::vector data) { @@ -304,6 +319,7 @@ PYBIND11_MODULE(_nvbench, m) py::return_value_policy::reference, py::arg("name"), py::arg("values")); + py_benchmark_cls.def( "add_int64_power_of_two_axis", [](nvbench::benchmark_base &self, std::string name, std::vector data) { @@ -315,6 +331,7 @@ PYBIND11_MODULE(_nvbench, m) py::return_value_policy::reference, py::arg("name"), py::arg("values")); + py_benchmark_cls.def( "add_float64_axis", [](nvbench::benchmark_base &self, std::string name, std::vector data) { @@ -324,6 +341,7 @@ PYBIND11_MODULE(_nvbench, m) py::return_value_policy::reference, py::arg("name"), py::arg("values")); + py_benchmark_cls.def( "add_string_axis", [](nvbench::benchmark_base &self, std::string name, std::vector data) { @@ -333,6 +351,7 @@ PYBIND11_MODULE(_nvbench, m) py::return_value_policy::reference, py::arg("name"), py::arg("values")); + py_benchmark_cls.def( "set_name", [](nvbench::benchmark_base &self, std::string name) { @@ -341,6 +360,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("name")); + py_benchmark_cls.def( "set_is_cpu_only", [](nvbench::benchmark_base &self, bool is_cpu_only) { @@ -349,6 +369,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("is_cpu_only")); + // TODO: should this be exposed? py_benchmark_cls.def( "set_run_once", @@ -358,6 +379,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("run_once")); + py_benchmark_cls.def( "set_skip_time", [](nvbench::benchmark_base &self, nvbench::float64_t skip_duration_seconds) { @@ -366,6 +388,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("duration_seconds")); + py_benchmark_cls.def( "set_timeout", [](nvbench::benchmark_base &self, nvbench::float64_t duration_seconds) { @@ -374,6 +397,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("duration_seconds")); + py_benchmark_cls.def( "set_throttle_threshold", [](nvbench::benchmark_base &self, nvbench::float32_t threshold) { @@ -382,6 +406,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("threshold")); + py_benchmark_cls.def( "set_throttle_recovery_delay", [](nvbench::benchmark_base &self, nvbench::float32_t delay) { @@ -390,6 +415,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("delay_seconds")); + py_benchmark_cls.def( "set_stopping_criterion", [](nvbench::benchmark_base &self, std::string criterion) { @@ -398,6 +424,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::return_value_policy::reference, py::arg("criterion")); + py_benchmark_cls.def( "set_criterion_param_int64", [](nvbench::benchmark_base &self, std::string name, nvbench::int64_t value) { @@ -407,6 +434,7 @@ PYBIND11_MODULE(_nvbench, m) py::return_value_policy::reference, py::arg("name"), py::arg("value")); + py_benchmark_cls.def( "set_criterion_param_float64", [](nvbench::benchmark_base &self, std::string name, nvbench::float64_t value) { @@ -416,6 +444,7 @@ PYBIND11_MODULE(_nvbench, m) py::return_value_policy::reference, py::arg("name"), py::arg("value")); + py_benchmark_cls.def( "set_criterion_param_string", [](nvbench::benchmark_base &self, std::string name, std::string value) { @@ -425,6 +454,7 @@ PYBIND11_MODULE(_nvbench, m) py::return_value_policy::reference, py::arg("name"), py::arg("value")); + py_benchmark_cls.def( "set_min_samples", [](nvbench::benchmark_base &self, nvbench::int64_t count) { @@ -508,9 +538,11 @@ PYBIND11_MODULE(_nvbench, m) pystate_cls.def("has_device", [](const nvbench::state &state) -> bool { return static_cast(state.get_device()); }); + pystate_cls.def("has_printers", [](const nvbench::state &state) -> bool { return state.get_benchmark().get_printer().has_value(); }); + pystate_cls.def("get_device", [](const nvbench::state &state) { auto dev = state.get_device(); if (dev.has_value()) @@ -550,6 +582,7 @@ PYBIND11_MODULE(_nvbench, m) &nvbench::state::add_element_count, py::arg("count"), py::arg("column_name") = py::str("")); + pystate_cls.def("set_element_count", &nvbench::state::set_element_count, py::arg("count")); pystate_cls.def("get_element_count", &nvbench::state::get_element_count); @@ -566,6 +599,7 @@ PYBIND11_MODULE(_nvbench, m) py::arg("nbytes"), py::pos_only{}, py::arg("column_name") = py::str("")); + pystate_cls.def( "add_global_memory_writes", [](nvbench::state &state, std::size_t nbytes, const std::string &column_name) -> void { @@ -575,10 +609,12 @@ PYBIND11_MODULE(_nvbench, m) py::arg("nbytes"), py::pos_only{}, py::arg("column_name") = py::str("")); + pystate_cls.def( "get_benchmark", [](const nvbench::state &state) { return std::ref(state.get_benchmark()); }, py::return_value_policy::reference); + pystate_cls.def("get_throttle_threshold", &nvbench::state::get_throttle_threshold); pystate_cls.def("set_throttle_threshold", &nvbench::state::set_throttle_threshold, @@ -590,22 +626,27 @@ PYBIND11_MODULE(_nvbench, m) py::arg("min_samples_count")); pystate_cls.def("get_disable_blocking_kernel", &nvbench::state::get_disable_blocking_kernel); + pystate_cls.def("set_disable_blocking_kernel", &nvbench::state::set_disable_blocking_kernel, py::arg("disable_blocking_kernel")); pystate_cls.def("get_run_once", &nvbench::state::get_run_once); + pystate_cls.def("set_run_once", &nvbench::state::set_run_once, py::arg("run_once")); pystate_cls.def("get_timeout", &nvbench::state::get_timeout); + pystate_cls.def("set_timeout", &nvbench::state::set_timeout, py::arg("duration")); pystate_cls.def("get_blocking_kernel_timeout", &nvbench::state::get_blocking_kernel_timeout); + pystate_cls.def("set_blocking_kernel_timeout", &nvbench::state::set_blocking_kernel_timeout, py::arg("duration")); pystate_cls.def("collect_cupti_metrics", &nvbench::state::collect_cupti_metrics); + pystate_cls.def("is_cupti_required", &nvbench::state::is_cupti_required); pystate_cls.def( @@ -670,6 +711,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::arg("name"), py::arg("value")); + pystate_cls.def( "add_summary", [](nvbench::state &state, std::string column_name, std::int64_t value) { @@ -680,6 +722,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::arg("name"), py::arg("value")); + pystate_cls.def( "add_summary", [](nvbench::state &state, std::string column_name, double value) { @@ -690,6 +733,7 @@ PYBIND11_MODULE(_nvbench, m) }, py::arg("name"), py::arg("value")); + pystate_cls.def("get_axis_values_as_string", [](const nvbench::state &state) { return state.get_axis_values_as_string(); }); pystate_cls.def("get_axis_values", &py_get_axis_values);