juuso-oskari 473869aba5 Lift kPageBlockSize <= page_size constraint in CK-UA pipeline
Refactor the K/V DRAM access in the unified-attention pipeline to use
tile_scatter_gather with a unified per-(thread, Y0-iter) page-offset
formula:

    logical_token = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N
    logical_page  = logical_token / page_size
    within_page   = logical_token % page_size
    phys_page     = block_tables[block_table_offset + logical_page]
    page_offsets[i] = (phys_page * page_size + within_page) * row_stride

The page indirection now lives entirely in page_offsets, refreshed via
update_page_idx() between iters. The per-iter SRD rebase
(set_bottom_tensor_view_data_ptr + init_raw) and the use_ptr_rebase
overflow heuristic are gone.
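
The sketch below is a minimal host-side model of that formula, not the
device code: the function and parameter names (compute_page_offsets,
thread_N_pos, num_Y0_iters, ...) are illustrative, and int32_t stands
in for index_t (see the overflow TODO below).

    #include <cstdint>
    #include <vector>

    // Host-side model of the per-(thread, Y0-iter) page-offset formula.
    std::vector<int32_t> compute_page_offsets(const std::vector<int32_t>& block_tables,
                                              int32_t block_table_offset,
                                              int32_t tile_idx,
                                              int32_t kPageBlockSize,
                                              int32_t thread_N_pos, // this thread's N position in the tile
                                              int32_t Y0_step_N,    // = N1*N2 from the K/V tile distribution
                                              int32_t num_Y0_iters,
                                              int32_t page_size,    // PageSize in tokens
                                              int32_t row_stride)   // elements per token row
    {
        std::vector<int32_t> page_offsets(num_Y0_iters);
        for(int32_t i = 0; i < num_Y0_iters; ++i)
        {
            const int32_t logical_token = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N;
            const int32_t logical_page  = logical_token / page_size;
            const int32_t within_page   = logical_token % page_size;
            const int32_t phys_page     = block_tables[block_table_offset + logical_page];
            page_offsets[i] = (phys_page * page_size + within_page) * row_stride;
        }
        return page_offsets;
    }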

Effects:
 - The assertion kv_page_size_in_blocks >= 1 (i.e. kPageBlockSize <=
   page_size) in the kernel is dropped. Tiles may now span multiple
   cache pages, as long as Y0_step_N (= N1*N2 from the K/V tile dist)
   divides page_size so that a wave-wide load never straddles a page
   (see the check sketched after this list).
 - Pipeline arg renamed kv_page_size_in_blocks -> page_size (PageSize
   in tokens). Kernel passes kargs.page_size through directly.
 - Validated correctness vs Triton on bf16 / d=64 / decode_s with
   block_size in {16, 32, 64}; max abs diff 1.22e-04 in all cases.
   Perf is on par with the prior pass-1 scaffolding (~3.6 ms on the
   131072-context shape).
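
The divisibility requirement from the first effect above, as a host-side
check; the helper name and its placement are hypothetical:

    // A tile may span several pages, but each Y0 step must stay inside one
    // page: require that Y0_step_N (= N1*N2) divides page_size.
    inline bool is_supported_page_size(int page_size, int Y0_step_N)
    {
        return Y0_step_N > 0 && page_size % Y0_step_N == 0;
    }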

TODO(overflow): page_offsets are index_t; caches whose
num_blocks * page_size * row_stride exceeds INT32_MAX will wrap.
A future change should plumb long_index_t through the scatter-gather
load path or compute a per-batch min-page shift in a pre-pass.
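
A host-side guard for this, sketched with a hypothetical helper and
made-up example sizes:

    #include <cstdint>
    #include <limits>

    // The largest page_offset is about num_blocks * page_size * row_stride;
    // it must fit in index_t (int32) until long_index_t is plumbed through.
    inline bool page_offsets_fit_in_int32(int64_t num_blocks, int64_t page_size, int64_t row_stride)
    {
        return num_blocks * page_size * row_stride
               <= std::numeric_limits<int32_t>::max();
    }

    // e.g. num_blocks = 131072, page_size = 32, row_stride = 512 gives
    // 131072 * 32 * 512 = 2^31 > INT32_MAX, so the offsets would wrap.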

TODO(unsupported regime): page_size < Y0_step_N (a wave crosses a page
mid-iter) needs per-lane VGPR SRDs and is not implemented.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-11 10:04:01 +00:00