mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 06:48:53 +00:00
nvbench.State.exec validates arg to be a callable
Add names to method arguments to make it more self-descriptive.
This commit is contained in:
@@ -259,7 +259,9 @@ PYBIND11_MODULE(_nvbench, m)
|
||||
self.add_int64_axis(std::move(name), std::move(data));
|
||||
return std::ref(self);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
py::return_value_policy::reference,
|
||||
py::arg("name"),
|
||||
py::arg("values"));
|
||||
py_benchmark_cls.def(
|
||||
"add_int64_power_of_two_axis",
|
||||
[](nvbench::benchmark_base &self, std::string name, std::vector<nvbench::int64_t> data) {
|
||||
@@ -268,42 +270,51 @@ PYBIND11_MODULE(_nvbench, m)
|
||||
nvbench::int64_axis_flags::power_of_two);
|
||||
return std::ref(self);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
py::return_value_policy::reference,
|
||||
py::arg("name"),
|
||||
py::arg("values"));
|
||||
py_benchmark_cls.def(
|
||||
"add_float64_axis",
|
||||
[](nvbench::benchmark_base &self, std::string name, std::vector<nvbench::float64_t> data) {
|
||||
self.add_float64_axis(std::move(name), std::move(data));
|
||||
return std::ref(self);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
py::return_value_policy::reference,
|
||||
py::arg("name"),
|
||||
py::arg("values"));
|
||||
py_benchmark_cls.def(
|
||||
"add_string_axis",
|
||||
[](nvbench::benchmark_base &self, std::string name, std::vector<std::string> data) {
|
||||
self.add_string_axis(std::move(name), std::move(data));
|
||||
return std::ref(self);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
py::return_value_policy::reference,
|
||||
py::arg("name"),
|
||||
py::arg("values"));
|
||||
py_benchmark_cls.def(
|
||||
"set_name",
|
||||
[](nvbench::benchmark_base &self, std::string name) {
|
||||
self.set_name(std::move(name));
|
||||
return std::ref(self);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
py::return_value_policy::reference,
|
||||
py::arg("name"));
|
||||
py_benchmark_cls.def(
|
||||
"set_is_cpu_only",
|
||||
[](nvbench::benchmark_base &self, bool is_cpu_only) {
|
||||
self.set_is_cpu_only(is_cpu_only);
|
||||
return std::ref(self);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
py::return_value_policy::reference,
|
||||
py::arg("is_cpu_only"));
|
||||
py_benchmark_cls.def(
|
||||
"set_run_once",
|
||||
[](nvbench::benchmark_base &self, bool v) {
|
||||
self.set_run_once(v);
|
||||
[](nvbench::benchmark_base &self, bool run_once) {
|
||||
self.set_run_once(run_once);
|
||||
return std::ref(self);
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
py::return_value_policy::reference,
|
||||
py::arg("run_once"));
|
||||
|
||||
// == STEP 5
|
||||
// Define PyState class
|
||||
@@ -421,7 +432,7 @@ PYBIND11_MODULE(_nvbench, m)
|
||||
&nvbench::state::add_element_count,
|
||||
py::arg("count"),
|
||||
py::arg("column_name") = py::str(""));
|
||||
pystate_cls.def("set_element_count", &nvbench::state::set_element_count);
|
||||
pystate_cls.def("set_element_count", &nvbench::state::set_element_count, py::arg("count"));
|
||||
pystate_cls.def("get_element_count", &nvbench::state::get_element_count);
|
||||
|
||||
pystate_cls.def("skip", &nvbench::state::skip, py::arg("reason"));
|
||||
@@ -478,40 +489,49 @@ PYBIND11_MODULE(_nvbench, m)
|
||||
|
||||
pystate_cls.def(
|
||||
"exec",
|
||||
[](nvbench::state &state, py::object callable_fn, bool batched, bool sync) {
|
||||
[](nvbench::state &state, py::object py_launcher_fn, bool batched, bool sync) {
|
||||
if (!PyCallable_Check(py_launcher_fn.ptr()))
|
||||
{
|
||||
throw py::type_error("Argument of exec method must be a callable object");
|
||||
}
|
||||
|
||||
// wrapper to invoke Python callable
|
||||
auto launcher_fn = [callable_fn](nvbench::launch &launch_descr) -> void {
|
||||
auto cpp_launcher_fn = [py_launcher_fn](nvbench::launch &launch_descr) -> void {
|
||||
// cast C++ object to python object
|
||||
auto launch_pyarg = py::cast(std::ref(launch_descr), py::return_value_policy::reference);
|
||||
// call Python callable
|
||||
callable_fn(launch_pyarg);
|
||||
py_launcher_fn(launch_pyarg);
|
||||
};
|
||||
|
||||
if (sync)
|
||||
{
|
||||
if (batched)
|
||||
{
|
||||
state.exec(nvbench::exec_tag::sync, launcher_fn);
|
||||
constexpr auto tag = nvbench::exec_tag::sync;
|
||||
state.exec(tag, cpp_launcher_fn);
|
||||
}
|
||||
else
|
||||
{
|
||||
state.exec(nvbench::exec_tag::sync | nvbench::exec_tag::no_batch, launcher_fn);
|
||||
constexpr auto tag = nvbench::exec_tag::sync | nvbench::exec_tag::no_batch;
|
||||
state.exec(tag, cpp_launcher_fn);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (batched)
|
||||
{
|
||||
state.exec(nvbench::exec_tag::none, launcher_fn);
|
||||
constexpr auto tag = nvbench::exec_tag::none;
|
||||
state.exec(tag, cpp_launcher_fn);
|
||||
}
|
||||
else
|
||||
{
|
||||
state.exec(nvbench::exec_tag::no_batch, launcher_fn);
|
||||
constexpr auto tag = nvbench::exec_tag::no_batch;
|
||||
state.exec(tag, cpp_launcher_fn);
|
||||
}
|
||||
}
|
||||
},
|
||||
"Executor for given callable fn(state : Launch)",
|
||||
py::arg("fn"),
|
||||
"Executor for given launcher callable fn(state : Launch)",
|
||||
py::arg("launcher_fn"),
|
||||
py::pos_only{},
|
||||
py::arg("batched") = true,
|
||||
py::arg("sync") = false);
|
||||
@@ -527,7 +547,7 @@ PYBIND11_MODULE(_nvbench, m)
|
||||
summ.set_string("name", std::move(column_name));
|
||||
summ.set_string("value", std::move(value));
|
||||
},
|
||||
py::arg("column_name"),
|
||||
py::arg("name"),
|
||||
py::arg("value"));
|
||||
pystate_cls.def(
|
||||
"add_summary",
|
||||
|
||||
Reference in New Issue
Block a user