Fix int32 overflow in CK-UA pipeline via pointer rebasing

tensor_coordinate::get_offset() returns index_t (int32), causing overflow
when page_idx * block_size * stride > 2^31 (~131K blocks for d64/GQA-8).
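
As a rough check (assuming block_size = 32 and a K/V row stride of
num_kv_heads * head_dim = 8 * 64 = 512 elements, which matches the d64/GQA-8
figure but is not spelled out above): 131072 * 32 * 512 = 2^31 exactly, so
page indices from ~131K blocks onward wrap the int32 offset.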

Fix: rebase the K/V data pointers for each page using int64 arithmetic instead
of calling set_window_origin with large offsets. After rebasing p_data_ and
buffer_size_, call init_raw() to refresh the AMD buffer resource descriptor,
then set_window_origin({0,0}) to reset the cached coordinates.
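
Sketch of that rebasing step, lifted out of the pipeline for readability (not
part of the diff): it assumes only the tile-window members already shown in the
diff below (bottom_tensor_view_.buf_.p_data_, buffer_size_, init_raw(),
set_window_origin()) and mirrors the rebase_window lambda added there.

    // Sketch only; Window is any ck_tile tile window exposing the members above.
    // long_index_t: the int64 index type already used by the pipeline.
    template <typename Window, typename T>
    __device__ void rebase_window_sketch(Window& window,
                                         T* base_ptr,
                                         long_index_t elem_offset,   // row * row_stride, computed in int64
                                         long_index_t buf_size_orig,
                                         long_index_t fallback_size) // e.g. kPageBlockSize * kHeadDim
    {
        // Move the base pointer instead of the window origin so the per-thread
        // coordinate math stays small enough to fit in int32.
        window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;

        // Shrink the buffer size by the same amount so the buffer resource
        // descriptor still bounds-checks the rebased region.
        const long_index_t new_size = buf_size_orig - elem_offset;
        window.bottom_tensor_view_.buf_.buffer_size_ = new_size > 0 ? new_size : fallback_size;

        // Refresh the AMD buffer resource descriptor and reset cached coordinates.
        window.init_raw();
        window.set_window_origin({0, 0});
    }

The strides default to 0 so existing callers keep the original
set_window_origin path, and the rebase is only enabled for hdim <= 64 (see the
pipeline comments below).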

Tested: num_blocks up to 2M with num_kv_heads = 1/8 and block_size = 32/64. All pass.
Made-with: Cursor
This commit is contained in:
root
2026-04-02 09:30:00 +00:00
parent e8587b86c2
commit 8506db8761
2 changed files with 93 additions and 16 deletions


@@ -468,7 +468,9 @@ struct UnifiedAttentionKernel
kv_page_size_in_blocks,
mask,
kargs.scale_s,
-smem_ptr);
+smem_ptr,
+static_cast<long_index_t>(kargs.stride_k_cache_1),
+static_cast<long_index_t>(kargs.stride_v_cache_1));
}();
// O DRAM and O DRAM window


@@ -187,7 +187,9 @@ struct UnifiedAttentionPipeline
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale_s,
-void* smem_ptr) const
+void* smem_ptr,
+long_index_t k_row_stride = 0,
+long_index_t v_row_stride = 0) const
{
using namespace ck_tile;
static_assert(
@@ -361,17 +363,53 @@ struct UnifiedAttentionPipeline
block_table_offset += num_blocks_start;
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
-// When row strides are provided, use pointer rebasing to avoid int32 overflow
-// in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
-// When strides are 0 (legacy callers), use the original set_window_origin approach.
+// Use pointer rebasing to avoid int32 overflow in tensor_coordinate for large KV pools.
+// Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs.
+// hdim=128 configs have different buffer_view internals that cause issues with rebasing.
+const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
+auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
+auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
+auto* k_base_ptr = k_view.buf_.p_data_;
+auto* v_base_ptr = v_view.buf_.p_data_;
+if(use_ptr_rebase)
+{
+long_index_t k_off =
+static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
+k_view.buf_.p_data_ = k_base_ptr + k_off;
+auto new_k = k_view.buf_.buffer_size_ - k_off;
+k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
+long_index_t v_off =
+static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
+v_view.buf_.p_data_ = v_base_ptr + v_off;
+auto new_v = v_view.buf_.buffer_size_ - v_off;
+v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
+}
+else
+{
+// Legacy path: use original view with absolute window origin
+k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
+v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
+}
+const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
auto k_dram_window =
-make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
+make_tile_window(k_view,
k_dram_block_window_tmp.get_window_lengths(),
-{kv_blk_idx_initial * PageSize, 0},
+{init_origin, 0},
Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window.init_raw();
auto v_dram_window =
-make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
+make_tile_window(v_view,
v_dram_block_window_tmp.get_window_lengths(),
-{kv_blk_idx_initial * PageSize, 0},
+{init_origin, 0},
Policy::template MakeVDramTileDistribution<Problem>());
v_dram_window.init_raw();
@@ -469,16 +507,39 @@ struct UnifiedAttentionPipeline
// index_t kv_block_idx = 0;
// only for block 0 and thread
if(blockIdx.x == 0 && threadIdx.x == 0) {}
+auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset,
+auto buf_size_orig) {
+window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;
+auto new_size = buf_size_orig - elem_offset;
+window.bottom_tensor_view_.buf_.buffer_size_ = new_size > 0 ? new_size : kPageBlockSize * kHeadDim;
+window.init_raw();
+window.set_window_origin({0, 0});
+};
+const auto k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_;
+const auto v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_;
auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
k_block_idx++;
index_t k_page_blk_idx =
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
-k_dram_window.set_window_origin(
-{k_page_blk_idx * PageSize +
-(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
-0});
+if(use_ptr_rebase)
+{
+long_index_t k_row =
+static_cast<long_index_t>(k_page_blk_idx) * PageSize +
+(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
+}
+else
+{
+k_dram_window.set_window_origin(
+{k_page_blk_idx * PageSize +
+(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
+0});
+}
};
auto V_mem_load = [&](auto v_lds_write_idx) {
@@ -487,10 +548,20 @@ struct UnifiedAttentionPipeline
index_t v_page_blk_idx =
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
-v_dram_window.set_window_origin(
-{v_page_blk_idx * PageSize +
-(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
-0});
+if(use_ptr_rebase)
+{
+long_index_t v_row =
+static_cast<long_index_t>(v_page_blk_idx) * PageSize +
+(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
+}
+else
+{
+v_dram_window.set_window_origin(
+{v_page_blk_idx * PageSize +
+(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
+0});
+}
};
auto K_lds_load = [&](auto k_lds_read_idx) {
@@ -1123,7 +1194,9 @@ struct UnifiedAttentionPipeline
const index_t kv_page_size_in_blocks,
FmhaMask mask,
float scale_s,
-void* smem_ptr) const
+void* smem_ptr,
+long_index_t k_row_stride = 0,
+long_index_t v_row_stride = 0) const
{
using namespace ck_tile;
@@ -1143,7 +1216,9 @@ struct UnifiedAttentionPipeline
identity{},
mask,
scale_s,
-smem_ptr);
+smem_ptr,
+k_row_stride,
+v_row_stride);
}
};