mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Optimize sequence_merge using direct concatenation for small cases
Replace linear recursive instantiation with direct pack expansion for 1-4 sequences, and binary tree reduction for larger cases. Before: O(N) depth for merging N sequences After: O(log N) depth with O(1) for up to 4 sequences This further reduces maximum nesting depth from 26 to 22 levels when combined with the previous sequence_gen optimization. Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -199,30 +199,71 @@ template <index_t N>
|
||||
using make_index_sequence =
|
||||
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
|
||||
|
||||
// merge sequence
|
||||
template <typename Seq, typename... Seqs>
|
||||
struct sequence_merge
|
||||
// merge sequence - optimized to avoid recursive instantiation
|
||||
namespace detail {
|
||||
|
||||
// Helper to concatenate multiple sequences in one step using fold expression
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge_impl;
|
||||
|
||||
// Base case: single sequence
|
||||
template <index_t... Is>
|
||||
struct sequence_merge_impl<Sequence<Is...>>
|
||||
{
|
||||
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
|
||||
using type = Sequence<Is...>;
|
||||
};
|
||||
|
||||
// Two sequences: direct concatenation
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <typename Seq>
|
||||
struct sequence_merge<Seq>
|
||||
// Three sequences: direct concatenation (avoids one level of recursion)
|
||||
template <index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
|
||||
{
|
||||
using type = Seq;
|
||||
using type = Sequence<Xs..., Ys..., Zs...>;
|
||||
};
|
||||
|
||||
// Four sequences: direct concatenation
|
||||
template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
|
||||
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
|
||||
{
|
||||
using type = Sequence<As..., Bs..., Cs..., Ds...>;
|
||||
};
|
||||
|
||||
// General case: binary tree reduction (O(log N) depth instead of O(N))
|
||||
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
|
||||
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
|
||||
{
|
||||
// Merge pairs first, then recurse
|
||||
using left = typename sequence_merge_impl<S1, S2>::type;
|
||||
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
|
||||
using type = typename sequence_merge_impl<left, right>::type;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge
|
||||
{
|
||||
using type = typename detail::sequence_merge_impl<Seqs...>::type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct sequence_merge<>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation
|
||||
namespace detail {
|
||||
|
||||
// Helper that applies functor F to indices and produces a Sequence
|
||||
// __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1, ..., N-1>
|
||||
// __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1,
|
||||
// ..., N-1>
|
||||
template <typename T, T... Is>
|
||||
struct sequence_gen_helper
|
||||
{
|
||||
@@ -236,8 +277,8 @@ struct sequence_gen_helper
|
||||
template <index_t NSize, typename F>
|
||||
struct sequence_gen
|
||||
{
|
||||
using type = typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::
|
||||
template apply<F>;
|
||||
using type =
|
||||
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
|
||||
@@ -20,6 +20,7 @@ struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
|
||||
using type = Tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
|
||||
template <typename T, index_t N>
|
||||
struct StaticallyIndexedArrayImpl
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user