Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-06-29 19:28:33 +00:00)
The K/V async-load width selector (GetKVAlignmentBytes, reached through GetAlignmentK/V) computed its dwordx4 budget from the full block (kBlockSize = 512 threads). Spread across 512 threads, the 4 KB FP8 prefill tile gives only 8 B per thread, short of the 16 B one dwordx4 load moves, so the tile never split cleanly and the selector fell back to dword. With the FA4 per-warp-group decoupling, a single 4-warp group (256 threads) fills the tile on its own: 4 KB / 256 = exactly 16 B per thread, i.e. one dwordx4.

Thread the load-thread count into GetAlignmentK/V as a NumWarps template parameter. It defaults to the shape's NumWarps, so all sizing, paged and decode instantiations stay byte-identical. Only the load-path callers (MakeK/VDramTileDistribution, the LDS store/load descriptors, and the kAlignmentK/V DRAM-view vector) pass the decoupled GetK/VLoadNumWarps count, which unlocks the wide load.

Effect: global_load_lds_dword 36 -> 9 (now dwordx4) and buffer_load_dword 36 -> 9 (now dwordx4), in both runtime branches; VGPRs 181 -> 173; LDS and SGPR usage unchanged. Accuracy: PASS, 0% (non-causal and causal). Latency-neutral on sq8192, because the kernel is bound by memory latency rather than load issue, but the instruction and VGPR footprint is strictly better.

Co-authored-by: Cursor <cursoragent@cursor.com>