juuso-oskari badc807025 CK-UA: enable Tier-2 LDS page-table cache on decode + fix split-KV bulk-load OOB
Three coupled changes to the Tier-2 LDS-resident page-table cache:

1. Drop the `kBlockSize >= 8 * warp_size` gate from both the runtime
   kScalarPromote{K,V}PageIdx predicates and the static
   GetPageTableLdsBytes() allocator. The original conservative gate
   excluded TinyDecode (kBlockSize == 64); the trade-off has since
   flipped now that bf16 m16 doubles the per-tile iter count (and
   thus the per-tile page-table refresh count) via the halved kBlockN
   change. Enabling Tier-2 on TinyDecode eliminates the per-iter
   `s_waitcnt vmcnt(0)` drains that the per-lane block_tables_ptr_
   vector loads were forcing.

2. Fix a silent corruption bug in GetPageTableLdsBytes(): the
   LDS-allocation gate still carried the old
   `kBlockSize >= 8 * warp_size` condition while operator()'s runtime
   gate had already dropped it, so on TinyDecode the bulk-load path
   wrote into LDS regions belonging to the K/V double buffers above
   it. Both gates now share the same constexpr predicate.

3. Split-KV bulk-load correction. refresh_*_offsets indexes
   block_tables_lds by absolute page index (that is, relative to
   block_table_offset), so on splits 1+, where the CTA's
   split_token_offset > 0, the original bulk load covered only pages
   [0, num_pages_for_split) and read OOB. We now load
   [block_table_offset, block_table_offset + split_end_page) to
   cover every absolute page the CTA can index.
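
Items 2 and 3 can be illustrated with a small host-side C++ model. This is a sketch under stated assumptions, not the kernel code: `TinyDecodeCfg`, `kPagedKV`, `kMaxCachedPages`, `UsePageTableLds`, `RuntimeUsesPageTableLds`, `SplitPages`, `BulkLoad` and `CoversSplit` are hypothetical names, and the predicate body is a placeholder for whatever condition the real kScalarPromote{K,V}PageIdx gates use. Only `kBlockSize`, `GetPageTableLdsBytes` and `block_table_offset` mirror names from this message.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// --- Item 2: one predicate shared by the allocator and the runtime path ---

// Hypothetical tile configuration; only kBlockSize mirrors a name from the commit.
struct TinyDecodeCfg {
    static constexpr int  kBlockSize      = 64;   // TinyDecode tier
    static constexpr bool kPagedKV        = true; // assumed flag
    static constexpr int  kMaxCachedPages = 256;  // assumed cache capacity
};

// Single source of truth. It has no kBlockSize term, so TinyDecode qualifies.
template <typename Cfg>
constexpr bool UsePageTableLds() {
    return Cfg::kPagedKV && Cfg::kMaxCachedPages > 0;
}

// Allocator side: reserves bytes exactly when the runtime path will write them.
template <typename Cfg>
constexpr std::size_t GetPageTableLdsBytes() {
    return UsePageTableLds<Cfg>()
               ? static_cast<std::size_t>(Cfg::kMaxCachedPages) * sizeof(std::int32_t)
               : 0;
}

// Runtime side: consults the same predicate before touching the cache.
template <typename Cfg>
constexpr bool RuntimeUsesPageTableLds() {
    return UsePageTableLds<Cfg>();
}

// The two gates cannot drift apart because both call UsePageTableLds().
static_assert(RuntimeUsesPageTableLds<TinyDecodeCfg>() ==
                  (GetPageTableLdsBytes<TinyDecodeCfg>() > 0),
              "allocator and runtime gates must agree");

// --- Item 3: which page-table entries a split-KV CTA must bulk-load ---

struct PageRange {
    std::int32_t begin; // first sequence-relative page touched
    std::int32_t end;   // one past the last page touched
};

// Sequence-relative pages touched by tokens [tok_begin, tok_end).
inline PageRange SplitPages(std::int32_t tok_begin, std::int32_t tok_end,
                            std::int32_t page_size) {
    assert(page_size > 0 && tok_begin >= 0 && tok_begin < tok_end);
    return {tok_begin / page_size, (tok_end + page_size - 1) / page_size};
}

// Copy `count` entries starting at block_tables[block_table_offset] into a
// cache whose slot i corresponds to sequence-relative page i.
inline std::vector<std::int32_t> BulkLoad(const std::vector<std::int32_t>& block_tables,
                                          std::int32_t block_table_offset,
                                          std::int32_t count) {
    assert(block_table_offset >= 0 && count >= 0);
    assert(static_cast<std::size_t>(block_table_offset) + static_cast<std::size_t>(count) <=
           block_tables.size());
    return std::vector<std::int32_t>(block_tables.begin() + block_table_offset,
                                     block_tables.begin() + block_table_offset + count);
}

// True when every page the split can index has a loaded slot in the cache.
inline bool CoversSplit(const std::vector<std::int32_t>& cache, PageRange pages) {
    return static_cast<std::size_t>(pages.end) <= cache.size();
}
```

In this model a split covering tokens [64, 128) with a page size of 16 touches pages [4, 8). Loading only `pages.end - pages.begin` entries fills slots 0 to 3 and leaves the split's own pages without a slot, while loading `pages.end` entries covers them.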

Also: add explicit `s_waitcnt_lgkmcnt<0>()` after the bulk-load. On
multi-warp tiers the s_barrier carries the LDS-write drain
implicitly; on single-warp TinyDecode LLVM elides s_barrier entirely
and the refresh path reads stale LDS without the explicit drain.

Validated: correctness sweep across bf16/fp16/fp8 × {decode, prefill}
× b in {1,32,128,256}, sk up to 128k. Decode perf: ~1.18x geomean vs
Triton on long-context d=128 GQA8 (was 1.5x+ pre-fix).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 08:21:10 +00:00