mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 06:48:53 +00:00
Add default axis names.
Also cleaned up the annoying quirk where `set_type_axes_names` *had*
to be called on all benchmarks with type axes.
Default names are {"T", "U", "V", "W"} for up-to four type axes. For
five or more, {"T0", "T1", ...} is used instead.
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
#include <nvbench/axes_metadata.cuh>
|
||||
|
||||
#include <nvbench/detail/throw.cuh>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <fmt/ranges.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
@@ -28,6 +31,38 @@ axes_metadata &axes_metadata::operator=(const axes_metadata &other)
|
||||
return *this;
|
||||
}
|
||||
|
||||
void axes_metadata::set_type_axes_names(std::vector<std::string> names)
|
||||
try
|
||||
{
|
||||
if (names.size() < m_axes.size())
|
||||
{
|
||||
NVBENCH_THROW(std::runtime_error,
|
||||
"Number of names exceeds number of axes ({}).",
|
||||
m_axes.size());
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < names.size(); ++i)
|
||||
{
|
||||
auto &axis = *m_axes[i];
|
||||
if (axis.get_type() != nvbench::axis_type::type)
|
||||
{
|
||||
NVBENCH_THROW(std::runtime_error,
|
||||
"Number of names exceeds number of type axes ({})",
|
||||
i);
|
||||
}
|
||||
|
||||
axis.set_name(std::move(names[i]));
|
||||
}
|
||||
}
|
||||
catch (std::exception &e)
|
||||
{
|
||||
NVBENCH_THROW(std::runtime_error,
|
||||
"Error in set_type_axes_names:\n{}\n"
|
||||
"TypeAxesNames: {}",
|
||||
e.what(),
|
||||
names);
|
||||
}
|
||||
|
||||
void axes_metadata::add_float64_axis(std::string name,
|
||||
std::vector<nvbench::float64_t> data)
|
||||
{
|
||||
@@ -205,4 +240,32 @@ axis_base &axes_metadata::get_axis(std::string_view name,
|
||||
return axis;
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
axes_metadata::generate_default_type_axis_names(std::size_t num_type_axes)
|
||||
{
|
||||
switch (num_type_axes)
|
||||
{
|
||||
case 0:
|
||||
return {};
|
||||
case 1:
|
||||
return {"T"};
|
||||
case 2:
|
||||
return {"T", "U"};
|
||||
case 3:
|
||||
return {"T", "U", "V"};
|
||||
case 4:
|
||||
return {"T", "U", "V", "W"};
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
std::vector<std::string> result;
|
||||
result.reserve(num_type_axes);
|
||||
for (std::size_t i = 0; i < num_type_axes; ++i)
|
||||
{
|
||||
result.emplace_back(fmt::format("T{}", i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -20,6 +20,9 @@ struct axes_metadata
|
||||
{
|
||||
using axes_type = std::vector<std::unique_ptr<nvbench::axis_base>>;
|
||||
|
||||
template <typename... TypeAxes>
|
||||
explicit axes_metadata(nvbench::type_list<TypeAxes...>);
|
||||
|
||||
axes_metadata() = default;
|
||||
axes_metadata(axes_metadata &&) = default;
|
||||
axes_metadata &operator=(axes_metadata &&) = default;
|
||||
@@ -27,7 +30,6 @@ struct axes_metadata
|
||||
axes_metadata(const axes_metadata &);
|
||||
axes_metadata &operator=(const axes_metadata &);
|
||||
|
||||
template <typename type_axes>
|
||||
void set_type_axes_names(std::vector<std::string> names);
|
||||
|
||||
void add_int64_axis(std::string name,
|
||||
@@ -68,32 +70,37 @@ struct axes_metadata
|
||||
[[nodiscard]] nvbench::axis_base &get_axis(std::string_view name,
|
||||
nvbench::axis_type type);
|
||||
|
||||
[[nodiscard]] static std::vector<std::string>
|
||||
generate_default_type_axis_names(std::size_t num_type_axes);
|
||||
|
||||
private:
|
||||
axes_type m_axes;
|
||||
};
|
||||
|
||||
template <typename type_axes>
|
||||
void axes_metadata::set_type_axes_names(std::vector<std::string> names)
|
||||
template <typename ...TypeAxes>
|
||||
axes_metadata::axes_metadata(nvbench::type_list<TypeAxes...>)
|
||||
: axes_metadata{}
|
||||
{
|
||||
if (names.size() != nvbench::tl::size<type_axes>::value)
|
||||
{ // TODO Find a way to get a better error message w/o bringing fmt
|
||||
// into this header.
|
||||
throw std::runtime_error("set_type_axes_names(): len(names) != "
|
||||
"len(type_axes)");
|
||||
}
|
||||
std::size_t axis_index = 0;
|
||||
auto names_iter = names.begin(); // contents will be moved from
|
||||
nvbench::tl::foreach<type_axes>([&axes = m_axes, &names_iter, &axis_index](
|
||||
[[maybe_unused]] auto wrapped_type) {
|
||||
// Note:
|
||||
// The word "type" appears 6 times in the next line.
|
||||
// Every. Single. Token.
|
||||
typedef typename decltype(wrapped_type)::type type_list;
|
||||
auto axis = std::make_unique<nvbench::type_axis>(std::move(*names_iter++),
|
||||
axis_index++);
|
||||
axis->set_inputs<type_list>();
|
||||
axes.push_back(std::move(axis));
|
||||
});
|
||||
using type_axes = nvbench::type_list<TypeAxes...>;
|
||||
constexpr auto num_type_axes = nvbench::tl::size<type_axes>::value;
|
||||
auto names = axes_metadata::generate_default_type_axis_names(num_type_axes);
|
||||
|
||||
auto names_iter = names.begin(); // contents will be moved from
|
||||
nvbench::tl::foreach<type_axes>(
|
||||
[&axes = m_axes, &names_iter]([[maybe_unused]] auto wrapped_type) {
|
||||
// This is always called before other axes are added, so the length of the
|
||||
// axes vector will be the type axis index:
|
||||
const std::size_t type_axis_index = axes.size();
|
||||
|
||||
// Note:
|
||||
// The word "type" appears 6 times in the next line.
|
||||
// Every. Single. Token.
|
||||
typedef typename decltype(wrapped_type)::type type_list;
|
||||
auto axis = std::make_unique<nvbench::type_axis>(std::move(*names_iter++),
|
||||
type_axis_index);
|
||||
axis->set_inputs<type_list>();
|
||||
axes.push_back(std::move(axis));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace nvbench
|
||||
|
||||
@@ -25,6 +25,7 @@ struct axis_base
|
||||
[[nodiscard]] std::unique_ptr<axis_base> clone() const;
|
||||
|
||||
[[nodiscard]] const std::string &get_name() const { return m_name; }
|
||||
void set_name(std::string name) { m_name = std::move(name); }
|
||||
|
||||
[[nodiscard]] axis_type get_type() const { return m_type; }
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
#include <nvbench/benchmark_base.cuh>
|
||||
|
||||
#include <nvbench/axes_metadata.cuh>
|
||||
#include <nvbench/type_list.cuh>
|
||||
#include <nvbench/runner.cuh>
|
||||
#include <nvbench/type_list.cuh>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -42,7 +42,9 @@ struct benchmark final : public benchmark_base
|
||||
static constexpr std::size_t num_type_configs =
|
||||
nvbench::tl::size<type_configs>{};
|
||||
|
||||
using benchmark_base::benchmark_base;
|
||||
benchmark::benchmark()
|
||||
: benchmark_base(type_axes{})
|
||||
{}
|
||||
|
||||
// Note that this inline virtual dtor may cause vtable issues if linking
|
||||
// benchmark TUs together. That's not a likely scenario, so we'll deal with
|
||||
@@ -57,7 +59,7 @@ private:
|
||||
|
||||
void do_set_type_axes_names(std::vector<std::string> names) final
|
||||
{
|
||||
m_axes.template set_type_axes_names<type_axes>(std::move(names));
|
||||
m_axes.set_type_axes_names(std::move(names));
|
||||
}
|
||||
|
||||
void do_run() final
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
#include <nvbench/benchmark_base.cuh>
|
||||
|
||||
#include <nvbench/device_manager.cuh>
|
||||
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
benchmark_base::benchmark_base()
|
||||
: m_devices(nvbench::device_manager::get().get_devices())
|
||||
{}
|
||||
|
||||
benchmark_base::~benchmark_base() = default;
|
||||
|
||||
std::unique_ptr<benchmark_base> benchmark_base::clone() const
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <nvbench/axes_metadata.cuh>
|
||||
#include <nvbench/device_info.cuh>
|
||||
#include <nvbench/device_manager.cuh>
|
||||
#include <nvbench/state.cuh>
|
||||
|
||||
#include <memory>
|
||||
@@ -24,7 +25,12 @@ struct runner;
|
||||
*/
|
||||
struct benchmark_base
|
||||
{
|
||||
benchmark_base();
|
||||
template <typename TypeAxes>
|
||||
explicit benchmark_base(TypeAxes type_axes)
|
||||
: m_axes(type_axes)
|
||||
, m_devices(nvbench::device_manager::get().get_devices())
|
||||
{}
|
||||
|
||||
virtual ~benchmark_base();
|
||||
|
||||
/**
|
||||
|
||||
@@ -24,10 +24,81 @@ using three_type_axes = nvbench::type_list<int_list, float_list, misc_list>;
|
||||
|
||||
using no_types = nvbench::type_list<>;
|
||||
|
||||
void test_default_type_axes_names()
|
||||
{
|
||||
using TL = nvbench::type_list<int>;
|
||||
|
||||
{
|
||||
nvbench::axes_metadata axes{};
|
||||
ASSERT(axes.get_axes().size() == 0);
|
||||
}
|
||||
|
||||
{
|
||||
nvbench::axes_metadata axes{nvbench::type_list<>{}};
|
||||
ASSERT(axes.get_axes().size() == 0);
|
||||
}
|
||||
|
||||
{
|
||||
nvbench::axes_metadata axes{nvbench::type_list<TL>{}};
|
||||
ASSERT(axes.get_axes().size() == 1);
|
||||
ASSERT(axes.get_type_axis(0).get_name() == "T");
|
||||
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
|
||||
}
|
||||
|
||||
{
|
||||
nvbench::axes_metadata axes{nvbench::type_list<TL, TL>{}};
|
||||
ASSERT(axes.get_axes().size() == 2);
|
||||
ASSERT(axes.get_type_axis(0).get_name() == "T");
|
||||
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
|
||||
ASSERT(axes.get_type_axis(1).get_name() == "U");
|
||||
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
|
||||
}
|
||||
|
||||
{
|
||||
nvbench::axes_metadata axes{nvbench::type_list<TL, TL, TL>{}};
|
||||
ASSERT(axes.get_axes().size() == 3);
|
||||
ASSERT(axes.get_type_axis(0).get_name() == "T");
|
||||
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
|
||||
ASSERT(axes.get_type_axis(1).get_name() == "U");
|
||||
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
|
||||
ASSERT(axes.get_type_axis(2).get_name() == "V");
|
||||
ASSERT(axes.get_type_axis(2).get_axis_index() == 2);
|
||||
}
|
||||
|
||||
{
|
||||
nvbench::axes_metadata axes{nvbench::type_list<TL, TL, TL, TL>{}};
|
||||
ASSERT(axes.get_axes().size() == 4);
|
||||
ASSERT(axes.get_type_axis(0).get_name() == "T");
|
||||
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
|
||||
ASSERT(axes.get_type_axis(1).get_name() == "U");
|
||||
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
|
||||
ASSERT(axes.get_type_axis(2).get_name() == "V");
|
||||
ASSERT(axes.get_type_axis(2).get_axis_index() == 2);
|
||||
ASSERT(axes.get_type_axis(3).get_name() == "W");
|
||||
ASSERT(axes.get_type_axis(3).get_axis_index() == 3);
|
||||
}
|
||||
|
||||
{
|
||||
nvbench::axes_metadata axes{nvbench::type_list<TL, TL, TL, TL, TL>{}};
|
||||
ASSERT(axes.get_axes().size() == 5);
|
||||
ASSERT(axes.get_type_axis(0).get_name() == "T0");
|
||||
ASSERT(axes.get_type_axis(0).get_axis_index() == 0);
|
||||
ASSERT(axes.get_type_axis(1).get_name() == "T1");
|
||||
ASSERT(axes.get_type_axis(1).get_axis_index() == 1);
|
||||
ASSERT(axes.get_type_axis(2).get_name() == "T2");
|
||||
ASSERT(axes.get_type_axis(2).get_axis_index() == 2);
|
||||
ASSERT(axes.get_type_axis(3).get_name() == "T3");
|
||||
ASSERT(axes.get_type_axis(3).get_axis_index() == 3);
|
||||
ASSERT(axes.get_type_axis(4).get_name() == "T4");
|
||||
ASSERT(axes.get_type_axis(4).get_axis_index() == 4);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void test_type_axes()
|
||||
{
|
||||
nvbench::axes_metadata axes;
|
||||
axes.set_type_axes_names<three_type_axes>({"Integer", "Float", "Other"});
|
||||
nvbench::axes_metadata axes{three_type_axes{}};
|
||||
axes.set_type_axes_names({"Integer", "Float", "Other"});
|
||||
|
||||
ASSERT(axes.get_type_axis("Integer").get_name() == "Integer");
|
||||
ASSERT(axes.get_type_axis("Float").get_name() == "Float");
|
||||
@@ -143,6 +214,7 @@ void test_string_axes()
|
||||
|
||||
int main()
|
||||
{
|
||||
test_default_type_axes_names();
|
||||
test_type_axes();
|
||||
test_float64_axes();
|
||||
test_int64_axes();
|
||||
|
||||
Reference in New Issue
Block a user