From 8e6154511e89031bf0c4440e57a94096e26076a3 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Mon, 8 Dec 2025 14:24:32 -0600 Subject: [PATCH] Introduce runner->run_or_skip(bool &) and benchmark->run_or_skip(bool &) These methods take reference to a boolean whose value signals whether benchmark instances pending for execution are to be skipped. void benchmark->run_or_skip(bool &) is called by Python to ensure that KeyboardInterrupt is properly handled in scripts that contain multiple benchmarks, or in case when single benchmark script is executed on a machine with more than one device. --- nvbench/benchmark.cuh | 7 +++++++ nvbench/benchmark_base.cuh | 2 ++ nvbench/runner.cuh | 19 +++++++++++-------- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/nvbench/benchmark.cuh b/nvbench/benchmark.cuh index 4456a94..963f7a3 100644 --- a/nvbench/benchmark.cuh +++ b/nvbench/benchmark.cuh @@ -81,6 +81,13 @@ private: runner.run(); } + void do_run_or_skip(bool &skip_remaining) final + { + nvbench::runner runner{*this, this->m_kernel_generator}; + runner.generate_states(); + runner.run_or_skip(skip_remaining); + } + kernel_generator m_kernel_generator; }; diff --git a/nvbench/benchmark_base.cuh b/nvbench/benchmark_base.cuh index dce0afc..3eddf2b 100644 --- a/nvbench/benchmark_base.cuh +++ b/nvbench/benchmark_base.cuh @@ -145,6 +145,7 @@ struct benchmark_base [[nodiscard]] std::vector &get_states() { return m_states; } void run() { this->do_run(); } + void run_or_skip(bool &skip_remaining) { this->do_run_or_skip(skip_remaining); } void set_printer(nvbench::printer_base &printer) { m_printer = std::ref(printer); } @@ -320,6 +321,7 @@ private: virtual std::unique_ptr do_clone() const = 0; virtual void do_set_type_axes_names(std::vector names) = 0; virtual void do_run() = 0; + virtual void do_run_or_skip(bool &skip_remaining) = 0; }; } // namespace nvbench diff --git a/nvbench/runner.cuh b/nvbench/runner.cuh index c3cc283..8a78f33 100644 --- a/nvbench/runner.cuh +++ b/nvbench/runner.cuh @@ -29,8 +29,6 @@ namespace nvbench struct stop_runner_loop : std::runtime_error { - // ask compiler to generate all constructor signatures - // that are defined for the base class using std::runtime_error::runtime_error; }; @@ -67,22 +65,28 @@ struct runner : public runner_base {} void run() + { + [[maybe_unused]] bool skip_remaining = false; + run_or_skip(skip_remaining); + } + + void run_or_skip(bool &skip_remaining) { if (m_benchmark.m_devices.empty()) { - this->run_device(std::nullopt); + this->run_device(std::nullopt, skip_remaining); } else { for (const auto &device : m_benchmark.m_devices) { - this->run_device(device); + this->run_device(device, skip_remaining); } } } private: - void run_device(const std::optional &device) + void run_device(const std::optional &device, bool &skip_remaining) { if (device) { @@ -92,11 +96,10 @@ private: // Iterate through type_configs: std::size_t type_config_index = 0; nvbench::tl::foreach( - [&self = *this, &states = m_benchmark.m_states, &type_config_index, &device]( + [&self = *this, &states = m_benchmark.m_states, &type_config_index, &device, &skip_remaining]( auto type_config_wrapper) { // Get current type_config: - using type_config = typename decltype(type_config_wrapper)::type; - bool skip_remaining = false; + using type_config = typename decltype(type_config_wrapper)::type; // Find states with the current device / type_config for (nvbench::state &cur_state : states)