[rocm-libraries] ROCm/rocm-libraries#5938 (commit 73f3650)

[CK_TILE] Optimize static_ford and sequence compile-time
 infrastructure (#5938)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Problem

Each `static_for<0, N, 1>` instantiates its lambda N times (one per
`number<I>` type). When nested, intermediate lambdas capture the outer
loop variable (a different type per iteration), creating unique closure
types. For a 3-level nest with M=4, N=4, K=2, this produces 4 + 16 + 32
= 52 IR functions, of which 20 are intermediate closures that get
inlined away but still cost frontend compile time.

ck_tile's `static_ford` was supposed to eliminate these intermediates
(as old CK's PR #5031 did successfully), but it used a **recursive**
`static_ford_impl` that recreated the same closure pattern plus added
`reorder_old_to_new`/`reorder_new_to_old` overhead.

Additionally, the sequence utility layer (`sequence_sort`,
`is_valid_sequence_map`) used recursive template metaprogramming that
generated O(N log N) intermediate types for every permutation validation
— called on every `reorder_new_to_old`/`reorder_old_to_new` invocation.

## Changes

### 1. Replace `sequence_sort` with constexpr insertion sort
Replace recursive merge sort (`sequence_sort_impl` +
`sorted_sequence_merge_impl`, O(N log N) intermediate type
instantiations) with constexpr insertion sort using `static_array`. O(1)
template depth, same `::type` and `::sorted2unsorted_map` API.

### 2. Replace `is_valid_sequence_map` with constexpr check
Replace sort-based permutation validation (which instantiated the full
`sequence_sort` chain) with a constexpr "seen array" loop. O(N)
constexpr steps instead of O(N log N) template instantiations.

### 3. Replace recursive `static_ford` with flat-loop `index_decomposer`
Replace `static_ford_impl` (recursive `static_for` nesting +
`pop_front`/`push_back` + `reorder_old_to_new` per iteration) with flat
`index_decomposer` using pre-computed strides. Add `decompose_reordered`
alias that folds reordering into decomposition, and `inverse_perm`
helper that avoids the `sequence_map_inverse` → `is_valid_sequence_map`
→ `sequence_sort` chain.

### 4. Eliminate internal lambda via `ford_applier`
The flat-loop approach still used `static_for` with a lambda, creating
M×N internal lambda instantiations per call site. Replace with
`ford_applier` struct that calls `f(decompose<I>{})` directly via fold
expression — zero intermediate closures:

```cpp
// Before: 2×M×N function instantiations
static_for<0, M*N, 1>{}([&](auto i) { f(decompose<i>{}); });

// After: M×N function instantiations (50% reduction)
ford_applier<Decomposer, make_index_sequence<M*N>>{}(f);
```

Also unified identity and non-identity order paths into a single
template with `constexpr if`.

### 5. Fix const-qualified sequence handling
Fix `is_valid_sequence_map` to handle const-qualified sequence types via
`remove_cvref_t` in callers (`tensor_adaptor.hpp`,
`tile_distribution_encoding.hpp`).

## Results (this PR only, without flattening)

### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942, load ~5)

| Target | Base (s) | Treat (s) | Delta | % | Wins | Significant? |
|--------|----------|-----------|-------|---|------|-------------|
| **flatmm** | 160.1 | 152.7 | **-7.4s** | **-4.6%** | 6/7 | **YES**
(W+=1, p<0.05) |
| universal_gemm | 228.4 | 224.7 | -3.7s | -1.6% | 6/7 | Trending (W+=4)
|

Per-trial diffs (flatmm): [-6, -20, -9, -8, -8, 4, -5]
Per-trial diffs (universal_gemm): [-2, -6, 4, -3, -2, -11, -6]

### IR Function Counts (device trace, gfx942)

| Target | Metric | Before | After | Delta | % |
|--------|--------|--------|-------|-------|---|
| **universal_gemm** | InstantiateFunction | 117,715 | 109,165 |
**-8,550** | **-7.3%** |
| **universal_gemm** | CodeGen Function | 47,912 | 45,044 | **-2,868** |
**-6.0%** |
| **flatmm** | InstantiateFunction | 100,939 | 95,127 | **-5,812** |
**-5.8%** |
| **flatmm** | CodeGen Function | 42,651 | 40,367 | **-2,284** |
**-5.4%** |

Note: The `ford_applier` (commit 3) has minimal additional effect in
this PR since ck_tile code does not yet use `static_ford` extensively.
Its impact compounds when the follow-up flattening PR #5939 converts 124
`static_for` nests to `static_ford`. Combined results with #5939: flatmm
**-7.5%** wall time (p<0.01), CodeGen **-10.5%**.

### ASM Equivalence
7/7 PASS — 979,943 lines of device assembly verified identical (gfx942 +
gfx1100). TUs: universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale.

## Test plan
- [x] `test_ck_tile_static_ford`: 13 behavioral tests
(identity/non-identity orders, 1D-4D, unit dimensions, edge cases)
- [x] `ck_tile_unit_sequence`: 88 tests (11 new for sorted2unsorted_map,
is_valid_sequence_map edge cases, sequence_unique_sort map round-trip)
- [x] ASM equivalence verified (980K lines)
- [x] Wilcoxon timing verified (7 trials, flatmm p<0.05)
- [ ] CI

🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
Christopher Millette
2026-04-02 21:25:56 +00:00
committed by assistant-librarian[bot]
parent 7cc9bae9d2
commit 144854dba1
7 changed files with 610 additions and 189 deletions

View File

@@ -39,7 +39,7 @@ CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
@@ -89,7 +89,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<T
{
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
return make_tuple(old_tuple[number<IRs>{}]...);
}
@@ -109,7 +109,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is..
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
@@ -120,7 +120,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is..
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};

View File

@@ -144,9 +144,11 @@ struct sequence
static_assert(MapOld2New::size() == size(),
"wrong! reorder map should have the same size as sequence to be rerodered");
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
static_assert(is_valid_sequence_map<remove_cvref_t<MapOld2New>>::value,
"wrong! invalid reorder map");
return reorder_new_to_old(typename sequence_map_inverse<MapOld2New>::type{});
return reorder_new_to_old(
typename sequence_map_inverse<remove_cvref_t<MapOld2New>>::type{});
}
CK_TILE_HOST_DEVICE static constexpr auto reverse()
@@ -548,163 +550,59 @@ struct sequence_reduce<Reduce, Seq>
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
// Sorts a sequence using constexpr insertion sort. O(1) template instantiation
// depth, replacing the recursive merge sort that created O(N log N) intermediate types.
namespace detail {
template <typename Values, typename Compare, typename IndexSeq>
struct sequence_sort_helper;
template <index_t... Vs, typename Compare, index_t... Idx>
struct sequence_sort_helper<sequence<Vs...>, Compare, sequence<Idx...>>
{
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl
struct sort_result
{
static constexpr bool choose_left = LeftValues::front() < RightValues::front();
static constexpr index_t chosen_value =
choose_left ? LeftValues::front() : RightValues::front();
static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front();
using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
using new_merged_ids = decltype(MergedIds::push_back(number<chosen_id>{}));
using new_left_values = typename std::
conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
using new_left_ids =
typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;
using new_right_values = typename std::
conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
using new_right_ids =
typename std::conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
static_array<index_t, sizeof...(Vs)> values;
static_array<index_t, sizeof...(Vs)> ids;
};
template <typename LeftValues,
typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
sequence<>,
sequence<>,
MergedValues,
MergedIds,
Comp>
static constexpr sort_result compute()
{
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
};
constexpr index_t n = sizeof...(Vs);
sort_result r{{{Vs...}}, {{Idx...}}};
// insertion sort — O(N^2) constexpr steps, O(1) template depth
for(index_t i = 1; i < n; ++i)
{
for(index_t j = i; j > 0 && Compare{}(r.values[j], r.values[j - 1]); --j)
{
auto tv = r.values[j];
r.values[j] = r.values[j - 1];
r.values[j - 1] = tv;
auto ti = r.ids[j];
r.ids[j] = r.ids[j - 1];
r.ids[j - 1] = ti;
}
}
return r;
}
template <typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<sequence<>,
sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
};
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename Comp>
struct sorted_sequence_merge
{
using merge = sorted_sequence_merge_impl<LeftValues,
LeftIds,
RightValues,
RightIds,
sequence<>,
sequence<>,
Comp>;
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
static constexpr index_t nsize = Values::size();
using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
using left_unsorted_values = typename split_unsorted_values::left_type;
using left_unsorted_ids = typename split_unsorted_ids::left_type;
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;
using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;
using merged_sorted = sorted_sequence_merge<left_sorted_values,
left_sorted_ids,
right_sorted_values,
right_sorted_ids,
Compare>;
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
static constexpr sort_result sorted = compute();
using sorted_values = sequence<sorted.values[Idx]...>;
using sorted_ids = sequence<sorted.ids[Idx]...>;
};
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
{
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
using sorted_values = typename std::
conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
using sorted_ids =
typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
};
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<sequence<Value>, sequence<Id>, Compare>
{
using sorted_values = sequence<Value>;
using sorted_ids = sequence<Id>;
};
template <typename Compare>
struct sequence_sort_impl<sequence<>, sequence<>, Compare>
{
using sorted_values = sequence<>;
using sorted_ids = sequence<>;
};
} // namespace detail
template <typename Values, typename Compare>
struct sequence_sort
{
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type;
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
static constexpr index_t n = Values::size();
using idx_seq = make_index_sequence<n>;
// this is output
using type = typename sort::sorted_values;
using sorted2unsorted_map = typename sort::sorted_ids;
using helper = detail::sequence_sort_helper<remove_cvref_t<Values>, Compare, idx_seq>;
using type = typename helper::sorted_values;
using sorted2unsorted_map = typename helper::sorted_ids;
};
template <typename Values, typename Less, typename Equal>
@@ -782,10 +680,42 @@ struct sequence_unique_sort
using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
// Validates that a sequence is a permutation of {0, 1, ..., N-1}.
// Uses a constexpr loop instead of instantiating sequence_sort.
namespace detail {
template <index_t... Is>
constexpr bool check_valid_sequence_map()
{
constexpr index_t n = sizeof...(Is);
if constexpr(n == 0)
{
return true;
}
else
{
constexpr index_t vals[] = {Is...};
static_array<bool, n> seen{};
for(index_t i = 0; i < n; ++i)
{
if(vals[i] < 0 || vals[i] >= n || seen[vals[i]])
return false;
seen[vals[i]] = true;
}
return true;
}
}
} // namespace detail
template <typename SeqMap>
struct is_valid_sequence_map
: std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
typename sequence_sort<SeqMap, less<index_t>>::type>
struct is_valid_sequence_map : std::false_type
{
};
template <index_t... Is>
struct is_valid_sequence_map<sequence<Is...>>
: std::integral_constant<bool, detail::check_valid_sequence_map<Is...>()>
{
};