API convention

This commit is contained in:
Georgy Evtushenko
2024-01-11 10:48:52 -08:00
parent 34c57965a2
commit 4be0c5bdcd
7 changed files with 43 additions and 28 deletions

View File

@@ -34,24 +34,25 @@ public:
: nvbench::stopping_criterion_base{"fixed", {{"max-samples", nvbench::int64_t{42}}}}
{}
// Process new measurements in the `add_measurement()` method:
virtual void add_measurement(nvbench::float64_t /* measurement */) override
{
m_num_samples++;
}
// Check if the stopping criterion is met in the `is_finished()` method:
virtual bool is_finished() override
{
return m_num_samples >= m_params.get_int64("max-samples");
}
protected:
// Setup the criterion in the `do_initialize()` method:
virtual void do_initialize() override
{
m_num_samples = 0;
}
// Process new measurements in the `add_measurement()` method:
virtual void do_add_measurement(nvbench::float64_t /* measurement */) override
{
m_num_samples++;
}
// Check if the stopping criterion is met in the `is_finished()` method:
virtual bool do_is_finished() override
{
return m_num_samples >= m_params.get_int64("max-samples");
}
};
// Register the criterion with NVBench:

View File

@@ -45,11 +45,11 @@ class entropy_criterion final : public stopping_criterion_base
public:
entropy_criterion();
virtual void add_measurement(nvbench::float64_t measurement) override;
virtual bool is_finished() override;
protected:
virtual void do_initialize() override;
virtual void do_add_measurement(nvbench::float64_t measurement) override;
virtual bool do_is_finished() override;
};
} // namespace nvbench::detail

View File

@@ -64,7 +64,7 @@ nvbench::float64_t entropy_criterion::compute_entropy()
return entropy;
}
void entropy_criterion::add_measurement(nvbench::float64_t measurement)
void entropy_criterion::do_add_measurement(nvbench::float64_t measurement)
{
m_total_samples++;
m_total_cuda_time += measurement;
@@ -100,7 +100,7 @@ void entropy_criterion::add_measurement(nvbench::float64_t measurement)
m_entropy_tracker.push_back(compute_entropy());
}
bool entropy_criterion::is_finished()
bool entropy_criterion::do_is_finished()
{
if (m_entropy_tracker.size() < 2)
{

View File

@@ -38,11 +38,10 @@ class stdrel_criterion final : public stopping_criterion_base
public:
stdrel_criterion();
virtual void add_measurement(nvbench::float64_t measurement) override;
virtual bool is_finished() override;
protected:
virtual void do_initialize() override;
virtual void do_add_measurement(nvbench::float64_t measurement) override;
virtual bool do_is_finished() override;
};
} // namespace nvbench::detail

View File

@@ -35,7 +35,7 @@ void stdrel_criterion::do_initialize()
m_noise_tracker.clear();
}
void stdrel_criterion::add_measurement(nvbench::float64_t measurement)
void stdrel_criterion::do_add_measurement(nvbench::float64_t measurement)
{
m_total_samples++;
m_total_cuda_time += measurement;
@@ -53,7 +53,7 @@ void stdrel_criterion::add_measurement(nvbench::float64_t measurement)
}
}
bool stdrel_criterion::is_finished()
bool stdrel_criterion::do_is_finished()
{
if (m_total_cuda_time <= m_params.get_float64("min-time"))
{

View File

@@ -97,24 +97,40 @@ public:
void initialize(const criterion_params &params)
{
m_params.set_from(params);
do_initialize();
this->do_initialize();
}
/**
* Add the latest measurement to the criterion
*/
virtual void add_measurement(nvbench::float64_t measurement) = 0;
void add_measurement(nvbench::float64_t measurement)
{
this->do_add_measurement(measurement);
}
/**
* Check if the criterion has been met for all measurements processed by `add_measurement`
*/
virtual bool is_finished() = 0;
bool is_finished()
{
return this->do_is_finished();
}
protected:
/**
* Initialize the criterion after updaring the parameters
*/
virtual void do_initialize() = 0;
/**
* Add the latest measurement to the criterion
*/
virtual void do_add_measurement(nvbench::float64_t measurement) = 0;
/**
* Check if the criterion has been met for all measurements processed by `add_measurement`
*/
virtual bool do_is_finished() = 0;
};
} // namespace nvbench

View File

@@ -34,11 +34,10 @@ public:
: nvbench::stopping_criterion_base("custom", nvbench::criterion_params{})
{}
virtual void add_measurement(nvbench::float64_t /* measurement */) override {}
virtual bool is_finished() override { return true; }
protected:
virtual void do_initialize() override {}
virtual void do_add_measurement(nvbench::float64_t /* measurement */) override {}
virtual bool do_is_finished() override { return true; }
};
void test_no_duplicates_are_allowed()