Got rid of the params description API

This commit is contained in:
Georgy Evtushenko
2024-01-10 12:16:55 -08:00
parent bcd5c7d885
commit 182c77e4f4
13 changed files with 157 additions and 123 deletions

View File

@@ -27,19 +27,12 @@
// Inherit from the stopping_criterion class:
class fixed_criterion final : public nvbench::stopping_criterion
{
nvbench::int64_t m_max_samples{};
nvbench::int64_t m_num_samples{};
public:
fixed_criterion() : nvbench::stopping_criterion{"fixed"} {}
// Setup the criterion in the `initialize()` method:
virtual void initialize(const nvbench::criterion_params &params) override
{
m_params = params;
m_num_samples = 0;
m_max_samples = m_params.has_value("max-samples") ? m_params.get_int64("max-samples") : 42;
}
fixed_criterion()
: nvbench::stopping_criterion{"fixed", {{"max-samples", nvbench::int64_t{42}}}}
{}
// Process new measurements in the `add_measurement()` method:
virtual void add_measurement(nvbench::float64_t /* measurement */) override
@@ -50,16 +43,14 @@ public:
// Check if the stopping criterion is met in the `is_finished()` method:
virtual bool is_finished() override
{
return m_num_samples >= m_max_samples;
return m_num_samples >= m_params.get_int64("max-samples");
}
// Describe criterion parameters in the `get_params_description()` method:
virtual const params_description &get_params_description() const override
protected:
// Setup the criterion in the `do_initialize()` method:
virtual void do_initialize() override
{
static const params_description desc{
{"max-samples", nvbench::named_values::type::int64}
};
return desc;
m_num_samples = 0;
}
};

View File

@@ -184,7 +184,10 @@ struct benchmark_base
/// Accumulate at least this many seconds of timing data per measurement.
/// Only applies to `stdrel` stopping criterion. @{
[[nodiscard]] nvbench::float64_t get_min_time() const { return m_criterion_params.get_float64("min-time"); }
[[nodiscard]] nvbench::float64_t get_min_time() const
{
return m_criterion_params.get_float64("min-time");
}
benchmark_base &set_min_time(nvbench::float64_t min_time)
{
m_criterion_params.set_float64("min-time", min_time);
@@ -196,7 +199,10 @@ struct benchmark_base
/// Noise is the relative standard deviation:
/// `noise = stdev / mean_time`.
/// Only applies to `stdrel` stopping criterion. @{
[[nodiscard]] nvbench::float64_t get_max_noise() const { return m_criterion_params.get_float64("max-noise"); }
[[nodiscard]] nvbench::float64_t get_max_noise() const
{
return m_criterion_params.get_float64("max-noise");
}
benchmark_base &set_max_noise(nvbench::float64_t max_noise)
{
m_criterion_params.set_float64("max-noise", max_noise);

View File

@@ -49,7 +49,8 @@ public:
nvbench::stopping_criterion& get_criterion(const std::string& name);
const nvbench::stopping_criterion& get_criterion(const std::string& name) const;
nvbench::stopping_criterion::params_description get_params_description() const;
using params_description = std::vector<std::pair<std::string, nvbench::named_values::type>>;
params_description get_params_description() const;
};
/**

View File

@@ -69,25 +69,28 @@ stopping_criterion &criterion_manager::add(std::unique_ptr<stopping_criterion> c
return *it->second.get();
}
nvbench::stopping_criterion::params_description criterion_manager::get_params_description() const
nvbench::criterion_manager::params_description criterion_manager::get_params_description() const
{
nvbench::stopping_criterion::params_description desc;
nvbench::criterion_manager::params_description desc;
for (auto &[criterion_name, criterion] : m_map)
{
for (auto param : criterion->get_params_description())
nvbench::criterion_params params = criterion->get_params();
for (auto param : params.get_names())
{
nvbench::named_values::type type = params.get_type(param);
if (std::find_if(desc.begin(), desc.end(), [&](auto d) {
return d.first == param.first && d.second != param.second;
return d.first == param && d.second != type;
}) != desc.end())
{
NVBENCH_THROW(std::runtime_error,
"Stopping criterion \"{}\" parameter \"{}\" is already used by another "
"criterion with a different type.",
criterion_name,
param.first);
param);
}
desc.push_back(param);
desc.emplace_back(param, type);
}
}

View File

@@ -29,10 +29,6 @@ namespace nvbench::detail
class entropy_criterion final : public stopping_criterion
{
// parameters
nvbench::float64_t m_max_angle{0.048};
nvbench::float64_t m_min_r2{0.36};
// state
nvbench::int64_t m_total_samples{};
nvbench::float64_t m_total_cuda_time{};
@@ -49,10 +45,11 @@ class entropy_criterion final : public stopping_criterion
public:
entropy_criterion();
virtual void initialize(const criterion_params &params) override;
virtual void add_measurement(nvbench::float64_t measurement) override;
virtual bool is_finished() override;
virtual const params_description &get_params_description() const override;
protected:
virtual void do_initialize() override;
};
} // namespace nvbench::detail

View File

@@ -26,29 +26,19 @@ namespace nvbench::detail
{
entropy_criterion::entropy_criterion()
: stopping_criterion{"entropy"}
: stopping_criterion{"entropy",
{{"max-angle", 0.048}, {"min-r2", 0.36}}}
{
m_freq_tracker.reserve(m_entropy_tracker.capacity() * 2);
m_probabilities.reserve(m_entropy_tracker.capacity() * 2);
}
void entropy_criterion::initialize(const criterion_params &params)
void entropy_criterion::do_initialize()
{
m_params = params;
m_total_samples = 0;
m_total_samples = 0;
m_total_cuda_time = 0.0;
m_entropy_tracker.clear();
m_freq_tracker.clear();
if (m_params.has_value("max-angle"))
{
m_max_angle = m_params.get_float64("max-angle");
}
if (m_params.has_value("min-r2"))
{
m_min_r2 = m_params.get_float64("min-r2");
}
}
nvbench::float64_t entropy_criterion::compute_entropy()
@@ -131,13 +121,13 @@ bool entropy_criterion::is_finished()
const auto [slope, intercept] = statistics::compute_linear_regression(begin, end, mean);
if (statistics::slope2deg(slope) > m_max_angle)
if (statistics::slope2deg(slope) > m_params.get_float64("max-angle"))
{
return false;
}
const auto r2 = statistics::compute_r2(begin, end, mean, slope, intercept);
if (r2 < m_min_r2)
if (r2 < m_params.get_float64("min-r2"))
{
return false;
}
@@ -145,13 +135,4 @@ bool entropy_criterion::is_finished()
return true;
}
const entropy_criterion::params_description &entropy_criterion::get_params_description() const
{
static const params_description desc{
{"max-angle", nvbench::named_values::type::float64},
{"min-r2", nvbench::named_values::type::float64},
};
return desc;
}
} // namespace nvbench::detail

View File

@@ -29,10 +29,6 @@ namespace nvbench::detail
class stdrel_criterion final : public stopping_criterion
{
// parameters
nvbench::float64_t m_min_time{nvbench::detail::compat_min_time()};
nvbench::float64_t m_max_noise{nvbench::detail::compat_max_noise()};
// state
nvbench::int64_t m_total_samples{};
nvbench::float64_t m_total_cuda_time{};
@@ -42,10 +38,11 @@ class stdrel_criterion final : public stopping_criterion
public:
stdrel_criterion();
virtual void initialize(const criterion_params &params) override;
virtual void add_measurement(nvbench::float64_t measurement) override;
virtual bool is_finished() override;
virtual const params_description &get_params_description() const override;
protected:
virtual void do_initialize() override;
};
} // namespace nvbench::detail

View File

@@ -22,26 +22,17 @@ namespace nvbench::detail
{
stdrel_criterion::stdrel_criterion()
: stopping_criterion{"stdrel"}
: stopping_criterion{"stdrel",
{{"max-noise", nvbench::detail::compat_max_noise()},
{"min-time", nvbench::detail::compat_min_time()}}}
{}
void stdrel_criterion::initialize(const criterion_params &params)
void stdrel_criterion::do_initialize()
{
m_params = params;
m_total_samples = 0;
m_total_cuda_time = 0.0;
m_cuda_times.clear();
m_noise_tracker.clear();
if (m_params.has_value("max-noise"))
{
m_max_noise = m_params.get_float64("max-noise");
}
if (m_params.has_value("min-time"))
{
m_min_time = m_params.get_float64("min-time");
}
}
void stdrel_criterion::add_measurement(nvbench::float64_t measurement)
@@ -64,13 +55,13 @@ void stdrel_criterion::add_measurement(nvbench::float64_t measurement)
bool stdrel_criterion::is_finished()
{
if (m_total_cuda_time <= m_min_time)
if (m_total_cuda_time <= m_params.get_float64("min-time"))
{
return false;
}
// Noise has dropped below threshold
if (m_noise_tracker.back() < m_max_noise)
if (m_noise_tracker.back() < m_params.get_float64("max-noise"))
{
return true;
}
@@ -104,13 +95,4 @@ bool stdrel_criterion::is_finished()
return false;
}
const stdrel_criterion::params_description &stdrel_criterion::get_params_description() const
{
static const params_description desc{
{"max-noise", nvbench::named_values::type::float64},
{"min-time", nvbench::named_values::type::float64},
};
return desc;
}
} // namespace nvbench::detail

View File

@@ -377,7 +377,7 @@ void option_parser::parse_range(option_parser::arg_iterator_t first,
}
};
const nvbench::stopping_criterion::params_description criterion_params =
const nvbench::criterion_manager::params_description criterion_params =
nvbench::criterion_manager::get().get_params_description();
while (first < last)

View File

@@ -149,16 +149,28 @@ struct state
/// Accumulate at least this many seconds of timing data per measurement.
/// Only applies to `stdrel` stopping criterion. @{
[[nodiscard]] nvbench::float64_t get_min_time() const { return m_criterion_params.get_float64("min-time"); }
void set_min_time(nvbench::float64_t min_time) { m_criterion_params.set_float64("min-time", min_time); }
[[nodiscard]] nvbench::float64_t get_min_time() const
{
return m_criterion_params.get_float64("min-time");
}
void set_min_time(nvbench::float64_t min_time)
{
m_criterion_params.set_float64("min-time", min_time);
}
/// @}
/// Specify the maximum amount of noise if a measurement supports noise.
/// Noise is the relative standard deviation:
/// `noise = stdev / mean_time`.
/// `noise = stdev / mean_time`.
/// Only applies to `stdrel` stopping criterion. @{
[[nodiscard]] nvbench::float64_t get_max_noise() const { return m_criterion_params.get_float64("max-noise"); }
void set_max_noise(nvbench::float64_t max_noise) { m_criterion_params.set_float64("max-noise", max_noise); }
[[nodiscard]] nvbench::float64_t get_max_noise() const
{
return m_criterion_params.get_float64("max-noise");
}
void set_max_noise(nvbench::float64_t max_noise)
{
m_criterion_params.set_float64("max-noise", max_noise);
}
/// @}
/// If a warmup run finishes in less than `skip_time`, the measurement will

View File

@@ -18,12 +18,14 @@
#pragma once
#include <nvbench/types.cuh>
#include <nvbench/named_values.cuh>
#include <nvbench/types.cuh>
#include <unordered_map>
#include <string>
#include <initializer_list>
#include <unordered_map>
namespace nvbench
{
@@ -42,14 +44,27 @@ class criterion_params
{
nvbench::named_values m_named_values;
public:
criterion_params();
criterion_params(std::initializer_list<std::pair<std::string, nvbench::named_values::value_type>>);
/**
* Set parameter values from another criterion_params object if they exist
*
* Parameters in `other` that do not correspond to parameters in `this` are ignored.
*/
void set_from(const criterion_params &other);
void set_int64(std::string name, nvbench::int64_t value);
void set_float64(std::string name, nvbench::float64_t value);
void set_string(std::string name, std::string value);
[[nodiscard]] std::vector<std::string> get_names() const;
[[nodiscard]] nvbench::named_values::type get_type(const std::string &name) const;
[[nodiscard]] bool has_value(const std::string &name) const;
[[nodiscard]] nvbench::int64_t get_int64(const std::string &name) const;
[[nodiscard]] nvbench::float64_t get_float64(const std::string &name) const;
[[nodiscard]] std::string get_string(const std::string &name) const;
};
/**
@@ -62,16 +77,28 @@ protected:
criterion_params m_params;
public:
explicit stopping_criterion(std::string name) : m_name(std::move(name)) { }
/**
* @param name Unique name of the criterion
* @param params Default values for all parameters of the criterion
*/
explicit stopping_criterion(std::string name, criterion_params params)
: m_name{std::move(name)}
, m_params{params}
{}
[[nodiscard]] const std::string &get_name() const { return m_name; }
[[nodiscard]] const criterion_params &get_params() const { return m_params; }
/**
* Initialize the criterion with the given parameters
*
* This method is called once per benchmark run, before any measurements are provided.
*/
virtual void initialize(const criterion_params &params) = 0;
void initialize(const criterion_params &params)
{
m_params.set_from(params);
do_initialize();
}
/**
* Add the latest measurement to the criterion
@@ -83,12 +110,11 @@ public:
*/
virtual bool is_finished() = 0;
using params_description = std::vector<std::pair<std::string, nvbench::named_values::type>>;
protected:
/**
* Return the parameter names and types for this criterion
* Initialize the criterion after updaring the parameters
*/
virtual const params_description &get_params_description() const = 0;
virtual void do_initialize() = 0;
};
} // namespace nvbench

View File

@@ -18,10 +18,48 @@
#include <nvbench/stopping_criterion.cuh>
#include <nvbench/detail/throw.cuh>
namespace nvbench
{
// Default constructor for compatibility with old code
criterion_params::criterion_params()
: criterion_params{{"max-noise", nvbench::detail::compat_max_noise()},
{"min-time", nvbench::detail::compat_min_time()}}
{}
criterion_params::criterion_params(
std::initializer_list<std::pair<std::string, nvbench::named_values::value_type>> list)
{
for (const auto &[name, value] : list)
{
m_named_values.set_value(name, value);
}
}
void criterion_params::set_from(const criterion_params &other)
{
for (const std::string &name : this->get_names())
{
if (other.has_value(name))
{
if (this->get_type(name) != other.get_type(name))
{
NVBENCH_THROW(std::runtime_error,
"Mismatched types for named value \"{}\". "
"Expected {}, got {}.",
name,
static_cast<int>(this->get_type(name)),
static_cast<int>(other.get_type(name)));
}
m_named_values.remove_value(name);
m_named_values.set_value(name, other.m_named_values.get_value(name));
}
}
}
void criterion_params::set_int64(std::string name, nvbench::int64_t value)
{
if (m_named_values.has_value(name))
@@ -54,10 +92,6 @@ void criterion_params::set_string(std::string name, std::string value)
bool criterion_params::has_value(const std::string &name) const
{
if (name == "max-noise" || name == "min-time")
{ // compat
return true;
}
return m_named_values.has_value(name);
}
@@ -68,18 +102,23 @@ nvbench::int64_t criterion_params::get_int64(const std::string &name) const
nvbench::float64_t criterion_params::get_float64(const std::string &name) const
{
if (!m_named_values.has_value(name))
{
if (name == "max-noise")
{ // compat
return nvbench::detail::compat_max_noise();
}
else if (name == "min-time")
{ // compat
return nvbench::detail::compat_min_time();
}
}
return m_named_values.get_float64(name);
}
std::string criterion_params::get_string(const std::string &name) const
{
return m_named_values.get_string(name);
}
std::vector<std::string> criterion_params::get_names() const
{
return m_named_values.get_names();
}
nvbench::named_values::type criterion_params::get_type(const std::string &name) const
{
return m_named_values.get_type(name);
}
} // namespace nvbench::detail

View File

@@ -30,16 +30,15 @@ void test_standard_criteria_exist()
class custom_criterion : public nvbench::stopping_criterion
{
public:
custom_criterion() : nvbench::stopping_criterion("custom") {}
custom_criterion()
: nvbench::stopping_criterion("custom", nvbench::criterion_params{})
{}
virtual void initialize(const nvbench::criterion_params &) override {}
virtual void add_measurement(nvbench::float64_t /* measurement */) override {}
virtual bool is_finished() override { return true; }
virtual const params_description &get_params_description() const override
{
static const params_description desc{};
return desc;
}
protected:
virtual void do_initialize() override {}
};
void test_no_duplicates_are_allowed()