Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-13 17:55:48 +00:00
Fix CK-UA int32 overflow: use saved original pointers and row strides for rebasing
When kCachePtrInt32OverflowPossible=true, we now:

1. Save original K/V buffer pointers at pipeline start
2. Always rebase by computing the offset from the original base pointer
3. Use k_row_stride/v_row_stride passed from kernel args

This fixes the bug where successive rebases would compound, since each rebase modified buf.p_data_ without tracking the original base.

Key insight: separate long_index_t variables for block_offset and elem_offset avoid compiler type promotion issues that caused assembly errors when computing the total offset in a single expression.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
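The offset arithmetic behind the fix can be illustrated with a small standalone sketch (plain C++, not CK code). The names below mirror the diff (PageSize, kPageBlockSize, k_page_blk_idx, k_row_stride), but every value is invented for the example; it only shows why the row offset has to be built from separate long_index_t pieces and then applied to the saved original base pointer rather than to an already-rebased one.

// Standalone sketch, not CK code: names mirror the diff, values are made up purely
// to show when the row offset stops fitting in 32 bits.
#include <cassert>
#include <cstdint>
#include <limits>

using index_t      = int32_t; // 32-bit index type, as used by the pipeline
using long_index_t = int64_t; // 64-bit index type used for the rebase offsets

int main()
{
    const index_t PageSize               = 1 << 20; // hypothetical page size in rows
    const index_t kPageBlockSize         = 128;
    const index_t kv_page_size_in_blocks = PageSize / kPageBlockSize;

    const index_t k_block_idx       = 5;
    const index_t k_page_blk_idx    = 4096; // block-table entry for a very large cache
    const long_index_t k_row_stride = 128;  // row stride, passed down from the kernel args

    // Pattern from the commit: each piece is kept in its own long_index_t variable,
    // so no intermediate product or sum is evaluated in 32-bit arithmetic.
    const long_index_t block_offset =
        static_cast<long_index_t>(k_page_blk_idx) * PageSize;
    const long_index_t elem_offset =
        static_cast<long_index_t>(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
    const long_index_t total_row_offset = block_offset + elem_offset;

    // With these values the true row offset is larger than INT32_MAX, i.e. exactly
    // the case where a 32-bit window origin (set_window_origin) can no longer be used.
    assert(total_row_offset > std::numeric_limits<index_t>::max());

    // The rebase adds this element offset to the *saved original* base pointer
    // (k_original_ptr in the diff), never to the previously rebased pointer, so
    // successive K_mem_load calls cannot compound offsets.
    const long_index_t rebase_elem_offset = total_row_offset * k_row_stride;
    (void)rebase_elem_offset;
    return 0;
}

The same reasoning applies to the V side with v_original_ptr and v_row_stride; per the commit message, keeping block_offset and elem_offset in separate long_index_t variables also sidestepped the compiler type promotion issue that caused assembly errors when the whole offset was computed in a single expression.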
@@ -384,6 +384,10 @@ struct UnifiedAttentionPipeline
                 Policy::template MakeVDramTileDistribution<Problem>());
     v_dram_window.init_raw();

+    // Store original buffer pointers for rebasing (needed when kCachePtrInt32OverflowPossible)
+    auto* k_original_ptr = k_dram_window.bottom_tensor_view_.buf_.p_data_;
+    auto* v_original_ptr = v_dram_window.bottom_tensor_view_.buf_.p_data_;
+
     // prefetch K tile
     constexpr index_t k0_loops = 1;
     constexpr index_t k1_loops = 1;
@@ -489,9 +493,8 @@ struct UnifiedAttentionPipeline
     // Page block index tracking
     // const index_t kv_page_size_in_blocks =
     //     PageSize / kPageBlockSize;
     // index_t kv_block_idx = 0;
     // only for block 0 and thread
     if(blockIdx.x == 0 && threadIdx.x == 0) {}


     auto K_mem_load = [&](auto k_lds_write_idx) {
         async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
         k_block_idx++;
@@ -499,27 +502,25 @@ struct UnifiedAttentionPipeline
         index_t k_page_blk_idx =
             block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];

-        // Calculate offset for this block
-        index_t offset = k_page_blk_idx * PageSize +
-                         (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;

-        // For large cache, check if we'd overflow int32 in set_window_origin
+        // For large cache, always rebase to avoid int32 overflow in set_window_origin
         if constexpr(kCachePtrInt32OverflowPossible)
         {
-            if(offset > kInt32Max)
-            {
-                // Rebase: advance pointer by offset, then use origin {0, 0}
-                auto& buf = k_dram_window.bottom_tensor_view_.buf_;
-                auto stride_0 = k_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
-                buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
-                k_dram_window.init_raw();
-                k_dram_window.set_window_origin({0, 0});
-                return;
-            }
-        }
+            // these need to be cast to long_index_t to avoid int32 overflow
+            long_index_t block_offset = k_page_blk_idx * PageSize;
+            long_index_t elem_offset = (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+            long_index_t total_row_offset = block_offset + elem_offset;
+            // Rebase: set pointer to original base + (row_offset * row_stride), then use origin {0, 0}
+            auto& buf = k_dram_window.bottom_tensor_view_.buf_;
+            buf.p_data_ = k_original_ptr + (total_row_offset * k_row_stride);
+            k_dram_window.init_raw();
+            k_dram_window.set_window_origin({0, 0});
+            return;

-        // Fast path: no overflow, just set window origin
-        k_dram_window.set_window_origin({offset, 0});
+        }else{
+            index_t offset = k_page_blk_idx * PageSize +
+                             (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+            k_dram_window.set_window_origin({offset, 0});
+        }
     };

     auto V_mem_load = [&](auto v_lds_write_idx) {
@@ -529,27 +530,25 @@ struct UnifiedAttentionPipeline
         index_t v_page_blk_idx =
             block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];

-        // Calculate offset for this block
-        index_t offset = v_page_blk_idx * PageSize +
-                         (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;

-        // For large cache, check if we'd overflow int32 in set_window_origin
+        // For large cache, always rebase to avoid int32 overflow in set_window_origin
         if constexpr(kCachePtrInt32OverflowPossible)
         {
-            if(offset > kInt32Max)
-            {
-                // Rebase: advance pointer by offset, then use origin {0, 0}
-                auto& buf = v_dram_window.bottom_tensor_view_.buf_;
-                auto stride_0 = v_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
-                buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
-                v_dram_window.init_raw();
-                v_dram_window.set_window_origin({0, 0});
-                return;
-            }
-        }
+            // these need to be cast to long_index_t to avoid int32 overflow
+            long_index_t block_offset = v_page_blk_idx * PageSize;
+            long_index_t elem_offset = (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+            long_index_t total_row_offset = block_offset + elem_offset;
+            // Rebase: set pointer to original base + (row_offset * row_stride), then use origin {0, 0}
+            auto& buf = v_dram_window.bottom_tensor_view_.buf_;
+            buf.p_data_ = v_original_ptr + (total_row_offset * v_row_stride);
+            v_dram_window.init_raw();
+            v_dram_window.set_window_origin({0, 0});
+            return;

-        // Fast path: no overflow, just set window origin
-        v_dram_window.set_window_origin({offset, 0});
+        }else{
+            index_t offset = v_page_blk_idx * PageSize +
+                             (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+            v_dram_window.set_window_origin({offset, 0});
+        }
     };

     auto K_lds_load = [&](auto k_lds_read_idx) {