[rocm-libraries] ROCm/rocm-libraries#5939 (commit 6fb1791)

[CK_TILE] Flatten nested static_for loops into static_ford
 (#5939)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary
Mechanical conversion of 129 nested `static_for`/`static_ford` patterns
to flat `static_ford` across 29 ck_tile header files.

Each conversion eliminates intermediate lambda closure instantiations by
replacing nested compile-time loops with a single flat iteration using
index decomposition.

### What `static_ford` eliminates

When `static_for` loops are nested, each level creates unique closure
types:
```cpp
// BEFORE: M + M×N = 20 IR functions (for M=4, N=4)
static_for<0, 4, 1>{}([&](auto m) {        // 4 closure instantiations
    static_for<0, 4, 1>{}([&](auto n) {     // 4×4 = 16 closure instantiations
        body(m, n);
    });
});

// AFTER: M×N = 16 IR functions (with ford_applier, no intermediates)
static_ford<sequence<4, 4>>{}([&](auto mn) {
    constexpr auto m = number<mn[number<0>{}]>{};
    constexpr auto n = number<mn[number<1>{}]>{};
    body(m, n);
});
```

### Pattern categories converted

| Category | Count | Description |
|----------|-------|-------------|
| C (2-level `static_for` chains) | 112 | Nested `static_for` →
`static_ford` |
| C3 (3-level `static_for` chains) | 9 | Three consecutive nests →
`static_ford` |
| Partial rescue | 3 | Outer 2 levels of blocked 4-level nests |
| B (nested `static_ford` merge) | 5 | Two nested `static_ford` → single
higher-dim `static_ford` |
| **Total** | **129** | Across 29 files |

6 false positives were detected and reverted (in `tensor_adaptor.hpp`,
`tile_distribution.hpp`, `tile_distribution_encoding.hpp`) where the
inner loop bound depended on the outer variable.

### Files changed by family

| Family | Files | Sites |
|--------|-------|-------|
| Block GEMM | 12 | ~20 |
| FlatMM pipelines | 4 | ~69 (including 5 ford-ford merges) |
| GEMM quant | 7 | ~22 |
| FlatMM kernel | 1 | 2 |
| FMHA | 1 | 2 |
| Reduce/norm | 2 | 2 |
| Epilogue | 1 | 1 |

### Blocked locations from review comments

- **block_gemm_areg_breg_creg_v1.hpp:356** — BLOCKED: runtime scale
loads (`scale_a_slice`, `scale_b_slice`, A warp tensor load) between
every nesting level
- **block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp:228** — BLOCKED:
`zero_accumulators()` before inner loop; `sched_barrier` + conditional
`block_sync_lds()` after inner loop
- **block_universal_gemm_as_aquant_bs_bquant_cr.hpp:298** — BLOCKED:
runtime `CWarpTensor` construction before inner loop; quantization scale
application code after inner loop
- **block_universal_gemm_as_aquant_bs_cr.hpp:277** — BLOCKED: same
pattern as above
- **block_universal_gemm_as_bs_bquant_cr.hpp:367** — BLOCKED: same
pattern as above

## Depends on
- #5938 ([CK_TILE] Optimize static_ford and sequence compile-time
infrastructure) — provides the `ford_applier` that makes these
conversions beneficial. Without it, `static_ford` uses a recursive
implementation that provides no IR function savings.

## Results (combined with #5938)

### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942)

| Target | Base (s) | Treat (s) | Delta | % | Significant? |
|--------|----------|-----------|-------|---|-------------|
| **flatmm** | 161.1 | 149.0 | **-12.1s** | **-7.5%** | **YES** (p<0.01,
7/7 wins) |
| **universal_gemm** | 225.4 | 220.3 | **-5.1s** | **-2.3%** | **YES**
(p<0.01, 7/7 wins) |

### IR Function Counts (device trace, gfx942)

| Target | InstFunc | CodeGen |
|--------|----------|---------|
| universal_gemm | **-8.5%** | **-9.2%** |
| flatmm | **-7.6%** | **-10.5%** |

### ASM Equivalence
5/5 PASS — 650,151 lines verified identical (gfx942). TUs:
universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale.

## Test plan
- [x] ASM equivalence verified (650K lines, gfx942)
- [x] Wilcoxon timing verified (7 trials, p<0.01)
- [x] IR function counts verified (-7.6% to -10.5% CodeGen reduction)
- [ ] CI

🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
Christopher Millette
2026-04-07 14:38:07 +00:00
committed by assistant-librarian[bot]
parent c2ac7aa7b0
commit a170e2bd9d
29 changed files with 2160 additions and 2219 deletions

View File

@@ -288,22 +288,22 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
// Block GEMM
@@ -366,16 +366,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -448,15 +448,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
bq_block_tile,
a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
aq_block_tile_2 = load_tile(aq_copy_dram_window);
@@ -473,15 +473,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
aq_block_tile = load_tile(aq_copy_dram_window);
@@ -520,16 +520,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
aq_block_tile_2 = load_tile(aq_copy_dram_window);
bq_block_tile_2 = load_tile(bq_copy_dram_window);

View File

@@ -275,22 +275,22 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
// Block GEMM
@@ -337,16 +337,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -424,15 +424,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_block_tile,
a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -461,15 +461,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -518,16 +518,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
bq_block_tile_2 = load_tile(bq_copy_dram_window);