From 62e8f73545c49a3882d25e709d4a63fcdf5fd41e Mon Sep 17 00:00:00 2001
From: juuso-oskari
Date: Wed, 6 May 2026 12:16:30 +0000
Subject: [PATCH] Fix int32 overflow in CK-UA via pointer rebasing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When using large KV caches (>131K blocks for d64/GQA-8), the tensor
coordinate offset calculation overflows int32:

    offset = row_index * stride, where stride = num_kv_heads * head_dim = 512

With 150K blocks at block_size=32:

    max_row    = (150,000 - 1) × 32 = 4,799,968
    max_offset = 4,799,968 × 512    = 2,457,583,616 > 2^31

This caused 77.7% of output elements to be incorrect.

Solution: Pointer rebasing

- Add k_row_stride and v_row_stride parameters to the pipeline
- Calculate the offset in int64 and rebase the buffer pointer:
  base_ptr + (int64)offset
- Set the window origin to 0 (a small int32 relative to the new base)
- Call init_raw() to update the AMD buffer resource descriptor
- Enabled only for hdim <= 64 (hdim=128 has a different buffer layout)
- Falls back to the original set_window_origin when strides are not provided

Test results:

- 150K blocks (overflow):    CK vs Triton max diff 4.9e-4 (PASS)
- 1K blocks (no overflow):   CK vs Triton max diff 4.9e-4 (PASS)
- 131K blocks (large):       CK vs Triton max diff 1.2e-4 (PASS)

Made-with: Claude Code
---
 .../kernel/unified_attention_kernel.hpp       |  4 +-
 .../pipeline/unified_attention_pipeline.hpp   | 97 ++++++++++++++++---
 2 files changed, 85 insertions(+), 16 deletions(-)

diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
index 7c1facc545..dd190669e7 100644
--- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
+++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
@@ -468,7 +468,9 @@ struct UnifiedAttentionKernel
                 kv_page_size_in_blocks,
                 mask,
                 kargs.scale_s,
-                smem_ptr);
+                smem_ptr,
+                kargs.stride_k_cache_1,
+                kargs.stride_v_cache_1);
         }();

     // O DRAM and O DRAM window
diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
index 3f60dff312..a99db4cc94 100644
--- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
+++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
@@ -187,7 +187,9 @@ struct UnifiedAttentionPipeline
               const OAccElementFunction& o_acc_element_func,
               FmhaMask mask,
               float scale_s,
-              void* smem_ptr) const
+              void* smem_ptr,
+              long_index_t k_row_stride = 0,
+              long_index_t v_row_stride = 0) const
     {
         using namespace ck_tile;
         static_assert(
@@ -361,17 +363,48 @@ struct UnifiedAttentionPipeline
         block_table_offset += num_blocks_start;

         index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];

+        // Use pointer rebasing to avoid int32 overflow in tensor_coordinate::get_offset()
+        // for large KV pools (>131K blocks for d64/GQA-8).
+        // Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs.
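+        // Worked example from the commit message (block_size=32, stride =
+        // num_kv_heads * head_dim = 512): row 4,799,968 gives offset
+        // 4,799,968 * 512 = 2,457,583,616 > 2^31. Rebasing folds that offset
+        // into the 64-bit base pointer, so the window origin (and every int32
+        // coordinate derived from it) stays near 0.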
+        const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
+
+        // Get views and save original base pointers
+        auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
+        auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
+        auto* k_base_ptr = k_view.buf_.p_data_;
+        auto* v_base_ptr = v_view.buf_.p_data_;
+        const auto k_buf_size_orig = k_view.buf_.buffer_size_;
+        const auto v_buf_size_orig = v_view.buf_.buffer_size_;
+
+        if(use_ptr_rebase)
+        {
+            // Rebase pointers to avoid int32 overflow in window origin coordinates
+            long_index_t k_off =
+                static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
+            k_view.buf_.p_data_ = k_base_ptr + k_off;
+            auto new_k = k_buf_size_orig - k_off;
+            k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
+
+            long_index_t v_off =
+                static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
+            v_view.buf_.p_data_ = v_base_ptr + v_off;
+            auto new_v = v_buf_size_orig - v_off;
+            v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
+        }
+
+        const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
+
         auto k_dram_window =
-            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
+            make_tile_window(k_view,
                              k_dram_block_window_tmp.get_window_lengths(),
-                             {kv_blk_idx_initial * PageSize, 0},
+                             {init_origin, 0},
                              Policy::template MakeKDramTileDistribution<Problem>());
         k_dram_window.init_raw();

         auto v_dram_window =
-            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
+            make_tile_window(v_view,
                              v_dram_block_window_tmp.get_window_lengths(),
-                             {kv_blk_idx_initial * PageSize, 0},
+                             {init_origin, 0},
                              Policy::template MakeVDramTileDistribution<Problem>());
         v_dram_window.init_raw();

@@ -463,6 +496,16 @@ struct UnifiedAttentionPipeline
         constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
         constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();

+        // Helper lambda to rebase window pointer (avoids int32 overflow)
+        auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset,
+                                auto buf_size_orig) {
+            window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;
+            auto new_size = buf_size_orig - elem_offset;
+            window.bottom_tensor_view_.buf_.buffer_size_ =
+                new_size > 0 ? new_size : kPageBlockSize * kHeadDim;
+            window.init_raw();
+            window.set_window_origin({0, 0});
+        };
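+        // Used below in K_mem_load / V_mem_load, e.g.:
+        //   rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);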
+
         // Page block index tracking
         // const index_t kv_page_size_in_blocks =
         //     PageSize / kPageBlockSize;
@@ -475,10 +518,20 @@ struct UnifiedAttentionPipeline
             index_t k_page_blk_idx =
                 block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];

-            k_dram_window.set_window_origin(
-                {k_page_blk_idx * PageSize +
-                     (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
-                 0});
+            if(use_ptr_rebase)
+            {
+                long_index_t k_row =
+                    static_cast<long_index_t>(k_page_blk_idx) * PageSize +
+                    (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+                rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
+            }
+            else
+            {
+                k_dram_window.set_window_origin(
+                    {k_page_blk_idx * PageSize +
+                         (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
+                     0});
+            }
         };

         auto V_mem_load = [&](auto v_lds_write_idx) {
@@ -487,10 +540,20 @@ struct UnifiedAttentionPipeline
             index_t v_page_blk_idx =
                 block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];

-            v_dram_window.set_window_origin(
-                {v_page_blk_idx * PageSize +
-                     (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
-                 0});
+            if(use_ptr_rebase)
+            {
+                long_index_t v_row =
+                    static_cast<long_index_t>(v_page_blk_idx) * PageSize +
+                    (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+                rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
+            }
+            else
+            {
+                v_dram_window.set_window_origin(
+                    {v_page_blk_idx * PageSize +
+                         (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
+                     0});
+            }
         };

         auto K_lds_load = [&](auto k_lds_read_idx) {
@@ -1123,7 +1186,9 @@ struct UnifiedAttentionPipeline
               const index_t kv_page_size_in_blocks,
               FmhaMask mask,
               float scale_s,
-              void* smem_ptr) const
+              void* smem_ptr,
+              long_index_t k_row_stride = 0,
+              long_index_t v_row_stride = 0) const
     {
         using namespace ck_tile;

@@ -1143,7 +1208,9 @@ struct UnifiedAttentionPipeline
             identity{},
             mask,
             scale_s,
-            smem_ptr);
+            smem_ptr,
+            k_row_stride,
+            v_row_stride);
     }
 };
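
Note: the overflow arithmetic can be reproduced outside the kernel. The
sketch below is a minimal standalone illustration using the values from the
commit message; it is not CK code, and all names in it are illustrative:

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // d64/GQA-8 layout from the commit message:
        // stride = num_kv_heads * head_dim = 512, block_size = 32, 150K blocks.
        const int64_t row    = (150000 - 1) * 32; // 4,799,968 (first row of last block)
        const int64_t stride = 512;

        // Original path: the window-origin coordinate math is int32, so the
        // element offset wraps (2,457,583,616 - 2^32 = -1,837,383,680).
        const int64_t off64 = row * stride;
        const int32_t off32 = static_cast<int32_t>(off64);

        // Rebased path: fold the int64 offset into the 64-bit base pointer up
        // front; the window origin becomes 0 and the remaining int32
        // coordinates stay small, so nothing overflows.
        const int32_t origin_after_rebase = 0;

        std::printf("int64 offset: %lld\n", static_cast<long long>(off64));
        std::printf("int32 offset: %d (wrapped)\n", off32);
        std::printf("window origin after rebase: %d\n", origin_after_rebase);
    }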