mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Split the cooperative K/V cache loads across the two FA4 warp groups so each group owns exactly one tile's DRAM load and address arithmetic: WG0 loads V, WG1 loads K, and both read from the shared LDS buffers.

- kFA4WG0LoadsV / kFA4WG1LoadsK policy flags + GetVLoadNumWarps / GetKLoadNumWarps: the owning group's 4 waves alone fill the tile via 4-warp descriptors; the partner skips the load and reads from LDS.
- High-warp-group support for the raw async path: the raw store bakes the absolute warp id into the LDS M0, so WG1 (waves 4-7) needs a base shift (GetKStoreWarpShift / WarpIdShift in MakeKLdsStoreBlockDescriptor) to map back to the 4-warp layout, plus WG-relative (warp % NumWarps) page offsets so the gather token positions are correct.
- Stage B: move each tile's V LDS read into the PRECEDING softmax phase so the read latency hides under softmax VALU. Safe because V is now single-group-owned; uses drain-before-barrier (vmcnt<0> then s_barrier) so all 4 cooperating writer waves' slices are published before the read.
- Gate per-tile offset refresh per warp group (WG0 refreshes V, WG1 refreshes K), so each wave fetches a block-table page index for one tile instead of both; loop counters stay uniform.

Validated 0% mismatch vs GPU reference, causal + non-causal, sq 256..8192. Net latency vs the cooperative baseline: causal roughly -3% to -4.6%, non-causal roughly -2% to -4.7% across sq 2048..16384 (d128 fp8).

Co-authored-by: Cursor <cursoragent@cursor.com>
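
The warp-group ownership and the WG-relative indexing described above can be sketched in a few lines of host-side C++. This is only an illustration of the arithmetic, not the kernel code: the names kWarpsPerGroup, Tile, warp_group, wg_relative_warp, owned_tile and store_warp_shift are hypothetical, and the sign convention of the shift is an assumption (the real GetKStoreWarpShift / WarpIdShift may define it differently).

```cpp
#include <cassert>

// Hypothetical constants for illustration only; the commit states that each
// warp group has 4 waves and that WG1 covers waves 4-7.
constexpr int kWarpsPerGroup = 4;

enum class Tile { K, V };

// Warp group of an absolute warp id: 0 for waves 0-3, 1 for waves 4-7.
constexpr int warp_group(int abs_warp) { return abs_warp / kWarpsPerGroup; }

// WG-relative warp index (warp % NumWarps), as used for the page offsets.
constexpr int wg_relative_warp(int abs_warp) { return abs_warp % kWarpsPerGroup; }

// Ownership rule from the commit: WG0 loads V, WG1 loads K.
constexpr Tile owned_tile(int abs_warp) {
    return warp_group(abs_warp) == 0 ? Tile::V : Tile::K;
}

// Assumed shift that maps an absolute warp id back onto the 4-warp layout:
// 0 for WG0 and -4 for WG1, so that abs_warp + shift == wg_relative_warp.
constexpr int store_warp_shift(int abs_warp) {
    return -warp_group(abs_warp) * kWarpsPerGroup;
}

// Compile-time checks of the mapping for a WG0 wave and a WG1 wave.
static_assert(owned_tile(2) == Tile::V, "WG0 owns the V tile");
static_assert(owned_tile(6) == Tile::K, "WG1 owns the K tile");
static_assert(6 + store_warp_shift(6) == wg_relative_warp(6),
              "shifted id matches the WG-relative index");
```

Under these assumptions, wave 6 belongs to WG1, loads K, and addresses the LDS store as if it were wave 2 of a 4-warp layout.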