mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 06:48:53 +00:00
Different singleton convention
This commit is contained in:
@@ -27,15 +27,19 @@
|
||||
// Inherit from the stopping_criterion class:
|
||||
class fixed_criterion final : public nvbench::stopping_criterion
|
||||
{
|
||||
nvbench::criterion_params m_params{};
|
||||
nvbench::int64_t m_max_samples{};
|
||||
nvbench::int64_t m_num_samples{};
|
||||
|
||||
public:
|
||||
fixed_criterion() : nvbench::stopping_criterion{"fixed"} {}
|
||||
|
||||
// Setup the criterion in the `initialize()` method:
|
||||
virtual void initialize(const nvbench::criterion_params ¶ms) override
|
||||
{
|
||||
m_params = params;
|
||||
m_num_samples = 0;
|
||||
m_max_samples = params.has_value("max-samples") ? params.get_int64("max-samples") : 42;
|
||||
m_max_samples = m_params.has_value("max-samples") ? m_params.get_int64("max-samples") : 42;
|
||||
}
|
||||
|
||||
// Process new measurements in the `add_measurement()` method:
|
||||
@@ -61,9 +65,8 @@ public:
|
||||
};
|
||||
|
||||
// Register the criterion with NVBench:
|
||||
static bool registered = //
|
||||
nvbench::criterion_manager::register_criterion("fixed",
|
||||
std::make_unique<fixed_criterion>());
|
||||
static nvbench::stopping_criterion& criterion_ref = //
|
||||
nvbench::criterion_manager::get().add(std::make_unique<fixed_criterion>());
|
||||
|
||||
void throughput_bench(nvbench::state &state)
|
||||
{
|
||||
|
||||
@@ -37,12 +37,16 @@ class criterion_manager
|
||||
criterion_manager();
|
||||
|
||||
public:
|
||||
static criterion_manager &instance();
|
||||
/**
|
||||
* @return The singleton criterion_manager instance.
|
||||
*/
|
||||
static criterion_manager& get();
|
||||
|
||||
static nvbench::stopping_criterion* get(const std::string& name);
|
||||
|
||||
static bool register_criterion(std::string name,
|
||||
std::unique_ptr<nvbench::stopping_criterion> criterion);
|
||||
/**
|
||||
* Register a new stopping criterion.
|
||||
*/
|
||||
nvbench::stopping_criterion& add(std::unique_ptr<nvbench::stopping_criterion> criterion);
|
||||
nvbench::stopping_criterion& get_criterion(const std::string& name);
|
||||
|
||||
static nvbench::stopping_criterion::params_description get_params_description();
|
||||
};
|
||||
|
||||
@@ -28,43 +28,45 @@ criterion_manager::criterion_manager()
|
||||
m_map.emplace("entropy", std::make_unique<nvbench::detail::entropy_criterion>());
|
||||
}
|
||||
|
||||
criterion_manager &criterion_manager::instance()
|
||||
criterion_manager &criterion_manager::get()
|
||||
{
|
||||
static criterion_manager registry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
stopping_criterion* criterion_manager::get(const std::string& name)
|
||||
stopping_criterion& criterion_manager::get_criterion(const std::string& name)
|
||||
{
|
||||
criterion_manager& registry = instance();
|
||||
criterion_manager& registry = criterion_manager::get();
|
||||
|
||||
auto iter = registry.m_map.find(name);
|
||||
if (iter == registry.m_map.end())
|
||||
{
|
||||
NVBENCH_THROW(std::runtime_error, "No stopping criterion named \"{}\".", name);
|
||||
}
|
||||
return iter->second.get();
|
||||
return *iter->second.get();
|
||||
}
|
||||
|
||||
bool criterion_manager::register_criterion(std::string name,
|
||||
std::unique_ptr<stopping_criterion> criterion)
|
||||
stopping_criterion &criterion_manager::add(std::unique_ptr<stopping_criterion> criterion)
|
||||
{
|
||||
criterion_manager& manager = instance();
|
||||
criterion_manager& manager = criterion_manager::get();
|
||||
const std::string name = criterion->get_name();
|
||||
|
||||
if (manager.m_map.find(name) != manager.m_map.end())
|
||||
auto [it, success] = manager.m_map.emplace(name, std::move(criterion));
|
||||
|
||||
if (!success)
|
||||
{
|
||||
NVBENCH_THROW(std::runtime_error,
|
||||
"Stopping criterion \"{}\" is already registered.", name);
|
||||
}
|
||||
|
||||
return manager.m_map.emplace(std::move(name), std::move(criterion)).second;
|
||||
return *it->second.get();
|
||||
}
|
||||
|
||||
nvbench::stopping_criterion::params_description criterion_manager::get_params_description()
|
||||
{
|
||||
nvbench::stopping_criterion::params_description desc;
|
||||
|
||||
criterion_manager &manager = instance();
|
||||
criterion_manager& manager = criterion_manager::get();
|
||||
for (auto &[criterion_name, criterion] : manager.m_map)
|
||||
{
|
||||
for (auto param : criterion->get_params_description())
|
||||
|
||||
@@ -26,6 +26,7 @@ namespace nvbench::detail
|
||||
{
|
||||
|
||||
entropy_criterion::entropy_criterion()
|
||||
: stopping_criterion{"entropy"}
|
||||
{
|
||||
m_freq_tracker.reserve(m_entropy_tracker.capacity() * 2);
|
||||
m_ps.reserve(m_entropy_tracker.capacity() * 2);
|
||||
@@ -33,19 +34,20 @@ entropy_criterion::entropy_criterion()
|
||||
|
||||
void entropy_criterion::initialize(const criterion_params ¶ms)
|
||||
{
|
||||
m_params = params;
|
||||
m_total_samples = 0;
|
||||
m_total_cuda_time = 0.0;
|
||||
m_entropy_tracker.clear();
|
||||
m_freq_tracker.clear();
|
||||
|
||||
if (params.has_value("max-angle"))
|
||||
if (m_params.has_value("max-angle"))
|
||||
{
|
||||
m_max_angle = params.get_float64("max-angle");
|
||||
m_max_angle = m_params.get_float64("max-angle");
|
||||
}
|
||||
|
||||
if (params.has_value("min-r2"))
|
||||
if (m_params.has_value("min-r2"))
|
||||
{
|
||||
m_min_r2 = params.get_float64("min-r2");
|
||||
m_min_r2 = m_params.get_float64("min-r2");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ measure_cold_base::measure_cold_base(state &exec_state)
|
||||
: m_state{exec_state}
|
||||
, m_launch{m_state.get_cuda_stream()}
|
||||
, m_criterion_params{exec_state.get_criterion_params()}
|
||||
, m_stopping_criterion{nvbench::criterion_manager::get(exec_state.get_stopping_criterion())}
|
||||
, m_stopping_criterion{nvbench::criterion_manager::get().get_criterion(exec_state.get_stopping_criterion())}
|
||||
, m_run_once{exec_state.get_run_once()}
|
||||
, m_no_block{exec_state.get_disable_blocking_kernel()}
|
||||
, m_min_samples{exec_state.get_min_samples()}
|
||||
@@ -71,7 +71,7 @@ void measure_cold_base::initialize()
|
||||
m_cpu_times.clear();
|
||||
m_max_time_exceeded = false;
|
||||
|
||||
m_stopping_criterion->initialize(m_criterion_params);
|
||||
m_stopping_criterion.initialize(m_criterion_params);
|
||||
}
|
||||
|
||||
void measure_cold_base::run_trials_prologue() { m_walltime_timer.start(); }
|
||||
@@ -87,7 +87,7 @@ void measure_cold_base::record_measurements()
|
||||
m_total_cpu_time += cur_cpu_time;
|
||||
++m_total_samples;
|
||||
|
||||
m_stopping_criterion->add_measurement(cur_cuda_time);
|
||||
m_stopping_criterion.add_measurement(cur_cuda_time);
|
||||
}
|
||||
|
||||
bool measure_cold_base::is_finished()
|
||||
@@ -100,7 +100,7 @@ bool measure_cold_base::is_finished()
|
||||
// Check that we've gathered enough samples:
|
||||
if (m_total_samples > m_min_samples)
|
||||
{
|
||||
if (m_stopping_criterion->is_finished())
|
||||
if (m_stopping_criterion.is_finished())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ protected:
|
||||
nvbench::blocking_kernel m_blocker;
|
||||
|
||||
nvbench::criterion_params m_criterion_params;
|
||||
nvbench::stopping_criterion* m_stopping_criterion{};
|
||||
nvbench::stopping_criterion& m_stopping_criterion;
|
||||
|
||||
bool m_run_once{false};
|
||||
bool m_no_block{false};
|
||||
|
||||
@@ -40,6 +40,8 @@ class stdrel_criterion final : public stopping_criterion
|
||||
nvbench::detail::ring_buffer<nvbench::float64_t> m_noise_tracker{512};
|
||||
|
||||
public:
|
||||
stdrel_criterion();
|
||||
|
||||
virtual void initialize(const criterion_params ¶ms) override;
|
||||
virtual void add_measurement(nvbench::float64_t measurement) override;
|
||||
virtual bool is_finished() override;
|
||||
|
||||
@@ -21,21 +21,26 @@
|
||||
namespace nvbench::detail
|
||||
{
|
||||
|
||||
stdrel_criterion::stdrel_criterion()
|
||||
: stopping_criterion{"stdrel"}
|
||||
{}
|
||||
|
||||
void stdrel_criterion::initialize(const criterion_params ¶ms)
|
||||
{
|
||||
m_params = params;
|
||||
m_total_samples = 0;
|
||||
m_total_cuda_time = 0.0;
|
||||
m_cuda_times.clear();
|
||||
m_noise_tracker.clear();
|
||||
|
||||
if (params.has_value("max-noise"))
|
||||
if (m_params.has_value("max-noise"))
|
||||
{
|
||||
m_max_noise = params.get_float64("max-noise");
|
||||
m_max_noise = m_params.get_float64("max-noise");
|
||||
}
|
||||
|
||||
if (params.has_value("min-time"))
|
||||
if (m_params.has_value("min-time"))
|
||||
{
|
||||
m_min_time = params.get_float64("min-time");
|
||||
m_min_time = m_params.get_float64("min-time");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,15 @@ public:
|
||||
*/
|
||||
class stopping_criterion
|
||||
{
|
||||
protected:
|
||||
std::string m_name;
|
||||
criterion_params m_params;
|
||||
|
||||
public:
|
||||
stopping_criterion(std::string name) : m_name(std::move(name)) { }
|
||||
|
||||
[[nodiscard]] const std::string &get_name() const { return m_name; }
|
||||
|
||||
/**
|
||||
* Initialize the criterion with the given parameters
|
||||
*
|
||||
|
||||
@@ -23,13 +23,15 @@
|
||||
|
||||
void test_standard_criteria_exist()
|
||||
{
|
||||
ASSERT(nvbench::criterion_manager::get("stdrel") != nullptr);
|
||||
ASSERT(nvbench::criterion_manager::get("entropy") != nullptr);
|
||||
ASSERT(nvbench::criterion_manager::get().get_criterion("stdrel").get_name() == "stdrel");
|
||||
ASSERT(nvbench::criterion_manager::get().get_criterion("entropy").get_name() == "entropy");
|
||||
}
|
||||
|
||||
class custom_criterion : public nvbench::stopping_criterion
|
||||
{
|
||||
public:
|
||||
custom_criterion() : nvbench::stopping_criterion("custom") {}
|
||||
|
||||
virtual void initialize(const nvbench::criterion_params &) override {}
|
||||
virtual void add_measurement(nvbench::float64_t /* measurement */) override {}
|
||||
virtual bool is_finished() override { return true; }
|
||||
@@ -42,10 +44,11 @@ public:
|
||||
|
||||
void test_no_duplicates_are_allowed()
|
||||
{
|
||||
nvbench::criterion_manager& manager = nvbench::criterion_manager::get();
|
||||
bool exception_triggered = false;
|
||||
|
||||
try {
|
||||
nvbench::stopping_criterion* custom = nvbench::criterion_manager::get("custom");
|
||||
nvbench::stopping_criterion& custom = manager.get_criterion("custom");
|
||||
} catch(...) {
|
||||
exception_triggered = true;
|
||||
}
|
||||
@@ -53,14 +56,14 @@ void test_no_duplicates_are_allowed()
|
||||
|
||||
std::unique_ptr<custom_criterion> custom_ptr = std::make_unique<custom_criterion>();
|
||||
custom_criterion* custom_raw = custom_ptr.get();
|
||||
ASSERT(nvbench::criterion_manager::register_criterion("custom", std::move(custom_ptr)));
|
||||
ASSERT(&manager.add(std::move(custom_ptr)) == custom_raw);
|
||||
|
||||
nvbench::stopping_criterion* custom = nvbench::criterion_manager::get("custom");
|
||||
ASSERT(custom_raw == custom);
|
||||
nvbench::stopping_criterion& custom = nvbench::criterion_manager::get().get_criterion("custom");
|
||||
ASSERT(custom_raw == &custom);
|
||||
|
||||
exception_triggered = false;
|
||||
try {
|
||||
nvbench::criterion_manager::register_criterion("custom", std::make_unique<custom_criterion>());
|
||||
manager.add(std::make_unique<custom_criterion>());
|
||||
} catch(...) {
|
||||
exception_triggered = true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user