composable_kernel/include/ck_tile/ops/epilogue/chainer
Christopher Millette a170e2bd9d [rocm-libraries] ROCm/rocm-libraries#5939 (commit 6fb1791)
[CK_TILE] Flatten nested static_for loops into static_ford (#5939)

## Summary
Mechanical conversion of 129 nested `static_for`/`static_ford` patterns
to flat `static_ford` across 29 ck_tile header files.

Each conversion eliminates intermediate lambda closure instantiations by
replacing nested compile-time loops with a single flat iteration using
index decomposition.

### What `static_ford` eliminates

When `static_for` loops are nested, each level creates unique closure
types:
```cpp
// BEFORE: M + M×N = 20 IR functions (for M=4, N=4)
static_for<0, 4, 1>{}([&](auto m) {        // 4 closure instantiations
    static_for<0, 4, 1>{}([&](auto n) {     // 4×4 = 16 closure instantiations
        body(m, n);
    });
});

// AFTER: M×N = 16 IR functions (with ford_applier, no intermediates)
static_ford<sequence<4, 4>>{}([&](auto mn) {
    constexpr auto m = number<mn[number<0>{}]>{};
    constexpr auto n = number<mn[number<1>{}]>{};
    body(m, n);
});
```
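
The index decomposition can be illustrated outside ck_tile. The following is a minimal standard C++17 sketch, not the `static_ford` or `ford_applier` implementation: `index_pair`, `flat_for_2d`, and `sum_grid` are hypothetical names introduced for illustration. A single pack expansion over `Rows * Cols` flat indices recovers each `(m, n)` as `(i / Cols, i % Cols)`, so the body is invoked directly for every pair without an intermediate per-row closure.

```cpp
#include <cstddef>
#include <utility>

// Hypothetical compile-time index pair; not a ck_tile type.
template <std::size_t M, std::size_t N>
struct index_pair
{
    static constexpr std::size_t m = M;
    static constexpr std::size_t n = N;
};

// Expands one flat index pack and calls f once per (m, n) pair.
template <std::size_t Cols, typename F, std::size_t... Is>
constexpr void flat_for_2d_impl(F&& f, std::index_sequence<Is...>)
{
    // Row-major decomposition of the flat index i: m = i / Cols, n = i % Cols.
    (f(index_pair<Is / Cols, Is % Cols>{}), ...);
}

// Hypothetical flat 2D compile-time loop over a Rows x Cols index space.
template <std::size_t Rows, std::size_t Cols, typename F>
constexpr void flat_for_2d(F&& f)
{
    flat_for_2d_impl<Cols>(f, std::make_index_sequence<Rows * Cols>{});
}

// Example body: accumulate m * 10 + n over a 2 x 3 grid.
constexpr int sum_grid()
{
    int total = 0;
    flat_for_2d<2, 3>([&](auto idx) {
        total += static_cast<int>(decltype(idx)::m * 10 + decltype(idx)::n);
    });
    return total;
}

// Visited pairs: (0,0) (0,1) (0,2) (1,0) (1,1) (1,2) -> 0+1+2+10+11+12.
static_assert(sum_grid() == 36, "flat loop must visit every (m, n) pair once");
```

The fold over the comma operator evaluates left to right, so the visiting order matches the nested form (outer index slowest).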

### Pattern categories converted

| Category | Count | Description |
|----------|-------|-------------|
| C (2-level `static_for` chains) | 112 | Nested `static_for` → `static_ford` |
| C3 (3-level `static_for` chains) | 9 | Three consecutive nests → `static_ford` |
| Partial rescue | 3 | Outer 2 levels of blocked 4-level nests |
| B (nested `static_ford` merge) | 5 | Two nested `static_ford` → single higher-dim `static_ford` |
| **Total** | **129** | Across 29 files |

6 false positives were detected and reverted (in `tensor_adaptor.hpp`,
`tile_distribution.hpp`, `tile_distribution_encoding.hpp`) where the
inner loop bound depended on the outer variable.
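
Why a dependent inner bound blocks the conversion: a flat `static_ford<sequence<M, N>>` covers a fixed rectangular index space. The sketch below uses plain constexpr loops to compare visit counts. The triangular bound `n <= m` is an assumed example of a dependent bound and is not taken from the reverted files; the function names are hypothetical.

```cpp
// Nested form whose inner bound depends on the outer variable (assumed
// triangular bound, for illustration only).
constexpr int count_dependent(int m_max)
{
    int visits = 0;
    for (int m = 0; m < m_max; ++m)
        for (int n = 0; n <= m; ++n) // bound depends on m
            ++visits;
    return visits;
}

// What a flat rectangular iteration over m_max x n_max would visit.
constexpr int count_rectangular(int m_max, int n_max)
{
    int visits = 0;
    for (int i = 0; i < m_max * n_max; ++i)
        ++visits;
    return visits;
}

static_assert(count_dependent(4) == 10, "1 + 2 + 3 + 4 visits");
static_assert(count_rectangular(4, 4) == 16, "full 4 x 4 rectangle");
```

Flattening such a site naively would run the body 16 times instead of 10, so the loop would no longer be equivalent to the original.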

### Files changed by family

| Family | Files | Sites |
|--------|-------|-------|
| Block GEMM | 12 | ~20 |
| FlatMM pipelines | 4 | ~69 (including 5 ford-ford merges) |
| GEMM quant | 7 | ~22 |
| FlatMM kernel | 1 | 2 |
| FMHA | 1 | 2 |
| Reduce/norm | 2 | 2 |
| Epilogue | 1 | 1 |

### Blocked locations from review comments

- **block_gemm_areg_breg_creg_v1.hpp:356** — BLOCKED: runtime scale
loads (`scale_a_slice`, `scale_b_slice`, A warp tensor load) between
every nesting level
- **block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp:228** — BLOCKED:
`zero_accumulators()` before inner loop; `sched_barrier` + conditional
`block_sync_lds()` after inner loop
- **block_universal_gemm_as_aquant_bs_bquant_cr.hpp:298** — BLOCKED:
runtime `CWarpTensor` construction before inner loop; quantization scale
application code after inner loop
- **block_universal_gemm_as_aquant_bs_cr.hpp:277** — BLOCKED: same
pattern as above
- **block_universal_gemm_as_bs_bquant_cr.hpp:367** — BLOCKED: same
pattern as above

## Depends on
- #5938 ([CK_TILE] Optimize static_ford and sequence compile-time
infrastructure) — provides the `ford_applier` that makes these
conversions beneficial. Without it, `static_ford` uses a recursive
implementation that provides no IR function savings.

## Results (combined with #5938)

### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942)

| Target | Base (s) | Treat (s) | Delta | % | Significant? |
|--------|----------|-----------|-------|---|-------------|
| **flatmm** | 161.1 | 149.0 | **-12.1s** | **-7.5%** | **YES** (p<0.01, 7/7 wins) |
| **universal_gemm** | 225.4 | 220.3 | **-5.1s** | **-2.3%** | **YES** (p<0.01, 7/7 wins) |

### IR Function Counts (device trace, gfx942)

| Target | InstFunc | CodeGen |
|--------|----------|---------|
| universal_gemm | **-8.5%** | **-9.2%** |
| flatmm | **-7.6%** | **-10.5%** |

### ASM Equivalence
5/5 PASS — 650,151 lines verified identical (gfx942). TUs:
universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale.

## Test plan
- [x] ASM equivalence verified (650K lines, gfx942)
- [x] Wilcoxon timing verified (7 trials, p<0.01)
- [x] IR function counts verified (-7.6% to -10.5% reduction across InstFunc and CodeGen)
- [ ] CI

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-04-07 14:38:07 +00:00

CK Tile Epilogue Chainer

Overview

The Epilogue Chainer provides a modular epilogue processing framework through scheduler-defined operation graphs.

Architecture

Core Design Principle

The chainer follows a Scheduler-Graph-Node architecture with shared context:

  • Scheduler: Defines operation graphs and creates a shared context
  • Graph: Composes multiple operations into sequential processing units
  • Node: Wraps individual epilogue operations with their arguments

EpilogueChainer

The EpilogueChainer struct coordinates epilogue processing. It delegates context creation and schedule generation to the scheduler, then processes the resulting operation graphs.

EpilogueNode

Individual epilogue operations are wrapped in EpilogueNode structures that capture the required arguments at construction time and automatically forward them during processing. Nodes support both parameterized and parameter-free operations.
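
The capture-then-forward idea can be sketched in standard C++17. This is an illustration under stated assumptions, not the EpilogueNode implementation: `DemoContext`, `DemoNode`, `ScaleBy`, and `AddOne` are hypothetical names, and the context is reduced to a single integer.

```cpp
#include <tuple>

// Hypothetical context reduced to one value; the real context holds tiles
// and windows.
struct DemoContext
{
    int working_tile = 0;
};

// Hypothetical node: stores an operation and the arguments it needs later.
template <typename Op, typename... Args>
struct DemoNode
{
    Op op;
    std::tuple<Args...> args;

    constexpr DemoNode(Op o, Args... a) : op(o), args(a...) {}

    // Forwards the stored arguments after the context. With an empty
    // argument tuple this becomes the parameter-free call op(ctx).
    template <typename Context>
    constexpr void operator()(Context& ctx) const
    {
        std::apply([&](const auto&... a) { op(ctx, a...); }, args);
    }
};

// Parameterized operation: needs a factor captured by the node.
struct ScaleBy
{
    constexpr void operator()(DemoContext& ctx, int factor) const
    {
        ctx.working_tile *= factor;
    }
};

// Parameter-free operation.
struct AddOne
{
    constexpr void operator()(DemoContext& ctx) const { ctx.working_tile += 1; }
};

constexpr int run_nodes()
{
    DemoContext ctx{};
    ctx.working_tile = 5;
    DemoNode<ScaleBy, int> scale{ScaleBy{}, 3}; // argument captured here
    DemoNode<AddOne> add_one{AddOne{}};         // no arguments to capture
    scale(ctx);                                 // 5 * 3 = 15
    add_one(ctx);                               // 15 + 1 = 16
    return ctx.working_tile;
}

static_assert(run_nodes() == 16, "both node kinds run against the context");
```

Storing the arguments in a tuple keeps the call site uniform: every node is invoked as `node(ctx)`, regardless of how many arguments its operation takes.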

EpilogueGraph

The EpilogueGraph composes multiple nodes into sequential processing units that iterate over multiple accesses if needed, running all operations in order for each iteration.
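
A minimal standard C++17 sketch of this run-all-nodes-per-access behaviour follows. `DemoGraph`, `DemoContext`, `Increment`, and `Accumulate` are hypothetical names, and the access loop is written as an ordinary `for` loop for brevity; how the real EpilogueGraph iterates is not shown here.

```cpp
#include <cstddef>
#include <tuple>

// Hypothetical context with two plain values standing in for tiles.
struct DemoContext
{
    int working_tile = 0;
    int stored       = 0;
};

// Hypothetical graph: an ordered tuple of nodes run once per access.
template <std::size_t NumAccess, typename... Nodes>
struct DemoGraph
{
    std::tuple<Nodes...> nodes;

    constexpr explicit DemoGraph(Nodes... n) : nodes(n...) {}

    template <typename Context>
    constexpr void operator()(Context& ctx) const
    {
        for(std::size_t access = 0; access < NumAccess; ++access)
        {
            // The comma fold runs the nodes left to right, i.e. in the
            // order they were passed to the graph.
            std::apply([&](const auto&... node) { (node(ctx), ...); }, nodes);
        }
    }
};

struct Increment
{
    constexpr void operator()(DemoContext& ctx) const { ctx.working_tile += 1; }
};

struct Accumulate
{
    constexpr void operator()(DemoContext& ctx) const { ctx.stored += ctx.working_tile; }
};

constexpr int run_graph()
{
    DemoContext ctx{};
    DemoGraph<3, Increment, Accumulate> graph{Increment{}, Accumulate{}};
    graph(ctx); // accesses contribute 1, then 2, then 3
    return ctx.stored;
}

static_assert(run_graph() == 6, "all nodes run in order for each access");
```

Because the nodes share one context, later nodes observe what earlier nodes wrote during the same access.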

Files

Core Infrastructure

  • epilogue_chainer.hpp - General chainer, node, and graph infrastructure
  • common_epilogue_ops.hpp - Epilogue operations usable with any epilogue type

CShuffle Implementation

  • cshuffle_epilogue_chainer_ops.hpp - CShuffle-specific problem, context, and slice operations
  • cshuffle_epilogue_schedule.hpp - CShuffle scheduler with pre-built schedules

Usage

Common Operations (common_epilogue_ops.hpp)

These operations work with any context that provides the standardized interface:

  • ScaleScalarOp - Scale working-tile by scalar values
  • CastAndStoreToLdsOp<DstType> - Cast working-tile and store to LDS
  • LoadFromLdsOp<Pattern> - Load output tile from LDS with sync
  • ElementwiseOp<Func, NumAux> - Apply elementwise operation with auxiliary tensors
  • StoreOp<MemOp> - Store output tile to global memory
  • MoveWindowsOp<SFC, NumAux> - Advance windows to next position

CShuffle-Specific Operations (cshuffle_epilogue_chainer_ops.hpp)

These operations are specific to CShuffle epilogue:

  • CShuffleSliceOp - Slice accumulator tile based on distribution
  • CShuffleScaleWindowOp - Scale using tensor windows with shuffle distribution

Context Interface

Operations communicate through a shared context with standardized members:

  • working_tile: Tile for intermediate computations
  • out_tile: Output tile
  • aux_windows: Tuple of auxiliary tensor windows
  • lds_write_window: Window for writing to LDS
  • lds_read_window: Window for reading from LDS
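
The sketch below shows how operations can depend only on these member names. It is standard C++17 with hypothetical stand-ins: `DemoTile`, `DemoWindow`, `DemoContext`, `DemoScaleOp`, and `DemoCopyToOutOp` are not ck_tile types, and tiles are modelled as small arrays.

```cpp
#include <array>
#include <cstddef>
#include <tuple>

// Hypothetical stand-ins: a "tile" is a small array, a "window" an offset.
using DemoTile = std::array<float, 4>;
struct DemoWindow
{
    int offset = 0;
};

// Context exposing the five member names listed above.
struct DemoContext
{
    DemoTile working_tile{};
    DemoTile out_tile{};
    std::tuple<DemoWindow, DemoWindow> aux_windows{};
    DemoWindow lds_write_window{};
    DemoWindow lds_read_window{};
};

// Touches only ctx.working_tile, so it works with any context type that
// provides a member of that name.
struct DemoScaleOp
{
    float scale;

    template <typename Context>
    constexpr void operator()(Context& ctx) const
    {
        for(std::size_t i = 0; i < ctx.working_tile.size(); ++i)
            ctx.working_tile[i] *= scale;
    }
};

// Copies the working tile into the output tile element by element.
struct DemoCopyToOutOp
{
    template <typename Context>
    constexpr void operator()(Context& ctx) const
    {
        for(std::size_t i = 0; i < ctx.out_tile.size(); ++i)
            ctx.out_tile[i] = ctx.working_tile[i];
    }
};

constexpr float run_demo_ops()
{
    DemoContext ctx{};
    for(std::size_t i = 0; i < ctx.working_tile.size(); ++i)
        ctx.working_tile[i] = static_cast<float>(i + 1); // 1, 2, 3, 4
    DemoScaleOp{2.0f}(ctx);
    DemoCopyToOutOp{}(ctx);
    return ctx.out_tile[3]; // 4 * 2
}

static_assert(run_demo_ops() == 8.0f, "ops communicate through the context");
```

Templating each operation on the context type is one way to let the same operation serve several epilogue types, provided each context supplies the members the operation uses.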

Schedule Tags

  • DefaultScheduleTag - Standard: Slice → CastStore → Load → ApplyD → Store → Move
  • RowColQuantScheduleTag - With window scaling
  • TensorQuantScheduleTag - With scalar scaling