Introduce runner->run_or_skip(bool &) and benchmark->run_or_skip(bool &)

These methods take reference to a boolean whose value signals whether
benchmark instances pending for execution are to be skipped.

void benchmark->run_or_skip(bool &) is called by Python to ensure
that KeyboardInterrupt is properly handled in scripts that contain
multiple benchmarks, or in case when single benchmark script is
executed on a machine with more than one device.
This commit is contained in:
Oleksandr Pavlyk
2025-12-08 14:24:32 -06:00
parent a7763bdd7a
commit 8e6154511e
3 changed files with 20 additions and 8 deletions

View File

@@ -81,6 +81,13 @@ private:
runner.run();
}
void do_run_or_skip(bool &skip_remaining) final
{
nvbench::runner<benchmark> runner{*this, this->m_kernel_generator};
runner.generate_states();
runner.run_or_skip(skip_remaining);
}
kernel_generator m_kernel_generator;
};

View File

@@ -145,6 +145,7 @@ struct benchmark_base
[[nodiscard]] std::vector<nvbench::state> &get_states() { return m_states; }
void run() { this->do_run(); }
void run_or_skip(bool &skip_remaining) { this->do_run_or_skip(skip_remaining); }
void set_printer(nvbench::printer_base &printer) { m_printer = std::ref(printer); }
@@ -320,6 +321,7 @@ private:
virtual std::unique_ptr<benchmark_base> do_clone() const = 0;
virtual void do_set_type_axes_names(std::vector<std::string> names) = 0;
virtual void do_run() = 0;
virtual void do_run_or_skip(bool &skip_remaining) = 0;
};
} // namespace nvbench

View File

@@ -29,8 +29,6 @@ namespace nvbench
struct stop_runner_loop : std::runtime_error
{
// ask compiler to generate all constructor signatures
// that are defined for the base class
using std::runtime_error::runtime_error;
};
@@ -67,22 +65,28 @@ struct runner : public runner_base
{}
void run()
{
[[maybe_unused]] bool skip_remaining = false;
run_or_skip(skip_remaining);
}
void run_or_skip(bool &skip_remaining)
{
if (m_benchmark.m_devices.empty())
{
this->run_device(std::nullopt);
this->run_device(std::nullopt, skip_remaining);
}
else
{
for (const auto &device : m_benchmark.m_devices)
{
this->run_device(device);
this->run_device(device, skip_remaining);
}
}
}
private:
void run_device(const std::optional<nvbench::device_info> &device)
void run_device(const std::optional<nvbench::device_info> &device, bool &skip_remaining)
{
if (device)
{
@@ -92,11 +96,10 @@ private:
// Iterate through type_configs:
std::size_t type_config_index = 0;
nvbench::tl::foreach<type_configs>(
[&self = *this, &states = m_benchmark.m_states, &type_config_index, &device](
[&self = *this, &states = m_benchmark.m_states, &type_config_index, &device, &skip_remaining](
auto type_config_wrapper) {
// Get current type_config:
using type_config = typename decltype(type_config_wrapper)::type;
bool skip_remaining = false;
using type_config = typename decltype(type_config_wrapper)::type;
// Find states with the current device / type_config
for (nvbench::state &cur_state : states)