Refactor the K/V DRAM access in the unified-attention pipeline to use
tile_scatter_gather with a unified per-(thread, Y0-iter) page-offset
formula:

    logical_token   = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N
    logical_page    = logical_token / page_size
    within_page     = logical_token % page_size
    phys_page       = block_tables[block_table_offset + logical_page]
    page_offsets[i] = (phys_page * page_size + within_page) * row_stride

The page indirection now lives entirely in page_offsets, which are
refreshed via update_page_idx() between iterations. The per-iteration
SRD rebase (set_bottom_tensor_view_data_ptr + init_raw) and the
use_ptr_rebase overflow heuristic are removed.
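
For reference, the formula above can be sketched as a host-side C++
function. This is only an illustration under stated assumptions, not the
kernel code: compute_page_offsets, its parameter list, and num_iters are
hypothetical names, index_t is assumed to be a 32-bit signed integer, and
block_tables is modelled as a flat vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using index_t = std::int32_t; // assumption: 32-bit signed index type

// Hypothetical host-side reference of the per-(thread, Y0-iter) formula.
// Returns one element offset per Y0 iteration for a single thread.
std::vector<index_t> compute_page_offsets(const std::vector<index_t>& block_tables,
                                          index_t block_table_offset,
                                          index_t tile_idx,
                                          index_t kPageBlockSize,
                                          index_t thread_N_pos,
                                          index_t Y0_step_N,
                                          index_t num_iters,
                                          index_t page_size,
                                          index_t row_stride)
{
    assert(page_size > 0 && num_iters >= 0);
    std::vector<index_t> page_offsets(static_cast<std::size_t>(num_iters));
    for(index_t i = 0; i < num_iters; ++i)
    {
        // Token index in the logical (unpaged) sequence.
        const index_t logical_token = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N;
        // Split into a logical page and a position inside that page.
        const index_t logical_page = logical_token / page_size;
        const index_t within_page  = logical_token % page_size;
        // Translate the logical page to a physical page via the block table.
        const index_t phys_page =
            block_tables.at(static_cast<std::size_t>(block_table_offset + logical_page));
        // Element offset of the row inside the paged K/V cache.
        page_offsets[static_cast<std::size_t>(i)] =
            (phys_page * page_size + within_page) * row_stride;
    }
    return page_offsets;
}
```

With block_tables = {7, 3, 5, 1}, tile_idx = 1, kPageBlockSize = 32,
thread_N_pos = 2, Y0_step_N = 8, page_size = 16 and row_stride = 64, the
four iterations touch logical tokens 34, 42, 50 and 58, so the tile spans
logical pages 2 and 3 (physical pages 5 and 1).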
Effects:
- The assertion kv_page_size_in_blocks >= 1 (i.e. kPageBlockSize <=
  page_size) in the kernel is dropped. Tiles may now span multiple
  cache pages, as long as Y0_step_N (= N1*N2 from the K/V tile dist)
  divides page_size so that a wave-wide load never straddles a page.
- Pipeline arg renamed kv_page_size_in_blocks -> page_size (the page
  size in tokens). Kernel passes kargs.page_size through directly.
- Validated correctness vs Triton on bf16 / d=64 / decode_s with
  block_size in {16, 32, 64}; max abs diff 1.22e-04 in all cases.
  Perf is on par with the prior pass-1 scaffolding (~3.6 ms on the
  131072-context shape).
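
The divisibility requirement in the first bullet can be illustrated with a
small host-side check. This is a sketch under assumptions, not code from
this change: both function names are hypothetical, and it assumes that one
wave-wide load covers Y0_step_N consecutive tokens starting at a multiple
of Y0_step_N.

```cpp
#include <cstdint>

using index_t = std::int32_t; // assumption: 32-bit signed index type

// Hypothetical predicate: true when Y0_step_N divides page_size, which is
// the regime this change supports.
inline bool is_supported_page_size(index_t page_size, index_t Y0_step_N)
{
    return page_size > 0 && Y0_step_N > 0 && page_size % Y0_step_N == 0;
}

// Brute-force cross-check: walk groups of Y0_step_N consecutive tokens
// (each starting at a multiple of Y0_step_N) and report whether any group
// has its first and last token on different pages.
inline bool any_group_straddles_page(index_t page_size, index_t Y0_step_N, index_t num_tokens)
{
    for(index_t first = 0; first + Y0_step_N <= num_tokens; first += Y0_step_N)
    {
        const index_t last = first + Y0_step_N - 1;
        if(first / page_size != last / page_size)
            return true;
    }
    return false;
}
```

For example, with a hypothetical Y0_step_N of 8, page sizes 16, 32 and 64
pass the predicate and no group straddles a page, while page_size = 8 with
Y0_step_N = 16 fails it and falls into the unsupported regime.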
TODO(overflow): page_offsets are index_t; caches whose
num_blocks * page_size * row_stride exceeds INT32_MAX will wrap.
A future change should plumb long_index_t through the scatter-gather
load path or compute a per-batch min-page shift in a pre-pass.
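
Until that is addressed, a caller could guard against the wrap on the
host. The following is a minimal sketch, not part of this change:
page_offsets_fit_in_index_t is a hypothetical helper, index_t and
long_index_t are assumed to be 32-bit and 64-bit signed integers, and the
inputs are assumed to be non-negative with a product that fits in 64 bits.

```cpp
#include <cstdint>
#include <limits>

using index_t      = std::int32_t; // assumption: 32-bit signed index type
using long_index_t = std::int64_t; // assumption: 64-bit signed index type

// Hypothetical host-side guard: evaluate the cache extent in 64-bit
// arithmetic and compare it against the largest value index_t can hold.
inline bool page_offsets_fit_in_index_t(long_index_t num_blocks,
                                        long_index_t page_size,
                                        long_index_t row_stride)
{
    const long_index_t extent = num_blocks * page_size * row_stride;
    return extent <= static_cast<long_index_t>(std::numeric_limits<index_t>::max());
}
```

As an example, 2^20 blocks of 16 tokens with a row stride of 64 elements
give an extent of 2^30 and fit, while the same cache with a row stride of
128 gives 2^31 and does not.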
TODO(unsupported regime): page_size < Y0_step_N (a wave crosses a page
boundary mid-iteration) needs per-lane VGPR SRDs and is not implemented.
Co-authored-by: Cursor <cursoragent@cursor.com>