Add default axis names.

Also cleaned up the annoying quirk where `set_type_axes_names` *had*
to be called on all benchmarks with type axes.

Default names are {"T", "U", "V", "W"} for up-to four type axes. For
five or more, {"T0", "T1", ...} is used instead.
This commit is contained in:
Allison Vacanti
2021-02-19 12:37:05 -05:00
parent 324b0d107e
commit 8d6d934dfe
7 changed files with 179 additions and 34 deletions

View File

@@ -1,6 +1,9 @@
#include <nvbench/axes_metadata.cuh>
#include <nvbench/detail/throw.cuh>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <algorithm>
#include <stdexcept>
@@ -28,6 +31,38 @@ axes_metadata &axes_metadata::operator=(const axes_metadata &other)
return *this;
}
void axes_metadata::set_type_axes_names(std::vector<std::string> names)
try
{
if (names.size() < m_axes.size())
{
NVBENCH_THROW(std::runtime_error,
"Number of names exceeds number of axes ({}).",
m_axes.size());
}
for (std::size_t i = 0; i < names.size(); ++i)
{
auto &axis = *m_axes[i];
if (axis.get_type() != nvbench::axis_type::type)
{
NVBENCH_THROW(std::runtime_error,
"Number of names exceeds number of type axes ({})",
i);
}
axis.set_name(std::move(names[i]));
}
}
catch (std::exception &e)
{
NVBENCH_THROW(std::runtime_error,
"Error in set_type_axes_names:\n{}\n"
"TypeAxesNames: {}",
e.what(),
names);
}
void axes_metadata::add_float64_axis(std::string name,
std::vector<nvbench::float64_t> data)
{
@@ -205,4 +240,32 @@ axis_base &axes_metadata::get_axis(std::string_view name,
return axis;
}
std::vector<std::string>
axes_metadata::generate_default_type_axis_names(std::size_t num_type_axes)
{
switch (num_type_axes)
{
case 0:
return {};
case 1:
return {"T"};
case 2:
return {"T", "U"};
case 3:
return {"T", "U", "V"};
case 4:
return {"T", "U", "V", "W"};
default:
break;
}
std::vector<std::string> result;
result.reserve(num_type_axes);
for (std::size_t i = 0; i < num_type_axes; ++i)
{
result.emplace_back(fmt::format("T{}", i));
}
return result;
}
} // namespace nvbench

View File

@@ -20,6 +20,9 @@ struct axes_metadata
{
using axes_type = std::vector<std::unique_ptr<nvbench::axis_base>>;
template <typename... TypeAxes>
explicit axes_metadata(nvbench::type_list<TypeAxes...>);
axes_metadata() = default;
axes_metadata(axes_metadata &&) = default;
axes_metadata &operator=(axes_metadata &&) = default;
@@ -27,7 +30,6 @@ struct axes_metadata
axes_metadata(const axes_metadata &);
axes_metadata &operator=(const axes_metadata &);
template <typename type_axes>
void set_type_axes_names(std::vector<std::string> names);
void add_int64_axis(std::string name,
@@ -68,32 +70,37 @@ struct axes_metadata
[[nodiscard]] nvbench::axis_base &get_axis(std::string_view name,
nvbench::axis_type type);
[[nodiscard]] static std::vector<std::string>
generate_default_type_axis_names(std::size_t num_type_axes);
private:
axes_type m_axes;
};
template <typename type_axes>
void axes_metadata::set_type_axes_names(std::vector<std::string> names)
template <typename ...TypeAxes>
axes_metadata::axes_metadata(nvbench::type_list<TypeAxes...>)
: axes_metadata{}
{
if (names.size() != nvbench::tl::size<type_axes>::value)
{ // TODO Find a way to get a better error message w/o bringing fmt
// into this header.
throw std::runtime_error("set_type_axes_names(): len(names) != "
"len(type_axes)");
}
std::size_t axis_index = 0;
auto names_iter = names.begin(); // contents will be moved from
nvbench::tl::foreach<type_axes>([&axes = m_axes, &names_iter, &axis_index](
[[maybe_unused]] auto wrapped_type) {
// Note:
// The word "type" appears 6 times in the next line.
// Every. Single. Token.
typedef typename decltype(wrapped_type)::type type_list;
auto axis = std::make_unique<nvbench::type_axis>(std::move(*names_iter++),
axis_index++);
axis->set_inputs<type_list>();
axes.push_back(std::move(axis));
});
using type_axes = nvbench::type_list<TypeAxes...>;
constexpr auto num_type_axes = nvbench::tl::size<type_axes>::value;
auto names = axes_metadata::generate_default_type_axis_names(num_type_axes);
auto names_iter = names.begin(); // contents will be moved from
nvbench::tl::foreach<type_axes>(
[&axes = m_axes, &names_iter]([[maybe_unused]] auto wrapped_type) {
// This is always called before other axes are added, so the length of the
// axes vector will be the type axis index:
const std::size_t type_axis_index = axes.size();
// Note:
// The word "type" appears 6 times in the next line.
// Every. Single. Token.
typedef typename decltype(wrapped_type)::type type_list;
auto axis = std::make_unique<nvbench::type_axis>(std::move(*names_iter++),
type_axis_index);
axis->set_inputs<type_list>();
axes.push_back(std::move(axis));
});
}
} // namespace nvbench

View File

@@ -25,6 +25,7 @@ struct axis_base
[[nodiscard]] std::unique_ptr<axis_base> clone() const;
[[nodiscard]] const std::string &get_name() const { return m_name; }
void set_name(std::string name) { m_name = std::move(name); }
[[nodiscard]] axis_type get_type() const { return m_type; }

View File

@@ -3,8 +3,8 @@
#include <nvbench/benchmark_base.cuh>
#include <nvbench/axes_metadata.cuh>
#include <nvbench/type_list.cuh>
#include <nvbench/runner.cuh>
#include <nvbench/type_list.cuh>
#include <string>
#include <vector>
@@ -42,7 +42,9 @@ struct benchmark final : public benchmark_base
static constexpr std::size_t num_type_configs =
nvbench::tl::size<type_configs>{};
using benchmark_base::benchmark_base;
benchmark::benchmark()
: benchmark_base(type_axes{})
{}
// Note that this inline virtual dtor may cause vtable issues if linking
// benchmark TUs together. That's not a likely scenario, so we'll deal with
@@ -57,7 +59,7 @@ private:
void do_set_type_axes_names(std::vector<std::string> names) final
{
m_axes.template set_type_axes_names<type_axes>(std::move(names));
m_axes.set_type_axes_names(std::move(names));
}
void do_run() final

View File

@@ -1,14 +1,8 @@
#include <nvbench/benchmark_base.cuh>
#include <nvbench/device_manager.cuh>
namespace nvbench
{
benchmark_base::benchmark_base()
: m_devices(nvbench::device_manager::get().get_devices())
{}
benchmark_base::~benchmark_base() = default;
std::unique_ptr<benchmark_base> benchmark_base::clone() const

View File

@@ -2,6 +2,7 @@
#include <nvbench/axes_metadata.cuh>
#include <nvbench/device_info.cuh>
#include <nvbench/device_manager.cuh>
#include <nvbench/state.cuh>
#include <memory>
@@ -24,7 +25,12 @@ struct runner;
*/
struct benchmark_base
{
benchmark_base();
template <typename TypeAxes>
explicit benchmark_base(TypeAxes type_axes)
: m_axes(type_axes)
, m_devices(nvbench::device_manager::get().get_devices())
{}
virtual ~benchmark_base();
/**

View File

@@ -24,10 +24,81 @@ using three_type_axes = nvbench::type_list<int_list, float_list, misc_list>;
using no_types = nvbench::type_list<>;
void test_default_type_axes_names()
{
using TL = nvbench::type_list<int>;
{
nvbench::axes_metadata axes{};
ASSERT(axes.get_axes().size() == 0);
}
{
nvbench::axes_metadata axes{nvbench::type_list<>{}};
ASSERT(axes.get_axes().size() == 0);
}
{
nvbench::axes_metadata axes{nvbench::type_list<TL>{}};
ASSERT(axes.get_axes().size() == 1);
ASSERT(axes.get_type_axis(0).get_name() == "T");
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
}
{
nvbench::axes_metadata axes{nvbench::type_list<TL, TL>{}};
ASSERT(axes.get_axes().size() == 2);
ASSERT(axes.get_type_axis(0).get_name() == "T");
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
ASSERT(axes.get_type_axis(1).get_name() == "U");
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
}
{
nvbench::axes_metadata axes{nvbench::type_list<TL, TL, TL>{}};
ASSERT(axes.get_axes().size() == 3);
ASSERT(axes.get_type_axis(0).get_name() == "T");
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
ASSERT(axes.get_type_axis(1).get_name() == "U");
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
ASSERT(axes.get_type_axis(2).get_name() == "V");
ASSERT(axes.get_type_axis(2).get_axis_index() == 2);
}
{
nvbench::axes_metadata axes{nvbench::type_list<TL, TL, TL, TL>{}};
ASSERT(axes.get_axes().size() == 4);
ASSERT(axes.get_type_axis(0).get_name() == "T");
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
ASSERT(axes.get_type_axis(1).get_name() == "U");
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
ASSERT(axes.get_type_axis(2).get_name() == "V");
ASSERT(axes.get_type_axis(2).get_axis_index() == 2);
ASSERT(axes.get_type_axis(3).get_name() == "W");
ASSERT(axes.get_type_axis(3).get_axis_index() == 3);
}
{
nvbench::axes_metadata axes{nvbench::type_list<TL, TL, TL, TL, TL>{}};
ASSERT(axes.get_axes().size() == 5);
ASSERT(axes.get_type_axis(0).get_name() == "T0");
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
ASSERT(axes.get_type_axis(1).get_name() == "T1");
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
ASSERT(axes.get_type_axis(2).get_name() == "T2");
ASSERT(axes.get_type_axis(2).get_axis_index() == 2);
ASSERT(axes.get_type_axis(3).get_name() == "T3");
ASSERT(axes.get_type_axis(3).get_axis_index() == 3);
ASSERT(axes.get_type_axis(4).get_name() == "T4");
ASSERT(axes.get_type_axis(4).get_axis_index() == 4);
}
}
void test_type_axes()
{
nvbench::axes_metadata axes;
axes.set_type_axes_names<three_type_axes>({"Integer", "Float", "Other"});
nvbench::axes_metadata axes{three_type_axes{}};
axes.set_type_axes_names({"Integer", "Float", "Other"});
ASSERT(axes.get_type_axis("Integer").get_name() == "Integer");
ASSERT(axes.get_type_axis("Float").get_name() == "Float");
@@ -143,6 +214,7 @@ void test_string_axes()
int main()
{
test_default_type_axes_names();
test_type_axes();
test_float64_axes();
test_int64_axes();