mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
tile_scatter_gather::async_load_raw / async_load_raw_long recompute
get_warp_id() (threadIdx.x/warp_size plus a convergent v_readfirstlane)
on every K/V load issue to form the m0 / LDS-wave base. The value is
wave-uniform and constant for the window's lifetime, but LLVM cannot
hoist it out of, or CSE it across, the load loop: v_readfirstlane is
convergent, and the m0 write is an asm volatile with a memory clobber.
Together they pin the recompute to each issue.
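A minimal host-side C++ model of the per-issue computation, for illustration only. It is not CK code: kWarpSize, lane_warp_id, readfirstlane and wave_is_uniform are hypothetical names, and the model assumes a 64-wide wave (matching the ÷64 above) with all lanes active. It shows why the value can live in an SGPR: every lane of a wave derives the same threadIdx.x / 64, which for unsigned indices is a right shift by 6 (the s_lshr).

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Host-side model only, not CK code. Assumes a 64-wide wave.
constexpr std::uint32_t kWarpSize = 64;

// Per-lane value, mirroring threadIdx.x / warp_size.
// For unsigned values, /64 is the same as >>6 (the s_lshr in the ISA).
constexpr std::uint32_t lane_warp_id(std::uint32_t thread_idx) {
    return thread_idx / kWarpSize;
}

// Model of v_readfirstlane: broadcast the value held by the first lane
// (all lanes are assumed active, so that is lane 0).
template <std::size_t N>
std::uint32_t readfirstlane(const std::array<std::uint32_t, N>& lanes) {
    return lanes[0];
}

// Compute the per-lane warp ids of one wave and check that they all agree
// with the broadcast value, i.e. that the warp id is wave-uniform.
bool wave_is_uniform(std::uint32_t wave) {
    std::array<std::uint32_t, kWarpSize> lanes{};
    for (std::uint32_t lane = 0; lane < kWarpSize; ++lane)
        lanes[lane] = lane_warp_id(wave * kWarpSize + lane);
    const std::uint32_t sgpr = readfirstlane(lanes);
    for (std::uint32_t v : lanes)
        if (v != sgpr) return false;
    return sgpr == wave;
}
```

For example, threads 0..63 all yield warp id 0 and thread 64 starts warp 1, so one broadcast value per wave is enough.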
Materialize the warp id once at window construction (cached_warp_id_,
set only for the global-memory gather windows that issue these loads)
and read the cached SGPR in both async paths. ISA: the per-issue
s_lshr (÷64) and v_readfirstlane drop out of the loop (warp-id
v_readfirstlane sites 36 -> 11; the ÷64 shift is down to 2 static sites).
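A sketch of the before/after pattern as a host-side C++ model, under stated assumptions. Only the idea of a cached_warp_id_ member comes from this change; WarpIdSource, RecomputingWindow, CachingWindow, run_loop and lds_wave_stride are hypothetical, and forming the LDS-wave base as warp_id * stride is an assumed simplification. The evaluation counter stands in for the s_lshr + v_readfirstlane pair; on the GPU the cost appears as emitted instructions, because the compiler cannot hoist the real recompute by itself.

```cpp
#include <cstdint>

// Hypothetical host-side model, not the tile_scatter_gather implementation.
// Counts how often the warp id is materialized (stand-in for s_lshr +
// v_readfirstlane in the ISA).
struct WarpIdSource {
    std::uint32_t thread_idx;
    std::uint32_t warp_size = 64;  // assumption: 64-wide waves
    int evaluations = 0;
    std::uint32_t get_warp_id() {
        ++evaluations;
        return thread_idx / warp_size;
    }
};

// Before: every load issue recomputes the warp id.
struct RecomputingWindow {
    WarpIdSource& src;
    std::uint32_t issue_load(std::uint32_t lds_wave_stride) {
        return src.get_warp_id() * lds_wave_stride;  // modeled LDS-wave base
    }
};

// After: the warp id is read once in the constructor and cached.
struct CachingWindow {
    std::uint32_t cached_warp_id_;
    explicit CachingWindow(WarpIdSource& src) : cached_warp_id_(src.get_warp_id()) {}
    std::uint32_t issue_load(std::uint32_t lds_wave_stride) const {
        return cached_warp_id_ * lds_wave_stride;
    }
};

struct Counts {
    int recompute;    // warp-id evaluations with the recomputing window
    int cached;       // warp-id evaluations with the caching window
    bool same_bases;  // both windows produced identical bases on every issue
};

// Issue n_issues loads through each window and compare the results.
Counts run_loop(std::uint32_t thread_idx, int n_issues, std::uint32_t stride) {
    WarpIdSource a{thread_idx};
    WarpIdSource b{thread_idx};
    RecomputingWindow before{a};
    CachingWindow after{b};
    bool same = true;
    for (int i = 0; i < n_issues; ++i)
        same = same && (before.issue_load(stride) == after.issue_load(stride));
    return {a.evaluations, b.evaluations, same};
}
```

With 16 issues, the recomputing window evaluates the warp id 16 times and the caching window once, while both produce the same bases. Caching at construction is valid only because the value is wave-uniform and fixed for the window's lifetime.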
Matched sweeps (line-tables-identical codegen, GQA-6 d128 page64):
  bf16 prefill: -7.25% aggregate, 12/12 shapes improved, 0 regressions
    (CK/Triton at sq>=5000 moves from ~0.83x to ~0.90x)
  fp8 (all): -0.4% to -1.3% aggregate (fp8 prefill is gated by the
    ds_bpermute repack, which masks the addressing savings)
Correctness vs torch reference: PASS (fp8 + bf16).
Co-authored-by: Cursor <cursoragent@cursor.com>