Refactoring / renaming.

This commit is contained in:
Allison Piper
2025-05-02 20:30:23 +00:00
parent a2bf266e16
commit c8909c7d1b
13 changed files with 231 additions and 199 deletions

View File

@@ -96,13 +96,13 @@ struct under_diag final : nvbench::user_axis_space
mutable std::size_t y_pos = 0;
mutable std::size_t x_start = 0;
nvbench::detail::axis_space_iterator do_get_iterator(axes_info info) const
nvbench::detail::axis_space_iterator do_get_iterator(axis_value_indices info) const
{
// generate our increment function
auto adv_func = [&, info](std::size_t &inc_index, std::size_t /*len*/) -> bool {
inc_index++;
x_pos++;
if (x_pos == info[0].size)
if (x_pos == info[0].axis_size)
{
x_pos = ++x_start;
y_pos = x_start;
@@ -112,25 +112,24 @@ struct under_diag final : nvbench::user_axis_space
};
// our update function
auto diag_under = [&, info](std::size_t,
std::vector<nvbench::detail::axis_index>::iterator start,
std::vector<nvbench::detail::axis_index>::iterator end) {
start->index = x_pos;
end->index = y_pos;
};
auto diag_under =
[&, info](std::size_t, axis_value_indices::iterator start, axis_value_indices::iterator end) {
start->value_index = x_pos;
end->value_index = y_pos;
};
const size_t iteration_length = ((info[0].size * (info[1].size + 1)) / 2);
const size_t iteration_length = ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
return nvbench::detail::axis_space_iterator(info, iteration_length, adv_func, diag_under);
}
std::size_t do_get_size(const axes_info &info) const
std::size_t do_get_size(const axis_value_indices &info) const
{
return ((info[0].size * (info[1].size + 1)) / 2);
return ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
}
std::size_t do_get_active_count(const axes_info &info) const
std::size_t do_get_active_count(const axis_value_indices &info) const
{
return ((info[0].size * (info[1].size + 1)) / 2);
return ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
}
std::unique_ptr<nvbench::iteration_space_base> do_clone() const
@@ -160,36 +159,38 @@ struct gauss final : nvbench::user_axis_space
: nvbench::user_axis_space(std::move(input_indices))
{}
nvbench::detail::axis_space_iterator do_get_iterator(axes_info info) const
nvbench::detail::axis_space_iterator do_get_iterator(axis_value_indices info) const
{
const double mid_point = static_cast<double>((info[0].size / 2));
const double mid_point = static_cast<double>((info[0].axis_size / 2));
std::random_device rd{};
std::mt19937 gen{rd()};
std::normal_distribution<> d{mid_point, 2};
const size_t iteration_length = info[0].size;
const size_t iteration_length = info[0].axis_size;
std::vector<std::size_t> gauss_indices(iteration_length);
for (auto &g : gauss_indices)
{
auto v = std::min(static_cast<double>(info[0].size), d(gen));
auto v = std::min(static_cast<double>(info[0].axis_size), d(gen));
v = std::max(0.0, v);
g = static_cast<std::size_t>(v);
}
// our update function
auto gauss_func = [=](std::size_t index,
std::vector<nvbench::detail::axis_index>::iterator start,
std::vector<nvbench::detail::axis_index>::iterator) {
start->index = gauss_indices[index];
};
auto gauss_func =
[=](std::size_t index, axis_value_indices::iterator start, axis_value_indices::iterator) {
start->value_index = gauss_indices[index];
};
return nvbench::detail::axis_space_iterator(info, iteration_length, gauss_func);
}
std::size_t do_get_size(const axes_info &info) const { return info[0].size; }
std::size_t do_get_size(const axis_value_indices &info) const { return info[0].axis_size; }
std::size_t do_get_active_count(const axes_info &info) const { return info[0].size; }
std::size_t do_get_active_count(const axis_value_indices &info) const
{
return info[0].axis_size;
}
std::unique_ptr<iteration_space_base> do_clone() const { return std::make_unique<gauss>(*this); }
};

View File

@@ -30,76 +30,86 @@ namespace nvbench
namespace detail
{
struct axis_index
// Tracks current value and axis information used while iterating through axes.
struct axis_value_index
{
axis_index() = default;
axis_value_index() = default;
explicit axis_index(const axis_base *axis)
: index(0)
, name(axis->get_name())
, type(axis->get_type())
, size(axis->get_size())
, active_size(axis->get_size())
{
if (type == nvbench::axis_type::type)
{
active_size = static_cast<const nvbench::type_axis *>(axis)->get_active_count();
}
}
std::size_t index;
std::string name;
nvbench::axis_type type;
std::size_t size;
std::size_t active_size;
explicit axis_value_index(const axis_base *axis)
: value_index(0)
, axis_name(axis->get_name())
, axis_type(axis->get_type())
, axis_size(axis->get_size())
, axis_active_size(axis_type == nvbench::axis_type::type
? static_cast<const nvbench::type_axis *>(axis)->get_active_count()
: axis->get_size())
{}
std::size_t value_index;
std::string axis_name;
nvbench::axis_type axis_type;
std::size_t axis_size;
std::size_t axis_active_size;
};
struct axis_space_iterator
{
using axes_info = std::vector<detail::axis_index>;
using AdvanceSignature = bool(std::size_t &current_index, std::size_t length);
using UpdateSignature = void(std::size_t index,
axes_info::iterator start,
axes_info::iterator end);
using axis_value_indices = std::vector<detail::axis_value_index>;
using advance_signature = bool(std::size_t &current_iteration, std::size_t iteration_size);
using update_signature = void(std::size_t current_iteration,
axis_value_indices::iterator start_axis_value_info,
axis_value_indices::iterator end_axis_value_info);
axis_space_iterator(std::vector<detail::axis_index> info,
std::size_t iter_count,
std::function<axis_space_iterator::AdvanceSignature> &&advance,
std::function<axis_space_iterator::UpdateSignature> &&update)
: m_info(info)
, m_iteration_size(iter_count)
axis_space_iterator(axis_value_indices info,
std::size_t iteration_size,
std::function<axis_space_iterator::advance_signature> &&advance,
std::function<axis_space_iterator::update_signature> &&update)
: m_iteration_size(iteration_size)
, m_axis_value_indices(std::move(info))
, m_advance(std::move(advance))
, m_update(std::move(update))
{}
axis_space_iterator(std::vector<detail::axis_index> info,
axis_space_iterator(axis_value_indices info,
std::size_t iter_count,
std::function<axis_space_iterator::UpdateSignature> &&update)
: m_info(info)
, m_iteration_size(iter_count)
std::function<axis_space_iterator::update_signature> &&update)
: m_iteration_size(iter_count)
, m_axis_value_indices(std::move(info))
, m_update(std::move(update))
{}
[[nodiscard]] bool next() { return this->m_advance(m_current_index, m_iteration_size); }
[[nodiscard]] bool next() { return m_advance(m_current_iteration, m_iteration_size); }
void update_indices(std::vector<axis_index> &indices) const
void update_axis_value_indices(axis_value_indices &info) const
{
using diff_t = typename axes_info::difference_type;
indices.insert(indices.end(), m_info.begin(), m_info.end());
axes_info::iterator end = indices.end();
axes_info::iterator start = end - static_cast<diff_t>(m_info.size());
this->m_update(m_current_index, start, end);
using diff_t = typename axis_value_indices::difference_type;
info.insert(info.end(), m_axis_value_indices.begin(), m_axis_value_indices.end());
axis_value_indices::iterator end = info.end();
axis_value_indices::iterator start = end - static_cast<diff_t>(m_axis_value_indices.size());
m_update(m_current_iteration, start, end);
}
axes_info m_info;
std::size_t m_iteration_size = 1;
std::function<AdvanceSignature> m_advance = [](std::size_t &current_index, std::size_t length) {
(current_index + 1 == length) ? current_index = 0 : current_index++;
return (current_index == 0); // we rolled over
};
std::function<UpdateSignature> m_update = nullptr;
[[nodiscard]] const axis_value_indices &get_axis_value_indices() const
{
return m_axis_value_indices;
}
[[nodiscard]] axis_value_indices &get_axis_value_indices() { return m_axis_value_indices; }
[[nodiscard]] std::size_t get_iteration_size() const { return m_iteration_size; }
private:
std::size_t m_current_index = 0;
std::size_t m_current_iteration = 0;
std::size_t m_iteration_size = 1;
axis_value_indices m_axis_value_indices;
std::function<advance_signature> m_advance = [](std::size_t &current_iteration,
std::size_t iteration_size) {
(current_iteration + 1 == iteration_size) ? current_iteration = 0 : current_iteration++;
return (current_iteration == 0); // we rolled over
};
std::function<update_signature> m_update = nullptr;
};
} // namespace detail

View File

@@ -78,13 +78,12 @@ struct state_iterator
[[nodiscard]] std::size_t get_number_of_states() const;
void init();
[[nodiscard]] std::vector<axis_index> get_current_indices() const;
[[nodiscard]] std::vector<axis_value_index> get_current_axis_value_indices() const;
[[nodiscard]] bool iter_valid() const;
void next();
std::vector<axis_space_iterator> m_space;
std::vector<axis_space_iterator> m_axis_space_iterators;
std::size_t m_axes_count = 0;
std::size_t m_current_space = 0;
std::size_t m_current_iteration = 0;
std::size_t m_max_iteration = 1;
};

View File

@@ -18,6 +18,7 @@
#include <nvbench/benchmark_base.cuh>
#include <nvbench/detail/state_generator.cuh>
#include <nvbench/detail/throw.cuh>
#include <nvbench/detail/transform_reduce.cuh>
#include <nvbench/device_info.cuh>
#include <nvbench/named_values.cuh>
@@ -25,6 +26,7 @@
#include <algorithm>
#include <cassert>
#include <exception>
#include <functional>
#include <numeric>
@@ -34,10 +36,10 @@ namespace nvbench::detail
void state_iterator::add_iteration_space(const nvbench::detail::axis_space_iterator &iter)
{
m_axes_count += iter.m_info.size();
m_max_iteration *= iter.m_iteration_size;
m_axes_count += iter.get_axis_value_indices().size();
m_max_iteration *= iter.get_iteration_size();
m_space.push_back(std::move(iter));
m_axis_space_iterators.push_back(std::move(iter));
}
[[nodiscard]] std::size_t state_iterator::get_number_of_states() const
@@ -45,22 +47,26 @@ void state_iterator::add_iteration_space(const nvbench::detail::axis_space_itera
return this->m_max_iteration;
}
void state_iterator::init()
{
m_current_space = 0;
m_current_iteration = 0;
}
void state_iterator::init() { m_current_iteration = 0; }
[[nodiscard]] std::vector<axis_index> state_iterator::get_current_indices() const
[[nodiscard]] std::vector<axis_value_index> state_iterator::get_current_axis_value_indices() const
{
std::vector<axis_index> indices;
indices.reserve(m_axes_count);
for (auto &m : m_space)
std::vector<axis_value_index> info;
info.reserve(m_axes_count);
for (auto &iter : m_axis_space_iterators)
{
m.update_indices(indices);
iter.update_axis_value_indices(info);
}
// verify length
return indices;
if (info.size() != m_axes_count)
{
NVBENCH_THROW(std::runtime_error,
"Internal error: State iterator has {} axes, but only {} were updated.",
m_axes_count,
info.size());
}
return info;
}
[[nodiscard]] bool state_iterator::iter_valid() const
@@ -72,9 +78,9 @@ void state_iterator::next()
{
m_current_iteration++;
for (auto &&space : this->m_space)
for (auto &iter : this->m_axis_space_iterators)
{
auto rolled_over = space.next();
const auto rolled_over = iter.next();
if (rolled_over)
{
continue;
@@ -128,13 +134,13 @@ void state_generator::build_axis_configs()
auto &[config, active_mask] =
m_type_axis_configs.emplace_back(std::make_pair(nvbench::named_values{}, true));
for (const auto &axis_info : ti.get_current_indices())
for (const auto &info : ti.get_current_axis_value_indices())
{
const auto &axis = axes.get_type_axis(axis_info.name);
const auto &axis = axes.get_type_axis(info.axis_name);
active_mask &= axis.get_is_active(axis_info.index);
active_mask &= axis.get_is_active(info.value_index);
config.set_string(axis.get_name(), axis.get_input_string(axis_info.index));
config.set_string(axis.get_name(), axis.get_input_string(info.value_index));
}
}
@@ -143,30 +149,33 @@ void state_generator::build_axis_configs()
auto &config = m_non_type_axis_configs.emplace_back();
// Add non-type parameters to state:
for (const auto &axis_info : vi.get_current_indices())
for (const auto &axis_value : vi.get_current_axis_value_indices())
{
switch (axis_info.type)
switch (axis_value.axis_type)
{
default:
case axis_type::type:
assert("unreachable." && false);
break;
case axis_type::int64:
config.set_int64(axis_info.name,
axes.get_int64_axis(axis_info.name).get_value(axis_info.index));
config.set_int64(
axis_value.axis_name,
axes.get_int64_axis(axis_value.axis_name).get_value(axis_value.value_index));
break;
case axis_type::float64:
config.set_float64(axis_info.name,
axes.get_float64_axis(axis_info.name).get_value(axis_info.index));
config.set_float64(
axis_value.axis_name,
axes.get_float64_axis(axis_value.axis_name).get_value(axis_value.value_index));
break;
case axis_type::string:
config.set_string(axis_info.name,
axes.get_string_axis(axis_info.name).get_value(axis_info.index));
config.set_string(
axis_value.axis_name,
axes.get_string_axis(axis_value.axis_name).get_value(axis_value.value_index));
break;
} // switch (type)
} // for (axis_info : current_indices)
} // for (axis_values)
}
if (m_type_axis_configs.empty())

View File

@@ -51,30 +51,24 @@ namespace nvbench
*/
struct iteration_space_base
{
using axes_type = std::vector<std::unique_ptr<nvbench::axis_base>>;
using axes_info = std::vector<detail::axis_index>;
using axes_type = std::vector<std::unique_ptr<nvbench::axis_base>>;
using axis_value_indices = std::vector<detail::axis_value_index>;
using AdvanceSignature = nvbench::detail::axis_space_iterator::AdvanceSignature;
using UpdateSignature = nvbench::detail::axis_space_iterator::UpdateSignature;
using advance_signature = nvbench::detail::axis_space_iterator::advance_signature;
using update_signature = nvbench::detail::axis_space_iterator::update_signature;
/*!
* Construct a new derived iteration_space
*
* The input_indices and output_indices combine together to allow the iteration space to know
* what axes they should query from axes_metadata and where each of those map to in the output
* iteration space.
* @param[input_indices] recorded indices of each axis from the axes metadata value space
* @param[input_axis_indices] Index of each associated axis in axes_metadata.
*/
iteration_space_base(std::vector<std::size_t> input_indices);
iteration_space_base(std::vector<std::size_t> input_axis_indices);
virtual ~iteration_space_base();
[[nodiscard]] std::unique_ptr<iteration_space_base> clone() const;
/*!
* Returns the iterator over the @a axis provided
*
* @param[axes]
*
* Returns the iterator over the @a axes provided
*/
[[nodiscard]] detail::axis_space_iterator get_iterator(const axes_type &axes) const;
@@ -97,12 +91,12 @@ struct iteration_space_base
[[nodiscard]] std::size_t get_active_count(const axes_type &axes) const;
protected:
std::vector<std::size_t> m_input_indices;
std::vector<std::size_t> m_axis_indices;
virtual std::unique_ptr<iteration_space_base> do_clone() const = 0;
virtual detail::axis_space_iterator do_get_iterator(axes_info info) const = 0;
virtual std::size_t do_get_size(const axes_info &info) const = 0;
virtual std::size_t do_get_active_count(const axes_info &info) const = 0;
virtual std::unique_ptr<iteration_space_base> do_clone() const = 0;
virtual detail::axis_space_iterator do_get_iterator(axis_value_indices info) const = 0;
virtual std::size_t do_get_size(const axis_value_indices &info) const = 0;
virtual std::size_t do_get_active_count(const axis_value_indices &info) const = 0;
};
} // namespace nvbench

View File

@@ -23,8 +23,8 @@
namespace nvbench
{
iteration_space_base::iteration_space_base(std::vector<std::size_t> input_indices)
: m_input_indices(std::move(input_indices))
iteration_space_base::iteration_space_base(std::vector<std::size_t> input_axis_indices)
: m_axis_indices(std::move(input_axis_indices))
{}
iteration_space_base::~iteration_space_base() = default;
@@ -37,15 +37,15 @@ std::unique_ptr<iteration_space_base> iteration_space_base::clone() const
namespace
{
nvbench::iteration_space_base::axes_info
get_axes_info(const nvbench::iteration_space_base::axes_type &axes,
const std::vector<std::size_t> &indices)
nvbench::iteration_space_base::axis_value_indices
get_axis_value_indices(const nvbench::iteration_space_base::axes_type &axes,
const std::vector<std::size_t> &indices)
{
nvbench::iteration_space_base::axes_info info;
nvbench::iteration_space_base::axis_value_indices info;
info.reserve(indices.size());
for (auto &n : indices)
for (auto &idx : indices)
{
info.emplace_back(axes[n].get());
info.emplace_back(axes[idx].get());
}
return info;
}
@@ -53,16 +53,16 @@ get_axes_info(const nvbench::iteration_space_base::axes_type &axes,
detail::axis_space_iterator iteration_space_base::get_iterator(const axes_type &axes) const
{
return this->do_get_iterator(get_axes_info(axes, m_input_indices));
return this->do_get_iterator(get_axis_value_indices(axes, m_axis_indices));
}
std::size_t iteration_space_base::get_size(const axes_type &axes) const
{
return this->do_get_size(get_axes_info(axes, m_input_indices));
return this->do_get_size(get_axis_value_indices(axes, m_axis_indices));
}
std::size_t iteration_space_base::get_active_count(const axes_type &axes) const
{
return this->do_get_active_count(get_axes_info(axes, m_input_indices));
return this->do_get_active_count(get_axis_value_indices(axes, m_axis_indices));
}
} // namespace nvbench

View File

@@ -27,17 +27,16 @@ namespace nvbench
* Provides linear forward iteration over a single axis.
*
* The default for all axes added to a benchmark
*
*/
struct linear_axis_space final : iteration_space_base
{
linear_axis_space(std::size_t in);
linear_axis_space(std::size_t axis_index);
~linear_axis_space();
std::unique_ptr<iteration_space_base> do_clone() const override;
detail::axis_space_iterator do_get_iterator(axes_info info) const override;
std::size_t do_get_size(const axes_info &info) const override;
std::size_t do_get_active_count(const axes_info &info) const override;
detail::axis_space_iterator do_get_iterator(axis_value_indices info) const override;
std::size_t do_get_size(const axis_value_indices &info) const override;
std::size_t do_get_active_count(const axis_value_indices &info) const override;
};
} // namespace nvbench

View File

@@ -23,26 +23,30 @@
namespace nvbench
{
linear_axis_space::linear_axis_space(std::size_t in_index)
: iteration_space_base({in_index})
linear_axis_space::linear_axis_space(std::size_t axis_index)
: iteration_space_base({axis_index})
{}
linear_axis_space::~linear_axis_space() = default;
detail::axis_space_iterator linear_axis_space::do_get_iterator(axes_info info) const
detail::axis_space_iterator linear_axis_space::do_get_iterator(axis_value_indices info) const
{
auto update_func = [=](std::size_t inc_index, axes_info::iterator start, axes_info::iterator) {
start->index = inc_index;
};
auto update_func = [](std::size_t current_iteration,
axis_value_indices::iterator start,
axis_value_indices::iterator) { start->value_index = current_iteration; };
return detail::axis_space_iterator(info, info[0].size, update_func);
const auto axis_size = info[0].axis_size;
return detail::axis_space_iterator(std::move(info), axis_size, update_func);
}
std::size_t linear_axis_space::do_get_size(const axes_info &info) const { return info[0].size; }
std::size_t linear_axis_space::do_get_active_count(const axes_info &info) const
std::size_t linear_axis_space::do_get_size(const axis_value_indices &info) const
{
return info[0].active_size;
return info[0].axis_size;
}
std::size_t linear_axis_space::do_get_active_count(const axis_value_indices &info) const
{
return info[0].axis_active_size;
}
std::unique_ptr<iteration_space_base> linear_axis_space::do_clone() const

View File

@@ -35,7 +35,7 @@ namespace nvbench
* : nvbench::user_axis_space(std::move(input_indices))
* {}
*
* nvbench::detail::axis_space_iterator do_get_iterator(axes_info info) const
* nvbench::detail::axis_space_iterator do_get_iterator(axis_value_indices info) const
* {
* // our increment function
* auto adv_func = [](std::size_t &inc_index,
@@ -46,19 +46,19 @@ namespace nvbench
*
* // our update function
* auto update_func = [](std::size_t inc_index,
* axes_info::iterator start,
* axes_info::iterator end) {
* axis_value_indices::iterator start,
* axis_value_indices::iterator end) {
* for (; start != end; ++start) {
* start->index = inc_index;
* }
* };
* return detail::axis_space_iterator(info, (info[0].size/3),
* return detail::axis_space_iterator(info, (info[0].axis_size/3),
* adv_func, update_func);
* }
*
* std::size_t do_get_size(const axes_info &info) const
* std::size_t do_get_size(const axis_value_indices &info) const
* {
* return (info[0].size/3);
* return (info[0].axis_size/3);
* }
* ...
* };

View File

@@ -37,13 +37,13 @@ namespace nvbench
*/
struct zip_axis_space final : iteration_space_base
{
zip_axis_space(std::vector<std::size_t> input_indices);
zip_axis_space(std::vector<std::size_t> input_axis_indices);
~zip_axis_space();
std::unique_ptr<iteration_space_base> do_clone() const override;
detail::axis_space_iterator do_get_iterator(axes_info info) const override;
std::size_t do_get_size(const axes_info &info) const override;
std::size_t do_get_active_count(const axes_info &info) const override;
detail::axis_space_iterator do_get_iterator(axis_value_indices info) const override;
std::size_t do_get_size(const axis_value_indices &info) const override;
std::size_t do_get_active_count(const axis_value_indices &info) const override;
};
} // namespace nvbench

View File

@@ -18,35 +18,51 @@
#include "zip_axis_space.cuh"
#include <nvbench/detail/throw.cuh>
#include <nvbench/type_axis.cuh>
#include <exception>
namespace nvbench
{
zip_axis_space::zip_axis_space(std::vector<std::size_t> input_indices)
: iteration_space_base(std::move(input_indices))
zip_axis_space::zip_axis_space(std::vector<std::size_t> input_axis_indices)
: iteration_space_base(std::move(input_axis_indices))
{}
zip_axis_space::~zip_axis_space() = default;
detail::axis_space_iterator zip_axis_space::do_get_iterator(axes_info info) const
detail::axis_space_iterator zip_axis_space::do_get_iterator(axis_value_indices info) const
{
auto update_func =
[=](std::size_t inc_index, axes_info::iterator start, axes_info::iterator end) {
for (; start != end; ++start)
{
start->index = inc_index;
}
};
const auto axis_size = info[0].axis_size;
for (const auto &axis : info)
{
if (axis.axis_active_size != axis_size)
{
NVBENCH_THROW(std::runtime_error, "%s", "All zipped axes must have the same size.");
}
}
return detail::axis_space_iterator(info, info[0].size, update_func);
auto update_func = [](std::size_t current_iteration,
axis_value_indices::iterator start_axis_value_info,
axis_value_indices::iterator end_axis_value_info) {
for (; start_axis_value_info != end_axis_value_info; ++start_axis_value_info)
{
start_axis_value_info->value_index = current_iteration;
}
};
return detail::axis_space_iterator(std::move(info), axis_size, update_func);
}
std::size_t zip_axis_space::do_get_size(const axes_info &info) const { return info[0].size; }
std::size_t zip_axis_space::do_get_active_count(const axes_info &info) const
std::size_t zip_axis_space::do_get_size(const axis_value_indices &info) const
{
return info[0].active_size;
return info[0].axis_size;
}
std::size_t zip_axis_space::do_get_active_count(const axis_value_indices &info) const
{
return info[0].axis_active_size;
}
std::unique_ptr<iteration_space_base> zip_axis_space::do_clone() const

View File

@@ -188,13 +188,13 @@ struct under_diag final : nvbench::user_axis_space
mutable std::size_t y_pos = 0;
mutable std::size_t x_start = 0;
nvbench::detail::axis_space_iterator do_get_iterator(axes_info info) const
nvbench::detail::axis_space_iterator do_get_iterator(axis_value_indices info) const
{
// generate our increment function
auto adv_func = [&, info](std::size_t &inc_index, std::size_t /*len*/) -> bool {
inc_index++;
x_pos++;
if (x_pos == info[0].size)
if (x_pos == info[0].axis_size)
{
x_pos = ++x_start;
y_pos = x_start;
@@ -204,25 +204,24 @@ struct under_diag final : nvbench::user_axis_space
};
// our update function
auto diag_under = [&, info](std::size_t,
std::vector<nvbench::detail::axis_index>::iterator start,
std::vector<nvbench::detail::axis_index>::iterator end) {
start->index = x_pos;
end->index = y_pos;
};
auto diag_under =
[&, info](std::size_t, axis_value_indices::iterator start, axis_value_indices::iterator end) {
start->value_index = x_pos;
end->value_index = y_pos;
};
const size_t iteration_length = ((info[0].size * (info[1].size + 1)) / 2);
const size_t iteration_length = ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
return nvbench::detail::axis_space_iterator(info, iteration_length, adv_func, diag_under);
}
std::size_t do_get_size(const axes_info &info) const
std::size_t do_get_size(const axis_value_indices &info) const
{
return ((info[0].size * (info[1].size + 1)) / 2);
return ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
}
std::size_t do_get_active_count(const axes_info &info) const
std::size_t do_get_active_count(const axis_value_indices &info) const
{
return ((info[0].size * (info[1].size + 1)) / 2);
return ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
}
std::unique_ptr<nvbench::iteration_space_base> do_clone() const

View File

@@ -65,11 +65,12 @@ void test_single_state()
ASSERT(sg.get_number_of_states() == 1);
sg.init();
ASSERT(sg.iter_valid());
ASSERT(sg.get_current_indices().size() == 1);
ASSERT(sg.get_current_indices()[0].name == "OnlyAxis");
ASSERT(sg.get_current_indices()[0].index == 0);
ASSERT(sg.get_current_indices()[0].size == 1);
ASSERT(sg.get_current_indices()[0].type == nvbench::axis_type::string);
ASSERT(sg.get_current_axis_value_indices().size() == 1);
ASSERT(sg.get_current_axis_value_indices()[0].axis_name == "OnlyAxis");
ASSERT(sg.get_current_axis_value_indices()[0].axis_size == 1);
ASSERT(sg.get_current_axis_value_indices()[0].axis_active_size == 1);
ASSERT(sg.get_current_axis_value_indices()[0].axis_type == nvbench::axis_type::string);
ASSERT(sg.get_current_axis_value_indices()[0].value_index == 0);
sg.next();
ASSERT(!sg.iter_valid());
@@ -112,14 +113,14 @@ void test_basic()
{
line.clear();
fmt::format_to(std::back_inserter(line), "| {:^2}", line_num++);
for (auto &axis_index : sg.get_current_indices())
for (auto &axis_value : sg.get_current_axis_value_indices())
{
ASSERT(axis_index.type == nvbench::axis_type::string);
ASSERT(axis_value.axis_type == nvbench::axis_type::string);
fmt::format_to(std::back_inserter(line),
" | {}: {}/{}",
axis_index.name,
axis_index.index,
axis_index.size);
axis_value.axis_name,
axis_value.value_index,
axis_value.axis_size);
}
fmt::format_to(std::back_inserter(buffer), "{} |\n", fmt::to_string(line));
}