Add python api for cold warmup parameters (#363)

This commit is contained in:
Oleksandr Pavlyk
2026-05-18 10:56:44 -05:00
committed by GitHub
parent ce75dab94b
commit 4472e7b59b
5 changed files with 201 additions and 14 deletions

View File

@@ -63,6 +63,8 @@ class Benchmark:
def set_criterion_param_int64(self, name: str, value: SupportsInt) -> Self: ...
def set_criterion_param_string(self, name: str, value: str) -> Self: ...
def set_min_samples(self, count: SupportsInt) -> Self: ...
def set_cold_warmup_runs(self, count: SupportsInt) -> Self: ...
def set_cold_max_warmup_walltime(self, duration_seconds: SupportsFloat) -> Self: ...
def set_is_cpu_only(self, is_cpu_only: bool) -> Self: ...
class Launch:
@@ -104,6 +106,10 @@ class State:
def set_throttle_threshold(self, threshold_fraction: SupportsFloat) -> None: ...
def get_min_samples(self) -> int: ...
def set_min_samples(self, min_samples_count: SupportsInt) -> None: ...
def get_cold_warmup_runs(self) -> int: ...
def set_cold_warmup_runs(self, cold_warmup_runs: SupportsInt) -> None: ...
def get_cold_max_warmup_walltime(self) -> float: ...
def set_cold_max_warmup_walltime(self, duration_seconds: SupportsFloat) -> None: ...
def get_disable_blocking_kernel(self) -> bool: ...
def set_disable_blocking_kernel(self, flag: bool) -> None: ...
def get_run_once(self) -> bool: ...
@@ -206,6 +212,14 @@ class _OptionDecorators:
) -> Callable[[_F], _F]: ...
def min_samples(self, count: SupportsInt) -> Callable[[_F], _F]: ...
def set_min_samples(self, count: SupportsInt) -> Callable[[_F], _F]: ...
def cold_warmup_runs(self, count: SupportsInt) -> Callable[[_F], _F]: ...
def set_cold_warmup_runs(self, count: SupportsInt) -> Callable[[_F], _F]: ...
def cold_max_warmup_walltime(
self, duration_seconds: SupportsFloat
) -> Callable[[_F], _F]: ...
def set_cold_max_warmup_walltime(
self, duration_seconds: SupportsFloat
) -> Callable[[_F], _F]: ...
def is_cpu_only(self, value: bool = True) -> Callable[[_F], _F]: ...
def set_is_cpu_only(self, value: bool) -> Callable[[_F], _F]: ...

View File

@@ -289,6 +289,28 @@ class _OptionDecorators:
lambda benchmark: benchmark.set_min_samples(count)
)
def cold_warmup_runs(self, count: int) -> Callable[[_F], _F]:
"""Set the number of cold measurement warmup runs."""
return self.set_cold_warmup_runs(count)
def set_cold_warmup_runs(self, count: int) -> Callable[[_F], _F]:
"""Set the number of cold measurement warmup runs."""
return _append_benchmark_action(
lambda benchmark: benchmark.set_cold_warmup_runs(count)
)
def cold_max_warmup_walltime(self, duration_seconds: float) -> Callable[[_F], _F]:
"""Set the maximum walltime spent on cold measurement warmup runs."""
return self.set_cold_max_warmup_walltime(duration_seconds)
def set_cold_max_warmup_walltime(
self, duration_seconds: float
) -> Callable[[_F], _F]:
"""Set the maximum walltime spent on cold measurement warmup runs."""
return _append_benchmark_action(
lambda benchmark: benchmark.set_cold_max_warmup_walltime(duration_seconds)
)
def is_cpu_only(self, value: bool = True) -> Callable[[_F], _F]:
"""Set whether the benchmark only performs CPU work."""
return self.set_is_cpu_only(value)

View File

@@ -472,6 +472,8 @@ static void def_class_Benchmark(py::module_ m)
// nvbench::benchmark_base::set_criterion_param_float64
// nvbench::benchmark_base::set_criterion_param_string
// nvbench::benchmark_base::set_min_samples
// nvbench::benchmark_base::set_cold_warmup_runs
// nvbench::benchmark_base::set_cold_max_warmup_walltime
static constexpr const char *class_Benchmark_doc = R"XXXX(
Represents NVBench benchmark.
@@ -731,6 +733,36 @@ Set minimal samples count before stopping criterion applies
method_set_min_samples_doc,
py::return_value_policy::reference,
py::arg("min_samples_count"));
// method Benchmark.set_cold_warmup_runs
auto method_set_cold_warmup_runs_impl = [](nvbench::benchmark_base &self,
nvbench::int64_t count) {
self.set_cold_warmup_runs(count);
return std::ref(self);
};
static constexpr const char *method_set_cold_warmup_runs_doc = R"XXXX(
Set the number of cold measurement warmup runs
)XXXX";
py_benchmark_cls.def("set_cold_warmup_runs",
method_set_cold_warmup_runs_impl,
method_set_cold_warmup_runs_doc,
py::return_value_policy::reference,
py::arg("cold_warmup_runs"));
// method Benchmark.set_cold_max_warmup_walltime
auto method_set_cold_max_warmup_walltime_impl = [](nvbench::benchmark_base &self,
nvbench::float64_t duration_seconds) {
self.set_cold_max_warmup_walltime(duration_seconds);
return std::ref(self);
};
static constexpr const char *method_set_cold_max_warmup_walltime_doc = R"XXXX(
Set the maximum walltime spent on cold measurement warmup runs, in seconds
)XXXX";
py_benchmark_cls.def("set_cold_max_warmup_walltime",
method_set_cold_max_warmup_walltime_impl,
method_set_cold_max_warmup_walltime_doc,
py::return_value_policy::reference,
py::arg("duration_seconds"));
}
void def_class_State(py::module_ m)
@@ -763,6 +795,10 @@ void def_class_State(py::module_ m)
// nvbench::state::get_skip_reason
// nvbench::state::get_min_samples
// nvbench::state::set_min_samples
// nvbench::state::get_cold_warmup_runs
// nvbench::state::set_cold_warmup_runs
// nvbench::state::get_cold_max_warmup_walltime
// nvbench::state::set_cold_max_warmup_walltime
// nvbench::state::get_criterion_params
// nvbench::state::get_stopping_criterion
// nvbench::state::get_run_once
@@ -1023,6 +1059,40 @@ Set the number of benchmark timings for NVBench to perform before stopping crite
method_set_min_samples_doc,
py::arg("min_samples_count"));
// method State.get_cold_warmup_runs
static constexpr const char *method_get_cold_warmup_runs_doc = R"XXXX(
Get the number of cold measurement warmup runs
)XXXX";
pystate_cls.def("get_cold_warmup_runs",
&nvbench::state::get_cold_warmup_runs,
method_get_cold_warmup_runs_doc);
// method State.set_cold_warmup_runs
static constexpr const char *method_set_cold_warmup_runs_doc = R"XXXX(
Set the number of cold measurement warmup runs
)XXXX";
pystate_cls.def("set_cold_warmup_runs",
&nvbench::state::set_cold_warmup_runs,
method_set_cold_warmup_runs_doc,
py::arg("cold_warmup_runs"));
// method State.get_cold_max_warmup_walltime
static constexpr const char *method_get_cold_max_warmup_walltime_doc = R"XXXX(
Get the maximum walltime spent on cold measurement warmup runs, in seconds
)XXXX";
pystate_cls.def("get_cold_max_warmup_walltime",
&nvbench::state::get_cold_max_warmup_walltime,
method_get_cold_max_warmup_walltime_doc);
// method State.set_cold_max_warmup_walltime
static constexpr const char *method_set_cold_max_warmup_walltime_doc = R"XXXX(
Set the maximum walltime spent on cold measurement warmup runs, in seconds
)XXXX";
pystate_cls.def("set_cold_max_warmup_walltime",
&nvbench::state::set_cold_max_warmup_walltime,
method_set_cold_max_warmup_walltime_doc,
py::arg("duration_seconds"));
// method State.get_disable_blocking_kernel
static constexpr const char *method_get_disable_blocking_kernel_doc = R"XXXX(
Return True if use of blocking kernel by NVBench is disabled, False otherwise

View File

@@ -60,11 +60,15 @@ __global__ void sleep_kernel(double seconds) {
def no_axes(state: bench.State):
state.set_min_samples(1000)
state.set_cold_warmup_runs(5)
state.set_cold_max_warmup_walltime(0.25)
sleep_dur = 1e-3
krn = make_sleep_kernel()
launch_config = core.LaunchConfig(grid=1, block=1, shmem_size=0)
print(f"Stopping criterion used: {state.get_stopping_criterion()}")
print(f"Cold warmup runs: {state.get_cold_warmup_runs()}")
print(f"Cold max warmup walltime: {state.get_cold_max_warmup_walltime()}")
def launcher(launch: bench.Launch):
s = as_core_Stream(launch.get_stream())
@@ -216,12 +220,14 @@ if __name__ == "__main__":
# benchmark with no axes, that uses default value
default_b = bench.register(default_value)
default_b.set_min_samples(7)
default_b.set_cold_warmup_runs(11)
# specify axis
axes_b = bench.register(single_float64_axis).add_float64_axis(
"Duration", [7e-5, 1e-4, 5e-4]
)
axes_b.set_timeout(20)
axes_b.set_cold_max_warmup_walltime(0.5)
axes_b.set_skip_time(1e-5)
axes_b.set_throttle_threshold(0.2)
axes_b.set_throttle_recovery_delay(0.1)

View File

@@ -39,19 +39,23 @@ def test_api_ctor(cls):
cls()
def t_bench(state: bench.State):
s = {"a": 1, "b": 0.5, "c": "test", "d": {"a": 1}}
def launcher(launch: bench.Launch):
for _ in range(10000):
_ = json.dumps(s)
state.exec(launcher)
def test_cpu_only():
saved_timers = []
observed = {}
@bench.register()
@bench.option.set_is_cpu_only(True)
def t_bench(state: bench.State):
s = {"a": 1, "b": 0.5, "c": "test", "d": {"a": 1}}
def launcher(launch: bench.Launch):
for _ in range(10000):
_ = json.dumps(s)
state.exec(launcher)
@bench.register()
@bench.option.set_is_cpu_only(True)
def t_bench_timer(state: bench.State):
s = {"a": 1, "b": 0.5, "c": "test", "d": {"a": 1}}
@@ -67,11 +71,20 @@ def test_cpu_only():
state.exec(launcher, timer=True)
b = bench.register(t_bench)
b.set_is_cpu_only(True)
@bench.register()
@bench.option.set_is_cpu_only(True)
@bench.option.set_cold_warmup_runs(13)
@bench.option.set_cold_max_warmup_walltime(0.5)
def cold_warmup_state_probe(state: bench.State):
observed["benchmark_runs"] = state.get_cold_warmup_runs()
observed["benchmark_walltime"] = state.get_cold_max_warmup_walltime()
b_timer = bench.register(t_bench_timer)
b_timer.set_is_cpu_only(True)
state.set_cold_warmup_runs(3)
state.set_cold_max_warmup_walltime(0.125)
observed["state_runs"] = state.get_cold_warmup_runs()
observed["state_walltime"] = state.get_cold_max_warmup_walltime()
state.exec(lambda launch: None)
bench.run_all_benchmarks(["-q", "--profile"])
@@ -79,6 +92,13 @@ def test_cpu_only():
with pytest.raises(RuntimeError, match="Timer is no longer valid"):
saved_timers[0].start()
assert observed == {
"benchmark_runs": 13,
"benchmark_walltime": 0.5,
"state_runs": 3,
"state_walltime": 0.125,
}
def docstring_check(doc_str: Union[str, None]) -> None:
assert isinstance(doc_str, str)
@@ -132,6 +152,10 @@ def test_decorator_docstrings():
obj_has_docstring_check(bench.option.set_criterion_param_string)
obj_has_docstring_check(bench.option.min_samples)
obj_has_docstring_check(bench.option.set_min_samples)
obj_has_docstring_check(bench.option.cold_warmup_runs)
obj_has_docstring_check(bench.option.set_cold_warmup_runs)
obj_has_docstring_check(bench.option.cold_max_warmup_walltime)
obj_has_docstring_check(bench.option.set_cold_max_warmup_walltime)
obj_has_docstring_check(bench.option.is_cpu_only)
obj_has_docstring_check(bench.option.set_is_cpu_only)
@@ -149,6 +173,14 @@ def test_register_decorator_preserves_function_and_applies_options(monkeypatch):
self.calls.append(("min_samples", count))
return self
def set_cold_warmup_runs(self, count):
self.calls.append(("cold_warmup_runs", count))
return self
def set_cold_max_warmup_walltime(self, duration_seconds):
self.calls.append(("cold_max_warmup_walltime", duration_seconds))
return self
fake_benchmark = FakeBenchmark()
registered_functions = []
@@ -161,6 +193,8 @@ def test_register_decorator_preserves_function_and_applies_options(monkeypatch):
@bench.register()
@bench.axis.int64("Elements", [1, 2, 3])
@bench.option.min_samples(11)
@bench.option.cold_warmup_runs(7)
@bench.option.cold_max_warmup_walltime(0.25)
def decorated(state: bench.State):
pass
@@ -168,6 +202,41 @@ def test_register_decorator_preserves_function_and_applies_options(monkeypatch):
assert fake_benchmark.calls == [
("int64", "Elements", [1, 2, 3]),
("min_samples", 11),
("cold_warmup_runs", 7),
("cold_max_warmup_walltime", 0.25),
]
assert callable(decorated)
def test_set_cold_warmup_option_decorators_apply_options(monkeypatch):
class FakeBenchmark:
def __init__(self):
self.calls = []
def set_cold_warmup_runs(self, count):
self.calls.append(("cold_warmup_runs", count))
return self
def set_cold_max_warmup_walltime(self, duration_seconds):
self.calls.append(("cold_max_warmup_walltime", duration_seconds))
return self
fake_benchmark = FakeBenchmark()
def fake_register(fn):
return fake_benchmark
monkeypatch.setattr(bench, "_register", fake_register)
@bench.register()
@bench.option.set_cold_warmup_runs(13)
@bench.option.set_cold_max_warmup_walltime(0.5)
def decorated(state: bench.State):
pass
assert fake_benchmark.calls == [
("cold_warmup_runs", 13),
("cold_max_warmup_walltime", 0.5),
]
assert callable(decorated)
@@ -241,6 +310,10 @@ def test_State_doc():
obj_has_docstring_check(cl.get_int64)
obj_has_docstring_check(cl.get_float64)
obj_has_docstring_check(cl.get_string)
obj_has_docstring_check(cl.get_cold_warmup_runs)
obj_has_docstring_check(cl.set_cold_warmup_runs)
obj_has_docstring_check(cl.get_cold_max_warmup_walltime)
obj_has_docstring_check(cl.set_cold_max_warmup_walltime)
obj_has_docstring_check(cl.skip)
@@ -269,3 +342,5 @@ def test_Benchmark_doc():
obj_has_docstring_check(cl.add_int64_power_of_two_axis)
obj_has_docstring_check(cl.add_float64_axis)
obj_has_docstring_check(cl.add_string_axis)
obj_has_docstring_check(cl.set_cold_warmup_runs)
obj_has_docstring_check(cl.set_cold_max_warmup_walltime)