mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
The K_mem_load / V_mem_load lambdas unconditionally call refresh_*_offsets
after each load to prepare per-element page_offsets for the next tile.
On the *last* tile of the (split-KV per-split) loop the next load is
never issued, but the refresh still reads
block_tables[block_table_offset + (last_relative_tile + 1)], one entry
past the sequence's last valid logical_page on the final split. When
block_tables happens to be the last allocation in a memory page and the
next page is unmapped, that read faults.
The PyTorch caching allocator hides the bug for small workloads (the
4-byte OOB lands in adjacent live memory and just returns garbage), but
it reproduces reliably once a workload deep-copies >~30 distinct
block_tables tensors and the allocator scatters them across unmapped
page boundaries. The fault is not split-KV specific — the single-launch
path (num_splits == 1) hits the same OOB on the final tile of the only
"split". Verified on MI355 with a 200-config decode FP8 sweep (b ∈
{1,4,8,16,32}, sk ∈ {512,1024,2048}, d ∈ {64,128}, GQA-{2,8}, bs ∈
{16,32,64}, bf16+fp16, ±FP8): 200/200 pass against the reference; the
same configs hit "memory access fault by GPU node" at iter ~27 before
the fix.
Note on the gate: k_block_idx / v_block_idx are 0-based *relative to
this split*, while num_total_loop is the absolute end index, so the
correct bound is `num_total_loop - num_blocks_start` (= per-split iter
count). Skipping the refresh leaves k_page_offsets / v_page_offsets
stale on the final iter, which is harmless because no subsequent load
consumes them.
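
The gate arithmetic above can be modeled on the host. This is a minimal
sketch, not the kernel code: run_split, Gate and RefreshStats are
hypothetical names, one block_tables entry per tile is assumed, and
block_table_offset is assumed to point at the split's first entry. It
records which index the refresh would touch instead of reading it.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side model of one split's tile loop (not the CK kernel).
enum class Gate { None, AbsoluteBound, PerSplitBound };

struct RefreshStats {
    int refreshes = 0;          // how many times the offset refresh ran
    int max_index_read = -1;    // largest block_tables index it would touch
    bool out_of_bounds = false; // true if any index was >= block_tables.size()
};

inline RefreshStats run_split(const std::vector<int>& block_tables,
                              int block_table_offset, // assumed: split's first entry
                              int num_blocks_start,   // absolute start tile of the split
                              int num_total_loop,     // absolute end tile (exclusive)
                              Gate gate)
{
    assert(num_total_loop >= num_blocks_start);
    RefreshStats stats;
    const int iters_this_split = num_total_loop - num_blocks_start;

    // block_idx is 0-based relative to this split.
    for (int block_idx = 0; block_idx < iters_this_split; ++block_idx) {
        // (the load of the current tile would happen here)

        bool do_refresh = true;
        if (gate == Gate::AbsoluteBound) {
            // Wrong gate: compares a relative index with an absolute bound.
            do_refresh = block_idx + 1 < num_total_loop;
        } else if (gate == Gate::PerSplitBound) {
            // Gate from the note above: relative index vs per-split iter count.
            do_refresh = block_idx + 1 < iters_this_split;
        }
        if (!do_refresh) continue; // offsets stay stale; no later load uses them

        // The refresh prepares offsets for the *next* tile.
        const int next_index = block_table_offset + block_idx + 1;
        ++stats.refreshes;
        if (next_index > stats.max_index_read) stats.max_index_read = next_index;
        if (next_index >= static_cast<int>(block_tables.size())) {
            stats.out_of_bounds = true;
        }
    }
    return stats;
}
```

With an 8-entry table and a final split covering tiles 4..7
(block_table_offset = 4, num_blocks_start = 4, num_total_loop = 8), the
ungated and absolute-bound variants both touch index 8, while the
per-split bound stops at index 7.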
Co-authored-by: Cursor <cursoragent@cursor.com>