mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 14:58:54 +00:00
Add state_generator::create method to encapsulate state creation.
This commit is contained in:
@@ -1,14 +1,71 @@
|
||||
#include <nvbench/detail/state_generator.cuh>
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
std::vector<nvbench::state> state_generator::create(const axes_metadata &axes)
|
||||
{
|
||||
state_generator sg;
|
||||
{
|
||||
const auto &axes_vec = axes.get_axes();
|
||||
std::for_each(axes_vec.cbegin(), axes_vec.cend(), [&sg](const auto &axis) {
|
||||
if (axis->get_type() != nvbench::axis_type::type)
|
||||
{
|
||||
sg.add_axis(*axis);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<nvbench::state> states;
|
||||
{
|
||||
states.reserve(sg.get_number_of_states());
|
||||
for (sg.init(); sg.iter_valid(); sg.next())
|
||||
{
|
||||
nvbench::state state;
|
||||
for (const axis_index &axis_info : sg.get_current_indices())
|
||||
{
|
||||
switch (axis_info.type)
|
||||
{
|
||||
default:
|
||||
case axis_type::type:
|
||||
assert("unreachable." && false);
|
||||
break;
|
||||
|
||||
case axis_type::int64:
|
||||
state.set_param(
|
||||
axis_info.axis,
|
||||
axes.get_int64_axis(axis_info.axis).get_value(axis_info.index));
|
||||
break;
|
||||
|
||||
case axis_type::float64:
|
||||
state.set_param(
|
||||
axis_info.axis,
|
||||
axes.get_float64_axis(axis_info.axis).get_value(axis_info.index));
|
||||
break;
|
||||
|
||||
case axis_type::string:
|
||||
state.set_param(
|
||||
axis_info.axis,
|
||||
axes.get_string_axis(axis_info.axis).get_value(axis_info.index));
|
||||
break;
|
||||
}
|
||||
}
|
||||
states.push_back(std::move(state));
|
||||
}
|
||||
}
|
||||
|
||||
return states;
|
||||
}
|
||||
|
||||
std::size_t state_generator::get_number_of_states() const
|
||||
{
|
||||
return std::transform_reduce(m_indices.cbegin(),
|
||||
@@ -23,7 +80,7 @@ std::size_t state_generator::get_number_of_states() const
|
||||
void state_generator::init()
|
||||
{
|
||||
m_current = 0;
|
||||
m_total = this->get_number_of_states();
|
||||
m_total = this->get_number_of_states();
|
||||
for (axis_index &entry : m_indices)
|
||||
{
|
||||
entry.index = 0;
|
||||
@@ -34,7 +91,7 @@ bool state_generator::iter_valid() const { return m_current < m_total; }
|
||||
|
||||
void state_generator::next()
|
||||
{
|
||||
for (axis_index& axis_info : m_indices)
|
||||
for (axis_index &axis_info : m_indices)
|
||||
{
|
||||
axis_info.index += 1;
|
||||
if (axis_info.index >= axis_info.size)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <nvbench/axis_base.cuh> // for axis_type
|
||||
#include <nvbench/axes_metadata.cuh>
|
||||
#include <nvbench/axis_base.cuh>
|
||||
#include <nvbench/state.cuh>
|
||||
|
||||
#include <string_view>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace nvbench
|
||||
@@ -13,9 +15,13 @@ namespace detail
|
||||
|
||||
struct state_generator
|
||||
{
|
||||
|
||||
static std::vector<nvbench::state> create(const axes_metadata &axes);
|
||||
|
||||
protected:
|
||||
struct axis_index
|
||||
{
|
||||
std::string_view axis;
|
||||
std::string axis;
|
||||
nvbench::axis_type type;
|
||||
std::size_t index;
|
||||
std::size_t size;
|
||||
@@ -26,7 +32,7 @@ struct state_generator
|
||||
this->add_axis(axis.get_name(), axis.get_type(), axis.get_size());
|
||||
}
|
||||
|
||||
void add_axis(std::string_view axis,
|
||||
void add_axis(std::string axis,
|
||||
nvbench::axis_type type,
|
||||
std::size_t size)
|
||||
{
|
||||
@@ -35,9 +41,6 @@ struct state_generator
|
||||
|
||||
[[nodiscard]] std::size_t get_number_of_states() const;
|
||||
|
||||
// Yep, this class is its own non-STL-style iterator.
|
||||
// It's fiiiiine, we're in detail::. PRs welcome.
|
||||
//
|
||||
// Usage:
|
||||
// ```
|
||||
// state_generator sg;
|
||||
@@ -60,7 +63,6 @@ struct state_generator
|
||||
[[nodiscard]] bool iter_valid() const;
|
||||
void next();
|
||||
|
||||
private:
|
||||
std::vector<axis_index> m_indices;
|
||||
std::size_t m_current{};
|
||||
std::size_t m_total{};
|
||||
|
||||
@@ -9,6 +9,11 @@
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
const state::param_type &state::get_param(const std::string &axis_name) const
|
||||
{
|
||||
return m_params.at(axis_name);
|
||||
}
|
||||
|
||||
nvbench::int64_t state::get_int64(const std::string &axis_name) const
|
||||
{
|
||||
return std::get<nvbench::int64_t>(m_params.at(axis_name));
|
||||
|
||||
@@ -9,11 +9,16 @@
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
namespace detail
|
||||
{
|
||||
struct state_generator;
|
||||
}
|
||||
|
||||
struct state
|
||||
{
|
||||
// move-only
|
||||
state(const state &) = delete;
|
||||
state(state &&) = default;
|
||||
state(state &&) = default;
|
||||
state &operator=(const state &) = delete;
|
||||
state &operator=(state &&) = default;
|
||||
|
||||
@@ -26,16 +31,21 @@ struct state
|
||||
get_string(const std::string &axis_name) const;
|
||||
|
||||
protected:
|
||||
state() = default;
|
||||
friend struct nvbench::detail::state_generator;
|
||||
|
||||
using param_type =
|
||||
std::variant<nvbench::int64_t, nvbench::float64_t, std::string>;
|
||||
using params_type = std::unordered_map<std::string, param_type>;
|
||||
|
||||
state() = default;
|
||||
|
||||
explicit state(params_type params)
|
||||
: m_params{std::move(params)}
|
||||
{}
|
||||
|
||||
[[nodiscard]] const params_type &get_params() const { return m_params; }
|
||||
[[nodiscard]] const param_type &get_param(const std::string &name) const;
|
||||
|
||||
void set_param(std::string axis_name, nvbench::int64_t value);
|
||||
void set_param(std::string axis_name, nvbench::float64_t value);
|
||||
void set_param(std::string axis_name, std::string value);
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
#include <nvbench/detail/state_generator.cuh>
|
||||
|
||||
#include <nvbench/axes_metadata.cuh>
|
||||
#include <nvbench/axis_base.cuh>
|
||||
|
||||
#include "test_asserts.cuh"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
struct state_generator_tester : nvbench::detail::state_generator
|
||||
{
|
||||
using nvbench::detail::state_generator::add_axis;
|
||||
using nvbench::detail::state_generator::get_current_indices;
|
||||
using nvbench::detail::state_generator::get_number_of_states;
|
||||
using nvbench::detail::state_generator::init;
|
||||
using nvbench::detail::state_generator::iter_valid;
|
||||
using nvbench::detail::state_generator::next;
|
||||
};
|
||||
|
||||
void test_empty()
|
||||
{
|
||||
// no axes = one state
|
||||
nvbench::detail::state_generator sg;
|
||||
state_generator_tester sg;
|
||||
ASSERT(sg.get_number_of_states() == 1);
|
||||
sg.init();
|
||||
ASSERT(sg.iter_valid());
|
||||
@@ -20,18 +31,24 @@ void test_empty()
|
||||
void test_single_state()
|
||||
{
|
||||
// one single-value axis = one state
|
||||
nvbench::detail::state_generator sg;
|
||||
state_generator_tester sg;
|
||||
sg.add_axis("OnlyAxis", nvbench::axis_type::string, 1);
|
||||
ASSERT(sg.get_number_of_states() == 1);
|
||||
sg.init();
|
||||
ASSERT(sg.iter_valid());
|
||||
ASSERT(sg.get_current_indices().size() == 1);
|
||||
ASSERT(sg.get_current_indices()[0].axis == "OnlyAxis");
|
||||
ASSERT(sg.get_current_indices()[0].index == 0);
|
||||
ASSERT(sg.get_current_indices()[0].size == 1);
|
||||
ASSERT(sg.get_current_indices()[0].type == nvbench::axis_type::string);
|
||||
|
||||
sg.next();
|
||||
ASSERT(!sg.iter_valid());
|
||||
}
|
||||
|
||||
void test_basic()
|
||||
{
|
||||
nvbench::detail::state_generator sg;
|
||||
state_generator_tester sg;
|
||||
sg.add_axis("Axis1", nvbench::axis_type::string, 2);
|
||||
sg.add_axis("Axis2", nvbench::axis_type::string, 3);
|
||||
sg.add_axis("Axis3", nvbench::axis_type::string, 3);
|
||||
@@ -51,6 +68,7 @@ void test_basic()
|
||||
fmt::format_to(line, "| {:^2}", line_num++);
|
||||
for (auto &axis_index : sg.get_current_indices())
|
||||
{
|
||||
ASSERT(axis_index.type == nvbench::axis_type::string);
|
||||
fmt::format_to(line,
|
||||
" | {}: {}/{}",
|
||||
axis_index.axis,
|
||||
@@ -104,9 +122,85 @@ void test_basic()
|
||||
fmt::format("Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test));
|
||||
}
|
||||
|
||||
void test_create()
|
||||
{
|
||||
nvbench::axes_metadata axes;
|
||||
axes.add_float64_axis("Radians", {3.14, 6.28});
|
||||
axes.add_int64_axis("VecSize", {2, 3, 4}, nvbench::int64_axis_flags::none);
|
||||
axes.add_int64_axis("NumInputs",
|
||||
{10, 15, 20},
|
||||
nvbench::int64_axis_flags::power_of_two);
|
||||
axes.add_string_axis("Strategy", {"Recursive", "Iterative"});
|
||||
|
||||
const auto states = nvbench::detail::state_generator::create(axes);
|
||||
|
||||
ASSERT_MSG(states.size() == (2 * 3 * 3 * 2),
|
||||
"Actual: {} Expected: {}",
|
||||
states.size(),
|
||||
2 * 3 * 3 * 2);
|
||||
|
||||
fmt::memory_buffer buffer;
|
||||
for (const nvbench::state &state : states)
|
||||
{
|
||||
fmt::format_to(buffer,
|
||||
"Radians: {:.2f} | "
|
||||
"VecSize: {:1d} | "
|
||||
"NumInputs: {:7d} | "
|
||||
"Strategy: {}\n",
|
||||
state.get_float64("Radians"),
|
||||
state.get_int64("VecSize"),
|
||||
state.get_int64("NumInputs"),
|
||||
state.get_string("Strategy"));
|
||||
}
|
||||
|
||||
const std::string ref =
|
||||
R"expected(Radians: 3.14 | VecSize: 2 | NumInputs: 1024 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 2 | NumInputs: 1024 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 3 | NumInputs: 1024 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 3 | NumInputs: 1024 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 4 | NumInputs: 1024 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 4 | NumInputs: 1024 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 2 | NumInputs: 32768 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 2 | NumInputs: 32768 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 3 | NumInputs: 32768 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 3 | NumInputs: 32768 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 4 | NumInputs: 32768 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 4 | NumInputs: 32768 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 2 | NumInputs: 1048576 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 2 | NumInputs: 1048576 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 3 | NumInputs: 1048576 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 3 | NumInputs: 1048576 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 4 | NumInputs: 1048576 | Strategy: Recursive
|
||||
Radians: 6.28 | VecSize: 4 | NumInputs: 1048576 | Strategy: Recursive
|
||||
Radians: 3.14 | VecSize: 2 | NumInputs: 1024 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 2 | NumInputs: 1024 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 3 | NumInputs: 1024 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 3 | NumInputs: 1024 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 4 | NumInputs: 1024 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 4 | NumInputs: 1024 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 2 | NumInputs: 32768 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 2 | NumInputs: 32768 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 3 | NumInputs: 32768 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 3 | NumInputs: 32768 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 4 | NumInputs: 32768 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 4 | NumInputs: 32768 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 2 | NumInputs: 1048576 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 2 | NumInputs: 1048576 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 3 | NumInputs: 1048576 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 3 | NumInputs: 1048576 | Strategy: Iterative
|
||||
Radians: 3.14 | VecSize: 4 | NumInputs: 1048576 | Strategy: Iterative
|
||||
Radians: 6.28 | VecSize: 4 | NumInputs: 1048576 | Strategy: Iterative
|
||||
)expected";
|
||||
|
||||
const std::string test = fmt::to_string(buffer);
|
||||
ASSERT_MSG(test == ref,
|
||||
fmt::format("Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test));
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
test_empty();
|
||||
test_single_state();
|
||||
test_basic();
|
||||
test_create();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user