More cleanup

This commit is contained in:
Robert Maynard
2022-04-12 10:54:36 -04:00
parent 4c964d2923
commit 26467f3855
4 changed files with 84 additions and 31 deletions

View File

@@ -77,14 +77,14 @@ struct axes_metadata
Args &&...args)
{
(this->add_axis(std::forward<Args>(args)), ...);
this->user_iteration_axes({args.get_name()...}, std::move(make));
this->user_iteration_axes(std::move(make), {args.get_name()...});
}
void zip_axes(std::vector<std::string> names);
void
user_iteration_axes(std::vector<std::string> names,
std::function<nvbench::make_user_space_signature> make);
user_iteration_axes(std::function<nvbench::make_user_space_signature> make,
std::vector<std::string> names);
[[nodiscard]] const axes_iteration_space &get_type_iteration_space() const
{

View File

@@ -265,8 +265,8 @@ void axes_metadata::zip_axes(std::vector<std::string> names)
}
void axes_metadata::user_iteration_axes(
std::vector<std::string> names,
std::function<nvbench::make_user_space_signature> make)
std::function<nvbench::make_user_space_signature> make,
std::vector<std::string> names)
{
// compute the numeric index for each name we have
auto [input_indices,

View File

@@ -118,12 +118,6 @@ struct benchmark_base
return *this;
}
/// Hands `names` to the axes metadata's `zip_axes` and returns this
/// benchmark so that further configuration calls can be chained.
benchmark_base &zip_axes(std::vector<std::string> names)
{
  this->m_axes.zip_axes(std::move(names));
  return *this;
}
template<typename... Args>
benchmark_base &add_user_iteration_axes(Args&&... args)
{
@@ -131,15 +125,6 @@ struct benchmark_base
return *this;
}
/// Hands `names` and the `make` callable to the axes metadata's
/// `user_iteration_axes` and returns this benchmark so that further
/// configuration calls can be chained.
benchmark_base &user_iteration_axes(
  std::vector<std::string> names,
  std::function<nvbench::make_user_space_signature> make)
{
  this->m_axes.user_iteration_axes(std::move(names), std::move(make));
  return *this;
}
benchmark_base &set_devices(std::vector<int> device_ids);
benchmark_base &set_devices(std::vector<nvbench::device_info> devices)
@@ -272,6 +257,38 @@ struct benchmark_base
/// @}
protected:
/// Moves existing axes into a zip axis iteration space.
///
/// Any iteration spaces that the named axes currently belong to are removed,
/// and every other axis in those spaces is restored to the default linear
/// space.
///
/// This is meant to be used only by the option_parser.
///
/// @param names Names of the axes to zip together.
/// @return This benchmark, so that calls can be chained.
/// @{
benchmark_base &zip_axes(std::vector<std::string> names)
{
  this->m_axes.zip_axes(std::move(names));
  return *this;
}
/// @}
/// Moves existing axes into a user axis iteration space.
///
/// Any iteration spaces that the named axes currently belong to are removed,
/// and every other axis in those spaces is restored to the default linear
/// space.
///
/// This is meant to be used only by the option_parser.
///
/// @param make  Callable forwarded to the axes metadata together with `names`.
/// @param names Names of the axes to move into the user iteration space.
/// @return This benchmark, so that calls can be chained.
/// @{
benchmark_base &user_iteration_axes(
  std::function<nvbench::make_user_space_signature> make,
  std::vector<std::string> names)
{
  this->m_axes.user_iteration_axes(std::move(make), std::move(names));
  return *this;
}
/// @}
friend struct nvbench::runner_base;
template <typename BenchmarkType>

View File

@@ -60,6 +60,42 @@ void no_op_generator(nvbench::state &state)
}
NVBENCH_DEFINE_CALLABLE(no_op_generator, no_op_callable);
/// Benchmark type for the tests in this file that re-exports
/// `nvbench::benchmark_base::zip_axes` and
/// `nvbench::benchmark_base::user_iteration_axes` through public
/// using-declarations, so they can be called directly on a benchmark object.
///
/// @tparam KernelGenerator Exposed as `kernel_generator`.
/// @tparam TypeAxes        Type axes passed to the base class constructor;
///                         defaults to an empty `nvbench::type_list`.
template <typename KernelGenerator, typename TypeAxes = nvbench::type_list<>>
struct rezippable_benchmark final : public nvbench::benchmark_base
{
  using kernel_generator = KernelGenerator;
  using type_axes        = TypeAxes;
  using type_configs     = nvbench::tl::cartesian_product<type_axes>;

  /// Number of entries in `type_configs`.
  static constexpr std::size_t num_type_configs =
    nvbench::tl::size<type_configs>{};

  rezippable_benchmark()
      : benchmark_base(type_axes{})
  {}

  // Make the base-class iteration-space mutators callable from test code.
  using nvbench::benchmark_base::user_iteration_axes;
  using nvbench::benchmark_base::zip_axes;

private:
  /// Creates a default-constructed benchmark of this same concrete type.
  std::unique_ptr<benchmark_base> do_clone() const final
  {
    return std::make_unique<rezippable_benchmark>();
  }

  /// Forwards the type-axis names to the axes metadata.
  void do_set_type_axes_names(std::vector<std::string> names) final
  {
    this->m_axes.set_type_axes_names(std::move(names));
  }

  /// Builds a runner for this benchmark, generates its states, then runs it.
  void do_run() final
  {
    nvbench::runner<rezippable_benchmark> bench_runner{*this};
    bench_runner.generate_states();
    bench_runner.run();
  }
};
template <typename Integer, typename Float, typename Other>
void template_no_op_generator(nvbench::state &state,
nvbench::type_list<Integer, Float, Other>)
@@ -91,7 +127,7 @@ void test_zip_axes()
void test_tie_invalid_names()
{
using benchmark_type = nvbench::benchmark<no_op_callable>;
using benchmark_type = rezippable_benchmark<no_op_callable>;
benchmark_type bench;
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
bench.add_int64_axis("I64 Axis", {1, 3, 2});
@@ -114,11 +150,11 @@ void test_tie_unequal_length()
void test_tie_type_axi()
{
using benchmark_type =
nvbench::benchmark<template_no_op_callable,
nvbench::type_list<nvbench::type_list<nvbench::int8_t>,
nvbench::type_list<nvbench::float32_t>,
nvbench::type_list<bool>>>;
using benchmark_type = rezippable_benchmark<
template_no_op_callable,
nvbench::type_list<nvbench::type_list<nvbench::int8_t>,
nvbench::type_list<nvbench::float32_t>,
nvbench::type_list<bool>>>;
benchmark_type bench;
bench.set_type_axes_names({"Integer", "Float", "Other"});
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
@@ -129,7 +165,7 @@ void test_tie_type_axi()
void test_rezip_axes()
{
using benchmark_type = nvbench::benchmark<no_op_callable>;
using benchmark_type = rezippable_benchmark<no_op_callable>;
benchmark_type bench;
bench.add_int64_axis("IAxis_A", {1, 3, 2, 4, 5});
bench.add_int64_axis("IAxis_B", {1, 3, 2, 4, 5});
@@ -155,7 +191,7 @@ void test_rezip_axes()
void test_rezip_axes2()
{
using benchmark_type = nvbench::benchmark<no_op_callable>;
using benchmark_type = rezippable_benchmark<no_op_callable>;
benchmark_type bench;
bench.add_int64_axis("IAxis_A", {1, 3, 2, 4, 5});
bench.add_int64_axis("IAxis_B", {1, 3, 2, 4, 5});
@@ -298,15 +334,15 @@ struct under_diag final : nvbench::user_axis_space
void test_user_axes()
{
using benchmark_type = nvbench::benchmark<no_op_callable>;
using benchmark_type = rezippable_benchmark<no_op_callable>;
benchmark_type bench;
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
bench.add_int64_axis("I64 Axis", {1, 3, 2, 4, 5});
bench.user_iteration_axes(
{"F64 Axis", "I64 Axis"},
[](auto... args) -> std::unique_ptr<nvbench::axis_space_base> {
return std::make_unique<under_diag>(args...);
});
},
{"F64 Axis", "I64 Axis"});
ASSERT_MSG(bench.get_config_count() == 15 * bench.get_devices().size(),
"Got {}",