juuso-oskari c11722bf3e CK-UA: decouple K/V DRAM loads per warp-group + early V read (FA4 fp8)
Split the cooperative K/V cache loads across the two FA4 warp groups so
each group owns exactly one tile's DRAM load and address arithmetic:
WG0 loads V, WG1 loads K, and both read from the shared LDS buffers.
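The ownership split can be modeled in a few lines. This is a minimal sketch, assuming 8 waves per workgroup with waves 0-3 forming WG0 and waves 4-7 forming WG1. `warp_group_of`, `dram_tile_of` and `load_num_warps` are hypothetical illustration names, not CK APIs. `load_num_warps` only mirrors the idea behind GetVLoadNumWarps / GetKLoadNumWarps.

```python
# Hypothetical model of the FA4 split (not CK code): 8 waves per workgroup,
# two warp groups of 4 waves each.
NUM_WARPS_PER_GROUP = 4
NUM_WARPS = 2 * NUM_WARPS_PER_GROUP


def warp_group_of(warp_id: int) -> int:
    """Waves 0-3 form WG0, waves 4-7 form WG1."""
    return warp_id // NUM_WARPS_PER_GROUP


def dram_tile_of(warp_id: int) -> str:
    """The one tile whose DRAM load (and address arithmetic) this wave owns."""
    return "V" if warp_group_of(warp_id) == 0 else "K"


def load_num_warps(tile: str) -> int:
    """How many waves cooperate on a tile's DRAM load: only the owning group."""
    return sum(1 for w in range(NUM_WARPS) if dram_tile_of(w) == tile)


if __name__ == "__main__":
    for w in range(NUM_WARPS):
        # Every wave still consumes both K and V from the shared LDS buffers.
        print(f"wave {w} (WG{warp_group_of(w)}): loads {dram_tile_of(w)} from DRAM")
```

In this model each tile is loaded by exactly 4 waves instead of all 8, which is why 4-warp descriptors suffice.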

- kFA4WG0LoadsV / kFA4WG1LoadsK policy flags plus GetVLoadNumWarps /
  GetKLoadNumWarps: the owning group's 4 waves alone fill the tile via
  4-warp descriptors; the partner group skips the load and reads the tile
  from LDS.
- High-warp-group support for the raw async path: the raw store bakes the
  absolute warp id into the LDS base address it programs into M0, so WG1
  (waves 4-7) needs a base shift (GetKStoreWarpShift / WarpIdShift in
  MakeKLdsStoreBlockDescriptor) to map back onto the 4-warp layout, plus
  WG-relative (warp % NumWarps) page offsets so the gather token positions
  are correct.
- Stage B: move each tile's V LDS read into the PRECEDING softmax phase so
  its latency is hidden under the softmax VALU work. This is safe because V
  is now owned by a single group; the writers drain before the barrier
  (vmcnt<0>, then s_barrier) so the slices from all 4 cooperating writer
  waves are published before the read.
- Gate the per-tile offset refresh by warp group (WG0 refreshes V, WG1
  refreshes K), so each wave fetches a block-table page index for one tile
  instead of both; loop counters stay uniform.
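The high-warp-group fix in the raw async path comes down to index arithmetic. Below is a minimal sketch, assuming each wave writes one fixed-size LDS slice and gathers a fixed number of consecutive tokens per tile. `SLICE_BYTES`, `ROWS_PER_WARP` and all function names are hypothetical, and the sizes are placeholders rather than the kernel's values. `warp_id_shift` plays the role the message gives WarpIdShift, but it is not that API.

```python
# Hypothetical model (not CK code); sizes are placeholders for illustration.
NUM_WARPS = 4        # waves per warp group, i.e. the 4-warp descriptor layout
SLICE_BYTES = 1024   # assumed LDS bytes written by each wave per tile
ROWS_PER_WARP = 16   # assumed tokens gathered by each wave per tile


def warp_id_shift(warp_group: int) -> int:
    """Base shift that maps a group's absolute wave ids onto slots 0..NUM_WARPS-1."""
    return warp_group * NUM_WARPS


def lds_store_offset(abs_warp_id: int, warp_group: int) -> int:
    """LDS offset derived from the absolute warp id (as the raw store does), minus the shift."""
    return (abs_warp_id - warp_id_shift(warp_group)) * SLICE_BYTES


def gather_token_positions(abs_warp_id: int, tile_start: int) -> list[int]:
    """Tokens a wave gathers, indexed by the WG-relative id warp % NUM_WARPS."""
    rel = abs_warp_id % NUM_WARPS
    first = tile_start + rel * ROWS_PER_WARP
    return list(range(first, first + ROWS_PER_WARP))


def page_indices(abs_warp_id: int, tile_start: int, page_size: int,
                 block_table: list[int]) -> list[int]:
    """Paged-KV block-table entry for each token this wave gathers."""
    return [block_table[t // page_size]
            for t in gather_token_positions(abs_warp_id, tile_start)]
```

Without the shift, WG1 wave 4 would write at offset `4 * SLICE_BYTES`, outside the 4-warp layout. Without the modulo, it would gather tokens past the end of the tile.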
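The per-group offset refresh can be sketched the same way. This assumes K and V tiles at the same loop index share one block-table lookup and that a tile spans `tile_tokens` tokens; every name here is hypothetical. Each wave runs the same number of iterations, and gating only changes which tiles it refreshes.

```python
# Hypothetical model (not CK code) of gating the per-tile offset refresh.
NUM_WARPS_PER_GROUP = 4


def offset_refreshes(warp_id: int, num_tiles: int, tile_tokens: int,
                     page_size: int, block_table: list[int],
                     gated: bool = True) -> list[tuple[int, str, int]]:
    """Block-table lookups one wave performs over the KV loop.

    The trip count is num_tiles for every wave, so the loop counter stays
    uniform; gating restricts the refresh to the tile this wave's group owns.
    """
    own = "V" if warp_id // NUM_WARPS_PER_GROUP == 0 else "K"
    tiles = (own,) if gated else ("K", "V")
    lookups = []
    for i in range(num_tiles):
        page = block_table[(i * tile_tokens) // page_size]
        for tile in tiles:
            lookups.append((i, tile, page))
    return lookups
```

With gating, each wave performs one lookup per iteration instead of two.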
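The drain-before-barrier ordering behind the early V read can be illustrated with a loose CPU analogy. Threads stand in for waves, and a thread-pool future stands in for each wave's async DRAM-to-LDS write. `Future.result()` plays the part of vmcnt<0>, and `threading.Barrier` plays s_barrier. The analogy is an assumption, and only one reader is modeled, whereas the real barrier spans all waves.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

NUM_WRITERS = 4  # the owning group's 4 waves


def run_stage(v_tile_slices):
    """Writers fill LDS slices asynchronously; a reader consumes the tile after the barrier."""
    lds = [None] * NUM_WRITERS
    barrier = threading.Barrier(NUM_WRITERS + 1)  # 4 writers + 1 reader
    result = []

    def async_copy(slot, data):
        lds[slot] = data  # stands in for the async DRAM -> LDS write

    with ThreadPoolExecutor(max_workers=NUM_WRITERS) as dma:
        def writer(w):
            pending = dma.submit(async_copy, w, v_tile_slices[w])
            pending.result()  # drain: this wave's own write has landed
            barrier.wait()    # publish to every participant

        def reader():
            barrier.wait()    # all writers have drained and arrived
            result.extend(lds)  # early V read sees all 4 slices

        threads = [threading.Thread(target=writer, args=(w,)) for w in range(NUM_WRITERS)]
        threads.append(threading.Thread(target=reader))
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    return result
```

Putting the drain before the barrier means no writer arrives while its own write is still in flight. Once the barrier releases, the reader therefore sees every slice.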

Validated with 0% mismatch vs the GPU reference, causal and non-causal,
sq 256..8192. Net latency change vs the cooperative baseline (d128 fp8,
sq 2048..16384): causal roughly -3% to -4.6%, non-causal roughly -2% to
-4.7%.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-09 14:35:00 +00:00