Avoid overloading get_int64_or_default as get_int64

Expose get_int64_or_default as its own Python method instead of an overload
of get_int64, and do the same for the float64 and string counterparts.

Provide names for the arguments of the Python bindings.

Tried generating Python stubs automatically with

```
stubgen -m cuda.nvbench._nvbench
```

Gave up on this, since the generated stubs do not include docstrings.
It would still be nice to compare the auto-generated _nvbench.pyi with
__init__.pyi for discrepancies, though.
This commit is contained in:
Oleksandr Pavlyk
2025-07-22 16:13:44 -05:00
parent dc7f9edfd4
commit 51fa07fab8
2 changed files with 78 additions and 40 deletions

View File

@@ -98,13 +98,22 @@ class State:
def get_stream(self) -> CudaStream: def get_stream(self) -> CudaStream:
"CudaStream object from this configuration" "CudaStream object from this configuration"
... ...
def get_int64(self, name: str, default_value: Optional[int] = None) -> int: def get_int64(self, name: str) -> int:
"Get value for given Int64 axis from this configuration" "Get value for given Int64 axis from this configuration"
... ...
def get_float64(self, name: str, default_value: Optional[float] = None) -> float: def get_int64_or_default_value(self, name: str, default_value: int) -> int:
"Get value for given Int64 axis from this configuration"
...
def get_float64(self, name: str) -> float:
"Get value for given Float64 axis from this configuration" "Get value for given Float64 axis from this configuration"
... ...
def get_string(self, name: str, default_value: Optional[str] = None) -> str: def get_float64_or_default_value(self, name: str, default_value: float) -> float:
"Get value for given Float64 axis from this configuration"
...
def get_string(self, name: str) -> str:
"Get value for given String axis from this configuration"
...
def get_string_or_default_value(self, name: str, default_value: str) -> str:
"Get value for given String axis from this configuration" "Get value for given String axis from this configuration"
... ...
def add_element_count(self, count: int, column_name: Optional[str] = None) -> None: def add_element_count(self, count: int, column_name: Optional[str] = None) -> None:
@@ -140,7 +149,7 @@ class State:
def get_min_samples(self) -> int: def get_min_samples(self) -> int:
"Get the number of benchmark timings NVBench performs before stopping criterion begins being used" "Get the number of benchmark timings NVBench performs before stopping criterion begins being used"
... ...
def set_min_samples(self, count: int) -> None: def set_min_samples(self, min_samples_count: int) -> None:
"Set the number of benchmark timings for NVBench to perform before stopping criterion begins being used" "Set the number of benchmark timings for NVBench to perform before stopping criterion begins being used"
... ...
def get_disable_blocking_kernel(self) -> bool: def get_disable_blocking_kernel(self) -> bool:
@@ -152,20 +161,20 @@ class State:
def get_run_once(self) -> bool: def get_run_once(self) -> bool:
"Boolean flag whether configuration should only run once" "Boolean flag whether configuration should only run once"
... ...
def set_run_once(self, flag: bool) -> None: def set_run_once(self, run_once_flag: bool) -> None:
"Set run-once flag for this configuration" "Set run-once flag for this configuration"
... ...
def get_timeout(self) -> float: def get_timeout(self) -> float:
"Get time-out value for benchmark execution of this configuration" "Get time-out value for benchmark execution of this configuration"
... ...
def set_timeout(self, duration: float) -> None: def set_timeout(self, duration: float) -> None:
"Set time-out value for benchmark execution of this configuration" "Set time-out value for benchmark execution of this configuration, in seconds"
... ...
def get_blocking_kernel_timeout(self) -> float: def get_blocking_kernel_timeout(self) -> float:
"Get time-out value for execution of blocking kernel" "Get time-out value for execution of blocking kernel"
... ...
def set_blocking_kernel_timeout(self, duration: float) -> None: def set_blocking_kernel_timeout(self, duration: float) -> None:
"Set time-out value for execution of blocking kernel" "Set time-out value for execution of blocking kernel, in seconds"
... ...
def collect_cupti_metrics(self) -> None: def collect_cupti_metrics(self) -> None:
"Request NVBench to record CUPTI metrics while running benchmark for this configuration" "Request NVBench to record CUPTI metrics while running benchmark for this configuration"

View File

@@ -405,14 +405,26 @@ PYBIND11_MODULE(_nvbench, m)
[](nvbench::state &state) { return std::ref(state.get_cuda_stream()); }, [](nvbench::state &state) { return std::ref(state.get_cuda_stream()); },
py::return_value_policy::reference); py::return_value_policy::reference);
pystate_cls.def("get_int64", &nvbench::state::get_int64); pystate_cls.def("get_int64", &nvbench::state::get_int64, py::arg("name"));
pystate_cls.def("get_int64", &nvbench::state::get_int64_or_default); pystate_cls.def("get_int64_or_default",
&nvbench::state::get_int64_or_default,
py::arg("name"),
py::pos_only{},
py::arg("default_value"));
pystate_cls.def("get_float64", &nvbench::state::get_float64); pystate_cls.def("get_float64", &nvbench::state::get_float64, py::arg("name"));
pystate_cls.def("get_float64", &nvbench::state::get_float64_or_default); pystate_cls.def("get_float64_or_default",
&nvbench::state::get_float64_or_default,
py::arg("name"),
py::pos_only{},
py::arg("default_value"));
pystate_cls.def("get_string", &nvbench::state::get_string); pystate_cls.def("get_string", &nvbench::state::get_string, py::arg("name"));
pystate_cls.def("get_string", &nvbench::state::get_string_or_default); pystate_cls.def("get_string_or_default",
&nvbench::state::get_string_or_default,
py::arg("name"),
py::pos_only{},
py::arg("default_value"));
pystate_cls.def("add_element_count", pystate_cls.def("add_element_count",
&nvbench::state::add_element_count, &nvbench::state::add_element_count,
@@ -421,7 +433,7 @@ PYBIND11_MODULE(_nvbench, m)
pystate_cls.def("set_element_count", &nvbench::state::set_element_count); pystate_cls.def("set_element_count", &nvbench::state::set_element_count);
pystate_cls.def("get_element_count", &nvbench::state::get_element_count); pystate_cls.def("get_element_count", &nvbench::state::get_element_count);
pystate_cls.def("skip", &nvbench::state::skip); pystate_cls.def("skip", &nvbench::state::skip, py::arg("reason"));
pystate_cls.def("is_skipped", &nvbench::state::is_skipped); pystate_cls.def("is_skipped", &nvbench::state::is_skipped);
pystate_cls.def("get_skip_reason", &nvbench::state::get_skip_reason); pystate_cls.def("get_skip_reason", &nvbench::state::get_skip_reason);
@@ -450,19 +462,25 @@ PYBIND11_MODULE(_nvbench, m)
pystate_cls.def("get_throttle_threshold", &nvbench::state::get_throttle_threshold); pystate_cls.def("get_throttle_threshold", &nvbench::state::get_throttle_threshold);
pystate_cls.def("get_min_samples", &nvbench::state::get_min_samples); pystate_cls.def("get_min_samples", &nvbench::state::get_min_samples);
pystate_cls.def("set_min_samples", &nvbench::state::set_min_samples); pystate_cls.def("set_min_samples",
&nvbench::state::set_min_samples,
py::arg("min_samples_count"));
pystate_cls.def("get_disable_blocking_kernel", &nvbench::state::get_disable_blocking_kernel); pystate_cls.def("get_disable_blocking_kernel", &nvbench::state::get_disable_blocking_kernel);
pystate_cls.def("set_disable_blocking_kernel", &nvbench::state::set_disable_blocking_kernel); pystate_cls.def("set_disable_blocking_kernel",
&nvbench::state::set_disable_blocking_kernel,
py::arg("disable_blocking_kernel"));
pystate_cls.def("get_run_once", &nvbench::state::get_run_once); pystate_cls.def("get_run_once", &nvbench::state::get_run_once);
pystate_cls.def("set_run_once", &nvbench::state::set_run_once); pystate_cls.def("set_run_once", &nvbench::state::set_run_once, py::arg("run_once"));
pystate_cls.def("get_timeout", &nvbench::state::get_timeout); pystate_cls.def("get_timeout", &nvbench::state::get_timeout);
pystate_cls.def("set_timeout", &nvbench::state::set_timeout); pystate_cls.def("set_timeout", &nvbench::state::set_timeout, py::arg("duration"));
pystate_cls.def("get_blocking_kernel_timeout", &nvbench::state::get_blocking_kernel_timeout); pystate_cls.def("get_blocking_kernel_timeout", &nvbench::state::get_blocking_kernel_timeout);
pystate_cls.def("set_blocking_kernel_timeout", &nvbench::state::set_blocking_kernel_timeout); pystate_cls.def("set_blocking_kernel_timeout",
&nvbench::state::set_blocking_kernel_timeout,
py::arg("duration"));
pystate_cls.def("collect_cupti_metrics", &nvbench::state::collect_cupti_metrics); pystate_cls.def("collect_cupti_metrics", &nvbench::state::collect_cupti_metrics);
pystate_cls.def("is_cupti_required", &nvbench::state::is_cupti_required); pystate_cls.def("is_cupti_required", &nvbench::state::is_cupti_required);
@@ -510,26 +528,36 @@ PYBIND11_MODULE(_nvbench, m)
pystate_cls.def("get_short_description", pystate_cls.def("get_short_description",
[](const nvbench::state &state) { return state.get_short_description(); }); [](const nvbench::state &state) { return state.get_short_description(); });
pystate_cls.def("add_summary", pystate_cls.def(
[](nvbench::state &state, std::string column_name, std::string value) { "add_summary",
auto &summ = state.add_summary("nv/python/" + column_name); [](nvbench::state &state, std::string column_name, std::string value) {
summ.set_string("description", "User tag: " + column_name); auto &summ = state.add_summary("nv/python/" + column_name);
summ.set_string("name", std::move(column_name)); summ.set_string("description", "User tag: " + column_name);
summ.set_string("value", std::move(value)); summ.set_string("name", std::move(column_name));
}); summ.set_string("value", std::move(value));
pystate_cls.def("add_summary", },
[](nvbench::state &state, std::string column_name, std::int64_t value) { py::arg("column_name"),
auto &summ = state.add_summary("nv/python/" + column_name); py::arg("value"));
summ.set_string("description", "User tag: " + column_name); pystate_cls.def(
summ.set_string("name", std::move(column_name)); "add_summary",
summ.set_int64("value", value); [](nvbench::state &state, std::string column_name, std::int64_t value) {
}); auto &summ = state.add_summary("nv/python/" + column_name);
pystate_cls.def("add_summary", [](nvbench::state &state, std::string column_name, double value) { summ.set_string("description", "User tag: " + column_name);
auto &summ = state.add_summary("nv/python/" + column_name); summ.set_string("name", std::move(column_name));
summ.set_string("description", "User tag: " + column_name); summ.set_int64("value", value);
summ.set_string("name", std::move(column_name)); },
summ.set_float64("value", value); py::arg("name"),
}); py::arg("value"));
pystate_cls.def(
"add_summary",
[](nvbench::state &state, std::string column_name, double value) {
auto &summ = state.add_summary("nv/python/" + column_name);
summ.set_string("description", "User tag: " + column_name);
summ.set_string("name", std::move(column_name));
summ.set_float64("value", value);
},
py::arg("name"),
py::arg("value"));
// Use handle to take a memory leak here, since this object's destructor may be called after // Use handle to take a memory leak here, since this object's destructor may be called after
// interpreter has shut down // interpreter has shut down
@@ -546,7 +574,8 @@ PYBIND11_MODULE(_nvbench, m)
"register", "register",
[&](py::object fn) { return std::ref(global_registry->add_bench(fn)); }, [&](py::object fn) { return std::ref(global_registry->add_bench(fn)); },
"Register benchmark function of type Callable[[nvbench.State], None]", "Register benchmark function of type Callable[[nvbench.State], None]",
py::return_value_policy::reference); py::return_value_policy::reference,
py::arg("benchmark_fn"));
m.def( m.def(
"run_all_benchmarks", "run_all_benchmarks",