diff --git a/nvbench/CMakeLists.txt b/nvbench/CMakeLists.txt index 52947ef..ca687f3 100644 --- a/nvbench/CMakeLists.txt +++ b/nvbench/CMakeLists.txt @@ -5,6 +5,7 @@ set(srcs params.cu string_axis.cu type_axis.cu + detail/state_generator.cu ) # TODO shared may be a good idea to reduce compilation overhead for large diff --git a/nvbench/detail/state_generator.cu b/nvbench/detail/state_generator.cu new file mode 100644 index 0000000..88ae995 --- /dev/null +++ b/nvbench/detail/state_generator.cu @@ -0,0 +1,51 @@ +#include + +#include +#include + +namespace nvbench +{ + +namespace detail +{ + +std::size_t state_generator::get_number_of_states() const +{ + return std::transform_reduce(m_indices.cbegin(), + m_indices.cend(), + std::size_t{1}, + std::multiplies<>{}, + [](const axis_index &size_info) { + return size_info.size; + }); +} + +void state_generator::init() +{ + m_current = 0; + m_total = this->get_number_of_states(); + for (axis_index &entry : m_indices) + { + entry.index = 0; + } +} + +bool state_generator::iter_valid() const { return m_current < m_total; } + +void state_generator::next() +{ + for (axis_index& axis_info : m_indices) + { + axis_info.index += 1; + if (axis_info.index >= axis_info.size) + { + axis_info.index = 0; + continue; // carry the addition to the next entry in m_indices + } + break; // done + } + m_current += 1; +} + +} // namespace detail +} // namespace nvbench diff --git a/nvbench/detail/state_generator.cuh b/nvbench/detail/state_generator.cuh new file mode 100644 index 0000000..1ae7ec4 --- /dev/null +++ b/nvbench/detail/state_generator.cuh @@ -0,0 +1,70 @@ +#pragma once + +#include // for axis_type + +#include +#include + +namespace nvbench +{ + +namespace detail +{ + +struct state_generator +{ + struct axis_index + { + std::string_view axis; + nvbench::axis_type type; + std::size_t index; + std::size_t size; + }; + + void add_axis(const nvbench::axis_base &axis) + { + this->add_axis(axis.get_name(), axis.get_type(), axis.get_size()); + } + + void add_axis(std::string_view axis, + nvbench::axis_type type, + std::size_t size) + { + m_indices.push_back({std::move(axis), type, std::size_t{0}, size}); + } + + [[nodiscard]] std::size_t get_number_of_states() const; + + // Yep, this class is its own non-STL-style iterator. + // It's fiiiiine, we're in detail::. PRs welcome. + // + // Usage: + // ``` + // state_generator sg; + // sg.add_axis(...); + // for (sg.init(); sg.iter_valid(); sg.next()) + // { + // for (const auto& axis_index : sg.get_current_indices()) + // { + // std::string axis_name = axis_index.axis; + // nvbench::axis_type type = axis_index.type; + // std::size_t value_index = axis_index.index; + // } + // } + // ``` + void init(); + [[nodiscard]] const std::vector &get_current_indices() + { + return m_indices; + } + [[nodiscard]] bool iter_valid() const; + void next(); + +private: + std::vector m_indices; + std::size_t m_current{}; + std::size_t m_total{}; +}; + +} // namespace detail +} // namespace nvbench diff --git a/testing/CMakeLists.txt b/testing/CMakeLists.txt index 53e614e..fe3350c 100644 --- a/testing/CMakeLists.txt +++ b/testing/CMakeLists.txt @@ -3,6 +3,7 @@ set(test_srcs int64_axis.cu float64_axis.cu params.cu + state_generator.cu string_axis.cu type_axis.cu type_list.cu diff --git a/testing/state_generator.cu b/testing/state_generator.cu new file mode 100644 index 0000000..836ce6a --- /dev/null +++ b/testing/state_generator.cu @@ -0,0 +1,112 @@ +#include + +#include + +#include "test_asserts.cuh" + +#include + +void test_empty() +{ + // no axes = one state + nvbench::detail::state_generator sg; + ASSERT(sg.get_number_of_states() == 1); + sg.init(); + ASSERT(sg.iter_valid()); + sg.next(); + ASSERT(!sg.iter_valid()); +} + +void test_single_state() +{ + // one single-value axis = one state + nvbench::detail::state_generator sg; + sg.add_axis("OnlyAxis", nvbench::axis_type::string, 1); + ASSERT(sg.get_number_of_states() == 1); + sg.init(); + ASSERT(sg.iter_valid()); + sg.next(); + ASSERT(!sg.iter_valid()); +} + +void test_basic() +{ + nvbench::detail::state_generator sg; + sg.add_axis("Axis1", nvbench::axis_type::string, 2); + sg.add_axis("Axis2", nvbench::axis_type::string, 3); + sg.add_axis("Axis3", nvbench::axis_type::string, 3); + sg.add_axis("Axis4", nvbench::axis_type::string, 2); + + ASSERT_MSG(sg.get_number_of_states() == (2 * 3 * 3 * 2), + "Actual: {} Expected: {}", + sg.get_number_of_states(), + 2 * 3 * 3 * 2); + + fmt::memory_buffer buffer; + fmt::memory_buffer line; + std::size_t line_num{0}; + for (sg.init(); sg.iter_valid(); sg.next()) + { + line.clear(); + fmt::format_to(line, "| {:^2}", line_num++); + for (auto &axis_index : sg.get_current_indices()) + { + fmt::format_to(line, + " | {}: {}/{}", + axis_index.axis, + axis_index.index, + axis_index.size); + } + fmt::format_to(buffer, "{} |\n", fmt::to_string(line)); + } + + const std::string ref = + R"expected(| 0 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 0/2 | +| 1 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 0/2 | +| 2 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 0/2 | +| 3 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 0/2 | +| 4 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 0/2 | +| 5 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 0/2 | +| 6 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 0/2 | +| 7 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 0/2 | +| 8 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 0/2 | +| 9 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 0/2 | +| 10 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 0/2 | +| 11 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 0/2 | +| 12 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 0/2 | +| 13 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 0/2 | +| 14 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 0/2 | +| 15 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 0/2 | +| 16 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 0/2 | +| 17 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 0/2 | +| 18 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 1/2 | +| 19 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 1/2 | +| 20 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 1/2 | +| 21 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 1/2 | +| 22 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 1/2 | +| 23 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 1/2 | +| 24 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 1/2 | +| 25 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 1/2 | +| 26 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 1/2 | +| 27 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 1/2 | +| 28 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 1/2 | +| 29 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 1/2 | +| 30 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 1/2 | +| 31 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 1/2 | +| 32 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 1/2 | +| 33 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 1/2 | +| 34 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 1/2 | +| 35 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 1/2 | +)expected"; + + const std::string test = fmt::to_string(buffer); + ASSERT_MSG(test == ref, + fmt::format("Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test)); +} + +int main() +{ + test_empty(); + test_single_state(); + test_basic(); +} diff --git a/testing/test_asserts.cuh b/testing/test_asserts.cuh index e0a8094..0f56e75 100644 --- a/testing/test_asserts.cuh +++ b/testing/test_asserts.cuh @@ -14,7 +14,7 @@ } \ } while (false) -#define ASSERT_MSG(cond, msg) \ +#define ASSERT_MSG(cond, fmtstr, ...) \ do \ { \ if (cond) \ @@ -25,7 +25,7 @@ __FILE__, \ __LINE__, \ #cond, \ - msg); \ + fmt::format(fmtstr, __VA_ARGS__)); \ exit(EXIT_FAILURE); \ } \ } while (false)