Fix get_config_count for CPU-only benchmarks. (#218)

This commit is contained in:
Allison Piper
2025-05-01 12:34:35 -04:00
committed by GitHub
parent 433376fd83
commit 9d189280de
2 changed files with 13 additions and 2 deletions

View File

@@ -20,6 +20,9 @@
#include <nvbench/criterion_manager.cuh>
#include <nvbench/detail/transform_reduce.cuh>
#include <algorithm>
#include <cstdint>
namespace nvbench
{
@@ -86,7 +89,8 @@ std::size_t benchmark_base::get_config_count() const
return axis_ptr->get_size();
});
return per_device_count * m_devices.size();
// Devices will be empty for cpu-only benchmarks.
return per_device_count * std::max(std::size_t(1), m_devices.size());
}
benchmark_base &benchmark_base::set_stopping_criterion(std::string criterion)

View File

@@ -18,6 +18,7 @@
#include <nvbench/benchmark.cuh>
#include <nvbench/callable.cuh>
#include <nvbench/device_manager.cuh>
#include <nvbench/named_values.cuh>
#include <nvbench/state.cuh>
#include <nvbench/type_list.cuh>
@@ -27,6 +28,7 @@
#include <fmt/format.h>
#include <algorithm>
#include <cstdint>
#include <utility>
#include <variant>
#include <vector>
@@ -279,6 +281,7 @@ void test_clone()
void test_get_config_count()
{
lots_of_types_bench bench;
bench.set_devices(nvbench::device_manager::get().get_devices());
bench.set_type_axes_names({"Integer", "Float", "Other"});
bench.get_axes().get_type_axis(0).set_active_inputs({"I16", "I32"}); // 2, 2
bench.get_axes().get_type_axis(1).set_active_inputs({"F32", "F64"}); // 2, 4
@@ -288,9 +291,13 @@ void test_get_config_count()
bench.add_string_axis("baz", {"str", "ing"}); // 2, 72
bench.add_string_axis("baz", {"single"}); // 1, 72
auto const num_devices = bench.get_devices().size();
auto const num_devices = std::max(std::size_t(1), bench.get_devices().size());
ASSERT_MSG(bench.get_config_count() == 72 * num_devices, "Got {}", bench.get_config_count());
// Check that zero devices (e.g. CPU-only) is the same as a single device:
bench.set_devices(std::vector<int>{});
ASSERT_MSG(bench.get_config_count() == 72, "Got {}", bench.get_config_count());
}
int main()