composable_kernel/include
juuso-oskari be398c224f CK-UA: sliding-window page-table cache to lift the max-KV-length ceiling
The Tier-2 LDS page-table cache bulk-loaded the whole per-split window
[split_start_page, split_end_page) up front and trapped unless
`split_window_pages <= kPageTableLdsEntries` (4096). That capped max KV length
at 4096*page_size per split (65536 tokens at page=16), so shapes such as decode
sk=131072 page=16 (no/few splits) or long-context prefill (num_splits==1)
device-asserted instead of running.
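
The ceiling arithmetic above can be checked with a small host-side sketch. Only the 4096-entry capacity and the quoted shapes come from this commit; the helper names `pages_for` and `max_kv_len_per_split` are hypothetical and are not part of the kernel.

```cpp
#include <cstdint>

// LDS page-table capacity quoted in the commit message.
constexpr std::int64_t kPageTableLdsEntries = 4096;

// Hypothetical helper: number of pages needed to hold kv_len tokens.
constexpr std::int64_t pages_for(std::int64_t kv_len, std::int64_t page_size) {
    return (kv_len + page_size - 1) / page_size;  // ceiling division
}

// Hypothetical helper: longest KV length one split could cover when the
// whole window had to fit in LDS at once.
constexpr std::int64_t max_kv_len_per_split(std::int64_t page_size) {
    return kPageTableLdsEntries * page_size;
}

// page=16: the old ceiling was 65536 tokens per split.
static_assert(max_kv_len_per_split(16) == 65536);
// decode sk=131072 page=16 needs 8192 pages, twice the LDS capacity.
static_assert(pages_for(131072, 16) == 8192);
// chunked prefill sk=80000 page=16 needs 5000 pages, also above 4096.
static_assert(pages_for(80000, 16) == 5000);
```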

Replace the fixed window with a sliding one:
- `lds_window_base` is the absolute page index of LDS entry 0; every refresh
  reads block_tables_lds[abs_page - lds_window_base] (previously
  block_tables_lds[abs_page - split_start_page]).
- The initial load fills min(split_window_pages, 4096) entries.
- slide_page_table() runs at the existing CTA convergence barriers (the
  single-WG decode loop and the FA4 prefill slot-A barrier, both branches
  symmetric). The slide predicate is a pure function of the wave-uniform tile
  bookkeeping (k_block_idx/v_block_idx) and the page geometry, so every wave
  evaluates it identically and the two internal s_barriers stay matched across
  both warp groups: no divergence, no deadlock. When the in-flight and
  next-prefetched tiles would leave the resident range, it slides the window to
  the lowest-needed page, which keeps the lagging consumer (V trails K) covered.
- Loop-invariant early-out (split_window_pages <= 4096): the hot, fits-in-LDS
  path skips the slide entirely (one predicted branch, no per-tile divides),
  so steady-state codegen is bit-identical to before.
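
The windowing logic in the list above can be modelled on the host as follows. This is a minimal single-threaded sketch under stated assumptions, not the kernel code: it has no barriers or waves, it treats one tile as one page, and every name other than `lds_window_base` and the 4096-entry capacity (`PageTableWindow`, `slide_if_needed`, `walk_matches`, the `lag` parameter) is hypothetical.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Host-side model of a sliding page-table cache (hypothetical names).
struct PageTableWindow {
    static constexpr std::int64_t kCapacity = 4096;  // LDS entries per the commit
    const std::vector<std::int32_t>& table;          // full page table ("global memory")
    std::int64_t split_start_page;
    std::int64_t split_end_page;                     // exclusive
    std::int64_t lds_window_base = 0;                // absolute page index of entry 0
    std::int64_t resident = 0;                       // number of valid cached entries
    std::vector<std::int32_t> lds;                   // stand-in for LDS storage
    int slides = 0;

    PageTableWindow(const std::vector<std::int32_t>& t, std::int64_t start, std::int64_t end)
        : table(t), split_start_page(start), split_end_page(end), lds(kCapacity) {
        load(start);  // initial load fills min(split_window_pages, 4096) entries
    }

    void load(std::int64_t base) {
        lds_window_base = base;
        resident = std::min(kCapacity, split_end_page - base);
        std::copy_n(table.begin() + base, resident, lds.begin());
    }

    // Slide only when the highest page needed soon is outside the resident
    // range; the new base is the lowest page still needed.
    void slide_if_needed(std::int64_t lowest_needed, std::int64_t highest_needed) {
        if (split_end_page - split_start_page <= kCapacity) return;  // fits: never slide
        if (highest_needed < lds_window_base + resident) return;     // still covered
        load(lowest_needed);
        ++slides;
    }

    std::int32_t lookup(std::int64_t abs_page) const {
        return lds[abs_page - lds_window_base];
    }
};

// Walks all pages in order with the V consumer trailing K by `lag` pages and
// checks every windowed lookup against the full table.
bool walk_matches(std::int64_t num_pages, std::int64_t lag, int* slides_out) {
    std::vector<std::int32_t> table(num_pages);
    for (std::int64_t i = 0; i < num_pages; ++i)
        table[i] = static_cast<std::int32_t>((i * 7 + 3) % num_pages);  // arbitrary mapping

    PageTableWindow w(table, 0, num_pages);
    bool ok = true;
    for (std::int64_t step = 0; step < num_pages + lag; ++step) {
        const std::int64_t k_page = std::min(step, num_pages - 1);
        const std::int64_t v_page = std::max<std::int64_t>(step - lag, 0);
        const std::int64_t next_k = std::min(k_page + 1, num_pages - 1);  // prefetched page
        w.slide_if_needed(v_page, next_k);
        ok = ok && w.lookup(k_page) == table[k_page] && w.lookup(v_page) == table[v_page];
    }
    *slides_out = w.slides;
    return ok;
}
```

In this model, `walk_matches(8192, 1, &n)` slides twice and `walk_matches(4096, 1, &n)` never slides, because the early-out returns before any range check. The real kernel's slide count depends on tile size and prefetch depth, which this sketch does not model.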

Validated on gfx950: previously-asserting decode sk=131072 page=16 (8192 pages,
~2 slides) and chunked prefill sq=2048 sk=80000 page=16 (5000 pages, FA4 slide)
now PASS vs torch ref. Full correctness matrix 263/263 PASS. Canonical perf
unchanged: bf16 paged ~10.0 ms (1.18x), fp8 paged 4.66 ms (1.65x); production
shape sweep within run-to-run noise.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-16 15:19:45 +00:00