Update new test to support device-init changes.

This commit is contained in:
Allison Piper
2025-05-01 16:40:09 +00:00
parent edefcd0f6a
commit 250d755bd6

View File

@@ -17,23 +17,24 @@
*/
#include <nvbench/benchmark.cuh>
#include <nvbench/callable.cuh>
#include <nvbench/device_manager.cuh>
#include <nvbench/named_values.cuh>
#include <nvbench/state.cuh>
#include <nvbench/type_list.cuh>
#include <nvbench/type_strings.cuh>
#include <nvbench/types.cuh>
#include "test_asserts.cuh"
#include <fmt/format.h>
#include <algorithm>
#include <iterator>
#include <utility>
#include <variant>
#include <vector>
#include "test_asserts.cuh"
template <typename T>
std::vector<T> sort(std::vector<T> &&vec)
{
@@ -114,12 +115,18 @@ void test_zip_axes()
{
using benchmark_type = nvbench::benchmark<no_op_callable>;
benchmark_type bench;
bench.set_devices(nvbench::device_manager::get().get_devices());
bench.add_zip_axes(nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5}));
ASSERT_MSG(bench.get_config_count() == 5 * bench.get_devices().size(),
"Got {}",
bench.get_config_count());
const auto num_devices = std::max(std::size_t(1), bench.get_devices().size());
ASSERT_MSG(bench.get_config_count() == 5 * num_devices,
"Got {}, expected {}",
bench.get_config_count(),
5 * bench.get_devices().size());
bench.set_devices(std::vector<int>{});
ASSERT_MSG(bench.get_config_count() == 5, "Got {}, expected {}", bench.get_config_count(), 5);
}
void test_zip_unequal_length()
@@ -241,6 +248,7 @@ void test_user_axes()
{
using benchmark_type = rezippable_benchmark<no_op_callable>;
benchmark_type bench;
bench.set_devices(nvbench::device_manager::get().get_devices());
bench.add_user_iteration_axes(
[](auto... args) -> std::unique_ptr<nvbench::iteration_space_base> {
return std::make_unique<under_diag>(args...);
@@ -248,9 +256,11 @@ void test_user_axes()
nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5}));
ASSERT_MSG(bench.get_config_count() == 15 * bench.get_devices().size(),
"Got {}",
bench.get_config_count());
const auto num_devices = std::max(std::size_t(1), bench.get_devices().size());
ASSERT_MSG(bench.get_config_count() == 15 * num_devices, "Got {}", bench.get_config_count());
bench.set_devices(std::vector<int>{});
ASSERT_MSG(bench.get_config_count() == 15, "Got {}", bench.get_config_count());
}
int main()