Add termination criteria API.

- min_samples
- min_time
- max_noise
- skip_time (not yet implemented)
- timeout

Refactored s/(trials)|(iters)/samples/s.
This commit is contained in:
Allison Vacanti
2021-02-15 11:56:10 -05:00
parent e5914ff620
commit d323f569b8
9 changed files with 258 additions and 84 deletions

View File

@@ -717,16 +717,10 @@ void test_devices()
ASSERT(states.size() == 12);
fmt::memory_buffer buffer;
const std::string table_format =
"| {:^5} | {:^6} | {:^5} | {:^3} |\n";
const std::string table_format = "| {:^5} | {:^6} | {:^5} | {:^3} |\n";
fmt::format_to(buffer, "\n");
fmt::format_to(buffer,
table_format,
"State",
"Device",
"S",
"I");
fmt::format_to(buffer, table_format, "State", "Device", "S", "I");
std::size_t config = 0;
for (const auto &state : states)
@@ -760,6 +754,36 @@ void test_devices()
ASSERT_MSG(test == ref, "Expected:\n\"{}\"\n\nActual:\n\"{}\"", ref, test);
}
void test_termination_criteria()
{
const nvbench::int64_t min_samples = 1000;
const nvbench::float64_t min_time = 2000;
const nvbench::float64_t max_noise = 3000;
const nvbench::float64_t skip_time = 4000;
const nvbench::float64_t timeout = 5000;
// for comparing floats
auto within_one = [](auto a, auto b) { return std::abs(a - b) < 1.; };
dummy_bench bench;
bench.set_devices(std::vector<int>{});
bench.set_min_samples(min_samples);
bench.set_min_time(min_time);
bench.set_max_noise(max_noise);
bench.set_skip_time(skip_time);
bench.set_timeout(timeout);
const std::vector<nvbench::state> states =
nvbench::detail::state_generator::create(bench);
ASSERT(states.size() == 1);
ASSERT(min_samples == states[0].get_min_samples());
ASSERT(within_one(min_time, states[0].get_min_time()));
ASSERT(within_one(max_noise, states[0].get_max_noise()));
ASSERT(within_one(skip_time, states[0].get_skip_time()));
ASSERT(within_one(timeout, states[0].get_timeout()));
}
int main()
try
{
@@ -770,9 +794,11 @@ try
test_create_with_types();
test_create_with_masked_types();
test_devices();
test_termination_criteria();
return 0;
}
catch (std::exception& e)
catch (std::exception &e)
{
fmt::print("{}\n", e.what());
return 1;