Different singleton convention

This commit is contained in:
Georgy Evtushenko
2024-01-08 14:08:12 -08:00
parent 5b6378e918
commit fade52fa2e
10 changed files with 68 additions and 39 deletions

View File

@@ -27,15 +27,19 @@
// Inherit from the stopping_criterion class:
class fixed_criterion final : public nvbench::stopping_criterion
{
nvbench::criterion_params m_params{};
nvbench::int64_t m_max_samples{};
nvbench::int64_t m_num_samples{};
public:
fixed_criterion() : nvbench::stopping_criterion{"fixed"} {}
// Setup the criterion in the `initialize()` method:
virtual void initialize(const nvbench::criterion_params &params) override
{
m_params = params;
m_num_samples = 0;
m_max_samples = params.has_value("max-samples") ? params.get_int64("max-samples") : 42;
m_max_samples = m_params.has_value("max-samples") ? m_params.get_int64("max-samples") : 42;
}
// Process new measurements in the `add_measurement()` method:
@@ -61,9 +65,8 @@ public:
};
// Register the criterion with NVBench:
static bool registered = //
nvbench::criterion_manager::register_criterion("fixed",
std::make_unique<fixed_criterion>());
static nvbench::stopping_criterion& criterion_ref = //
nvbench::criterion_manager::get().add(std::make_unique<fixed_criterion>());
void throughput_bench(nvbench::state &state)
{

View File

@@ -37,12 +37,16 @@ class criterion_manager
criterion_manager();
public:
static criterion_manager &instance();
/**
* @return The singleton criterion_manager instance.
*/
static criterion_manager& get();
static nvbench::stopping_criterion* get(const std::string& name);
static bool register_criterion(std::string name,
std::unique_ptr<nvbench::stopping_criterion> criterion);
/**
* Register a new stopping criterion.
*/
nvbench::stopping_criterion& add(std::unique_ptr<nvbench::stopping_criterion> criterion);
nvbench::stopping_criterion& get_criterion(const std::string& name);
static nvbench::stopping_criterion::params_description get_params_description();
};

View File

@@ -28,43 +28,45 @@ criterion_manager::criterion_manager()
m_map.emplace("entropy", std::make_unique<nvbench::detail::entropy_criterion>());
}
criterion_manager &criterion_manager::instance()
criterion_manager &criterion_manager::get()
{
static criterion_manager registry;
return registry;
}
stopping_criterion* criterion_manager::get(const std::string& name)
stopping_criterion& criterion_manager::get_criterion(const std::string& name)
{
criterion_manager& registry = instance();
criterion_manager& registry = criterion_manager::get();
auto iter = registry.m_map.find(name);
if (iter == registry.m_map.end())
{
NVBENCH_THROW(std::runtime_error, "No stopping criterion named \"{}\".", name);
}
return iter->second.get();
return *iter->second.get();
}
bool criterion_manager::register_criterion(std::string name,
std::unique_ptr<stopping_criterion> criterion)
stopping_criterion &criterion_manager::add(std::unique_ptr<stopping_criterion> criterion)
{
criterion_manager& manager = instance();
criterion_manager& manager = criterion_manager::get();
const std::string name = criterion->get_name();
if (manager.m_map.find(name) != manager.m_map.end())
auto [it, success] = manager.m_map.emplace(name, std::move(criterion));
if (!success)
{
NVBENCH_THROW(std::runtime_error,
"Stopping criterion \"{}\" is already registered.", name);
}
return manager.m_map.emplace(std::move(name), std::move(criterion)).second;
return *it->second.get();
}
nvbench::stopping_criterion::params_description criterion_manager::get_params_description()
{
nvbench::stopping_criterion::params_description desc;
criterion_manager &manager = instance();
criterion_manager& manager = criterion_manager::get();
for (auto &[criterion_name, criterion] : manager.m_map)
{
for (auto param : criterion->get_params_description())

View File

@@ -26,6 +26,7 @@ namespace nvbench::detail
{
entropy_criterion::entropy_criterion()
: stopping_criterion{"entropy"}
{
m_freq_tracker.reserve(m_entropy_tracker.capacity() * 2);
m_ps.reserve(m_entropy_tracker.capacity() * 2);
@@ -33,19 +34,20 @@ entropy_criterion::entropy_criterion()
void entropy_criterion::initialize(const criterion_params &params)
{
m_params = params;
m_total_samples = 0;
m_total_cuda_time = 0.0;
m_entropy_tracker.clear();
m_freq_tracker.clear();
if (params.has_value("max-angle"))
if (m_params.has_value("max-angle"))
{
m_max_angle = params.get_float64("max-angle");
m_max_angle = m_params.get_float64("max-angle");
}
if (params.has_value("min-r2"))
if (m_params.has_value("min-r2"))
{
m_min_r2 = params.get_float64("min-r2");
m_min_r2 = m_params.get_float64("min-r2");
}
}

View File

@@ -34,7 +34,7 @@ measure_cold_base::measure_cold_base(state &exec_state)
: m_state{exec_state}
, m_launch{m_state.get_cuda_stream()}
, m_criterion_params{exec_state.get_criterion_params()}
, m_stopping_criterion{nvbench::criterion_manager::get(exec_state.get_stopping_criterion())}
, m_stopping_criterion{nvbench::criterion_manager::get().get_criterion(exec_state.get_stopping_criterion())}
, m_run_once{exec_state.get_run_once()}
, m_no_block{exec_state.get_disable_blocking_kernel()}
, m_min_samples{exec_state.get_min_samples()}
@@ -71,7 +71,7 @@ void measure_cold_base::initialize()
m_cpu_times.clear();
m_max_time_exceeded = false;
m_stopping_criterion->initialize(m_criterion_params);
m_stopping_criterion.initialize(m_criterion_params);
}
void measure_cold_base::run_trials_prologue() { m_walltime_timer.start(); }
@@ -87,7 +87,7 @@ void measure_cold_base::record_measurements()
m_total_cpu_time += cur_cpu_time;
++m_total_samples;
m_stopping_criterion->add_measurement(cur_cuda_time);
m_stopping_criterion.add_measurement(cur_cuda_time);
}
bool measure_cold_base::is_finished()
@@ -100,7 +100,7 @@ bool measure_cold_base::is_finished()
// Check that we've gathered enough samples:
if (m_total_samples > m_min_samples)
{
if (m_stopping_criterion->is_finished())
if (m_stopping_criterion.is_finished())
{
return true;
}

View File

@@ -87,7 +87,7 @@ protected:
nvbench::blocking_kernel m_blocker;
nvbench::criterion_params m_criterion_params;
nvbench::stopping_criterion* m_stopping_criterion{};
nvbench::stopping_criterion& m_stopping_criterion;
bool m_run_once{false};
bool m_no_block{false};

View File

@@ -40,6 +40,8 @@ class stdrel_criterion final : public stopping_criterion
nvbench::detail::ring_buffer<nvbench::float64_t> m_noise_tracker{512};
public:
stdrel_criterion();
virtual void initialize(const criterion_params &params) override;
virtual void add_measurement(nvbench::float64_t measurement) override;
virtual bool is_finished() override;

View File

@@ -21,21 +21,26 @@
namespace nvbench::detail
{
stdrel_criterion::stdrel_criterion()
: stopping_criterion{"stdrel"}
{}
void stdrel_criterion::initialize(const criterion_params &params)
{
m_params = params;
m_total_samples = 0;
m_total_cuda_time = 0.0;
m_cuda_times.clear();
m_noise_tracker.clear();
if (params.has_value("max-noise"))
if (m_params.has_value("max-noise"))
{
m_max_noise = params.get_float64("max-noise");
m_max_noise = m_params.get_float64("max-noise");
}
if (params.has_value("min-time"))
if (m_params.has_value("min-time"))
{
m_min_time = params.get_float64("min-time");
m_min_time = m_params.get_float64("min-time");
}
}

View File

@@ -57,7 +57,15 @@ public:
*/
class stopping_criterion
{
protected:
std::string m_name;
criterion_params m_params;
public:
stopping_criterion(std::string name) : m_name(std::move(name)) { }
[[nodiscard]] const std::string &get_name() const { return m_name; }
/**
* Initialize the criterion with the given parameters
*

View File

@@ -23,13 +23,15 @@
void test_standard_criteria_exist()
{
ASSERT(nvbench::criterion_manager::get("stdrel") != nullptr);
ASSERT(nvbench::criterion_manager::get("entropy") != nullptr);
ASSERT(nvbench::criterion_manager::get().get_criterion("stdrel").get_name() == "stdrel");
ASSERT(nvbench::criterion_manager::get().get_criterion("entropy").get_name() == "entropy");
}
class custom_criterion : public nvbench::stopping_criterion
{
public:
custom_criterion() : nvbench::stopping_criterion("custom") {}
virtual void initialize(const nvbench::criterion_params &) override {}
virtual void add_measurement(nvbench::float64_t /* measurement */) override {}
virtual bool is_finished() override { return true; }
@@ -42,10 +44,11 @@ public:
void test_no_duplicates_are_allowed()
{
nvbench::criterion_manager& manager = nvbench::criterion_manager::get();
bool exception_triggered = false;
try {
nvbench::stopping_criterion* custom = nvbench::criterion_manager::get("custom");
nvbench::stopping_criterion& custom = manager.get_criterion("custom");
} catch(...) {
exception_triggered = true;
}
@@ -53,14 +56,14 @@ void test_no_duplicates_are_allowed()
std::unique_ptr<custom_criterion> custom_ptr = std::make_unique<custom_criterion>();
custom_criterion* custom_raw = custom_ptr.get();
ASSERT(nvbench::criterion_manager::register_criterion("custom", std::move(custom_ptr)));
ASSERT(&manager.add(std::move(custom_ptr)) == custom_raw);
nvbench::stopping_criterion* custom = nvbench::criterion_manager::get("custom");
ASSERT(custom_raw == custom);
nvbench::stopping_criterion& custom = nvbench::criterion_manager::get().get_criterion("custom");
ASSERT(custom_raw == &custom);
exception_triggered = false;
try {
nvbench::criterion_manager::register_criterion("custom", std::make_unique<custom_criterion>());
manager.add(std::make_unique<custom_criterion>());
} catch(...) {
exception_triggered = true;
}