diff --git a/nvbench/type_axis.cu b/nvbench/type_axis.cu index 2832f5d..62d0c04 100644 --- a/nvbench/type_axis.cu +++ b/nvbench/type_axis.cu @@ -1,8 +1,29 @@ #include +#include +#include + +#include +#include + namespace nvbench { type_axis::~type_axis() = default; +std::size_t type_axis::get_index(const std::string &input_string) const +{ + auto it = + std::find(m_input_strings.cbegin(), m_input_strings.cend(), input_string); + if (it == m_input_strings.end()) + { + throw std::runtime_error( + fmt::format("{}:{}: Invalid input string '{}' for type_axis `{}`.\n" + "Valid input strings: {}", + __FILE__, __LINE__, input_string, this->get_name(), m_input_strings)); + } + + return it - m_input_strings.cbegin(); } + +} // namespace nvbench diff --git a/nvbench/type_axis.cuh b/nvbench/type_axis.cuh index b2133c0..d1df670 100644 --- a/nvbench/type_axis.cuh +++ b/nvbench/type_axis.cuh @@ -24,6 +24,8 @@ struct type_axis final : public axis_base template void set_inputs(); + std::size_t get_index(const std::string& input_string) const; + private: std::size_t do_get_size() const final { return m_input_strings.size(); } std::string do_get_input_string(std::size_t i) const final diff --git a/testing/float64_axis.cu b/testing/float64_axis.cu new file mode 100644 index 0000000..674afd0 --- /dev/null +++ b/testing/float64_axis.cu @@ -0,0 +1,4 @@ +// +// Created by allie on 12/22/2020. +// + diff --git a/testing/type_axis.cu b/testing/type_axis.cu index 5b05b42..82eaea7 100644 --- a/testing/type_axis.cu +++ b/testing/type_axis.cu @@ -47,9 +47,23 @@ void test_several() ASSERT(axis.get_description(2) == ""); } +void test_get_index() +{ + nvbench::type_axis axis("GetIndexTest"); + axis.set_inputs< + nvbench:: + type_list>(); + + ASSERT(axis.get_index("I8") == 0); + ASSERT(axis.get_index("U16") == 1); + ASSERT(axis.get_index("F32") == 2); + ASSERT(axis.get_index("bool") == 3); +} + int main() { test_empty(); test_single(); test_several(); + test_get_index(); }