mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Add detail::state_generator.
This helper utility computes the cartesian product of the runtime axes.
This commit is contained in:
@@ -5,6 +5,7 @@ set(srcs
|
||||
params.cu
|
||||
string_axis.cu
|
||||
type_axis.cu
|
||||
detail/state_generator.cu
|
||||
)
|
||||
|
||||
# TODO shared may be a good idea to reduce compilation overhead for large
|
||||
|
||||
51
nvbench/detail/state_generator.cu
Normal file
51
nvbench/detail/state_generator.cu
Normal file
@@ -0,0 +1,51 @@
|
||||
#include <nvbench/detail/state_generator.cuh>
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
std::size_t state_generator::get_number_of_states() const
|
||||
{
|
||||
return std::transform_reduce(m_indices.cbegin(),
|
||||
m_indices.cend(),
|
||||
std::size_t{1},
|
||||
std::multiplies<>{},
|
||||
[](const axis_index &size_info) {
|
||||
return size_info.size;
|
||||
});
|
||||
}
|
||||
|
||||
void state_generator::init()
|
||||
{
|
||||
m_current = 0;
|
||||
m_total = this->get_number_of_states();
|
||||
for (axis_index &entry : m_indices)
|
||||
{
|
||||
entry.index = 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool state_generator::iter_valid() const { return m_current < m_total; }
|
||||
|
||||
void state_generator::next()
|
||||
{
|
||||
for (axis_index& axis_info : m_indices)
|
||||
{
|
||||
axis_info.index += 1;
|
||||
if (axis_info.index >= axis_info.size)
|
||||
{
|
||||
axis_info.index = 0;
|
||||
continue; // carry the addition to the next entry in m_indices
|
||||
}
|
||||
break; // done
|
||||
}
|
||||
m_current += 1;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace nvbench
|
||||
70
nvbench/detail/state_generator.cuh
Normal file
70
nvbench/detail/state_generator.cuh
Normal file
@@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include <nvbench/axis_base.cuh> // for axis_type
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
struct state_generator
|
||||
{
|
||||
struct axis_index
|
||||
{
|
||||
std::string_view axis;
|
||||
nvbench::axis_type type;
|
||||
std::size_t index;
|
||||
std::size_t size;
|
||||
};
|
||||
|
||||
void add_axis(const nvbench::axis_base &axis)
|
||||
{
|
||||
this->add_axis(axis.get_name(), axis.get_type(), axis.get_size());
|
||||
}
|
||||
|
||||
void add_axis(std::string_view axis,
|
||||
nvbench::axis_type type,
|
||||
std::size_t size)
|
||||
{
|
||||
m_indices.push_back({std::move(axis), type, std::size_t{0}, size});
|
||||
}
|
||||
|
||||
[[nodiscard]] std::size_t get_number_of_states() const;
|
||||
|
||||
// Yep, this class is its own non-STL-style iterator.
|
||||
// It's fiiiiine, we're in detail::. PRs welcome.
|
||||
//
|
||||
// Usage:
|
||||
// ```
|
||||
// state_generator sg;
|
||||
// sg.add_axis(...);
|
||||
// for (sg.init(); sg.iter_valid(); sg.next())
|
||||
// {
|
||||
// for (const auto& axis_index : sg.get_current_indices())
|
||||
// {
|
||||
// std::string axis_name = axis_index.axis;
|
||||
// nvbench::axis_type type = axis_index.type;
|
||||
// std::size_t value_index = axis_index.index;
|
||||
// }
|
||||
// }
|
||||
// ```
|
||||
void init();
|
||||
[[nodiscard]] const std::vector<axis_index> &get_current_indices()
|
||||
{
|
||||
return m_indices;
|
||||
}
|
||||
[[nodiscard]] bool iter_valid() const;
|
||||
void next();
|
||||
|
||||
private:
|
||||
std::vector<axis_index> m_indices;
|
||||
std::size_t m_current{};
|
||||
std::size_t m_total{};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace nvbench
|
||||
@@ -3,6 +3,7 @@ set(test_srcs
|
||||
int64_axis.cu
|
||||
float64_axis.cu
|
||||
params.cu
|
||||
state_generator.cu
|
||||
string_axis.cu
|
||||
type_axis.cu
|
||||
type_list.cu
|
||||
|
||||
112
testing/state_generator.cu
Normal file
112
testing/state_generator.cu
Normal file
@@ -0,0 +1,112 @@
|
||||
#include <nvbench/detail/state_generator.cuh>
|
||||
|
||||
#include <nvbench/axis_base.cuh>
|
||||
|
||||
#include "test_asserts.cuh"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
void test_empty()
|
||||
{
|
||||
// no axes = one state
|
||||
nvbench::detail::state_generator sg;
|
||||
ASSERT(sg.get_number_of_states() == 1);
|
||||
sg.init();
|
||||
ASSERT(sg.iter_valid());
|
||||
sg.next();
|
||||
ASSERT(!sg.iter_valid());
|
||||
}
|
||||
|
||||
void test_single_state()
|
||||
{
|
||||
// one single-value axis = one state
|
||||
nvbench::detail::state_generator sg;
|
||||
sg.add_axis("OnlyAxis", nvbench::axis_type::string, 1);
|
||||
ASSERT(sg.get_number_of_states() == 1);
|
||||
sg.init();
|
||||
ASSERT(sg.iter_valid());
|
||||
sg.next();
|
||||
ASSERT(!sg.iter_valid());
|
||||
}
|
||||
|
||||
void test_basic()
|
||||
{
|
||||
nvbench::detail::state_generator sg;
|
||||
sg.add_axis("Axis1", nvbench::axis_type::string, 2);
|
||||
sg.add_axis("Axis2", nvbench::axis_type::string, 3);
|
||||
sg.add_axis("Axis3", nvbench::axis_type::string, 3);
|
||||
sg.add_axis("Axis4", nvbench::axis_type::string, 2);
|
||||
|
||||
ASSERT_MSG(sg.get_number_of_states() == (2 * 3 * 3 * 2),
|
||||
"Actual: {} Expected: {}",
|
||||
sg.get_number_of_states(),
|
||||
2 * 3 * 3 * 2);
|
||||
|
||||
fmt::memory_buffer buffer;
|
||||
fmt::memory_buffer line;
|
||||
std::size_t line_num{0};
|
||||
for (sg.init(); sg.iter_valid(); sg.next())
|
||||
{
|
||||
line.clear();
|
||||
fmt::format_to(line, "| {:^2}", line_num++);
|
||||
for (auto &axis_index : sg.get_current_indices())
|
||||
{
|
||||
fmt::format_to(line,
|
||||
" | {}: {}/{}",
|
||||
axis_index.axis,
|
||||
axis_index.index,
|
||||
axis_index.size);
|
||||
}
|
||||
fmt::format_to(buffer, "{} |\n", fmt::to_string(line));
|
||||
}
|
||||
|
||||
const std::string ref =
|
||||
R"expected(| 0 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 0/2 |
|
||||
| 1 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 0/2 |
|
||||
| 2 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 0/2 |
|
||||
| 3 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 0/2 |
|
||||
| 4 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 0/2 |
|
||||
| 5 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 0/2 |
|
||||
| 6 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 0/2 |
|
||||
| 7 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 0/2 |
|
||||
| 8 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 0/2 |
|
||||
| 9 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 0/2 |
|
||||
| 10 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 0/2 |
|
||||
| 11 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 0/2 |
|
||||
| 12 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 0/2 |
|
||||
| 13 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 0/2 |
|
||||
| 14 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 0/2 |
|
||||
| 15 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 0/2 |
|
||||
| 16 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 0/2 |
|
||||
| 17 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 0/2 |
|
||||
| 18 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 1/2 |
|
||||
| 19 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 0/3 | Axis4: 1/2 |
|
||||
| 20 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 1/2 |
|
||||
| 21 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 0/3 | Axis4: 1/2 |
|
||||
| 22 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 1/2 |
|
||||
| 23 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 0/3 | Axis4: 1/2 |
|
||||
| 24 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 1/2 |
|
||||
| 25 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 1/3 | Axis4: 1/2 |
|
||||
| 26 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 1/2 |
|
||||
| 27 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 1/3 | Axis4: 1/2 |
|
||||
| 28 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 1/2 |
|
||||
| 29 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 1/3 | Axis4: 1/2 |
|
||||
| 30 | Axis1: 0/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 1/2 |
|
||||
| 31 | Axis1: 1/2 | Axis2: 0/3 | Axis3: 2/3 | Axis4: 1/2 |
|
||||
| 32 | Axis1: 0/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 1/2 |
|
||||
| 33 | Axis1: 1/2 | Axis2: 1/3 | Axis3: 2/3 | Axis4: 1/2 |
|
||||
| 34 | Axis1: 0/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 1/2 |
|
||||
| 35 | Axis1: 1/2 | Axis2: 2/3 | Axis3: 2/3 | Axis4: 1/2 |
|
||||
)expected";
|
||||
|
||||
const std::string test = fmt::to_string(buffer);
|
||||
ASSERT_MSG(test == ref,
|
||||
fmt::format("Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test));
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
test_empty();
|
||||
test_single_state();
|
||||
test_basic();
|
||||
}
|
||||
@@ -14,7 +14,7 @@
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define ASSERT_MSG(cond, msg) \
|
||||
#define ASSERT_MSG(cond, fmtstr, ...) \
|
||||
do \
|
||||
{ \
|
||||
if (cond) \
|
||||
@@ -25,7 +25,7 @@
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
#cond, \
|
||||
msg); \
|
||||
fmt::format(fmtstr, __VA_ARGS__)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
Reference in New Issue
Block a user