diff --git a/examples/custom_criterion.cu b/examples/custom_criterion.cu index 939d6ae..6d1dfb3 100644 --- a/examples/custom_criterion.cu +++ b/examples/custom_criterion.cu @@ -27,19 +27,12 @@ // Inherit from the stopping_criterion class: class fixed_criterion final : public nvbench::stopping_criterion { - nvbench::int64_t m_max_samples{}; nvbench::int64_t m_num_samples{}; public: - fixed_criterion() : nvbench::stopping_criterion{"fixed"} {} - - // Setup the criterion in the `initialize()` method: - virtual void initialize(const nvbench::criterion_params ¶ms) override - { - m_params = params; - m_num_samples = 0; - m_max_samples = m_params.has_value("max-samples") ? m_params.get_int64("max-samples") : 42; - } + fixed_criterion() + : nvbench::stopping_criterion{"fixed", {{"max-samples", nvbench::int64_t{42}}}} + {} // Process new measurements in the `add_measurement()` method: virtual void add_measurement(nvbench::float64_t /* measurement */) override @@ -50,16 +43,14 @@ public: // Check if the stopping criterion is met in the `is_finished()` method: virtual bool is_finished() override { - return m_num_samples >= m_max_samples; + return m_num_samples >= m_params.get_int64("max-samples"); } - // Describe criterion parameters in the `get_params_description()` method: - virtual const params_description &get_params_description() const override +protected: + // Setup the criterion in the `do_initialize()` method: + virtual void do_initialize() override { - static const params_description desc{ - {"max-samples", nvbench::named_values::type::int64} - }; - return desc; + m_num_samples = 0; } }; diff --git a/nvbench/benchmark_base.cuh b/nvbench/benchmark_base.cuh index efe0c83..55673b0 100644 --- a/nvbench/benchmark_base.cuh +++ b/nvbench/benchmark_base.cuh @@ -184,7 +184,10 @@ struct benchmark_base /// Accumulate at least this many seconds of timing data per measurement. /// Only applies to `stdrel` stopping criterion. @{ - [[nodiscard]] nvbench::float64_t get_min_time() const { return m_criterion_params.get_float64("min-time"); } + [[nodiscard]] nvbench::float64_t get_min_time() const + { + return m_criterion_params.get_float64("min-time"); + } benchmark_base &set_min_time(nvbench::float64_t min_time) { m_criterion_params.set_float64("min-time", min_time); @@ -196,7 +199,10 @@ struct benchmark_base /// Noise is the relative standard deviation: /// `noise = stdev / mean_time`. /// Only applies to `stdrel` stopping criterion. @{ - [[nodiscard]] nvbench::float64_t get_max_noise() const { return m_criterion_params.get_float64("max-noise"); } + [[nodiscard]] nvbench::float64_t get_max_noise() const + { + return m_criterion_params.get_float64("max-noise"); + } benchmark_base &set_max_noise(nvbench::float64_t max_noise) { m_criterion_params.set_float64("max-noise", max_noise); diff --git a/nvbench/criterion_manager.cuh b/nvbench/criterion_manager.cuh index 53206f5..e5c86ad 100644 --- a/nvbench/criterion_manager.cuh +++ b/nvbench/criterion_manager.cuh @@ -49,7 +49,8 @@ public: nvbench::stopping_criterion& get_criterion(const std::string& name); const nvbench::stopping_criterion& get_criterion(const std::string& name) const; - nvbench::stopping_criterion::params_description get_params_description() const; + using params_description = std::vector>; + params_description get_params_description() const; }; /** diff --git a/nvbench/criterion_manager.cxx b/nvbench/criterion_manager.cxx index 35d97f7..abf682c 100644 --- a/nvbench/criterion_manager.cxx +++ b/nvbench/criterion_manager.cxx @@ -69,25 +69,28 @@ stopping_criterion &criterion_manager::add(std::unique_ptr c return *it->second.get(); } -nvbench::stopping_criterion::params_description criterion_manager::get_params_description() const +nvbench::criterion_manager::params_description criterion_manager::get_params_description() const { - nvbench::stopping_criterion::params_description desc; + nvbench::criterion_manager::params_description desc; for (auto &[criterion_name, criterion] : m_map) { - for (auto param : criterion->get_params_description()) + nvbench::criterion_params params = criterion->get_params(); + + for (auto param : params.get_names()) { + nvbench::named_values::type type = params.get_type(param); if (std::find_if(desc.begin(), desc.end(), [&](auto d) { - return d.first == param.first && d.second != param.second; + return d.first == param && d.second != type; }) != desc.end()) { NVBENCH_THROW(std::runtime_error, "Stopping criterion \"{}\" parameter \"{}\" is already used by another " "criterion with a different type.", criterion_name, - param.first); + param); } - desc.push_back(param); + desc.emplace_back(param, type); } } diff --git a/nvbench/detail/entropy_criterion.cuh b/nvbench/detail/entropy_criterion.cuh index 2c143ab..e0a3763 100644 --- a/nvbench/detail/entropy_criterion.cuh +++ b/nvbench/detail/entropy_criterion.cuh @@ -29,10 +29,6 @@ namespace nvbench::detail class entropy_criterion final : public stopping_criterion { - // parameters - nvbench::float64_t m_max_angle{0.048}; - nvbench::float64_t m_min_r2{0.36}; - // state nvbench::int64_t m_total_samples{}; nvbench::float64_t m_total_cuda_time{}; @@ -49,10 +45,11 @@ class entropy_criterion final : public stopping_criterion public: entropy_criterion(); - virtual void initialize(const criterion_params ¶ms) override; virtual void add_measurement(nvbench::float64_t measurement) override; virtual bool is_finished() override; - virtual const params_description &get_params_description() const override; + +protected: + virtual void do_initialize() override; }; } // namespace nvbench::detail diff --git a/nvbench/detail/entropy_criterion.cxx b/nvbench/detail/entropy_criterion.cxx index 102b63e..c8ad16e 100644 --- a/nvbench/detail/entropy_criterion.cxx +++ b/nvbench/detail/entropy_criterion.cxx @@ -26,29 +26,19 @@ namespace nvbench::detail { entropy_criterion::entropy_criterion() - : stopping_criterion{"entropy"} + : stopping_criterion{"entropy", + {{"max-angle", 0.048}, {"min-r2", 0.36}}} { m_freq_tracker.reserve(m_entropy_tracker.capacity() * 2); m_probabilities.reserve(m_entropy_tracker.capacity() * 2); } -void entropy_criterion::initialize(const criterion_params ¶ms) +void entropy_criterion::do_initialize() { - m_params = params; - m_total_samples = 0; + m_total_samples = 0; m_total_cuda_time = 0.0; m_entropy_tracker.clear(); m_freq_tracker.clear(); - - if (m_params.has_value("max-angle")) - { - m_max_angle = m_params.get_float64("max-angle"); - } - - if (m_params.has_value("min-r2")) - { - m_min_r2 = m_params.get_float64("min-r2"); - } } nvbench::float64_t entropy_criterion::compute_entropy() @@ -131,13 +121,13 @@ bool entropy_criterion::is_finished() const auto [slope, intercept] = statistics::compute_linear_regression(begin, end, mean); - if (statistics::slope2deg(slope) > m_max_angle) + if (statistics::slope2deg(slope) > m_params.get_float64("max-angle")) { return false; } const auto r2 = statistics::compute_r2(begin, end, mean, slope, intercept); - if (r2 < m_min_r2) + if (r2 < m_params.get_float64("min-r2")) { return false; } @@ -145,13 +135,4 @@ bool entropy_criterion::is_finished() return true; } -const entropy_criterion::params_description &entropy_criterion::get_params_description() const -{ - static const params_description desc{ - {"max-angle", nvbench::named_values::type::float64}, - {"min-r2", nvbench::named_values::type::float64}, - }; - return desc; -} - } // namespace nvbench::detail diff --git a/nvbench/detail/stdrel_criterion.cuh b/nvbench/detail/stdrel_criterion.cuh index c9bf6b6..d9d353c 100644 --- a/nvbench/detail/stdrel_criterion.cuh +++ b/nvbench/detail/stdrel_criterion.cuh @@ -29,10 +29,6 @@ namespace nvbench::detail class stdrel_criterion final : public stopping_criterion { - // parameters - nvbench::float64_t m_min_time{nvbench::detail::compat_min_time()}; - nvbench::float64_t m_max_noise{nvbench::detail::compat_max_noise()}; - // state nvbench::int64_t m_total_samples{}; nvbench::float64_t m_total_cuda_time{}; @@ -42,10 +38,11 @@ class stdrel_criterion final : public stopping_criterion public: stdrel_criterion(); - virtual void initialize(const criterion_params ¶ms) override; virtual void add_measurement(nvbench::float64_t measurement) override; virtual bool is_finished() override; - virtual const params_description &get_params_description() const override; + +protected: + virtual void do_initialize() override; }; } // namespace nvbench::detail diff --git a/nvbench/detail/stdrel_criterion.cxx b/nvbench/detail/stdrel_criterion.cxx index 0c19282..a7a5d22 100644 --- a/nvbench/detail/stdrel_criterion.cxx +++ b/nvbench/detail/stdrel_criterion.cxx @@ -22,26 +22,17 @@ namespace nvbench::detail { stdrel_criterion::stdrel_criterion() - : stopping_criterion{"stdrel"} + : stopping_criterion{"stdrel", + {{"max-noise", nvbench::detail::compat_max_noise()}, + {"min-time", nvbench::detail::compat_min_time()}}} {} -void stdrel_criterion::initialize(const criterion_params ¶ms) +void stdrel_criterion::do_initialize() { - m_params = params; m_total_samples = 0; m_total_cuda_time = 0.0; m_cuda_times.clear(); m_noise_tracker.clear(); - - if (m_params.has_value("max-noise")) - { - m_max_noise = m_params.get_float64("max-noise"); - } - - if (m_params.has_value("min-time")) - { - m_min_time = m_params.get_float64("min-time"); - } } void stdrel_criterion::add_measurement(nvbench::float64_t measurement) @@ -64,13 +55,13 @@ void stdrel_criterion::add_measurement(nvbench::float64_t measurement) bool stdrel_criterion::is_finished() { - if (m_total_cuda_time <= m_min_time) + if (m_total_cuda_time <= m_params.get_float64("min-time")) { return false; } // Noise has dropped below threshold - if (m_noise_tracker.back() < m_max_noise) + if (m_noise_tracker.back() < m_params.get_float64("max-noise")) { return true; } @@ -104,13 +95,4 @@ bool stdrel_criterion::is_finished() return false; } -const stdrel_criterion::params_description &stdrel_criterion::get_params_description() const -{ - static const params_description desc{ - {"max-noise", nvbench::named_values::type::float64}, - {"min-time", nvbench::named_values::type::float64}, - }; - return desc; -} - } // namespace nvbench::detail diff --git a/nvbench/option_parser.cu b/nvbench/option_parser.cu index 4380ccd..26f8e1f 100644 --- a/nvbench/option_parser.cu +++ b/nvbench/option_parser.cu @@ -377,7 +377,7 @@ void option_parser::parse_range(option_parser::arg_iterator_t first, } }; - const nvbench::stopping_criterion::params_description criterion_params = + const nvbench::criterion_manager::params_description criterion_params = nvbench::criterion_manager::get().get_params_description(); while (first < last) diff --git a/nvbench/state.cuh b/nvbench/state.cuh index afd6291..09795de 100644 --- a/nvbench/state.cuh +++ b/nvbench/state.cuh @@ -149,16 +149,28 @@ struct state /// Accumulate at least this many seconds of timing data per measurement. /// Only applies to `stdrel` stopping criterion. @{ - [[nodiscard]] nvbench::float64_t get_min_time() const { return m_criterion_params.get_float64("min-time"); } - void set_min_time(nvbench::float64_t min_time) { m_criterion_params.set_float64("min-time", min_time); } + [[nodiscard]] nvbench::float64_t get_min_time() const + { + return m_criterion_params.get_float64("min-time"); + } + void set_min_time(nvbench::float64_t min_time) + { + m_criterion_params.set_float64("min-time", min_time); + } /// @} /// Specify the maximum amount of noise if a measurement supports noise. /// Noise is the relative standard deviation: - /// `noise = stdev / mean_time`. + /// `noise = stdev / mean_time`. /// Only applies to `stdrel` stopping criterion. @{ - [[nodiscard]] nvbench::float64_t get_max_noise() const { return m_criterion_params.get_float64("max-noise"); } - void set_max_noise(nvbench::float64_t max_noise) { m_criterion_params.set_float64("max-noise", max_noise); } + [[nodiscard]] nvbench::float64_t get_max_noise() const + { + return m_criterion_params.get_float64("max-noise"); + } + void set_max_noise(nvbench::float64_t max_noise) + { + m_criterion_params.set_float64("max-noise", max_noise); + } /// @} /// If a warmup run finishes in less than `skip_time`, the measurement will diff --git a/nvbench/stopping_criterion.cuh b/nvbench/stopping_criterion.cuh index 3abf9be..6a6b523 100644 --- a/nvbench/stopping_criterion.cuh +++ b/nvbench/stopping_criterion.cuh @@ -18,12 +18,14 @@ #pragma once -#include #include +#include -#include #include +#include +#include + namespace nvbench { @@ -42,14 +44,27 @@ class criterion_params { nvbench::named_values m_named_values; public: + criterion_params(); + criterion_params(std::initializer_list>); + + /** + * Set parameter values from another criterion_params object if they exist + * + * Parameters in `other` that do not correspond to parameters in `this` are ignored. + */ + void set_from(const criterion_params &other); void set_int64(std::string name, nvbench::int64_t value); void set_float64(std::string name, nvbench::float64_t value); void set_string(std::string name, std::string value); + [[nodiscard]] std::vector get_names() const; + [[nodiscard]] nvbench::named_values::type get_type(const std::string &name) const; + [[nodiscard]] bool has_value(const std::string &name) const; [[nodiscard]] nvbench::int64_t get_int64(const std::string &name) const; [[nodiscard]] nvbench::float64_t get_float64(const std::string &name) const; + [[nodiscard]] std::string get_string(const std::string &name) const; }; /** @@ -62,16 +77,28 @@ protected: criterion_params m_params; public: - explicit stopping_criterion(std::string name) : m_name(std::move(name)) { } + /** + * @param name Unique name of the criterion + * @param params Default values for all parameters of the criterion + */ + explicit stopping_criterion(std::string name, criterion_params params) + : m_name{std::move(name)} + , m_params{params} + {} [[nodiscard]] const std::string &get_name() const { return m_name; } + [[nodiscard]] const criterion_params &get_params() const { return m_params; } /** * Initialize the criterion with the given parameters * * This method is called once per benchmark run, before any measurements are provided. */ - virtual void initialize(const criterion_params ¶ms) = 0; + void initialize(const criterion_params ¶ms) + { + m_params.set_from(params); + do_initialize(); + } /** * Add the latest measurement to the criterion @@ -83,12 +110,11 @@ public: */ virtual bool is_finished() = 0; - using params_description = std::vector>; - +protected: /** - * Return the parameter names and types for this criterion + * Initialize the criterion after updaring the parameters */ - virtual const params_description &get_params_description() const = 0; + virtual void do_initialize() = 0; }; } // namespace nvbench diff --git a/nvbench/stopping_criterion.cxx b/nvbench/stopping_criterion.cxx index 0659c7e..976a1a7 100644 --- a/nvbench/stopping_criterion.cxx +++ b/nvbench/stopping_criterion.cxx @@ -18,10 +18,48 @@ #include +#include + namespace nvbench { +// Default constructor for compatibility with old code +criterion_params::criterion_params() + : criterion_params{{"max-noise", nvbench::detail::compat_max_noise()}, + {"min-time", nvbench::detail::compat_min_time()}} +{} + +criterion_params::criterion_params( + std::initializer_list> list) +{ + for (const auto &[name, value] : list) + { + m_named_values.set_value(name, value); + } +} + +void criterion_params::set_from(const criterion_params &other) +{ + for (const std::string &name : this->get_names()) + { + if (other.has_value(name)) + { + if (this->get_type(name) != other.get_type(name)) + { + NVBENCH_THROW(std::runtime_error, + "Mismatched types for named value \"{}\". " + "Expected {}, got {}.", + name, + static_cast(this->get_type(name)), + static_cast(other.get_type(name))); + } + m_named_values.remove_value(name); + m_named_values.set_value(name, other.m_named_values.get_value(name)); + } + } +} + void criterion_params::set_int64(std::string name, nvbench::int64_t value) { if (m_named_values.has_value(name)) @@ -54,10 +92,6 @@ void criterion_params::set_string(std::string name, std::string value) bool criterion_params::has_value(const std::string &name) const { - if (name == "max-noise" || name == "min-time") - { // compat - return true; - } return m_named_values.has_value(name); } @@ -68,18 +102,23 @@ nvbench::int64_t criterion_params::get_int64(const std::string &name) const nvbench::float64_t criterion_params::get_float64(const std::string &name) const { - if (!m_named_values.has_value(name)) - { - if (name == "max-noise") - { // compat - return nvbench::detail::compat_max_noise(); - } - else if (name == "min-time") - { // compat - return nvbench::detail::compat_min_time(); - } - } return m_named_values.get_float64(name); } +std::string criterion_params::get_string(const std::string &name) const +{ + return m_named_values.get_string(name); +} + +std::vector criterion_params::get_names() const +{ + return m_named_values.get_names(); +} + +nvbench::named_values::type criterion_params::get_type(const std::string &name) const +{ + return m_named_values.get_type(name); +} + + } // namespace nvbench::detail diff --git a/testing/criterion_manager.cu b/testing/criterion_manager.cu index 6c21eb1..f17f2c0 100644 --- a/testing/criterion_manager.cu +++ b/testing/criterion_manager.cu @@ -30,16 +30,15 @@ void test_standard_criteria_exist() class custom_criterion : public nvbench::stopping_criterion { public: - custom_criterion() : nvbench::stopping_criterion("custom") {} + custom_criterion() + : nvbench::stopping_criterion("custom", nvbench::criterion_params{}) + {} - virtual void initialize(const nvbench::criterion_params &) override {} virtual void add_measurement(nvbench::float64_t /* measurement */) override {} virtual bool is_finished() override { return true; } - virtual const params_description &get_params_description() const override - { - static const params_description desc{}; - return desc; - } + +protected: + virtual void do_initialize() override {} }; void test_no_duplicates_are_allowed()