Add nvbench::tl::foreach.

This commit is contained in:
Allison Vacanti
2020-12-22 15:17:08 -05:00
parent fb10be7e72
commit 95e2eaf607
5 changed files with 93 additions and 6 deletions

View File

@@ -7,7 +7,14 @@ namespace nvbench
{
template <typename... Ts>
struct type_list;
struct type_list
{};
template <typename T>
struct wrapped_type
{
using type = T;
};
namespace tl
{
@@ -83,6 +90,23 @@ struct cartesian_product<
using type = decltype(detail::concat(cur{}, next{}));
};
//------------------------------------------------------------------------------
template <typename TypeList, typename Functor, std::size_t... Is>
void foreach (std::index_sequence<Is...>, Functor && f)
{
// Garmonbozia...
((f(wrapped_type<decltype(detail::get<Is>(TypeList{}))>{})), ...);
}
template <typename TypeList, typename Functor>
void foreach (Functor &&f)
{
constexpr std::size_t list_size = decltype(detail::size(TypeList{}))::value;
using indices = std::make_index_sequence<list_size>;
detail::foreach<TypeList>(indices{}, std::forward<Functor>(f));
}
} // namespace detail
} // namespace tl
} // namespace nvbench

View File

@@ -9,8 +9,11 @@ namespace nvbench
{
template <typename... Ts>
struct type_list
{};
struct type_list;
// Wraps a type for use with nvbench::tl::foreach.
template <typename T>
struct wrapped_type;
namespace tl
{
@@ -99,6 +102,26 @@ using prepend_each = typename detail::prepend_each<T, TypeLists>::type;
template <typename TypeLists>
using cartesian_product = typename detail::cartesian_product<TypeLists>::type;
/**
* Invoke the Functor once for each type in TypeList. The type will be passed to
* `f` as a `nvbench::wrapped_type<T>` argument.
*
* ```c++
* using TL = nvbench::type_list<int8_t, int16_t, int32_t, int64_t>;
* std::vector<std::size_t> sizes;
* nvbench::tl::foreach<TL>([&sizes](auto wrapped_type) {
* using T = typename decltype(wrapped_type)::type;
* sizes.push_back(sizeof(T));
* });
* static_assert(sizes == {1, 2, 3, 4});
* ```
*/
template <typename TypeList, typename Functor>
void foreach (Functor &&f)
{
detail::foreach<TypeList>(std::forward<Functor>(f));
}
} // namespace tl
} // namespace nvbench

View File

@@ -7,6 +7,7 @@ foreach(test_src IN LISTS test_srcs)
get_filename_component(test_name "${test_src}" NAME_WLE)
string(PREPEND test_name "nvbench.test.")
add_executable(${test_name} "${test_src}")
target_include_directories(${test_name} PRIVATE "${CMAKE_CURRENT_LIST_DIR}")
target_link_libraries(${test_name} PRIVATE nvbench fmt)
set_target_properties(${test_name} PROPERTIES COMPILE_FEATURES cuda_std_17)
add_test(NAME ${test_name} COMMAND "$<TARGET_FILE:${test_target}>")

View File

@@ -1,6 +1,6 @@
#include <nvbench/int64_axis.cuh>
#include "testing/test_asserts.cuh"
#include "test_asserts.cuh"
#include <fmt/format.h>

View File

@@ -2,8 +2,15 @@
#include <nvbench/type_strings.cuh>
#include "test_asserts.cuh"
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>
// Unique, numbered types for testing type_list functionality.
using T0 = std::integral_constant<std::size_t, 0>;
@@ -103,5 +110,37 @@ struct test_cartesian_product
static_assert(std::is_same_v<nvbench::tl::cartesian_product<TLs>, CartProd>);
};
// This test only has static asserts.
int main() {}
struct test_foreach
{
using TL0 = nvbench::type_list<>;
using TL1 = nvbench::type_list<T0>;
using TL2 = nvbench::type_list<T0, T1>;
using TL3 = nvbench::type_list<T0, T1, T2>;
template <typename TypeList>
static void test(std::vector<std::string> ref_vals)
{
std::vector<std::string> test_vals;
nvbench::tl::foreach<TypeList>([&test_vals](auto wrapped_type) {
using T = typename decltype(wrapped_type)::type;
test_vals.push_back(nvbench::type_strings<T>::input_string());
});
ASSERT_MSG(test_vals == ref_vals,
fmt::format("{} != {}", test_vals, ref_vals));
}
static void run()
{
test<TL0>({});
test<TL1>({"T0"});
test<TL2>({"T0", "T1"});
test<TL3>({"T0", "T1", "T2"});
}
};
int main()
{
// Note that most tests in this file are just static asserts. Only those with
// runtime components are listed here.
test_foreach::run();
}