mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Introduce runner->run_or_skip(bool &) and benchmark->run_or_skip(bool &)
These methods take reference to a boolean whose value signals whether benchmark instances pending for execution are to be skipped. void benchmark->run_or_skip(bool &) is called by Python to ensure that KeyboardInterrupt is properly handled in scripts that contain multiple benchmarks, or in case when single benchmark script is executed on a machine with more than one device.
This commit is contained in:
@@ -81,6 +81,13 @@ private:
|
||||
runner.run();
|
||||
}
|
||||
|
||||
void do_run_or_skip(bool &skip_remaining) final
|
||||
{
|
||||
nvbench::runner<benchmark> runner{*this, this->m_kernel_generator};
|
||||
runner.generate_states();
|
||||
runner.run_or_skip(skip_remaining);
|
||||
}
|
||||
|
||||
kernel_generator m_kernel_generator;
|
||||
};
|
||||
|
||||
|
||||
@@ -145,6 +145,7 @@ struct benchmark_base
|
||||
[[nodiscard]] std::vector<nvbench::state> &get_states() { return m_states; }
|
||||
|
||||
void run() { this->do_run(); }
|
||||
void run_or_skip(bool &skip_remaining) { this->do_run_or_skip(skip_remaining); }
|
||||
|
||||
void set_printer(nvbench::printer_base &printer) { m_printer = std::ref(printer); }
|
||||
|
||||
@@ -320,6 +321,7 @@ private:
|
||||
virtual std::unique_ptr<benchmark_base> do_clone() const = 0;
|
||||
virtual void do_set_type_axes_names(std::vector<std::string> names) = 0;
|
||||
virtual void do_run() = 0;
|
||||
virtual void do_run_or_skip(bool &skip_remaining) = 0;
|
||||
};
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -29,8 +29,6 @@ namespace nvbench
|
||||
|
||||
struct stop_runner_loop : std::runtime_error
|
||||
{
|
||||
// ask compiler to generate all constructor signatures
|
||||
// that are defined for the base class
|
||||
using std::runtime_error::runtime_error;
|
||||
};
|
||||
|
||||
@@ -67,22 +65,28 @@ struct runner : public runner_base
|
||||
{}
|
||||
|
||||
void run()
|
||||
{
|
||||
[[maybe_unused]] bool skip_remaining = false;
|
||||
run_or_skip(skip_remaining);
|
||||
}
|
||||
|
||||
void run_or_skip(bool &skip_remaining)
|
||||
{
|
||||
if (m_benchmark.m_devices.empty())
|
||||
{
|
||||
this->run_device(std::nullopt);
|
||||
this->run_device(std::nullopt, skip_remaining);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (const auto &device : m_benchmark.m_devices)
|
||||
{
|
||||
this->run_device(device);
|
||||
this->run_device(device, skip_remaining);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void run_device(const std::optional<nvbench::device_info> &device)
|
||||
void run_device(const std::optional<nvbench::device_info> &device, bool &skip_remaining)
|
||||
{
|
||||
if (device)
|
||||
{
|
||||
@@ -92,11 +96,10 @@ private:
|
||||
// Iterate through type_configs:
|
||||
std::size_t type_config_index = 0;
|
||||
nvbench::tl::foreach<type_configs>(
|
||||
[&self = *this, &states = m_benchmark.m_states, &type_config_index, &device](
|
||||
[&self = *this, &states = m_benchmark.m_states, &type_config_index, &device, &skip_remaining](
|
||||
auto type_config_wrapper) {
|
||||
// Get current type_config:
|
||||
using type_config = typename decltype(type_config_wrapper)::type;
|
||||
bool skip_remaining = false;
|
||||
using type_config = typename decltype(type_config_wrapper)::type;
|
||||
|
||||
// Find states with the current device / type_config
|
||||
for (nvbench::state &cur_state : states)
|
||||
|
||||
Reference in New Issue
Block a user