mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 14:58:54 +00:00
Update new test to support device-init changes.
This commit is contained in:
@@ -17,23 +17,24 @@
|
||||
*/
|
||||
|
||||
#include <nvbench/benchmark.cuh>
|
||||
|
||||
#include <nvbench/callable.cuh>
|
||||
#include <nvbench/device_manager.cuh>
|
||||
#include <nvbench/named_values.cuh>
|
||||
#include <nvbench/state.cuh>
|
||||
#include <nvbench/type_list.cuh>
|
||||
#include <nvbench/type_strings.cuh>
|
||||
#include <nvbench/types.cuh>
|
||||
|
||||
#include "test_asserts.cuh"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "test_asserts.cuh"
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> sort(std::vector<T> &&vec)
|
||||
{
|
||||
@@ -114,12 +115,18 @@ void test_zip_axes()
|
||||
{
|
||||
using benchmark_type = nvbench::benchmark<no_op_callable>;
|
||||
benchmark_type bench;
|
||||
bench.set_devices(nvbench::device_manager::get().get_devices());
|
||||
bench.add_zip_axes(nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
|
||||
nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5}));
|
||||
|
||||
ASSERT_MSG(bench.get_config_count() == 5 * bench.get_devices().size(),
|
||||
"Got {}",
|
||||
bench.get_config_count());
|
||||
const auto num_devices = std::max(std::size_t(1), bench.get_devices().size());
|
||||
ASSERT_MSG(bench.get_config_count() == 5 * num_devices,
|
||||
"Got {}, expected {}",
|
||||
bench.get_config_count(),
|
||||
5 * bench.get_devices().size());
|
||||
|
||||
bench.set_devices(std::vector<int>{});
|
||||
ASSERT_MSG(bench.get_config_count() == 5, "Got {}, expected {}", bench.get_config_count(), 5);
|
||||
}
|
||||
|
||||
void test_zip_unequal_length()
|
||||
@@ -241,6 +248,7 @@ void test_user_axes()
|
||||
{
|
||||
using benchmark_type = rezippable_benchmark<no_op_callable>;
|
||||
benchmark_type bench;
|
||||
bench.set_devices(nvbench::device_manager::get().get_devices());
|
||||
bench.add_user_iteration_axes(
|
||||
[](auto... args) -> std::unique_ptr<nvbench::iteration_space_base> {
|
||||
return std::make_unique<under_diag>(args...);
|
||||
@@ -248,9 +256,11 @@ void test_user_axes()
|
||||
nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
|
||||
nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5}));
|
||||
|
||||
ASSERT_MSG(bench.get_config_count() == 15 * bench.get_devices().size(),
|
||||
"Got {}",
|
||||
bench.get_config_count());
|
||||
const auto num_devices = std::max(std::size_t(1), bench.get_devices().size());
|
||||
ASSERT_MSG(bench.get_config_count() == 15 * num_devices, "Got {}", bench.get_config_count());
|
||||
|
||||
bench.set_devices(std::vector<int>{});
|
||||
ASSERT_MSG(bench.get_config_count() == 15, "Got {}", bench.get_config_count());
|
||||
}
|
||||
|
||||
int main()
|
||||
|
||||
Reference in New Issue
Block a user