The Tier-2 LDS page-table cache bulk-loaded the whole per-split window
[split_start_page, split_end_page) up front and trapped on
`split_window_pages <= kPageTableLdsEntries` (4096). That capped max KV
length at 4096*page_size per split -- e.g. decode sk=131072 page=16
(no/few splits) or long-context prefill (num_splits==1) device-asserted
instead of running.

Replace the fixed window with a sliding one:

- `lds_window_base` is the absolute page index of LDS entry 0; every
  refresh reads block_tables_lds[abs_page - lds_window_base]
  (was - split_start_page).
- The initial load fills min(split_window_pages, 4096) entries.
- slide_page_table() runs at the existing CTA convergence barriers
  (single-WG decode loop + FA4 prefill slot-A barrier, both branches
  symmetric). The slide predicate is a pure function of the wave-uniform
  tile bookkeeping (k_block_idx/v_block_idx) + page geometry, so every
  wave evaluates it identically and the two internal s_barriers stay
  matched across both warp groups -- no divergence, no deadlock. It
  slides the window to the lowest-needed page when the in-flight +
  next-prefetched tiles would leave the resident range, keeping the
  lagging consumer (V trails K) covered.
- Loop-invariant early-out (split_window_pages <= 4096): the hot,
  fits-in-LDS path skips the slide entirely (one predicted branch, no
  per-tile divides), so steady-state codegen is bit-identical to before.

Validated on gfx950: previously-asserting decode sk=131072 page=16
(8192 pages, ~2 slides) and chunked prefill sq=2048 sk=80000 page=16
(5000 pages, FA4 slide) now PASS vs torch ref. Full correctness matrix
263/263 PASS. Canonical perf unchanged: bf16 paged ~10.0 ms (1.18x),
fp8 paged 4.66 ms (1.65x); production shape sweep within run-to-run
noise.

Co-authored-by: Cursor <cursoragent@cursor.com>
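
For readers who want the indexing spelled out, the sliding-window scheme described above can be modelled on the host roughly as follows. This is a minimal sketch, not the kernel code: `SlidingPageTable`, `slide_if_needed`, `walk_pages`, the one-page V lag and the synthetic block table are all hypothetical names and assumptions. Only `kPageTableLdsEntries` (4096), `lds_window_base` and the `abs_page - lds_window_base` lookup come from the message. Barriers and wave uniformity are not modelled.

```cpp
#include <algorithm>
#include <initializer_list>
#include <vector>

// Capacity of the LDS page-table cache, as stated in the message.
constexpr int kPageTableLdsEntries = 4096;

// Hypothetical host-side model of the sliding window (illustration only).
class SlidingPageTable {
public:
    SlidingPageTable(const std::vector<int>& block_table,
                     int split_start_page, int split_end_page)
        : table_(block_table), start_(split_start_page), end_(split_end_page) {
        load(split_start_page);  // initial load: min(split_window_pages, 4096) entries
    }

    // Slide to the lowest-needed page when the highest-needed page would
    // leave the resident range. Returns true if the window moved.
    bool slide_if_needed(int lowest_needed_page, int highest_needed_page) {
        // Loop-invariant early-out: the whole split fits, so never slide.
        if (end_ - start_ <= kPageTableLdsEntries) return false;
        const int resident_end = lds_window_base_ + static_cast<int>(lds_.size());
        if (highest_needed_page < resident_end) return false;
        load(lowest_needed_page);
        ++slides_;
        return true;
    }

    // Every read is relative to lds_window_base, not split_start_page.
    int lookup(int abs_page) const { return lds_[abs_page - lds_window_base_]; }
    int slides() const { return slides_; }

private:
    void load(int base) {
        lds_window_base_ = base;
        const int count = std::min(end_ - base, kPageTableLdsEntries);
        lds_.assign(table_.begin() + base, table_.begin() + base + count);
    }

    const std::vector<int>& table_;  // stands in for the global block table
    int start_, end_;
    int lds_window_base_ = 0;        // absolute page index of LDS entry 0
    std::vector<int> lds_;           // stands in for block_tables_lds
    int slides_ = 0;
};

struct WalkResult { int slides; int mismatches; };

// Walk every page once with V trailing K by one page (an assumed lag) and the
// next page prefetched, checking each lookup against the global table.
WalkResult walk_pages(int num_pages) {
    std::vector<int> block_table(num_pages);
    for (int i = 0; i < num_pages; ++i) block_table[i] = 7 * i + 3;  // synthetic mapping

    SlidingPageTable window(block_table, 0, num_pages);
    int mismatches = 0;
    for (int k = 0; k < num_pages; ++k) {
        const int v = std::max(k - 1, 0);                 // lagging consumer
        const int next = std::min(k + 1, num_pages - 1);  // next-prefetched page
        window.slide_if_needed(v, next);
        for (int page : {v, k, next})
            if (window.lookup(page) != block_table[page]) ++mismatches;
    }
    return {window.slides(), mismatches};
}
```

Under these assumptions, `walk_pages(8192)` slides at least once and reports no mismatches, while `walk_pages(100)` takes the early-out and never slides. The slide count in this model depends on the assumed lag and prefetch distance, so it need not match the slide counts quoted in the validation notes.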