Allow kernel_generator to be stateful (#234)

In Python, the kernel generator is a user-defined callable.
We need to capture the Python object of that callable in the
kernel generator provided for each benchmark.

To this end, nvbench::benchmark has been modified to have a member of
kernel_generator type (which must be copy-constructible). Its constructor
takes an optional parameter of type `kernel_generator` whose default value
is a default-constructed instance.

nvbench::runner was modified to store a kernel_generator instance as well.
Its run method creates a fresh copy of the stored instance for each invocation,
mirroring the previous behavior, where a fresh default-constructed instance
was created for each invocation.

nvbench tests/examples pass with this change.
This commit is contained in:
Oleksandr Pavlyk
2025-06-28 21:17:12 -05:00
committed by GitHub
parent c2a30cf0d2
commit c463a783bb
2 changed files with 15 additions and 5 deletions

View File

@@ -58,12 +58,16 @@ struct benchmark final : public benchmark_base
static constexpr std::size_t num_type_configs = nvbench::tl::size<type_configs>{};
benchmark()
benchmark(kernel_generator kgen = {})
: benchmark_base(type_axes{})
, m_kernel_generator(kgen)
{}
private:
std::unique_ptr<benchmark_base> do_clone() const final { return std::make_unique<benchmark>(); }
std::unique_ptr<benchmark_base> do_clone() const final
{
return std::make_unique<benchmark>(this->m_kernel_generator);
}
void do_set_type_axes_names(std::vector<std::string> names) final
{
@@ -72,10 +76,12 @@ private:
void do_run() final
{
nvbench::runner<benchmark> runner{*this};
nvbench::runner<benchmark> runner{*this, this->m_kernel_generator};
runner.generate_states();
runner.run();
}
kernel_generator m_kernel_generator;
};
} // namespace nvbench

View File

@@ -54,8 +54,9 @@ struct runner : public runner_base
using type_configs = typename benchmark_type::type_configs;
static constexpr std::size_t num_type_configs = benchmark_type::num_type_configs;
explicit runner(benchmark_type &bench)
explicit runner(benchmark_type &bench, kernel_generator kgen = {})
: runner_base{bench}
, m_kernel_generator{kgen}
{}
void run()
@@ -98,7 +99,8 @@ private:
self.run_state_prologue(cur_state);
try
{
kernel_generator{}(cur_state, type_config{});
auto kernel_generator_copy = self.m_kernel_generator;
kernel_generator_copy(cur_state, type_config{});
if (cur_state.is_skipped())
{
self.print_skip_notification(cur_state);
@@ -115,6 +117,8 @@ private:
++type_config_index;
});
}
kernel_generator m_kernel_generator;
};
} // namespace nvbench