mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 14:58:54 +00:00
Add termination criteria API.
- min_samples - min_time - max_noise - skip_time (not yet implemented) - timeout Refactored s/(trials)|(iters)/samples/s.
This commit is contained in:
@@ -717,16 +717,10 @@ void test_devices()
|
||||
ASSERT(states.size() == 12);
|
||||
|
||||
fmt::memory_buffer buffer;
|
||||
const std::string table_format =
|
||||
"| {:^5} | {:^6} | {:^5} | {:^3} |\n";
|
||||
const std::string table_format = "| {:^5} | {:^6} | {:^5} | {:^3} |\n";
|
||||
|
||||
fmt::format_to(buffer, "\n");
|
||||
fmt::format_to(buffer,
|
||||
table_format,
|
||||
"State",
|
||||
"Device",
|
||||
"S",
|
||||
"I");
|
||||
fmt::format_to(buffer, table_format, "State", "Device", "S", "I");
|
||||
|
||||
std::size_t config = 0;
|
||||
for (const auto &state : states)
|
||||
@@ -760,6 +754,36 @@ void test_devices()
|
||||
ASSERT_MSG(test == ref, "Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test);
|
||||
}
|
||||
|
||||
void test_termination_criteria()
|
||||
{
|
||||
const nvbench::int64_t min_samples = 1000;
|
||||
const nvbench::float64_t min_time = 2000;
|
||||
const nvbench::float64_t max_noise = 3000;
|
||||
const nvbench::float64_t skip_time = 4000;
|
||||
const nvbench::float64_t timeout = 5000;
|
||||
|
||||
// for comparing floats
|
||||
auto within_one = [](auto a, auto b) { return std::abs(a - b) < 1.; };
|
||||
|
||||
dummy_bench bench;
|
||||
bench.set_devices(std::vector<int>{});
|
||||
bench.set_min_samples(min_samples);
|
||||
bench.set_min_time(min_time);
|
||||
bench.set_max_noise(max_noise);
|
||||
bench.set_skip_time(skip_time);
|
||||
bench.set_timeout(timeout);
|
||||
|
||||
const std::vector<nvbench::state> states =
|
||||
nvbench::detail::state_generator::create(bench);
|
||||
|
||||
ASSERT(states.size() == 1);
|
||||
ASSERT(min_samples == states[0].get_min_samples());
|
||||
ASSERT(within_one(min_time, states[0].get_min_time()));
|
||||
ASSERT(within_one(max_noise, states[0].get_max_noise()));
|
||||
ASSERT(within_one(skip_time, states[0].get_skip_time()));
|
||||
ASSERT(within_one(timeout, states[0].get_timeout()));
|
||||
}
|
||||
|
||||
int main()
|
||||
try
|
||||
{
|
||||
@@ -770,9 +794,11 @@ try
|
||||
test_create_with_types();
|
||||
test_create_with_masked_types();
|
||||
test_devices();
|
||||
test_termination_criteria();
|
||||
|
||||
return 0;
|
||||
}
|
||||
catch (std::exception& e)
|
||||
catch (std::exception &e)
|
||||
{
|
||||
fmt::print("{}\n", e.what());
|
||||
return 1;
|
||||
|
||||
Reference in New Issue
Block a user