mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Previously, the 32x32x16 FP8 P-tile cvt and the QK-C -> PV-A cross-lane
swap ran in two separate static_for loops back-to-back inside fmha_alu1:
the whole tile was cvt'd into p.thread_buf_ first, then a second pass
issued one ds_bpermute_b32 per 8-fp8 K-chunk and read/wrote the same
buffer to swap the "bad" 4-byte halves between paired lanes.
The ds_bpermute goes through the LDS hardware and has nontrivial
latency, which the scheduler has no way to hide when the instruction
sits alone in a tight serial loop with only the gather/scatter packs
around it.
Fuse the two into one 8-fp8-per-iter loop:
1. cvt 8 fp32 -> 2 packed uint32 (lo_pack=slot[0..3], hi_pack=slot[4..7])
using the chained cvt_pk_fp8_f32 pattern matching cast_tile_pk_fp8_fp32.
2. Pick own_bad = (sub==0 ? hi_pack : lo_pack) and issue ds_bpermute on it.
3. Write back all 8 fp8 bytes; the "good" half lands first so its byte
stores can overlap with the in-flight ds_bpermute, and the next
iter's cvts can begin while the swap is still pending.
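To make the data movement of steps 1-3 concrete, here is a minimal
single-wave model in plain C++. It is not CK code, and every name in it
is hypothetical. It assumes wave64, sub = lane / 32 with partner lane =
lane ^ 32, and that each lane writes the word it receives into the slot
its own bad half came from. to_fp8_stub is a placeholder, not a real FP8
encoding, and cvt_pk_model only mimics the word-select behavior of the
chained cvt_pk_fp8_f32 pattern as we understand it. Being sequential,
the model checks which bytes end up where, not the latency overlap.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical model, not CK code. Assumptions: wave64, sub = lane / 32,
// swap partner = lane ^ 32.
constexpr int kWave = 64;
constexpr int kHalf = 32;

using Acc8    = std::array<float, 8>;    // 8 fp32 accumulators of one K-chunk
using Fp8x8   = std::array<uint8_t, 8>;  // the same chunk as 8 fp8 bytes
using WaveU32 = std::array<uint32_t, kWave>;

// Placeholder for FP32 -> FP8: NOT a real e4m3/e5m2 encoding. It keeps the
// integer part so the byte movement is easy to check.
inline uint8_t to_fp8_stub(float x) {
    assert(x >= 0.0f && x < 256.0f);
    return static_cast<uint8_t>(x);
}

// Mimics one link of the chained packed convert: convert two floats and write
// them into the low (hi_word = false) or high (hi_word = true) 16 bits of
// `old`, leaving the other 16 bits untouched.
inline uint32_t cvt_pk_model(float a, float b, uint32_t old, bool hi_word) {
    uint32_t pair = uint32_t(to_fp8_stub(a)) | (uint32_t(to_fp8_stub(b)) << 8);
    return hi_word ? (old & 0x0000FFFFu) | (pair << 16)
                   : (old & 0xFFFF0000u) | pair;
}

// Models ds_bpermute_b32 with all lanes active: a "pull" in which lane i
// reads src from lane (addr[i] / 4) mod 64; addresses are in bytes.
inline WaveU32 bpermute_model(const WaveU32& addr, const WaveU32& src) {
    WaveU32 dst{};
    for (int i = 0; i < kWave; ++i) dst[i] = src[(addr[i] >> 2) % kWave];
    return dst;
}

// Stores a packed word as 4 little-endian fp8 bytes starting at out[base].
inline void store_u32(Fp8x8& out, int base, uint32_t w) {
    for (int b = 0; b < 4; ++b) out[base + b] = uint8_t(w >> (8 * b));
}

// One iteration of the fused loop (one 8-fp8 K-chunk) across the whole wave.
inline std::array<Fp8x8, kWave> fused_iter_model(const std::array<Acc8, kWave>& acc) {
    WaveU32 lo{}, hi{}, own_bad{}, addr{};
    for (int l = 0; l < kWave; ++l) {
        const Acc8& v = acc[l];
        // Step 1: 8 fp32 -> 2 packed uint32 (lo = slots 0..3, hi = slots 4..7).
        lo[l] = cvt_pk_model(v[2], v[3], cvt_pk_model(v[0], v[1], 0u, false), true);
        hi[l] = cvt_pk_model(v[6], v[7], cvt_pk_model(v[4], v[5], 0u, false), true);
        // Step 2: pick the half that belongs to the partner lane.
        const int sub = l / kHalf;
        own_bad[l] = (sub == 0) ? hi[l] : lo[l];
        addr[l] = uint32_t(l ^ kHalf) * 4u;  // byte address of the partner lane
    }
    const WaveU32 recv = bpermute_model(addr, own_bad);
    std::array<Fp8x8, kWave> out{};
    for (int l = 0; l < kWave; ++l) {
        // Step 3: store the good half first, then put the received word
        // into the slot of the half that was sent away (assumed placement).
        if (l / kHalf == 0) { store_u32(out[l], 0, lo[l]); store_u32(out[l], 4, recv[l]); }
        else                { store_u32(out[l], 4, hi[l]); store_u32(out[l], 0, recv[l]); }
    }
    return out;
}
```

On hardware the bpermute is a single instruction for the whole wave;
splitting the write-back into a good-half store and a received-half
store is what leaves room for the overlap described in step 3.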
The 16x16x32 LDS-roundtrip branch keeps the original separated cvt
loop (no swap latency to hide there since the relayout goes through
LDS, not ds_bpermute).
Single-shape FP8 perf on gfx950 GPU 2 (CUDA graph, 50 iters):
  decode  d=128 b=4 sq=8   sk=4096: 0.2106 -> 0.1951 ms  (-7.4%)
  decode  d=64  b=4 sq=8   sk=4096: 0.1464 -> 0.1208 ms (-17.5%)
  prefill d=128 b=2 sq=512 sk=4k:   0.2558 -> 0.2220 ms (-13.2%)
BF16 unchanged (0.2046 -> 0.2039 ms, within noise).
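The percentages above are consistent with a plain relative change
measured against the "before" time. A minimal sketch, assuming that is
how they were derived; the helper names are hypothetical:

```cpp
#include <cassert>
#include <cmath>

// Relative change of the kernel time, in percent, against the "before"
// time (negative = faster).
inline double pct_change(double before_ms, double after_ms) {
    assert(before_ms > 0.0);
    return (after_ms - before_ms) / before_ms * 100.0;
}

// Round to one decimal place, the precision quoted in the results.
inline double round1(double x) { return std::round(x * 10.0) / 10.0; }
```

With these helpers the BF16 pair 0.2046 -> 0.2039 ms works out to about
-0.3%, in line with the "within noise" reading.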
Correctness: pytest UA correctness suite 405 passed / 80 skipped
(245 BF16/FP16 + 160 FP8), unchanged from before.
Co-authored-by: Cursor <cursoragent@cursor.com>