From cf11d1796b9364b1f77deb010563d60c69b46bd8 Mon Sep 17 00:00:00 2001
From: juuso-oskari
Date: Sun, 10 May 2026 08:59:34 +0000
Subject: [PATCH] Fix CK-UA int32 overflow: use saved original pointers and
 row strides for rebasing

When kCachePtrInt32OverflowPossible=true, we now:
1. Save the original K/V buffer pointers at pipeline start
2. Always rebase by computing the offset from the original base pointer
3. Use the k_row_stride/v_row_stride values passed from the kernel args

This fixes the bug where successive rebases would compound, since each
rebase modified buf.p_data_ without tracking the original base.

Key insight: computing block_offset and elem_offset in separate
long_index_t variables (widened with static_cast before the multiply)
avoids the compiler type-promotion issues that caused assembly errors
when the total offset was computed in a single expression.
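
For reference, the rebased pointer computation amounts to the sketch below
(illustrative only, not part of the diff; index_t/long_index_t stand in for
the 32-/64-bit ck_tile index types, and rebase_row_ptr, blocks_per_page,
page_size, page_block_size and row_stride are made-up names mirroring the
pipeline variables):

    #include <cstdint>

    using index_t      = std::int32_t;
    using long_index_t = std::int64_t;

    template <typename T>
    T* rebase_row_ptr(T* original_base, index_t page_blk_idx, index_t block_idx,
                      index_t blocks_per_page, index_t page_size,
                      index_t page_block_size, long_index_t row_stride)
    {
        // Widen before multiplying: page_blk_idx * page_size alone can exceed
        // INT32_MAX for a sufficiently large paged KV cache.
        long_index_t block_offset =
            static_cast<long_index_t>(page_blk_idx) * page_size;
        long_index_t elem_offset =
            static_cast<long_index_t>(block_idx % blocks_per_page) * page_block_size;
        long_index_t total_row_offset = block_offset + elem_offset;
        // Offset from the saved original base so repeated rebases never compound.
        return original_base + total_row_offset * row_stride;
    }

Keeping block_offset and elem_offset in separate named long_index_t values,
as the diff below does, is what sidesteps the single-expression promotion
issue described above.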

Co-Authored-By: Claude Sonnet 4
---
 .../pipeline/unified_attention_pipeline.hpp | 77 +++++++++----------
 1 file changed, 38 insertions(+), 39 deletions(-)

diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
index 3cda0352ef..702ec25b73 100644
--- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
+++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
@@ -384,6 +384,10 @@ struct UnifiedAttentionPipeline
             Policy::template MakeVDramTileDistribution());
         v_dram_window.init_raw();
 
+        // Store original buffer pointers for rebasing (needed when kCachePtrInt32OverflowPossible)
+        auto* k_original_ptr = k_dram_window.bottom_tensor_view_.buf_.p_data_;
+        auto* v_original_ptr = v_dram_window.bottom_tensor_view_.buf_.p_data_;
+
         // prefetch K tile
         constexpr index_t k0_loops = 1;
         constexpr index_t k1_loops = 1;
@@ -489,9 +493,8 @@ struct UnifiedAttentionPipeline
         // Page block index tracking
         // const index_t kv_page_size_in_blocks =
         //     PageSize / kPageBlockSize;
-        // index_t kv_block_idx = 0;
-        // only for block 0 and thread
-        if(blockIdx.x == 0 && threadIdx.x == 0) {}
+
+
         auto K_mem_load = [&](auto k_lds_write_idx) {
             async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
             k_block_idx++;
@@ -499,27 +502,25 @@ struct UnifiedAttentionPipeline
             index_t k_page_blk_idx =
                 block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
 
-            // Calculate offset for this block
-            index_t offset = k_page_blk_idx * PageSize +
-                             (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
-
-            // For large cache, check if we'd overflow int32 in set_window_origin
+            // For large cache, always rebase to avoid int32 overflow in set_window_origin
             if constexpr(kCachePtrInt32OverflowPossible)
             {
-                if(offset > kInt32Max)
-                {
-                    // Rebase: advance pointer by offset, then use origin {0, 0}
-                    auto& buf = k_dram_window.bottom_tensor_view_.buf_;
-                    auto stride_0 = k_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
-                    buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
-                    k_dram_window.init_raw();
-                    k_dram_window.set_window_origin({0, 0});
-                    return;
-                }
-            }
-
-            // Fast path: no overflow, just set window origin
-            k_dram_window.set_window_origin({offset, 0});
+                // widen to long_index_t before multiplying to avoid int32 overflow
+                long_index_t block_offset = static_cast<long_index_t>(k_page_blk_idx) * PageSize;
+                long_index_t elem_offset =
+                    static_cast<long_index_t>(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+                long_index_t total_row_offset = block_offset + elem_offset;
+                // Rebase: point at original base + total_row_offset * row stride, then use origin {0, 0}
+                auto& buf = k_dram_window.bottom_tensor_view_.buf_;
+                buf.p_data_ = k_original_ptr + (total_row_offset * k_row_stride);
+                k_dram_window.init_raw();
+                k_dram_window.set_window_origin({0, 0});
+            }
+            else
+            {
+                index_t offset = k_page_blk_idx * PageSize +
+                                 (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+                k_dram_window.set_window_origin({offset, 0});
+            }
         };
 
         auto V_mem_load = [&](auto v_lds_write_idx) {
@@ -529,27 +530,25 @@ struct UnifiedAttentionPipeline
             index_t v_page_blk_idx =
                 block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
 
-            // Calculate offset for this block
-            index_t offset = v_page_blk_idx * PageSize +
-                             (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
-
-            // For large cache, check if we'd overflow int32 in set_window_origin
+            // For large cache, always rebase to avoid int32 overflow in set_window_origin
             if constexpr(kCachePtrInt32OverflowPossible)
             {
-                if(offset > kInt32Max)
-                {
-                    // Rebase: advance pointer by offset, then use origin {0, 0}
-                    auto& buf = v_dram_window.bottom_tensor_view_.buf_;
-                    auto stride_0 = v_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
-                    buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
-                    v_dram_window.init_raw();
-                    v_dram_window.set_window_origin({0, 0});
-                    return;
-                }
-            }
-
-            // Fast path: no overflow, just set window origin
-            v_dram_window.set_window_origin({offset, 0});
+                // widen to long_index_t before multiplying to avoid int32 overflow
+                long_index_t block_offset = static_cast<long_index_t>(v_page_blk_idx) * PageSize;
+                long_index_t elem_offset =
+                    static_cast<long_index_t>(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+                long_index_t total_row_offset = block_offset + elem_offset;
+                // Rebase: point at original base + total_row_offset * row stride, then use origin {0, 0}
+                auto& buf = v_dram_window.bottom_tensor_view_.buf_;
+                buf.p_data_ = v_original_ptr + (total_row_offset * v_row_stride);
+                v_dram_window.init_raw();
+                v_dram_window.set_window_origin({0, 0});
+            }
+            else
+            {
+                index_t offset = v_page_blk_idx * PageSize +
+                                 (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
+                v_dram_window.set_window_origin({offset, 0});
+            }
         };
 
         auto K_lds_load = [&](auto k_lds_read_idx) {