mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Add nvbench::state.
This class holds a single value for each runtime axis.
This commit is contained in:
@@ -5,6 +5,7 @@ set(srcs
|
||||
params.cu
|
||||
string_axis.cu
|
||||
type_axis.cu
|
||||
state.cu
|
||||
detail/state_generator.cu
|
||||
)
|
||||
|
||||
|
||||
41
nvbench/state.cu
Normal file
41
nvbench/state.cu
Normal file
@@ -0,0 +1,41 @@
|
||||
#include <nvbench/state.cuh>
|
||||
|
||||
#include <nvbench/types.cuh>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
nvbench::int64_t state::get_int64(const std::string &axis_name) const
|
||||
{
|
||||
return std::get<nvbench::int64_t>(m_params.at(axis_name));
|
||||
}
|
||||
|
||||
nvbench::float64_t state::get_float64(const std::string &axis_name) const
|
||||
{
|
||||
return std::get<nvbench::float64_t>(m_params.at(axis_name));
|
||||
}
|
||||
|
||||
const std::string &state::get_string(const std::string &axis_name) const
|
||||
{
|
||||
return std::get<std::string>(m_params.at(axis_name));
|
||||
}
|
||||
void state::set_param(std::string axis_name, nvbench::int64_t value)
|
||||
{
|
||||
m_params.insert(std::make_pair(std::move(axis_name), value));
|
||||
}
|
||||
|
||||
void state::set_param(std::string axis_name, nvbench::float64_t value)
|
||||
{
|
||||
m_params.insert(std::make_pair(std::move(axis_name), value));
|
||||
}
|
||||
|
||||
void state::set_param(std::string axis_name, std::string value)
|
||||
{
|
||||
m_params.insert(std::make_pair(std::move(axis_name), std::move(value)));
|
||||
}
|
||||
|
||||
} // namespace nvbench
|
||||
46
nvbench/state.cuh
Normal file
46
nvbench/state.cuh
Normal file
@@ -0,0 +1,46 @@
|
||||
#pragma once
|
||||
|
||||
#include <nvbench/types.cuh>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
|
||||
namespace nvbench
|
||||
{
|
||||
|
||||
struct state
|
||||
{
|
||||
// move-only
|
||||
state(const state &) = delete;
|
||||
state(state &&) = default;
|
||||
state &operator=(const state &) = delete;
|
||||
state &operator=(state &&) = default;
|
||||
|
||||
[[nodiscard]] nvbench::int64_t get_int64(const std::string &axis_name) const;
|
||||
|
||||
[[nodiscard]] nvbench::float64_t
|
||||
get_float64(const std::string &axis_name) const;
|
||||
|
||||
[[nodiscard]] const std::string &
|
||||
get_string(const std::string &axis_name) const;
|
||||
|
||||
protected:
|
||||
state() = default;
|
||||
|
||||
using param_type =
|
||||
std::variant<nvbench::int64_t, nvbench::float64_t, std::string>;
|
||||
using params_type = std::unordered_map<std::string, param_type>;
|
||||
|
||||
explicit state(params_type params)
|
||||
: m_params{std::move(params)}
|
||||
{}
|
||||
|
||||
void set_param(std::string axis_name, nvbench::int64_t value);
|
||||
void set_param(std::string axis_name, nvbench::float64_t value);
|
||||
void set_param(std::string axis_name, std::string value);
|
||||
|
||||
params_type m_params;
|
||||
};
|
||||
|
||||
} // namespace nvbench
|
||||
@@ -3,6 +3,7 @@ set(test_srcs
|
||||
int64_axis.cu
|
||||
float64_axis.cu
|
||||
params.cu
|
||||
state.cu
|
||||
state_generator.cu
|
||||
string_axis.cu
|
||||
type_axis.cu
|
||||
|
||||
48
testing/state.cu
Normal file
48
testing/state.cu
Normal file
@@ -0,0 +1,48 @@
|
||||
#include <nvbench/state.cuh>
|
||||
|
||||
#include <nvbench/types.cuh>
|
||||
|
||||
#include "test_asserts.cuh"
|
||||
|
||||
// Subclass to gain access to protected members for testing:
|
||||
struct state_tester : public nvbench::state
|
||||
{
|
||||
using params_type = nvbench::state::params_type;
|
||||
|
||||
state_tester()
|
||||
: nvbench::state()
|
||||
{}
|
||||
explicit state_tester(params_type params)
|
||||
: nvbench::state{std::move(params)}
|
||||
{}
|
||||
|
||||
template <typename... Args>
|
||||
void set_param(Args &&...args)
|
||||
{
|
||||
this->state::set_param(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
const auto &get_params() const { return m_params; }
|
||||
};
|
||||
|
||||
void test_params()
|
||||
{
|
||||
// Build a state param by param
|
||||
state_tester state1;
|
||||
state1.set_param("TestInt", nvbench::int64_t{22});
|
||||
state1.set_param("TestFloat", nvbench::float64_t{3.14});
|
||||
state1.set_param("TestString", "A String!");
|
||||
|
||||
ASSERT(state1.get_int64("TestInt") == nvbench::int64_t{22});
|
||||
ASSERT(state1.get_float64("TestFloat") == nvbench::float64_t{3.14});
|
||||
ASSERT(state1.get_string("TestString") == "A String!");
|
||||
|
||||
// Construct a state from the parameter map built above:
|
||||
state_tester state2{state1.get_params()};
|
||||
|
||||
ASSERT(state2.get_int64("TestInt") == nvbench::int64_t{22});
|
||||
ASSERT(state2.get_float64("TestFloat") == nvbench::float64_t{3.14});
|
||||
ASSERT(state2.get_string("TestString") == "A String!");
|
||||
}
|
||||
|
||||
int main() { test_params(); }
|
||||
Reference in New Issue
Block a user