mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Allow kernel_generator to be stateful (#234)
In python kernel generator is a user-defined callable. We need to capture Python object of that callable in kernel generator provided for each benchmark. To this end, nvbench::benchmark has been modified to have member of kernel_generator type (must be copy-constructable). Constructor acquires an optional parameter of type `kernel_generator` with default value of default-contstructed instance. nvbench::runner was modified to store kernel_generator instance as well. Its run method creates a fresh copy of stored instance for each invocation, just as it was happening before. nvbench tests/examples pass with this change.
This commit is contained in:
@@ -58,12 +58,16 @@ struct benchmark final : public benchmark_base
|
||||
|
||||
static constexpr std::size_t num_type_configs = nvbench::tl::size<type_configs>{};
|
||||
|
||||
benchmark()
|
||||
benchmark(kernel_generator kgen = {})
|
||||
: benchmark_base(type_axes{})
|
||||
, m_kernel_generator(kgen)
|
||||
{}
|
||||
|
||||
private:
|
||||
std::unique_ptr<benchmark_base> do_clone() const final { return std::make_unique<benchmark>(); }
|
||||
std::unique_ptr<benchmark_base> do_clone() const final
|
||||
{
|
||||
return std::make_unique<benchmark>(this->m_kernel_generator);
|
||||
}
|
||||
|
||||
void do_set_type_axes_names(std::vector<std::string> names) final
|
||||
{
|
||||
@@ -72,10 +76,12 @@ private:
|
||||
|
||||
void do_run() final
|
||||
{
|
||||
nvbench::runner<benchmark> runner{*this};
|
||||
nvbench::runner<benchmark> runner{*this, this->m_kernel_generator};
|
||||
runner.generate_states();
|
||||
runner.run();
|
||||
}
|
||||
|
||||
kernel_generator m_kernel_generator;
|
||||
};
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -54,8 +54,9 @@ struct runner : public runner_base
|
||||
using type_configs = typename benchmark_type::type_configs;
|
||||
static constexpr std::size_t num_type_configs = benchmark_type::num_type_configs;
|
||||
|
||||
explicit runner(benchmark_type &bench)
|
||||
explicit runner(benchmark_type &bench, kernel_generator kgen = {})
|
||||
: runner_base{bench}
|
||||
, m_kernel_generator{kgen}
|
||||
{}
|
||||
|
||||
void run()
|
||||
@@ -98,7 +99,8 @@ private:
|
||||
self.run_state_prologue(cur_state);
|
||||
try
|
||||
{
|
||||
kernel_generator{}(cur_state, type_config{});
|
||||
auto kernel_generator_copy = self.m_kernel_generator;
|
||||
kernel_generator_copy(cur_state, type_config{});
|
||||
if (cur_state.is_skipped())
|
||||
{
|
||||
self.print_skip_notification(cur_state);
|
||||
@@ -115,6 +117,8 @@ private:
|
||||
++type_config_index;
|
||||
});
|
||||
}
|
||||
|
||||
kernel_generator m_kernel_generator;
|
||||
};
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
Reference in New Issue
Block a user