[rocm-libraries] ROCm/rocm-libraries#5938 (commit 73f3650)

[CK_TILE] Optimize static_ford and sequence compile-time
 infrastructure (#5938)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Problem

Each `static_for<0, N, 1>` instantiates its lambda N times (one per
`number<I>` type). When nested, intermediate lambdas capture the outer
loop variable (a different type per iteration), creating unique closure
types. For a 3-level nest with M=4, N=4, K=2, this produces 4 + 16 + 32
= 52 IR functions, of which 20 are intermediate closures that get
inlined away but still cost frontend compile time.

ck_tile's `static_ford` was supposed to eliminate these intermediates
(as old CK's PR #5031 did successfully), but it used a **recursive**
`static_ford_impl` that recreated the same closure pattern plus added
`reorder_old_to_new`/`reorder_new_to_old` overhead.

Additionally, the sequence utility layer (`sequence_sort`,
`is_valid_sequence_map`) used recursive template metaprogramming that
generated O(N log N) intermediate types for every permutation validation
— called on every `reorder_new_to_old`/`reorder_old_to_new` invocation.

## Changes

### 1. Replace `sequence_sort` with constexpr insertion sort
Replace recursive merge sort (`sequence_sort_impl` +
`sorted_sequence_merge_impl`, O(N log N) intermediate type
instantiations) with constexpr insertion sort using `static_array`. O(1)
template depth, same `::type` and `::sorted2unsorted_map` API.

### 2. Replace `is_valid_sequence_map` with constexpr check
Replace sort-based permutation validation (which instantiated the full
`sequence_sort` chain) with a constexpr "seen array" loop. O(N)
constexpr steps instead of O(N log N) template instantiations.

### 3. Replace recursive `static_ford` with flat-loop `index_decomposer`
Replace `static_ford_impl` (recursive `static_for` nesting +
`pop_front`/`push_back` + `reorder_old_to_new` per iteration) with flat
`index_decomposer` using pre-computed strides. Add `decompose_reordered`
alias that folds reordering into decomposition, and `inverse_perm`
helper that avoids the `sequence_map_inverse` → `is_valid_sequence_map`
→ `sequence_sort` chain.

### 4. Eliminate internal lambda via `ford_applier`
The flat-loop approach still used `static_for` with a lambda, creating
M×N internal lambda instantiations per call site. Replace with
`ford_applier` struct that calls `f(decompose<I>{})` directly via fold
expression — zero intermediate closures:

```cpp
// Before: 2×M×N function instantiations
static_for<0, M*N, 1>{}([&](auto i) { f(decompose<i>{}); });

// After: M×N function instantiations (50% reduction)
ford_applier<Decomposer, make_index_sequence<M*N>>{}(f);
```

Also unified identity and non-identity order paths into a single
template with `constexpr if`.

### 5. Fix const-qualified sequence handling
Fix `is_valid_sequence_map` to handle const-qualified sequence types via
`remove_cvref_t` in callers (`tensor_adaptor.hpp`,
`tile_distribution_encoding.hpp`).

## Results (this PR only, without flattening)

### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942, load ~5)

| Target | Base (s) | Treat (s) | Delta | % | Wins | Significant? |
|--------|----------|-----------|-------|---|------|-------------|
| **flatmm** | 160.1 | 152.7 | **-7.4s** | **-4.6%** | 6/7 | **YES**
(W+=1, p<0.05) |
| universal_gemm | 228.4 | 224.7 | -3.7s | -1.6% | 6/7 | Trending (W+=4)
|

Per-trial diffs (flatmm): [-6, -20, -9, -8, -8, 4, -5]
Per-trial diffs (universal_gemm): [-2, -6, 4, -3, -2, -11, -6]

### IR Function Counts (device trace, gfx942)

| Target | Metric | Before | After | Delta | % |
|--------|--------|--------|-------|-------|---|
| **universal_gemm** | InstantiateFunction | 117,715 | 109,165 |
**-8,550** | **-7.3%** |
| **universal_gemm** | CodeGen Function | 47,912 | 45,044 | **-2,868** |
**-6.0%** |
| **flatmm** | InstantiateFunction | 100,939 | 95,127 | **-5,812** |
**-5.8%** |
| **flatmm** | CodeGen Function | 42,651 | 40,367 | **-2,284** |
**-5.4%** |

Note: The `ford_applier` (commit 3) has minimal additional effect in
this PR since ck_tile code does not yet use `static_ford` extensively.
Its impact compounds when the follow-up flattening PR #5939 converts 124
`static_for` nests to `static_ford`. Combined results with #5939: flatmm
**-7.5%** wall time (p<0.01), CodeGen **-10.5%**.

### ASM Equivalence
7/7 PASS — 979,943 lines of device assembly verified identical (gfx942 +
gfx1100). TUs: universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale.

## Test plan
- [x] `test_ck_tile_static_ford`: 13 behavioral tests
(identity/non-identity orders, 1D-4D, unit dimensions, edge cases)
- [x] `ck_tile_unit_sequence`: 88 tests (11 new for sorted2unsorted_map,
is_valid_sequence_map edge cases, sequence_unique_sort map round-trip)
- [x] ASM equivalence verified (980K lines)
- [x] Wilcoxon timing verified (7 trials, flatmm p<0.05)
- [ ] CI

🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
Christopher Millette
2026-04-02 21:25:56 +00:00
committed by assistant-librarian[bot]
parent 7cc9bae9d2
commit 144854dba1
7 changed files with 610 additions and 189 deletions

View File

@@ -355,6 +355,102 @@ TEST(SequenceSort, SortSingleElement)
EXPECT_TRUE((std::is_same<Result, Expected>::value));
}
// Test sequence_sort sorted2unsorted_map (index tracking)
TEST(SequenceSort, SortedMapUnsorted)
{
using Seq = sequence<5, 2, 8, 1, 9>;
using Sort = sequence_sort<Seq, less<index_t>>;
using Map = typename Sort::sorted2unsorted_map;
// sorted = <1,2,5,8,9>, original indices = <3,1,0,2,4>
using Expected = sequence<3, 1, 0, 2, 4>;
EXPECT_TRUE((std::is_same<Map, Expected>::value));
}
TEST(SequenceSort, SortedMapAlreadySorted)
{
using Seq = sequence<1, 2, 3, 4, 5>;
using Sort = sequence_sort<Seq, less<index_t>>;
using Map = typename Sort::sorted2unsorted_map;
// Already sorted: map should be identity
using Expected = sequence<0, 1, 2, 3, 4>;
EXPECT_TRUE((std::is_same<Map, Expected>::value));
}
TEST(SequenceSort, SortedMapRoundTrip)
{
// Verify: sorted_values[i] == original[sorted2unsorted_map[i]]
using Seq = sequence<5, 2, 8, 1, 9>;
using Sort = sequence_sort<Seq, less<index_t>>;
// sorted = <1,2,5,8,9>, map = <3,1,0,2,4>
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(0)), Sort::type::at(0));
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(1)), Sort::type::at(1));
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(2)), Sort::type::at(2));
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(3)), Sort::type::at(3));
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(4)), Sort::type::at(4));
}
TEST(SequenceSort, SortedMapWithDuplicates)
{
using Seq = sequence<3, 1, 3, 1>;
using Sort = sequence_sort<Seq, less<index_t>>;
using Sorted = typename Sort::type;
using Map = typename Sort::sorted2unsorted_map;
// sorted = <1,1,3,3>
using ExpectedSorted = sequence<1, 1, 3, 3>;
EXPECT_TRUE((std::is_same<Sorted, ExpectedSorted>::value));
// Verify round-trip: original[map[i]] == sorted[i] for all i
// (don't assert specific index order for duplicates — sort stability may vary)
EXPECT_EQ(Seq::at(Map::at(0)), Sorted::at(0));
EXPECT_EQ(Seq::at(Map::at(1)), Sorted::at(1));
EXPECT_EQ(Seq::at(Map::at(2)), Sorted::at(2));
EXPECT_EQ(Seq::at(Map::at(3)), Sorted::at(3));
}
TEST(SequenceSort, SortedMapReverseSorted)
{
using Seq = sequence<5, 4, 3, 2, 1>;
using Sort = sequence_sort<Seq, less<index_t>>;
using Sorted = typename Sort::type;
using Map = typename Sort::sorted2unsorted_map;
using ExpSorted = sequence<1, 2, 3, 4, 5>;
using ExpMap = sequence<4, 3, 2, 1, 0>;
EXPECT_TRUE((std::is_same<Sorted, ExpSorted>::value));
EXPECT_TRUE((std::is_same<Map, ExpMap>::value));
}
TEST(SequenceSort, SortedMapEmpty)
{
using Sort = sequence_sort<sequence<>, less<index_t>>;
using Map = typename Sort::sorted2unsorted_map;
EXPECT_TRUE((std::is_same<Map, sequence<>>::value));
}
TEST(SequenceSort, SortedMapSingleElement)
{
using Sort = sequence_sort<sequence<42>, less<index_t>>;
using Map = typename Sort::sorted2unsorted_map;
EXPECT_TRUE((std::is_same<Map, sequence<0>>::value));
}
// Test sequence_unique_sort sorted2unsorted_map
TEST(SequenceUniqueSort, UniqueSortMap)
{
using Seq = sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
using Result = sequence_unique_sort<Seq, less<index_t>, equal<index_t>>;
using Map = typename Result::sorted2unsorted_map;
// sorted unique = <1,2,3,4,5,6,9>
// The map should reference the first occurrence of each unique value in the original
// Verify round-trip: for each i, original[map[i]] == sorted_unique[i]
using Values = typename Result::type;
EXPECT_EQ(Seq::at(Map::at(0)), Values::at(0)); // 1
EXPECT_EQ(Seq::at(Map::at(1)), Values::at(1)); // 2
EXPECT_EQ(Seq::at(Map::at(2)), Values::at(2)); // 3
EXPECT_EQ(Seq::at(Map::at(3)), Values::at(3)); // 4
EXPECT_EQ(Seq::at(Map::at(4)), Values::at(4)); // 5
EXPECT_EQ(Seq::at(Map::at(5)), Values::at(5)); // 6
EXPECT_EQ(Seq::at(Map::at(6)), Values::at(6)); // 9
}
// Test sequence_unique_sort
TEST(SequenceUniqueSort, UniqueSort)
{
@@ -405,6 +501,24 @@ TEST(SequenceMap, InvalidMapMissing)
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
}
TEST(SequenceMap, InvalidMapNegative)
{
using Map = sequence<0, -1, 2>;
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
}
TEST(SequenceMap, ValidMapSingleElement)
{
EXPECT_TRUE((is_valid_sequence_map<sequence<0>>::value));
}
TEST(SequenceMap, InvalidMapSingleElement)
{
EXPECT_FALSE((is_valid_sequence_map<sequence<1>>::value));
}
TEST(SequenceMap, ValidMapEmpty) { EXPECT_TRUE((is_valid_sequence_map<sequence<>>::value)); }
// Test sequence_map_inverse
// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new
// position j The inverse gives us new position i came from old position inverse[i]