Add nvbench::state.

This class holds a single value for each runtime axis.
This commit is contained in:
Allison Vacanti
2020-12-27 10:44:22 -05:00
parent 7b14ceb3fe
commit 093077de5f
5 changed files with 137 additions and 0 deletions

View File

@@ -5,6 +5,7 @@ set(srcs
params.cu
string_axis.cu
type_axis.cu
state.cu
detail/state_generator.cu
)

41
nvbench/state.cu Normal file
View File

@@ -0,0 +1,41 @@
#include <nvbench/state.cuh>
#include <nvbench/types.cuh>
#include <string>
#include <unordered_map>
#include <variant>
namespace nvbench
{
nvbench::int64_t state::get_int64(const std::string &axis_name) const
{
return std::get<nvbench::int64_t>(m_params.at(axis_name));
}
nvbench::float64_t state::get_float64(const std::string &axis_name) const
{
return std::get<nvbench::float64_t>(m_params.at(axis_name));
}
const std::string &state::get_string(const std::string &axis_name) const
{
return std::get<std::string>(m_params.at(axis_name));
}
void state::set_param(std::string axis_name, nvbench::int64_t value)
{
m_params.insert(std::make_pair(std::move(axis_name), value));
}
void state::set_param(std::string axis_name, nvbench::float64_t value)
{
m_params.insert(std::make_pair(std::move(axis_name), value));
}
void state::set_param(std::string axis_name, std::string value)
{
m_params.insert(std::make_pair(std::move(axis_name), std::move(value)));
}
} // namespace nvbench

46
nvbench/state.cuh Normal file
View File

@@ -0,0 +1,46 @@
#pragma once
#include <nvbench/types.cuh>
#include <string>
#include <unordered_map>
#include <variant>
namespace nvbench
{
struct state
{
// move-only
state(const state &) = delete;
state(state &&) = default;
state &operator=(const state &) = delete;
state &operator=(state &&) = default;
[[nodiscard]] nvbench::int64_t get_int64(const std::string &axis_name) const;
[[nodiscard]] nvbench::float64_t
get_float64(const std::string &axis_name) const;
[[nodiscard]] const std::string &
get_string(const std::string &axis_name) const;
protected:
state() = default;
using param_type =
std::variant<nvbench::int64_t, nvbench::float64_t, std::string>;
using params_type = std::unordered_map<std::string, param_type>;
explicit state(params_type params)
: m_params{std::move(params)}
{}
void set_param(std::string axis_name, nvbench::int64_t value);
void set_param(std::string axis_name, nvbench::float64_t value);
void set_param(std::string axis_name, std::string value);
params_type m_params;
};
} // namespace nvbench

View File

@@ -3,6 +3,7 @@ set(test_srcs
int64_axis.cu
float64_axis.cu
params.cu
state.cu
state_generator.cu
string_axis.cu
type_axis.cu

48
testing/state.cu Normal file
View File

@@ -0,0 +1,48 @@
#include <nvbench/state.cuh>
#include <nvbench/types.cuh>
#include "test_asserts.cuh"
// Subclass to gain access to protected members for testing:
struct state_tester : public nvbench::state
{
using params_type = nvbench::state::params_type;
state_tester()
: nvbench::state()
{}
explicit state_tester(params_type params)
: nvbench::state{std::move(params)}
{}
template <typename... Args>
void set_param(Args &&...args)
{
this->state::set_param(std::forward<Args>(args)...);
}
const auto &get_params() const { return m_params; }
};
void test_params()
{
// Build a state param by param
state_tester state1;
state1.set_param("TestInt", nvbench::int64_t{22});
state1.set_param("TestFloat", nvbench::float64_t{3.14});
state1.set_param("TestString", "A String!");
ASSERT(state1.get_int64("TestInt") == nvbench::int64_t{22});
ASSERT(state1.get_float64("TestFloat") == nvbench::float64_t{3.14});
ASSERT(state1.get_string("TestString") == "A String!");
// Construct a state from the parameter map built above:
state_tester state2{state1.get_params()};
ASSERT(state2.get_int64("TestInt") == nvbench::int64_t{22});
ASSERT(state2.get_float64("TestFloat") == nvbench::float64_t{3.14});
ASSERT(state2.get_string("TestString") == "A String!");
}
int main() { test_params(); }