mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Rewrite sequence_map_inverse using O(1) depth pack expansion
Replace the O(N) recursive template sequence_map_inverse_impl with a constexpr function and pack expansion for O(1) template depth. Results: sequence_map_inverse drops from 45 instances (187 ms) to 7 instances (10 ms), a 95% reduction.
This commit is contained in:
@@ -576,31 +576,35 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
|
||||
{
|
||||
};
|
||||
|
||||
// O(1) template depth helper to find source index in permutation inversion
|
||||
// For a permutation X2Y, finds i such that X2Y[i] == Target
|
||||
namespace detail {
|
||||
template <index_t Target, index_t... Is>
|
||||
__host__ __device__ constexpr index_t find_source_index(Sequence<Is...>)
|
||||
{
|
||||
constexpr index_t values[] = {Is...};
|
||||
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
|
||||
{
|
||||
if(values[i] == Target)
|
||||
return i;
|
||||
}
|
||||
return 0; // should not reach for valid permutation
|
||||
}
|
||||
|
||||
template <typename SeqMap, index_t... Positions>
|
||||
__host__ __device__ constexpr auto invert_permutation_impl(Sequence<Positions...>)
|
||||
{
|
||||
return Sequence<find_source_index<Positions>(SeqMap{})...>{};
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Invert a permutation sequence using O(1) template depth pack expansion
|
||||
// For X2Y = {a, b, c, ...}, computes Y2X where Y2X[X2Y[i]] = i
|
||||
template <typename SeqMap>
|
||||
struct sequence_map_inverse
|
||||
{
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
|
||||
struct sequence_map_inverse_impl
|
||||
{
|
||||
static constexpr auto new_y2x =
|
||||
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
|
||||
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
|
||||
type;
|
||||
};
|
||||
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin>
|
||||
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
|
||||
{
|
||||
using type = WorkingY2X;
|
||||
};
|
||||
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<SeqMap,
|
||||
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
|
||||
0,
|
||||
SeqMap::Size()>::type;
|
||||
using type = decltype(detail::invert_permutation_impl<SeqMap>(
|
||||
typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type{}));
|
||||
};
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
|
||||
Reference in New Issue
Block a user