mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Three coupled changes to the Tier-2 LDS-resident page-table cache:
1. Drop the `kBlockSize >= 8 * warp_size` gate from both the runtime
kScalarPromote{K,V}PageIdx predicates and the static
GetPageTableLdsBytes() allocator. The original, conservative gate
excluded TinyDecode (kBlockSize == 64), but the trade-off has since
flipped: the halved-kBlockN change for bf16 m16 doubles the per-tile
iteration count, and with it the per-tile page-table refresh count.
Enabling Tier-2 on TinyDecode eliminates the per-iteration
`s_waitcnt vmcnt(0)` drains that the per-lane block_tables_ptr_
vector loads were forcing.
2. Fix a silent corruption bug in GetPageTableLdsBytes(): the LDS
allocation gate still carried the old `kBlockSize >= 8 * warp_size`
check while operator()'s runtime gate had already dropped it, so on
TinyDecode no LDS was reserved for the page table and the bulk-load
path wrote into the LDS regions of the K/V double buffers. Both gates
now share the same constexpr predicate.
3. Split-KV bulk-load correction. refresh_*_offsets indexes
block_tables_lds by absolute page index (relative to
block_table_offset, not to the split's first page), so on splits 1+,
where the CTA's split_token_offset > 0, the original bulk load covered
only pages [0, num_pages_for_split) and the refresh path read past the
loaded range. We now load [block_table_offset, block_table_offset +
split_end_page) so that every absolute page index the CTA can touch is
resident in LDS.
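A minimal host-side C++ sketch of the single-predicate pattern from items 1 and 2, assuming wave64 hardware (warp_size == 64). PageTableTraits, kIsPagedKV_, kMaxPagesInLds_, UsesLdsPageTable, AllocMatchesRuntime and the 4-byte page-index entry are hypothetical placeholders, not composable_kernel's actual traits; the sketch only illustrates that when the allocator and the runtime path read one constexpr flag, they cannot drift apart the way the old pair did.

```cpp
#include <cstdint>

// Assumption: wave64 hardware (CDNA), so TinyDecode's kBlockSize == 64 is a
// single warp.
constexpr int32_t warp_size = 64;

// Hypothetical traits: kIsPagedKV_ and kMaxPagesInLds_ are placeholders,
// not composable_kernel's real template parameters.
template <int32_t kBlockSize_, bool kIsPagedKV_, int32_t kMaxPagesInLds_>
struct PageTableTraits {
    static constexpr int32_t kBlockSize = kBlockSize_;

    // Pre-fix pair (illustrative): the runtime side had dropped the
    // block-size gate, the allocator side had not.
    static constexpr bool kOldRuntimeGate = kIsPagedKV_;
    static constexpr bool kOldAllocGate =
        kIsPagedKV_ && kBlockSize >= 8 * warp_size;

    // Post-fix: one constexpr predicate read by both sides.
    static constexpr bool kPageTableInLds = kIsPagedKV_;

    // Allocator side: reserve LDS for the page-table cache only when the
    // shared predicate holds (assumes 4-byte page indices).
    static constexpr int32_t GetPageTableLdsBytes() {
        return kPageTableInLds
                   ? kMaxPagesInLds_ * static_cast<int32_t>(sizeof(int32_t))
                   : 0;
    }

    // Runtime side: whether the kernel takes the LDS bulk-load path.
    static constexpr bool UsesLdsPageTable() { return kPageTableInLds; }
};

// The invariant the old code violated: if the kernel writes the page table
// into LDS, the allocator must have reserved space for it.
template <typename Traits>
constexpr bool AllocMatchesRuntime() {
    return !Traits::UsesLdsPageTable() || Traits::GetPageTableLdsBytes() > 0;
}
```

In this model the old pair disagrees exactly on TinyDecode-sized blocks (runtime gate true, allocation gate false), which is the configuration where the commit reports the K/V buffers being overwritten.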
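The page-range fix in item 3 can be checked with a small host-side C++ sketch. It assumes a split's pages are derived from split_token_offset by floor/ceil division by a page size; PageRange, SplitPages, pages_for_split, old_bulk_load, new_bulk_load, covers, split_start_page, split_tokens and page_size are hypothetical names. Only block_table_offset, num_pages_for_split, split_end_page and split_token_offset come from the commit message.

```cpp
#include <cassert>
#include <cstdint>

// Half-open index range [begin, end) into the global block_tables array.
struct PageRange {
    int32_t begin;
    int32_t end;
};

// Pages touched by one split, relative to block_table_offset (the
// "absolute page index" refresh_*_offsets uses into block_tables_lds).
struct SplitPages {
    int32_t split_start_page;  // hypothetical name
    int32_t split_end_page;    // exclusive
};

// Assumption: a split covers tokens [split_token_offset,
// split_token_offset + split_tokens) and each page holds page_size tokens.
SplitPages pages_for_split(int32_t split_token_offset, int32_t split_tokens,
                           int32_t page_size) {
    assert(page_size > 0 && split_token_offset >= 0 && split_tokens > 0);
    const int32_t first = split_token_offset / page_size;
    const int32_t last_exclusive =
        (split_token_offset + split_tokens + page_size - 1) / page_size;
    return {first, last_exclusive};
}

// Pre-fix: num_pages_for_split entries starting at block_table_offset,
// i.e. absolute pages [0, num_pages_for_split).
PageRange old_bulk_load(int32_t block_table_offset, SplitPages s) {
    const int32_t num_pages_for_split = s.split_end_page - s.split_start_page;
    return {block_table_offset, block_table_offset + num_pages_for_split};
}

// Post-fix: [block_table_offset, block_table_offset + split_end_page).
PageRange new_bulk_load(int32_t block_table_offset, SplitPages s) {
    return {block_table_offset, block_table_offset + s.split_end_page};
}

// True if every block_tables entry the refresh path can read for this
// split was bulk-loaded into LDS.
bool covers(PageRange loaded, int32_t block_table_offset, SplitPages s) {
    return loaded.begin <= block_table_offset + s.split_start_page &&
           block_table_offset + s.split_end_page <= loaded.end;
}
```

With 128-token pages and 4096-token splits, split 0 reads pages [0, 32) and both loads cover it, while split 1 reads pages [32, 64), which only the fixed load covers. In this model the fixed load copies split_end_page entries, so later splits load a longer prefix of the block table than the pages they actually read.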
Also: add an explicit `s_waitcnt_lgkmcnt<0>()` after the bulk load. On
multi-warp tiers the s_barrier implicitly drains the LDS writes; on
single-warp TinyDecode, LLVM elides the s_barrier entirely, so without
the explicit drain the refresh path reads stale LDS.
Validated: correctness sweep across bf16/fp16/fp8 × {decode, prefill}
× b in {1,32,128,256}, sk up to 128k. Decode perf: ~1.18x geomean vs
Triton on long-context d=128 GQA8 (was 1.5x+ pre-fix).
Co-authored-by: Cursor <cursoragent@cursor.com>