Fix CK-UA int32 overflow: use saved original pointers and row strides for rebasing

When kCachePtrInt32OverflowPossible=true, we now:
1. Save original K/V buffer pointers at pipeline start
2. Always rebase by computing offset from original base pointer
3. Use k_row_stride/v_row_stride passed from kernel args

This fixes the bug where successive rebases would compound, since each
rebase modified buf.p_data_ without tracking the original base.
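
As an illustration, a minimal standalone sketch of the two patterns
(simplified, hypothetical types and names; the real code operates on
the window's bottom_tensor_view_ buffer):

    #include <cstdint>

    using index_t      = int32_t;
    using long_index_t = int64_t;

    struct Buf { const float* p_data_; }; // stand-in for the tensor-view buffer

    // Old (buggy) pattern: advance from the *current* pointer, so a
    // second rebase starts from an already-shifted base and the
    // offsets compound across iterations.
    void rebase_compounding(Buf& buf, long_index_t row_offset, long_index_t row_stride)
    {
        buf.p_data_ += row_offset * row_stride;
    }

    // Fixed pattern: always compute from the saved original base, so
    // the result depends only on the current offset, never on the
    // rebasing history.
    void rebase_from_base(Buf& buf, const float* original_ptr,
                          long_index_t row_offset, long_index_t row_stride)
    {
        buf.p_data_ = original_ptr + row_offset * row_stride;
    }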

Key insight: using separate long_index_t variables for block_offset and
elem_offset avoids the compiler type-promotion issues that caused
assembly errors when the total offset was computed in a single
expression.
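
As a sketch, the staged computation with the widening made explicit
(standalone; parameter names are hypothetical, and in the patch
PageSize and kPageBlockSize are compile-time constants):

    #include <cstdint>

    using index_t      = int32_t;
    using long_index_t = int64_t;

    // Each term is widened to 64 bits before its own multiply, so no
    // intermediate product is evaluated in 32-bit arithmetic.
    long_index_t row_offset(index_t page_blk_idx, index_t block_idx,
                            index_t page_size_in_blocks,
                            index_t page_size, index_t page_block_size)
    {
        long_index_t block_offset =
            static_cast<long_index_t>(page_blk_idx) * page_size;
        long_index_t elem_offset =
            static_cast<long_index_t>(block_idx % page_size_in_blocks) * page_block_size;
        return block_offset + elem_offset;
    }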

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Author: juuso-oskari
Date:   2026-05-10 08:59:34 +00:00
Parent: 397febf42c
Commit: cf11d1796b


@@ -384,6 +384,10 @@ struct UnifiedAttentionPipeline
     Policy::template MakeVDramTileDistribution<Problem>());
 v_dram_window.init_raw();
+// Store original buffer pointers for rebasing (needed when kCachePtrInt32OverflowPossible)
+auto* k_original_ptr = k_dram_window.bottom_tensor_view_.buf_.p_data_;
+auto* v_original_ptr = v_dram_window.bottom_tensor_view_.buf_.p_data_;
 // prefetch K tile
 constexpr index_t k0_loops = 1;
 constexpr index_t k1_loops = 1;
@@ -489,9 +493,8 @@ struct UnifiedAttentionPipeline
 // Page block index tracking
 // const index_t kv_page_size_in_blocks =
 //     PageSize / kPageBlockSize;
-// index_t kv_block_idx = 0;
-// only for block 0 and thread
-if(blockIdx.x == 0 && threadIdx.x == 0) {}
 auto K_mem_load = [&](auto k_lds_write_idx) {
     async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
     k_block_idx++;
@@ -499,27 +502,25 @@ struct UnifiedAttentionPipeline
 index_t k_page_blk_idx =
     block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
 // Calculate offset for this block
-index_t offset = k_page_blk_idx * PageSize +
-                 (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
-// For large cache, check if we'd overflow int32 in set_window_origin
+// For large cache, always rebase to avoid int32 overflow in set_window_origin
 if constexpr(kCachePtrInt32OverflowPossible)
 {
-    if(offset > kInt32Max)
-    {
-        // Rebase: advance pointer by offset, then use origin {0, 0}
-        auto& buf = k_dram_window.bottom_tensor_view_.buf_;
-        auto stride_0 = k_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
-        buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
-        k_dram_window.init_raw();
-        k_dram_window.set_window_origin({0, 0});
-        return;
-    }
-}
+    // these need to be cast to long_index_t to avoid int32 overflow
+    long_index_t block_offset = k_page_blk_idx * PageSize;
+    long_index_t elem_offset = (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+    long_index_t total_row_offset = block_offset + elem_offset;
+    // Rebase: set pointer to original base + (row_offset * row_stride), then use origin {0, 0}
+    auto& buf = k_dram_window.bottom_tensor_view_.buf_;
+    buf.p_data_ = k_original_ptr + (total_row_offset * k_row_stride);
+    k_dram_window.init_raw();
+    k_dram_window.set_window_origin({0, 0});
+    return;
-// Fast path: no overflow, just set window origin
-k_dram_window.set_window_origin({offset, 0});
+}else{
+    index_t offset = k_page_blk_idx * PageSize +
+                     (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+    k_dram_window.set_window_origin({offset, 0});
+}
 };
 auto V_mem_load = [&](auto v_lds_write_idx) {
@@ -529,27 +530,25 @@ struct UnifiedAttentionPipeline
 index_t v_page_blk_idx =
     block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
 // Calculate offset for this block
-index_t offset = v_page_blk_idx * PageSize +
-                 (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
-// For large cache, check if we'd overflow int32 in set_window_origin
+// For large cache, always rebase to avoid int32 overflow in set_window_origin
 if constexpr(kCachePtrInt32OverflowPossible)
 {
-    if(offset > kInt32Max)
-    {
-        // Rebase: advance pointer by offset, then use origin {0, 0}
-        auto& buf = v_dram_window.bottom_tensor_view_.buf_;
-        auto stride_0 = v_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
-        buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
-        v_dram_window.init_raw();
-        v_dram_window.set_window_origin({0, 0});
-        return;
-    }
-}
+    // these need to be cast to long_index_t to avoid int32 overflow
+    long_index_t block_offset = v_page_blk_idx * PageSize;
+    long_index_t elem_offset = (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+    long_index_t total_row_offset = block_offset + elem_offset;
+    // Rebase: set pointer to original base + (row_offset * row_stride), then use origin {0, 0}
+    auto& buf = v_dram_window.bottom_tensor_view_.buf_;
+    buf.p_data_ = v_original_ptr + (total_row_offset * v_row_stride);
+    v_dram_window.init_raw();
+    v_dram_window.set_window_origin({0, 0});
+    return;
-// Fast path: no overflow, just set window origin
-v_dram_window.set_window_origin({offset, 0});
+}else{
+    index_t offset = v_page_blk_idx * PageSize +
+                     (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+    v_dram_window.set_window_origin({offset, 0});
+}
 };
 auto K_lds_load = [&](auto k_lds_read_idx) {