mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
refactor
This commit is contained in:
@@ -1,41 +1,33 @@
|
||||
#pragma once
|
||||
#include "Sequence.hip.hpp"
|
||||
|
||||
template <index_t RemainDim>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class RemainLengths>
|
||||
struct static_ford_impl
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
|
||||
static_assert(RemainDim > 1, "wrong!");
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
|
||||
constexpr auto next_length = RemainLengths{}.Front();
|
||||
|
||||
static_for<0, next_length, 1>{}([=](auto I) {
|
||||
static_ford_impl<RemainDim - 1>{}(
|
||||
f, CurrentMultiIndex{}.PushBack(I), RemainLengths{}.PopFront());
|
||||
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
|
||||
CurrentMultiIndex::PushBack(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_ford_impl<1>
|
||||
struct static_ford_impl<Sequence<>>
|
||||
{
|
||||
// F signature: F(Sequence<Is...> multi_id)
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == 1, "wrong!");
|
||||
|
||||
constexpr index_t last_length = RemainLengths{}.Front();
|
||||
|
||||
static_for<0, last_length, 1>{}([=](auto I) { f(CurrentMultiIndex{}.PushBack(I)); });
|
||||
f(CurrentMultiIndex{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -43,16 +35,13 @@ struct static_ford_impl<1>
|
||||
template <class Lengths>
|
||||
struct static_ford
|
||||
{
|
||||
// F signature: F(Sequence<Is...> multi_id)
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
constexpr index_t first_length = Lengths{}.Front();
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
|
||||
static_for<0, first_length, 1>{}([=](auto I) {
|
||||
static_ford_impl<Lengths::GetSize() - 1>{}(
|
||||
f, Sequence<I.Get()>{}, Lengths{}.PopFront());
|
||||
});
|
||||
static_ford_impl<Lengths>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user