[rocm-libraries] ROCm/rocm-libraries#5028 (commit 5131491)

[CK_TILE] Optimize ck_tile::sequence to reduce template
 instantiation depth [2A] (#5028)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

### Rationale
`ck_tile::sequence` is the most fundamental metaprogramming type in
ck_tile — it underpins
tensor dimensions, strides, loop bounds, and index calculations. Six of
its metafunctions
use recursive template instantiation, producing O(N) to O(N²)
intermediate types that
the compiler must process. When these are used inside deeply nested GEMM
pipelines with
large dimension counts, the cumulative instantiation overhead becomes a
significant
contributor to frontend compile time.

Measurements on `test_gemm_pipeline_compv6` show 84,288
`InstantiateFunction` calls in
the frontend alone. Reducing template instantiation depth in these core
utilities has a
multiplicative effect because they are called from hundreds of sites.

### What changed

| Metafunction | Before | After |
|---|---|---|
| `sequence::modify` | O(N) recursive split/merge | O(1) pack expansion
|
| `sequence_gen` | O(log N) recursive binary split | O(1) via
`__make_integer_seq` |
| `uniform_sequence_gen` | Delegates to `sequence_gen` | O(1) via
`__make_integer_seq` |
| `sequence_reverse_inclusive_scan` | O(N) recursive | O(1) constexpr
for-loop + pack expansion |
| `sequence_inclusive_scan` | Computed via reverse + flip | O(1)
constexpr for-loop (unified impl) |
| `sequence_exclusive_scan` | O(N) recursive merge chain | O(1)
constexpr for-loop + pack expansion |
| `sequence_map_inverse` | O(N²) recursive modify calls | O(1) constexpr
for-loop + pack expansion |

Supporting changes:
- Portable `__type_pack_element` fallback with `__has_builtin` guard
(hipRTC-safe, no `<tuple>` dependency)
- Renamed reserved `__integer_sequence` to `integer_sequence_wrapper`
- Adopted `static_array` from develop (PR #4355) for constexpr
computation
- Unified forward and reverse inclusive scan into a single
`sequence_inclusive_scan_impl` with `bool Reverse` template parameter
- Added `sequence_inclusive_scan` struct (new public API for forward
scan direction)
- Replaced recursive `sequence_exclusive_scan` (3 template
specializations) with `sequence_exclusive_scan_impl` using the same
constexpr for-loop pattern as inclusive scan
- Rewired `exclusive_scan_sequence` and `prefix_sum_sequence` to use new
impl
- Added `CK_TILE_HOST_DEVICE` to `exclusive_scan_sequence` and
`prefix_sum_sequence` to match sibling scan function annotations

### Technical debt and housekeeping
- Unified all `namespace impl` to `namespace detail` across sequence.hpp
for consistency
- Removed dead comment block (orphaned `integer_sequence` alternative)
- Added defensive `static_assert(sizeof...(Is) > 0)` in
`sequence_map_inverse::build_inverse`
- Converted all multi-line Doxygen blocks from `///` to `/** */` per
style guide
- Corrected `constexpr static` to `static constexpr` keyword ordering in
`static_array`
- Added blank line between `#pragma once` and first `#include` in
`static_array.hpp`
- Trimmed redundant 4-line comment on `sequence_gen_helper` to a
one-liner
- Moved `sequence_gen` Doxygen comment below `namespace detail` block so
it directly precedes the struct it documents
- Added Doxygen `@brief`/`@tparam`/`@pre` documentation for
`sequence_gen` and `sequence_map_inverse` public APIs
- Added `@brief` documentation to `static_array` explaining relationship
to `ck_tile::array`
- Added scope comment at `namespace detail` openings

**Note:** `private:`/`public:` access modifier indentation is enforced
at 4 spaces by
`.clang-format`. The style guide calls for left-alignment, but the
formatter overrides
this. Requires a `.clang-format` config change to resolve — not
addressable in code.

### `static_array` hardening (from develop's PR #4355)
- Added zero-length array guard (`T elems[N > 0 ? N : 1]`)
- Added `CK_TILE_HOST_DEVICE` annotations to `operator[]` and `size()`
- Added `#include "ck_tile/core/config.hpp"` (IWYU for
`CK_TILE_HOST_DEVICE`)

### Value
Combined with the `static_ford` changes, measured impact on
`test_gemm_pipeline_compv6`:
- **Frontend: -28.9%** (InstantiateFunction: 84,288 → 69,439)
- **Backend: -13.1%** (CodeGen Functions: 3,170 → 2,203)
- **Wall-clock: -16.3%** (611.6s → 512.2s)

### Files changed (4)
- `sequence.hpp`: Metafunction optimizations, namespace unification,
documentation, style fixes
- `static_array.hpp`: Zero-length guard, `CK_TILE_HOST_DEVICE`,
documentation, style fixes
- `test_sequence.cpp`: 50 unit tests with runtime `EXPECT_EQ` assertions
(new file)
- `CMakeLists.txt`: Register new test target

## Test plan
- [x] 50 runtime unit tests covering all optimized and pre-existing
sequence APIs
- [x] Edge cases: empty sequences, single-element, larger sizes (N=8),
negative values, non-trivial init values
- [x] Both functor signatures tested (`operator()(index_t)` and
`operator()(number<I>)`)
- [x] Both scan reducers (`plus`, `multiplies`) with forward, reverse,
inclusive, and exclusive directions
- [x] Exclusive scan: sum, product, single, empty, non-zero init
- [x] Prefix sum: N+1 output verification, single, empty
- [x] Permutation round-trip verification for `sequence_map_inverse`
- [x] Full sequence public API coverage: modify, gen, uniform_gen, scans
(inclusive, exclusive, prefix sum), map_inverse, make_index_sequence,
size/sum/product, push/pop, reverse, extract, merge, arithmetic
operators, equality, transform
- [x] Portable `__type_pack_element` fallback tested implicitly (same
`at_index_t` interface)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Christopher Millette
2026-03-11 20:26:11 +00:00
committed by assistant-librarian[bot]
parent d8ee107a47
commit 56e1d5da08
4 changed files with 858 additions and 145 deletions

View File

@@ -35,16 +35,46 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence<I, Is...>);
template <typename Seq>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq);
namespace impl {
// Implementation details for sequence element access and index generation.
namespace detail {
// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element");
// O(1) type pack indexing via compiler builtin when available,
// with an O(N) recursive fallback for compilers that lack it (e.g., older MSVC).
// Does not depend on <tuple> so it is safe for hipRTC / GPU codegen.
#if defined(__has_builtin) && __has_builtin(__type_pack_element)
template <index_t I, typename... Ts>
using at_index_t = __type_pack_element<I, Ts...>;
} // namespace impl
#else
template <index_t I, typename T, typename... Ts>
struct type_pack_element_impl
{
using type = typename type_pack_element_impl<I - 1, Ts...>::type;
};
// we could implement as below, similiar to std. But let's reduce the symbol name...
// template< class T, T... Ints >
// class integer_sequence;
template <typename T, typename... Ts>
struct type_pack_element_impl<0, T, Ts...>
{
using type = T;
};
template <index_t I, typename... Ts>
using at_index_t = typename type_pack_element_impl<I, Ts...>::type;
#endif
// Bridge type for __make_integer_seq: converts integer pack to ck_tile::sequence
template <typename T, T... Ints>
struct integer_sequence_wrapper;
template <index_t... Ints>
struct integer_sequence_wrapper<index_t, Ints...>
{
using seq_type = sequence<Ints...>;
};
} // namespace detail
template <index_t N>
using make_index_sequence =
typename __make_integer_seq<detail::integer_sequence_wrapper, index_t, N>::seq_type;
template <index_t... Is>
struct sequence
@@ -59,7 +89,7 @@ struct sequence
CK_TILE_HOST_DEVICE static constexpr auto get()
{
static_assert(I < size(), "wrong! I too large");
return number<impl::at_index_t<I, constant<Is>...>{}>{};
return number<detail::at_index_t<I, constant<Is>...>{}>{};
}
template <index_t I>
@@ -80,7 +110,7 @@ struct sequence
CK_TILE_HOST_DEVICE static constexpr auto at()
{
static_assert(I < size(), "wrong! I too large");
return number<impl::at_index_t<I, constant<Is>...>{}>{};
return number<detail::at_index_t<I, constant<Is>...>{}>{};
}
template <index_t I>
@@ -184,15 +214,19 @@ struct sequence
template <index_t I, index_t X>
CK_TILE_HOST_DEVICE static constexpr auto modify(number<I>, number<X>)
{
static_assert(I < size(), "wrong!");
using seq_split = sequence_split<type, I>;
constexpr auto seq_left = typename seq_split::left_type{};
constexpr auto seq_right = typename seq_split::right_type{}.pop_front();
return seq_left.push_back(number<X>{}).push_back(seq_right);
static_assert(I >= 0 && I < size(), "Index I is out of bounds");
return modify_impl(make_index_sequence<size()>{}, number<I>{}, number<X>{});
}
private:
template <index_t... Idxs, index_t ModifyIdx, index_t NewVal>
CK_TILE_HOST_DEVICE static constexpr auto
modify_impl(sequence<Idxs...>, number<ModifyIdx>, number<NewVal>)
{
return sequence<(Idxs == ModifyIdx ? NewVal : get<Idxs>())...>{};
}
public:
template <typename F>
CK_TILE_HOST_DEVICE static constexpr auto transform(F f)
{
@@ -227,22 +261,6 @@ struct is_sequence<sequence<Is...>> : std::true_type
template <typename T>
inline constexpr bool is_sequence_v = is_sequence<T>::value;
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;
template <index_t... Ints>
struct __integer_sequence<index_t, Ints...>
{
using seq_type = sequence<Ints...>;
};
} // namespace impl
// similiar
template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
@@ -262,36 +280,36 @@ struct sequence_merge<Seq>
using type = Seq;
};
// generate sequence
namespace detail {
// Bridge: converts __make_integer_seq index pack into a sequence via functor application.
template <typename T, T... Ids>
struct sequence_gen_helper
{
template <typename F>
using apply = sequence<F{}(number<Ids>{})...>;
};
} // namespace detail
/**
* @brief Generate a compile-time sequence by applying a functor to indices 0..N-1.
* @tparam NSize Number of elements in the generated sequence.
* @tparam F Functor type; must be default-constructible with a constexpr call operator
* accepting number<I> (or index_t via implicit conversion) and returning index_t.
* Lambdas with captures cannot be used; use a template struct functor instead.
*/
template <index_t NSize, typename F>
struct sequence_gen
{
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type =
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
};
using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(number<I>{});
using type = sequence<Is>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = sequence<>;
};
using type = typename sequence_gen_impl<0, NSize, F>::type;
template <typename F>
struct sequence_gen<0, F>
{
using type = sequence<>;
};
// arithmetic sequence
@@ -321,20 +339,34 @@ struct arithmetic_sequence_gen<0, IEnd, 1>
using type = make_index_sequence<IEnd>;
};
// uniform sequence
// uniform sequence - optimized using __make_integer_seq
namespace detail {
template <typename T, T... Ids>
struct uniform_sequence_helper
{
// Comma operator: discard Ids, produce Value for each element
template <index_t Value>
using apply = sequence<((void)Ids, Value)...>;
};
} // namespace detail
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t) const { return I; }
};
using type = typename sequence_gen<NSize, F>::type;
using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>::
template apply<I>;
};
// inclusive scan (with init) sequence
namespace impl {
template <index_t I>
struct uniform_sequence_gen<0, I>
{
using type = sequence<>;
};
// inclusive scan (with init) sequence - optimized using constexpr for-loop with static_array
namespace detail {
template <typename Seq, typename Reduce, index_t Init, bool Reverse>
struct sequence_inclusive_scan_impl;
@@ -352,6 +384,9 @@ struct sequence_inclusive_scan_impl<sequence<Is...>, Reduce, Init, Reverse>
}
else
{
// Compute all scan values in a single constexpr evaluation using
// static_array, then unpack via index expansion. Avoids O(N) recursive
// template instantiation.
constexpr auto arr = []() {
static_array<index_t, size> values = {Is...};
static_array<index_t, size> result = {0};
@@ -381,18 +416,53 @@ struct sequence_inclusive_scan_impl<sequence<Is...>, Reduce, Init, Reverse>
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};
} // namespace impl
// Exclusive scan: result[0] = Init, result[i] = Reduce(values[i-1], result[i-1]) for i > 0.
template <typename Seq, typename Reduce, index_t Init>
struct sequence_exclusive_scan_impl;
template <index_t... Is, typename Reduce, index_t Init>
struct sequence_exclusive_scan_impl<sequence<Is...>, Reduce, Init>
{
template <index_t... Indices>
static constexpr auto compute(sequence<Indices...>)
{
constexpr index_t size = sizeof...(Is);
if constexpr(size == 0)
{
return sequence<>{};
}
else
{
constexpr auto arr = []() {
static_array<index_t, size> values = {Is...};
static_array<index_t, size> result = {0};
result[0] = Init;
for(index_t i = 1; i < size; ++i)
{
result[i] = Reduce{}(values[i - 1], result[i - 1]);
}
return result;
}();
return sequence<arr[Indices]...>{};
}
}
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};
} // namespace detail
template <typename Seq, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan
{
using type = typename impl::sequence_inclusive_scan_impl<Seq, Reduce, Init, true>::type;
using type = typename detail::sequence_inclusive_scan_impl<Seq, Reduce, Init, true>::type;
};
template <typename Seq, typename Reduce, index_t Init>
struct sequence_inclusive_scan
{
using type = typename impl::sequence_inclusive_scan_impl<Seq, Reduce, Init, false>::type;
using type = typename detail::sequence_inclusive_scan_impl<Seq, Reduce, Init, false>::type;
};
// split sequence
@@ -434,7 +504,7 @@ struct sequence_reverse<sequence<I0, I1>>
};
#endif
namespace impl {
namespace detail {
template <typename Id, index_t... Ns>
struct seq_reverse;
@@ -442,14 +512,14 @@ template <index_t... Ids, index_t... Ns>
struct seq_reverse<sequence<Ids...>, Ns...>
{
template <index_t I>
using element = impl::at_index_t<I, constant<Ns>...>;
using element = detail::at_index_t<I, constant<Ns>...>;
using type = sequence<element<(sizeof...(Ns) - 1 - Ids)>::value...>;
};
} // namespace impl
} // namespace detail
template <index_t... Ns>
struct sequence_reverse<sequence<Ns...>>
: impl::seq_reverse<make_index_sequence<sizeof...(Ns)>, Ns...>
: detail::seq_reverse<make_index_sequence<sizeof...(Ns)>, Ns...>
{
};
@@ -719,31 +789,48 @@ struct is_valid_sequence_map
{
};
template <typename SeqMap>
struct sequence_map_inverse
/**
* @brief Compute the inverse permutation of a sequence map.
* @tparam Is A valid permutation of {0, 1, ..., N-1}.
* @pre Input must satisfy is_valid_sequence_map (enforced by static_assert).
*
* Optimized using constexpr for-loop: O(1) template instantiation depth instead of O(N).
*/
template <index_t... Is>
struct sequence_map_inverse<sequence<Is...>>
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
static_assert(is_valid_sequence_map<sequence<Is...>>::value,
"sequence_map_inverse requires a valid permutation sequence map");
private:
static constexpr auto build_inverse()
{
static constexpr auto new_y2x =
WorkingY2X::modify(X2Y::get(number<XBegin>{}), number<XBegin>{});
static_assert(sizeof...(Is) > 0, "build_inverse requires non-empty sequence");
static_array<index_t, sizeof...(Is)> result = {0};
constexpr index_t input[] = {Is...};
for(index_t pos = 0; pos < static_cast<index_t>(sizeof...(Is)); ++pos)
{
result[input[pos]] = pos;
}
return result;
}
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
};
static constexpr auto inverse = build_inverse();
template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
template <index_t... Positions>
static constexpr auto compute(sequence<Positions...>)
{
using type = WorkingY2X;
};
return sequence<inverse[Positions]...>{};
}
using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::size(), 0>::type,
0,
SeqMap::size()>::type;
public:
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};
template <>
struct sequence_map_inverse<sequence<>>
{
using type = sequence<>;
};
template <index_t... Xs, index_t... Ys>
@@ -922,43 +1009,18 @@ CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number<I
}
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add
// ResultSeq TargetSeq Reduce
template <typename, typename, typename>
struct sequence_exclusive_scan;
template <index_t... Xs, index_t Y, index_t... Ys, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y, Ys...>, Reduce>
{
using old_scan = typename sequence_merge<sequence<Xs...>,
sequence<Reduce{}(Y, sequence<Xs...>{}.back())>>::type;
using type = typename sequence_exclusive_scan<old_scan, sequence<Ys...>, Reduce>::type;
};
template <index_t... Xs, index_t Y, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y>, Reduce>
{
using type = sequence<Xs...>;
};
template <index_t... Xs, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<>, Reduce>
{
using type = sequence<Xs...>;
};
template <typename Seq, typename Reduce, index_t Init>
constexpr auto exclusive_scan_sequence(Seq, Reduce, number<Init>)
CK_TILE_HOST_DEVICE constexpr auto exclusive_scan_sequence(Seq, Reduce, number<Init>)
{
// TODO: c++20 and later can pass in Reduce with a lambda expression
return typename sequence_exclusive_scan<sequence<Init>, Seq, Reduce>::type{};
return typename detail::sequence_exclusive_scan_impl<Seq, Reduce, Init>::type{};
}
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5, 9> (N+1 elements: prefix sums including both endpoints)
template <typename Seq>
constexpr auto prefix_sum_sequence(Seq)
CK_TILE_HOST_DEVICE constexpr auto prefix_sum_sequence(Seq)
{
return typename sequence_exclusive_scan<sequence<0>,
typename sequence_merge<Seq, sequence<0>>::type,
plus<index_t>>::type{};
using extended = typename sequence_merge<Seq, sequence<0>>::type;
return typename detail::sequence_exclusive_scan_impl<extended, plus<index_t>, 0>::type{};
}
template <typename Seq, index_t... Is>
@@ -1177,7 +1239,7 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
namespace impl {
namespace detail {
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
@@ -1239,7 +1301,7 @@ struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, Slice
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
} // namespace impl
} // namespace detail
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
@@ -1288,11 +1350,11 @@ constexpr auto reverse_slice_sequence(Seq,
SliceSize ==
0,
"slice size can't evenly divide input sizes");
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
using sliced_type = detail::reverse_slice_sequence_impl<
Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},

View File

@@ -2,29 +2,31 @@
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
// Fixed-size array with aggregate initialization
//
// This is a minimal array type designed for:
// - Constexpr/compile-time computation
// - GPU kernel code (trivially copyable)
// - Template metaprogramming
//
// Unlike ck_tile::array, this has no custom constructors,
// making it a literal type suitable for constexpr contexts.
// Use aggregate initialization: static_array<int, 3> arr{1, 2, 3};
/**
* @brief Fixed-size array with aggregate initialization for constexpr contexts.
*
* Unlike ck_tile::array, this has no custom constructors, making it a literal type
* suitable for constexpr evaluation and GPU kernel code. Use ck_tile::array when
* constructors or non-trivial initialization are needed.
* Use aggregate initialization: static_array<int, 3> arr{1, 2, 3};
*/
template <typename T, index_t N>
struct static_array
{
// Public aggregate initialization makes this a literal type
T elems[N];
// Public aggregate initialization makes this a literal type.
// N == 0 uses size 1 to avoid zero-length arrays (non-standard).
T elems[N > 0 ? N : 1];
// Basic constexpr accessors
constexpr const T& operator[](index_t i) const { return elems[i]; }
constexpr T& operator[](index_t i) { return elems[i]; }
CK_TILE_HOST_DEVICE constexpr const T& operator[](index_t i) const { return elems[i]; }
CK_TILE_HOST_DEVICE constexpr T& operator[](index_t i) { return elems[i]; }
constexpr static index_t size() { return N; }
CK_TILE_HOST_DEVICE static constexpr index_t size() { return N; }
};
} // namespace ck_tile