composable_kernel/include
juuso-oskari 87658a9518 CK-UA: hoist wave-uniform warp id out of the async-load issue loop
tile_scatter_gather::async_load_raw / async_load_raw_long recompute
get_warp_id() (threadIdx.x/warp_size + a convergent v_readfirstlane) at
every K/V load issue to form the m0 / LDS-wave base. The value is
wave-uniform and constant for the window's lifetime, but LLVM cannot
hoist or CSE it across the load loop: v_readfirstlane is convergent and
the m0 write is a volatile asm statement with a memory clobber, which
together pin the recompute to each issue.

Materialize the warp id once at window construction (cached_warp_id_,
set only for the global-memory gather windows that issue these loads)
and read the cached SGPR in both async paths. ISA effect: the per-issue
s_lshr (÷64) and v_readfirstlane drop out of the loop (warp-id
readfirstlane sites 36 -> 11, ÷64 shift sites down to 2 static sites).
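
The change can be pictured with a small host-side model. This is only a
sketch and not the CK code: WindowRecompute, WindowCached, issue_loads,
compute_warp_id, g_warp_id_computations and kWaveSize are hypothetical
names, the wave size of 64 is assumed from the ÷64 above, and a plain
counter stands in for the per-issue shift + v_readfirstlane cost.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed wave size (matches the ÷64 mentioned in the message).
constexpr std::uint32_t kWaveSize = 64;

// Counts how often the warp id is recomputed (stand-in for the
// s_lshr + v_readfirstlane pair emitted on the GPU).
int g_warp_id_computations = 0;

// Hypothetical stand-in for get_warp_id(): threadIdx.x / wave size.
std::uint32_t compute_warp_id(std::uint32_t thread_idx_x) {
    ++g_warp_id_computations;
    return thread_idx_x / kWaveSize;
}

// Before: the warp id is recomputed at every load issue.
struct WindowRecompute {
    std::uint32_t thread_idx_x;
    std::uint32_t lds_wave_base(std::uint32_t bytes_per_wave) const {
        return compute_warp_id(thread_idx_x) * bytes_per_wave;
    }
};

// After: the warp id is materialized once at window construction.
struct WindowCached {
    std::uint32_t cached_warp_id_;
    explicit WindowCached(std::uint32_t thread_idx_x)
        : cached_warp_id_(compute_warp_id(thread_idx_x)) {}
    std::uint32_t lds_wave_base(std::uint32_t bytes_per_wave) const {
        return cached_warp_id_ * bytes_per_wave;
    }
};

// Models the load-issue loop: one LDS-wave base per issued load.
template <class Window>
std::vector<std::uint32_t> issue_loads(const Window& w, int num_issues,
                                       std::uint32_t bytes_per_wave) {
    assert(num_issues >= 0);
    std::vector<std::uint32_t> bases;
    for (int i = 0; i < num_issues; ++i) {
        bases.push_back(w.lds_wave_base(bytes_per_wave));
    }
    return bases;
}
```

In this model both windows produce the same bases, but the cached window
computes the warp id once per window instead of once per issue.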

Matched sweeps (line-tables-identical codegen, GQA-6 d128 page64):
  bf16 prefill: -7.25% aggregate, 12/12 shapes improved, 0 regressions
                (CK/Triton at sq>=5000 moves ~0.83x -> ~0.90x)
  fp8  all:     -0.4..-1.3% aggregate (fp8 prefill is gated by the
                ds_bpermute repack, which masks the addressing savings)
Correctness vs torch reference: PASS (fp8 + bf16).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-01 11:05:21 +00:00