diff --git a/nvbench/CMakeLists.txt b/nvbench/CMakeLists.txt index ca687f3..28d9a84 100644 --- a/nvbench/CMakeLists.txt +++ b/nvbench/CMakeLists.txt @@ -5,6 +5,7 @@ set(srcs params.cu string_axis.cu type_axis.cu + state.cu detail/state_generator.cu ) diff --git a/nvbench/state.cu b/nvbench/state.cu new file mode 100644 index 0000000..94719ef --- /dev/null +++ b/nvbench/state.cu @@ -0,0 +1,41 @@ +#include + +#include + +#include +#include +#include + +namespace nvbench +{ + +nvbench::int64_t state::get_int64(const std::string &axis_name) const +{ + return std::get(m_params.at(axis_name)); +} + +nvbench::float64_t state::get_float64(const std::string &axis_name) const +{ + return std::get(m_params.at(axis_name)); +} + +const std::string &state::get_string(const std::string &axis_name) const +{ + return std::get(m_params.at(axis_name)); +} +void state::set_param(std::string axis_name, nvbench::int64_t value) +{ + m_params.insert(std::make_pair(std::move(axis_name), value)); +} + +void state::set_param(std::string axis_name, nvbench::float64_t value) +{ + m_params.insert(std::make_pair(std::move(axis_name), value)); +} + +void state::set_param(std::string axis_name, std::string value) +{ + m_params.insert(std::make_pair(std::move(axis_name), std::move(value))); +} + +} // namespace nvbench diff --git a/nvbench/state.cuh b/nvbench/state.cuh new file mode 100644 index 0000000..d896bf0 --- /dev/null +++ b/nvbench/state.cuh @@ -0,0 +1,46 @@ +#pragma once + +#include + +#include +#include +#include + +namespace nvbench +{ + +struct state +{ + // move-only + state(const state &) = delete; + state(state &&) = default; + state &operator=(const state &) = delete; + state &operator=(state &&) = default; + + [[nodiscard]] nvbench::int64_t get_int64(const std::string &axis_name) const; + + [[nodiscard]] nvbench::float64_t + get_float64(const std::string &axis_name) const; + + [[nodiscard]] const std::string & + get_string(const std::string &axis_name) const; + +protected: + state() = default; + + using param_type = + std::variant; + using params_type = std::unordered_map; + + explicit state(params_type params) + : m_params{std::move(params)} + {} + + void set_param(std::string axis_name, nvbench::int64_t value); + void set_param(std::string axis_name, nvbench::float64_t value); + void set_param(std::string axis_name, std::string value); + + params_type m_params; +}; + +} // namespace nvbench diff --git a/testing/CMakeLists.txt b/testing/CMakeLists.txt index fe3350c..2a7c16d 100644 --- a/testing/CMakeLists.txt +++ b/testing/CMakeLists.txt @@ -3,6 +3,7 @@ set(test_srcs int64_axis.cu float64_axis.cu params.cu + state.cu state_generator.cu string_axis.cu type_axis.cu diff --git a/testing/state.cu b/testing/state.cu new file mode 100644 index 0000000..9ab4172 --- /dev/null +++ b/testing/state.cu @@ -0,0 +1,48 @@ +#include + +#include + +#include "test_asserts.cuh" + +// Subclass to gain access to protected members for testing: +struct state_tester : public nvbench::state +{ + using params_type = nvbench::state::params_type; + + state_tester() + : nvbench::state() + {} + explicit state_tester(params_type params) + : nvbench::state{std::move(params)} + {} + + template + void set_param(Args &&...args) + { + this->state::set_param(std::forward(args)...); + } + + const auto &get_params() const { return m_params; } +}; + +void test_params() +{ + // Build a state param by param + state_tester state1; + state1.set_param("TestInt", nvbench::int64_t{22}); + state1.set_param("TestFloat", nvbench::float64_t{3.14}); + state1.set_param("TestString", "A String!"); + + ASSERT(state1.get_int64("TestInt") == nvbench::int64_t{22}); + ASSERT(state1.get_float64("TestFloat") == nvbench::float64_t{3.14}); + ASSERT(state1.get_string("TestString") == "A String!"); + + // Construct a state from the parameter map built above: + state_tester state2{state1.get_params()}; + + ASSERT(state2.get_int64("TestInt") == nvbench::int64_t{22}); + ASSERT(state2.get_float64("TestFloat") == nvbench::float64_t{3.14}); + ASSERT(state2.get_string("TestString") == "A String!"); +} + +int main() { test_params(); }