Lift kPageBlockSize <= page_size constraint in CK-UA pipeline

Refactor the K/V DRAM access in the unified-attention pipeline to use
tile_scatter_gather with a unified per-(thread, Y0-iter) page-offset
formula:

    logical_token = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N
    logical_page  = logical_token / page_size
    within_page   = logical_token % page_size
    phys_page     = block_tables[block_table_offset + logical_page]
    page_offsets[i] = (phys_page * page_size + within_page) * row_stride

The page indirection now lives entirely in page_offsets, refreshed via
update_page_idx() between iters. The per-iter SRD rebase
(set_bottom_tensor_view_data_ptr + init_raw) and the use_ptr_rebase
overflow heuristic are gone.
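
For reference, a minimal standalone sketch of that refresh step (plain C++;
NRepeat, Y0_step_N and thread_N_pos are passed in explicitly here, whereas the
real pipeline derives them from the K/V tile-distribution encoding):

    #include <cstdint>
    #include <vector>

    // Recompute the per-Y0-iter physical row offsets for one logical tile.
    // Mirrors refresh_k_offsets / refresh_v_offsets in the pipeline.
    std::vector<int32_t> refresh_offsets(int32_t tile_idx,
                                         int32_t kPageBlockSize,
                                         int32_t page_size,      // tokens per cache page
                                         int32_t thread_N_pos,   // lane/warp N position
                                         int32_t Y0_step_N,      // per-Y0-iter N advance
                                         int32_t NRepeat,        // number of Y0 iters
                                         int32_t row_stride,
                                         const int32_t* block_tables,
                                         int32_t block_table_offset)
    {
        std::vector<int32_t> page_offsets(NRepeat);
        for(int32_t i = 0; i < NRepeat; ++i)
        {
            const int32_t logical_token =
                tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N;
            const int32_t logical_page = logical_token / page_size;
            const int32_t within_page  = logical_token % page_size;
            const int32_t phys_page    = block_tables[block_table_offset + logical_page];
            page_offsets[i] = (phys_page * page_size + within_page) * row_stride;
        }
        return page_offsets;
    }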

Effects:
 - The assertion kv_page_size_in_blocks >= 1 (i.e. kPageBlockSize <=
   page_size) in the kernel is dropped. Tiles may now span multiple
   cache pages, as long as Y0_step_N (= N1*N2 from the K/V tile dist)
   divides page_size so that a wave-wide load never straddles a page.
 - Pipeline arg renamed kv_page_size_in_blocks -> page_size (PageSize
   in tokens). Kernel passes kargs.page_size through directly.
 - Validated correctness vs Triton on bf16 / d=64 / decode_s with
   block_size in {16, 32, 64}; max abs diff 1.22e-04 in all cases.
   Perf is on par with the prior pass-1 scaffolding (~3.6 ms on the
   131072-context shape).

TODO(overflow): page_offsets are index_t; caches whose
num_blocks * page_size * row_stride exceeds INT32_MAX will wrap.
A future change should plumb long_index_t through the scatter-gather
load path or compute a per-batch min-page shift in a pre-pass.
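
For scale (illustrative numbers only, assuming a per-head layout with
page_size = 16 tokens and row_stride = 64 elements; not taken from the
validated configs):

    #include <cstdint>
    #include <cstdio>

    // Where does (phys_page * page_size + within_page) * row_stride stop
    // fitting in int32? Example config is assumed, not from the commit.
    int main()
    {
        const int64_t page_size  = 16;   // tokens per cache page (assumed)
        const int64_t row_stride = 64;   // elements per cache row (assumed, d = 64)
        const int64_t max_pages  = INT32_MAX / (page_size * row_stride); // ~2.09M pages
        std::printf("offsets wrap beyond %lld pages (~%lld tokens)\n",
                    static_cast<long long>(max_pages),
                    static_cast<long long>(max_pages * page_size));
    }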

TODO(unsupported regime): page_size < Y0_step_N (a wave crosses a page
mid-iter) needs per-lane VGPR SRDs and is not implemented.
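
A minimal sketch of the precondition implied by the two notes above (the
helper name is hypothetical; the real kernel reads Y0_step_N out of the
tile-distribution encoding as N1 * N2):

    // True when the unified page-offset scheme applies: page_size must be a
    // multiple of the inner N step so a wave-wide load never straddles a page;
    // page_size < Y0_step_N would need per-lane VGPR SRDs (not implemented).
    bool page_layout_supported(int page_size, int Y0_step_N)
    {
        return page_size >= Y0_step_N && page_size % Y0_step_N == 0;
    }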

Co-authored-by: Cursor <cursoragent@cursor.com>
commit 473869aba5 (parent 8506db8761)
Author: juuso-oskari
Date:   2026-05-11 10:04:01 +00:00

2 changed files with 108 additions and 96 deletions

@@ -307,7 +307,7 @@ struct UnifiedAttentionKernel
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
(context_len + q_block_local_idx * kBlockQ + (kBlockM - 1) + 1));
(context_len + q_block_local_idx * kBlockQ + (kBlockQ - 1) + 1));
if(seq_len < _max_seq_prefix_len)
{
@@ -454,8 +454,12 @@ struct UnifiedAttentionKernel
return FmhaMask{cur_batch_query_len, seq_len};
}();
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
// Pass-2: the pipeline now uses a unified per-(thread, Y0-iter) page
// offset formula and accepts page_size in tokens directly. The earlier
// `kPageBlockSize <= page_size` constraint (which required at least one
// kernel tile to fit in a cache page) is gone — tiles may span multiple
// pages as long as the inner-N step (Y0_step_N from the K/V tile dist)
// divides page_size cleanly.
auto o_acc_tile = [&]() {
return UnifiedAttentionPipeline{}(q_dram_window,
@@ -465,7 +469,7 @@ struct UnifiedAttentionKernel
num_blocks_start,
kargs.block_tables_ptr,
block_table_offset,
kv_page_size_in_blocks,
kargs.page_size,
mask,
kargs.scale_s,
smem_ptr,


@@ -181,7 +181,7 @@ struct UnifiedAttentionPipeline
const index_t num_blocks_start,
const void* block_tables_ptr,
index_t block_table_offset,
const index_t kv_page_size_in_blocks,
const index_t page_size, // PageSize in tokens (cache rows per page)
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
@@ -355,62 +355,109 @@ struct UnifiedAttentionPipeline
}
}
index_t i_total_loops = num_blocks_start;
const index_t PageSize = kv_page_size_in_blocks * kPageBlockSize;
index_t i_total_loops = num_blocks_start;
const ck_tile::index_t* block_tables_ptr_ =
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
assert(k_block_idx == v_block_idx); // because of the following line
block_table_offset += num_blocks_start;
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
// When row strides are provided, use pointer rebasing to avoid int32 overflow
// in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
// When strides are 0 (legacy callers), use the original set_window_origin approach.
// Use pointer rebasing to avoid int32 overflow in tensor_coordinate for large KV pools.
// Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs.
// hdim=128 configs have different buffer_view internals that cause issues with rebasing.
const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
// Pass-2: unified page-offset formula. The kPageBlockSize <= page_size
// constraint is gone. For every (thread, Y0-iter) pair we compute:
//
// logical_token = tile_idx * kPageBlockSize
// + thread_N_pos // lane/warp partition
// + i * Y0_step_N // per-Y0-iter advance
// logical_page = logical_token / page_size // index into block_tables
// within_page = logical_token % page_size // row inside the page
// phys_page = block_tables[block_table_offset + logical_page]
// page_offsets[i] = (phys_page * page_size + within_page) * row_stride
//
// The page indirection moves entirely into page_offsets, so the per-iter
// SRD rebase (set_bottom_tensor_view_data_ptr + init_raw) is dropped —
// we just call update_page_idx() to refresh offsets between tiles. This
// works for any (kPageBlockSize, page_size) pair where Y0_step_N (= the
// inner N stride from the dist encoding, N1 * N2) divides page_size, so
// a single wave-wide load instruction never straddles a page boundary.
// If page_size < Y0_step_N, per-lane VGPR SRDs would be required and we
// don't currently support that.
//
// TODO(overflow): page_offsets are index_t (int32). For caches whose
// num_blocks * page_size * row_stride exceeds INT32_MAX, the offsets
// wrap and reads return wrong data. The previous pass had a one-shot
// base-pointer shift heuristic for this case (`use_ptr_rebase`); it has
// been removed here because it does not interact well with the unified
// formula when block_tables are non-monotonic (a far-away page produces
// a large negative relative offset that the HW OOB check clamps to 0).
// A robust fix would either plumb long_index_t through the gather load
// path or compute a per-batch min-page shift in a pre-pass.
const auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
const auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
using KDstrType = decltype(k_dist);
using VDstrType = decltype(v_dist);
constexpr index_t KNRepeat =
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
constexpr index_t VNRepeat =
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
constexpr index_t KY0_step_N =
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
constexpr index_t VY0_step_N =
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
const auto k_thread_coord = k_dist.calculate_index();
const auto v_thread_coord = v_dist.calculate_index();
const index_t k_thread_n_pos = k_thread_coord[number<0>{}];
const index_t v_thread_n_pos = v_thread_coord[number<0>{}];
statically_indexed_array<index_t, KNRepeat> k_page_offsets;
statically_indexed_array<index_t, VNRepeat> v_page_offsets;
auto refresh_k_offsets = [&](index_t k_tile_idx) {
static_for<0, KNRepeat, 1>{}([&](auto i) {
const index_t logical_token = k_tile_idx * kPageBlockSize + k_thread_n_pos +
static_cast<index_t>(i.value) * KY0_step_N;
const index_t logical_page = logical_token / page_size;
const index_t within_page = logical_token - logical_page * page_size;
const index_t phys_page =
block_tables_ptr_[block_table_offset + logical_page];
k_page_offsets(i) =
(phys_page * page_size + within_page) * k_row_stride;
});
};
auto refresh_v_offsets = [&](index_t v_tile_idx) {
static_for<0, VNRepeat, 1>{}([&](auto i) {
const index_t logical_token = v_tile_idx * kPageBlockSize + v_thread_n_pos +
static_cast<index_t>(i.value) * VY0_step_N;
const index_t logical_page = logical_token / page_size;
const index_t within_page = logical_token - logical_page * page_size;
const index_t phys_page =
block_tables_ptr_[block_table_offset + logical_page];
v_page_offsets(i) =
(phys_page * page_size + within_page) * v_row_stride;
});
};
refresh_k_offsets(k_block_idx);
refresh_v_offsets(v_block_idx);
auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
auto* k_base_ptr = k_view.buf_.p_data_;
auto* v_base_ptr = v_view.buf_.p_data_;
if(use_ptr_rebase)
{
long_index_t k_off =
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
k_view.buf_.p_data_ = k_base_ptr + k_off;
auto new_k = k_view.buf_.buffer_size_ - k_off;
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
long_index_t v_off =
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
v_view.buf_.p_data_ = v_base_ptr + v_off;
auto new_v = v_view.buf_.buffer_size_ - v_off;
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
}
else
{
// Legacy path: use original view with absolute window origin
k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
}
const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
auto k_dram_window =
make_tile_window(k_view,
k_dram_block_window_tmp.get_window_lengths(),
{init_origin, 0},
Policy::template MakeKDramTileDistribution<Problem>());
make_tile_scatter_gather(k_view,
k_dram_block_window_tmp.get_window_lengths(),
{0, 0},
k_dist,
k_page_offsets);
k_dram_window.init_raw();
auto v_dram_window =
make_tile_window(v_view,
v_dram_block_window_tmp.get_window_lengths(),
{init_origin, 0},
Policy::template MakeVDramTileDistribution<Problem>());
make_tile_scatter_gather(v_view,
v_dram_block_window_tmp.get_window_lengths(),
{0, 0},
v_dist,
v_page_offsets);
v_dram_window.init_raw();
// prefetch K tile
@@ -508,60 +555,21 @@ struct UnifiedAttentionPipeline
// only for block 0 and thread
if(blockIdx.x == 0 && threadIdx.x == 0) {}
auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset,
auto buf_size_orig) {
window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;
auto new_size = buf_size_orig - elem_offset;
window.bottom_tensor_view_.buf_.buffer_size_ = new_size > 0 ? new_size : kPageBlockSize * kHeadDim;
window.init_raw();
window.set_window_origin({0, 0});
};
const auto k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_;
const auto v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_;
// Pass-2: page indirection lives in page_offsets, not in the SRD. We
// refresh the per-iter offsets table and push it to the window via
// update_page_idx(); the SRD itself stays put (no init_raw per iter).
auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
k_block_idx++;
index_t k_page_blk_idx =
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
if(use_ptr_rebase)
{
long_index_t k_row =
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
}
else
{
k_dram_window.set_window_origin(
{k_page_blk_idx * PageSize +
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
0});
}
refresh_k_offsets(k_block_idx);
k_dram_window.update_page_idx(k_page_offsets);
};
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
v_block_idx++;
index_t v_page_blk_idx =
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
if(use_ptr_rebase)
{
long_index_t v_row =
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
}
else
{
v_dram_window.set_window_origin(
{v_page_blk_idx * PageSize +
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
0});
}
refresh_v_offsets(v_block_idx);
v_dram_window.update_page_idx(v_page_offsets);
};
auto K_lds_load = [&](auto k_lds_read_idx) {
@@ -1191,7 +1199,7 @@ struct UnifiedAttentionPipeline
const index_t num_blocks_start,
const void* block_tables_ptr,
index_t block_table_offset,
const index_t kv_page_size_in_blocks,
const index_t page_size, // PageSize in tokens (cache rows per page)
FmhaMask mask,
float scale_s,
void* smem_ptr,
@@ -1210,7 +1218,7 @@ struct UnifiedAttentionPipeline
num_blocks_start,
block_tables_ptr,
block_table_offset,
kv_page_size_in_blocks,
page_size,
identity{},
identity{},
identity{},