mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
more utility code
This commit is contained in:
@@ -8,20 +8,7 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class>
|
||||
struct is_static : integral_constant<bool, false>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T, T X>
|
||||
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
// Orders: Sequence<...>
|
||||
@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
__host__ __device__ constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
|
||||
static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
// Orders: Sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
__host__ __device__ constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
|
||||
// dimension
|
||||
@@ -139,7 +128,8 @@ struct ford
|
||||
|
||||
for(index_t i = 0; i < ordered_lengths.Front(); ++i)
|
||||
{
|
||||
ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f, Array<index_t, 1>{i});
|
||||
detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
|
||||
Array<index_t, 1>{i});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user