mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Fix int32 overflow in CK-UA via pointer rebasing
When using large KV caches (>131K blocks for d64/GQA-8), the tensor coordinate offset calculation overflows int32: offset = row_index * stride, where stride = num_kv_heads * head_dim = 512. With 150K blocks at block_size=32: max_row = 4,799,968, so max_offset = 4,799,968 × 512 = 2,457,583,616 > 2^31. This caused 77.7% of output elements to be incorrect.

Solution: pointer rebasing.
- Add k_row_stride and v_row_stride parameters to the pipeline.
- Calculate an int64 offset and rebase the buffer pointer: base_ptr + (int64)offset.
- Set the window origin to 0 (a small int32 relative to the new base).
- Call init_raw() to update the AMD buffer resource descriptor.
- Enabled only for hdim <= 64 (hdim=128 has a different buffer layout).
- Falls back to the original set_window_origin path when strides are not provided.

Test results:
- 150K blocks (overflow): CK vs Triton max diff 4.9e-4 (PASS)
- 1K blocks (no overflow): CK vs Triton max diff 4.9e-4 (PASS)
- 131K blocks (large): CK vs Triton max diff 1.2e-4 (PASS)

Made-with: Claude Code
This commit is contained in:
@@ -468,7 +468,9 @@ struct UnifiedAttentionKernel
|
||||
kv_page_size_in_blocks,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
kargs.stride_k_cache_1,
|
||||
kargs.stride_v_cache_1);
|
||||
}();
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
|
||||
@@ -187,7 +187,9 @@ struct UnifiedAttentionPipeline
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
long_index_t k_row_stride = 0,
|
||||
long_index_t v_row_stride = 0) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
static_assert(
|
||||
@@ -361,17 +363,48 @@ struct UnifiedAttentionPipeline
|
||||
block_table_offset += num_blocks_start;
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
|
||||
|
||||
// Use pointer rebasing to avoid int32 overflow in tensor_coordinate::get_offset()
|
||||
// for large KV pools (>131K blocks for d64/GQA-8).
|
||||
// Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs.
|
||||
const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
|
||||
|
||||
// Get views and save original base pointers
|
||||
auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
auto* k_base_ptr = k_view.buf_.p_data_;
|
||||
auto* v_base_ptr = v_view.buf_.p_data_;
|
||||
const auto k_buf_size_orig = k_view.buf_.buffer_size_;
|
||||
const auto v_buf_size_orig = v_view.buf_.buffer_size_;
|
||||
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Rebase pointers to avoid int32 overflow in window origin coordinates
|
||||
long_index_t k_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
|
||||
k_view.buf_.p_data_ = k_base_ptr + k_off;
|
||||
auto new_k = k_buf_size_orig - k_off;
|
||||
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
|
||||
|
||||
long_index_t v_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
|
||||
v_view.buf_.p_data_ = v_base_ptr + v_off;
|
||||
auto new_v = v_buf_size_orig - v_off;
|
||||
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
|
||||
}
|
||||
|
||||
const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tile_window(k_view,
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_blk_idx_initial * PageSize, 0},
|
||||
{init_origin, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_window.init_raw();
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tile_window(v_view,
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_blk_idx_initial * PageSize, 0},
|
||||
{init_origin, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dram_window.init_raw();
|
||||
|
||||
@@ -463,6 +496,16 @@ struct UnifiedAttentionPipeline
|
||||
constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
|
||||
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
// Helper lambda to rebase window pointer (avoids int32 overflow)
|
||||
auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset,
|
||||
auto buf_size_orig) {
|
||||
window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;
|
||||
auto new_size = buf_size_orig - elem_offset;
|
||||
window.bottom_tensor_view_.buf_.buffer_size_ = new_size > 0 ? new_size : kPageBlockSize * kHeadDim;
|
||||
window.init_raw();
|
||||
window.set_window_origin({0, 0});
|
||||
};
|
||||
|
||||
// Page block index tracking
|
||||
// const index_t kv_page_size_in_blocks =
|
||||
// PageSize / kPageBlockSize;
|
||||
@@ -475,10 +518,20 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
index_t k_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
|
||||
k_dram_window.set_window_origin(
|
||||
{k_page_blk_idx * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
long_index_t k_row =
|
||||
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
k_dram_window.set_window_origin(
|
||||
{k_page_blk_idx * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
}
|
||||
};
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
@@ -487,10 +540,20 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
index_t v_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
|
||||
v_dram_window.set_window_origin(
|
||||
{v_page_blk_idx * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
long_index_t v_row =
|
||||
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_dram_window.set_window_origin(
|
||||
{v_page_blk_idx * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
}
|
||||
};
|
||||
|
||||
auto K_lds_load = [&](auto k_lds_read_idx) {
|
||||
@@ -1123,7 +1186,9 @@ struct UnifiedAttentionPipeline
|
||||
const index_t kv_page_size_in_blocks,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
long_index_t k_row_stride = 0,
|
||||
long_index_t v_row_stride = 0) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
@@ -1143,7 +1208,9 @@ struct UnifiedAttentionPipeline
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
k_row_stride,
|
||||
v_row_stride);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user