mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Replace sequence_merge O(log N) recursion with O(1) fold expression
Use operator| with fold expression (Seqs{} | ...) to merge sequences
in O(1) template depth instead of O(log N) binary tree recursion.
- Reduces sequence_merge instantiations from 449 to 167 (63% reduction)
- Total template instantiations: 47,186 → 46,974 (-212)
- ADL finds operator| since Sequence is in ck namespace
This commit is contained in:
@@ -199,57 +199,20 @@ template <index_t N>
|
||||
using make_index_sequence =
|
||||
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
|
||||
|
||||
// merge sequence - optimized to avoid recursive instantiation
|
||||
namespace detail {
|
||||
|
||||
// Helper to concatenate multiple sequences in one step using fold expression
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge_impl;
|
||||
|
||||
// Base case: single sequence
|
||||
template <index_t... Is>
|
||||
struct sequence_merge_impl<Sequence<Is...>>
|
||||
// merge sequence - O(1) template depth using fold expression
|
||||
// Binary merge operator for fold expression - enables O(1) depth via (S1 | S2 | S3 | ...)
|
||||
// Must be in ck namespace for ADL to find it when used with Sequence types
|
||||
template <index_t... As, index_t... Bs>
|
||||
constexpr Sequence<As..., Bs...> operator|(Sequence<As...>, Sequence<Bs...>)
|
||||
{
|
||||
using type = Sequence<Is...>;
|
||||
};
|
||||
|
||||
// Two sequences: direct concatenation
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// Three sequences: direct concatenation (avoids one level of recursion)
|
||||
template <index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
|
||||
{
|
||||
using type = Sequence<Xs..., Ys..., Zs...>;
|
||||
};
|
||||
|
||||
// Four sequences: direct concatenation
|
||||
template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
|
||||
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
|
||||
{
|
||||
using type = Sequence<As..., Bs..., Cs..., Ds...>;
|
||||
};
|
||||
|
||||
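// General case: merge the first pair, recurse on the remaining sequences, then join both halves (each level consumes two sequences, so recursion depth is about N/2 rather than N)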
|
||||
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
|
||||
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
|
||||
{
|
||||
// Merge pairs first, then recurse
|
||||
using left = typename sequence_merge_impl<S1, S2>::type;
|
||||
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
|
||||
using type = typename sequence_merge_impl<left, right>::type;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge
|
||||
{
|
||||
using type = typename detail::sequence_merge_impl<Seqs...>::type;
|
||||
// Unary right fold: S1 | (S2 | (S3 | ...)) - concatenation is associative, so the result is the in-order merge; O(1) template depth
|
||||
using type = decltype((Seqs{} | ...));
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
Reference in New Issue
Block a user