Add type_axis::get_index.

This commit is contained in:
Allison Vacanti
2020-12-22 16:37:19 -05:00
parent 07e4cc36c2
commit 65eadda1c1
4 changed files with 41 additions and 0 deletions

View File

@@ -1,8 +1,29 @@
#include <nvbench/type_axis.cuh>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <algorithm>
#include <stdexcept>
namespace nvbench
{
type_axis::~type_axis() = default;
std::size_t type_axis::get_index(const std::string &input_string) const
{
auto it =
std::find(m_input_strings.cbegin(), m_input_strings.cend(), input_string);
if (it == m_input_strings.end())
{
throw std::runtime_error(
fmt::format("{}:{}: Invalid input string '{}' for type_axis `{}`.\n"
"Valid input strings: {}",
__FILE__, __LINE__, input_string, this->get_name(), m_input_strings));
}
return it - m_input_strings.cbegin();
}
} // namespace nvbench

View File

@@ -24,6 +24,8 @@ struct type_axis final : public axis_base
template <typename TypeList>
void set_inputs();
std::size_t get_index(const std::string& input_string) const;
private:
std::size_t do_get_size() const final { return m_input_strings.size(); }
std::string do_get_input_string(std::size_t i) const final

4
testing/float64_axis.cu Normal file
View File

@@ -0,0 +1,4 @@
//
// Created by allie on 12/22/2020.
//

View File

@@ -47,9 +47,23 @@ void test_several()
ASSERT(axis.get_description(2) == "");
}
void test_get_index()
{
nvbench::type_axis axis("GetIndexTest");
axis.set_inputs<
nvbench::
type_list<nvbench::int8_t, nvbench::uint16_t, nvbench::float32_t, bool>>();
ASSERT(axis.get_index("I8") == 0);
ASSERT(axis.get_index("U16") == 1);
ASSERT(axis.get_index("F32") == 2);
ASSERT(axis.get_index("bool") == 3);
}
int main()
{
test_empty();
test_single();
test_several();
test_get_index();
}