From 4be0c5bdcd627a609de80dd97a41f1c12adcdfcb Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Thu, 11 Jan 2024 10:48:52 -0800 Subject: [PATCH] API convention --- examples/custom_criterion.cu | 25 +++++++++++++------------ nvbench/detail/entropy_criterion.cuh | 6 +++--- nvbench/detail/entropy_criterion.cxx | 4 ++-- nvbench/detail/stdrel_criterion.cuh | 5 ++--- nvbench/detail/stdrel_criterion.cxx | 4 ++-- nvbench/stopping_criterion.cuh | 22 +++++++++++++++++++--- testing/criterion_manager.cu | 5 ++--- 7 files changed, 43 insertions(+), 28 deletions(-) diff --git a/examples/custom_criterion.cu b/examples/custom_criterion.cu index 3257971..4dbee9e 100644 --- a/examples/custom_criterion.cu +++ b/examples/custom_criterion.cu @@ -34,24 +34,25 @@ public: : nvbench::stopping_criterion_base{"fixed", {{"max-samples", nvbench::int64_t{42}}}} {} - // Process new measurements in the `add_measurement()` method: - virtual void add_measurement(nvbench::float64_t /* measurement */) override - { - m_num_samples++; - } - - // Check if the stopping criterion is met in the `is_finished()` method: - virtual bool is_finished() override - { - return m_num_samples >= m_params.get_int64("max-samples"); - } - protected: // Setup the criterion in the `do_initialize()` method: virtual void do_initialize() override { m_num_samples = 0; } + + // Process new measurements in the `add_measurement()` method: + virtual void do_add_measurement(nvbench::float64_t /* measurement */) override + { + m_num_samples++; + } + + // Check if the stopping criterion is met in the `is_finished()` method: + virtual bool do_is_finished() override + { + return m_num_samples >= m_params.get_int64("max-samples"); + } + }; // Register the criterion with NVBench: diff --git a/nvbench/detail/entropy_criterion.cuh b/nvbench/detail/entropy_criterion.cuh index a6478af..b0e4ebe 100644 --- a/nvbench/detail/entropy_criterion.cuh +++ b/nvbench/detail/entropy_criterion.cuh @@ -45,11 +45,11 @@ class entropy_criterion final : public stopping_criterion_base public: entropy_criterion(); - virtual void add_measurement(nvbench::float64_t measurement) override; - virtual bool is_finished() override; - protected: virtual void do_initialize() override; + virtual void do_add_measurement(nvbench::float64_t measurement) override; + virtual bool do_is_finished() override; + }; } // namespace nvbench::detail diff --git a/nvbench/detail/entropy_criterion.cxx b/nvbench/detail/entropy_criterion.cxx index 8b01ba1..6d9ba8c 100644 --- a/nvbench/detail/entropy_criterion.cxx +++ b/nvbench/detail/entropy_criterion.cxx @@ -64,7 +64,7 @@ nvbench::float64_t entropy_criterion::compute_entropy() return entropy; } -void entropy_criterion::add_measurement(nvbench::float64_t measurement) +void entropy_criterion::do_add_measurement(nvbench::float64_t measurement) { m_total_samples++; m_total_cuda_time += measurement; @@ -100,7 +100,7 @@ void entropy_criterion::add_measurement(nvbench::float64_t measurement) m_entropy_tracker.push_back(compute_entropy()); } -bool entropy_criterion::is_finished() +bool entropy_criterion::do_is_finished() { if (m_entropy_tracker.size() < 2) { diff --git a/nvbench/detail/stdrel_criterion.cuh b/nvbench/detail/stdrel_criterion.cuh index 4edb74e..5f87e84 100644 --- a/nvbench/detail/stdrel_criterion.cuh +++ b/nvbench/detail/stdrel_criterion.cuh @@ -38,11 +38,10 @@ class stdrel_criterion final : public stopping_criterion_base public: stdrel_criterion(); - virtual void add_measurement(nvbench::float64_t measurement) override; - virtual bool is_finished() override; - protected: virtual void do_initialize() override; + virtual void do_add_measurement(nvbench::float64_t measurement) override; + virtual bool do_is_finished() override; }; } // namespace nvbench::detail diff --git a/nvbench/detail/stdrel_criterion.cxx b/nvbench/detail/stdrel_criterion.cxx index 3eb3bfc..a6c5ea8 100644 --- a/nvbench/detail/stdrel_criterion.cxx +++ b/nvbench/detail/stdrel_criterion.cxx @@ -35,7 +35,7 @@ void stdrel_criterion::do_initialize() m_noise_tracker.clear(); } -void stdrel_criterion::add_measurement(nvbench::float64_t measurement) +void stdrel_criterion::do_add_measurement(nvbench::float64_t measurement) { m_total_samples++; m_total_cuda_time += measurement; @@ -53,7 +53,7 @@ void stdrel_criterion::add_measurement(nvbench::float64_t measurement) } } -bool stdrel_criterion::is_finished() +bool stdrel_criterion::do_is_finished() { if (m_total_cuda_time <= m_params.get_float64("min-time")) { diff --git a/nvbench/stopping_criterion.cuh b/nvbench/stopping_criterion.cuh index 0938a8f..36fb6eb 100644 --- a/nvbench/stopping_criterion.cuh +++ b/nvbench/stopping_criterion.cuh @@ -97,24 +97,40 @@ public: void initialize(const criterion_params ¶ms) { m_params.set_from(params); - do_initialize(); + this->do_initialize(); } /** * Add the latest measurement to the criterion */ - virtual void add_measurement(nvbench::float64_t measurement) = 0; + void add_measurement(nvbench::float64_t measurement) + { + this->do_add_measurement(measurement); + } /** * Check if the criterion has been met for all measurements processed by `add_measurement` */ - virtual bool is_finished() = 0; + bool is_finished() + { + return this->do_is_finished(); + } protected: /** * Initialize the criterion after updaring the parameters */ virtual void do_initialize() = 0; + + /** + * Add the latest measurement to the criterion + */ + virtual void do_add_measurement(nvbench::float64_t measurement) = 0; + + /** + * Check if the criterion has been met for all measurements processed by `add_measurement` + */ + virtual bool do_is_finished() = 0; }; } // namespace nvbench diff --git a/testing/criterion_manager.cu b/testing/criterion_manager.cu index 0cf9204..841cd8c 100644 --- a/testing/criterion_manager.cu +++ b/testing/criterion_manager.cu @@ -34,11 +34,10 @@ public: : nvbench::stopping_criterion_base("custom", nvbench::criterion_params{}) {} - virtual void add_measurement(nvbench::float64_t /* measurement */) override {} - virtual bool is_finished() override { return true; } - protected: virtual void do_initialize() override {} + virtual void do_add_measurement(nvbench::float64_t /* measurement */) override {} + virtual bool do_is_finished() override { return true; } }; void test_no_duplicates_are_allowed()