Add state_generator::create method to encapsulate state creation.

This commit is contained in:
Allison Vacanti
2020-12-29 19:35:07 -05:00
parent beaead2c3f
commit e631f1ff03
5 changed files with 183 additions and 15 deletions

View File

@@ -1,15 +1,26 @@
#include <nvbench/detail/state_generator.cuh>
#include <nvbench/axes_metadata.cuh>
#include <nvbench/axis_base.cuh>
#include "test_asserts.cuh"
#include <fmt/format.h>
struct state_generator_tester : nvbench::detail::state_generator
{
using nvbench::detail::state_generator::add_axis;
using nvbench::detail::state_generator::get_current_indices;
using nvbench::detail::state_generator::get_number_of_states;
using nvbench::detail::state_generator::init;
using nvbench::detail::state_generator::iter_valid;
using nvbench::detail::state_generator::next;
};
void test_empty()
{
// no axes = one state
nvbench::detail::state_generator sg;
state_generator_tester sg;
ASSERT(sg.get_number_of_states() == 1);
sg.init();
ASSERT(sg.iter_valid());
@@ -20,18 +31,24 @@ void test_empty()
void test_single_state()
{
// one single-value axis = one state
nvbench::detail::state_generator sg;
state_generator_tester sg;
sg.add_axis("OnlyAxis", nvbench::axis_type::string, 1);
ASSERT(sg.get_number_of_states() == 1);
sg.init();
ASSERT(sg.iter_valid());
ASSERT(sg.get_current_indices().size() == 1);
ASSERT(sg.get_current_indices()[0].axis == "OnlyAxis");
ASSERT(sg.get_current_indices()[0].index == 0);
ASSERT(sg.get_current_indices()[0].size == 1);
ASSERT(sg.get_current_indices()[0].type == nvbench::axis_type::string);
sg.next();
ASSERT(!sg.iter_valid());
}
void test_basic()
{
nvbench::detail::state_generator sg;
state_generator_tester sg;
sg.add_axis("Axis1", nvbench::axis_type::string, 2);
sg.add_axis("Axis2", nvbench::axis_type::string, 3);
sg.add_axis("Axis3", nvbench::axis_type::string, 3);
@@ -51,6 +68,7 @@ void test_basic()
fmt::format_to(line, "| {:^2}", line_num++);
for (auto &axis_index : sg.get_current_indices())
{
ASSERT(axis_index.type == nvbench::axis_type::string);
fmt::format_to(line,
" | {}: {}/{}",
axis_index.axis,
@@ -104,9 +122,85 @@ void test_basic()
fmt::format("Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test));
}
void test_create()
{
nvbench::axes_metadata axes;
axes.add_float64_axis("Radians", {3.14, 6.28});
axes.add_int64_axis("VecSize", {2, 3, 4}, nvbench::int64_axis_flags::none);
axes.add_int64_axis("NumInputs",
{10, 15, 20},
nvbench::int64_axis_flags::power_of_two);
axes.add_string_axis("Strategy", {"Recursive", "Iterative"});
const auto states = nvbench::detail::state_generator::create(axes);
ASSERT_MSG(states.size() == (2 * 3 * 3 * 2),
"Actual: {} Expected: {}",
states.size(),
2 * 3 * 3 * 2);
fmt::memory_buffer buffer;
for (const nvbench::state &state : states)
{
fmt::format_to(buffer,
"Radians: {:.2f} | "
"VecSize: {:1d} | "
"NumInputs: {:7d} | "
"Strategy: {}\n",
state.get_float64("Radians"),
state.get_int64("VecSize"),
state.get_int64("NumInputs"),
state.get_string("Strategy"));
}
const std::string ref =
R"expected(Radians: 3.14 | VecSize: 2 | NumInputs: 1024 | Strategy: Recursive
Radians: 6.28 | VecSize: 2 | NumInputs: 1024 | Strategy: Recursive
Radians: 3.14 | VecSize: 3 | NumInputs: 1024 | Strategy: Recursive
Radians: 6.28 | VecSize: 3 | NumInputs: 1024 | Strategy: Recursive
Radians: 3.14 | VecSize: 4 | NumInputs: 1024 | Strategy: Recursive
Radians: 6.28 | VecSize: 4 | NumInputs: 1024 | Strategy: Recursive
Radians: 3.14 | VecSize: 2 | NumInputs: 32768 | Strategy: Recursive
Radians: 6.28 | VecSize: 2 | NumInputs: 32768 | Strategy: Recursive
Radians: 3.14 | VecSize: 3 | NumInputs: 32768 | Strategy: Recursive
Radians: 6.28 | VecSize: 3 | NumInputs: 32768 | Strategy: Recursive
Radians: 3.14 | VecSize: 4 | NumInputs: 32768 | Strategy: Recursive
Radians: 6.28 | VecSize: 4 | NumInputs: 32768 | Strategy: Recursive
Radians: 3.14 | VecSize: 2 | NumInputs: 1048576 | Strategy: Recursive
Radians: 6.28 | VecSize: 2 | NumInputs: 1048576 | Strategy: Recursive
Radians: 3.14 | VecSize: 3 | NumInputs: 1048576 | Strategy: Recursive
Radians: 6.28 | VecSize: 3 | NumInputs: 1048576 | Strategy: Recursive
Radians: 3.14 | VecSize: 4 | NumInputs: 1048576 | Strategy: Recursive
Radians: 6.28 | VecSize: 4 | NumInputs: 1048576 | Strategy: Recursive
Radians: 3.14 | VecSize: 2 | NumInputs: 1024 | Strategy: Iterative
Radians: 6.28 | VecSize: 2 | NumInputs: 1024 | Strategy: Iterative
Radians: 3.14 | VecSize: 3 | NumInputs: 1024 | Strategy: Iterative
Radians: 6.28 | VecSize: 3 | NumInputs: 1024 | Strategy: Iterative
Radians: 3.14 | VecSize: 4 | NumInputs: 1024 | Strategy: Iterative
Radians: 6.28 | VecSize: 4 | NumInputs: 1024 | Strategy: Iterative
Radians: 3.14 | VecSize: 2 | NumInputs: 32768 | Strategy: Iterative
Radians: 6.28 | VecSize: 2 | NumInputs: 32768 | Strategy: Iterative
Radians: 3.14 | VecSize: 3 | NumInputs: 32768 | Strategy: Iterative
Radians: 6.28 | VecSize: 3 | NumInputs: 32768 | Strategy: Iterative
Radians: 3.14 | VecSize: 4 | NumInputs: 32768 | Strategy: Iterative
Radians: 6.28 | VecSize: 4 | NumInputs: 32768 | Strategy: Iterative
Radians: 3.14 | VecSize: 2 | NumInputs: 1048576 | Strategy: Iterative
Radians: 6.28 | VecSize: 2 | NumInputs: 1048576 | Strategy: Iterative
Radians: 3.14 | VecSize: 3 | NumInputs: 1048576 | Strategy: Iterative
Radians: 6.28 | VecSize: 3 | NumInputs: 1048576 | Strategy: Iterative
Radians: 3.14 | VecSize: 4 | NumInputs: 1048576 | Strategy: Iterative
Radians: 6.28 | VecSize: 4 | NumInputs: 1048576 | Strategy: Iterative
)expected";
const std::string test = fmt::to_string(buffer);
ASSERT_MSG(test == ref,
fmt::format("Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test));
}
int main()
{
test_empty();
test_single_state();
test_basic();
test_create();
}