diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 7c1facc545..087a8872b9 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -468,7 +468,9 @@ struct UnifiedAttentionKernel kv_page_size_in_blocks, mask, kargs.scale_s, - smem_ptr); + smem_ptr, + static_cast(kargs.stride_k_cache_1), + static_cast(kargs.stride_v_cache_1)); }(); // O DRAM and O DRAM window diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3f60dff312..29617948df 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -187,7 +187,9 @@ struct UnifiedAttentionPipeline const OAccElementFunction& o_acc_element_func, FmhaMask mask, float scale_s, - void* smem_ptr) const + void* smem_ptr, + long_index_t k_row_stride = 0, + long_index_t v_row_stride = 0) const { using namespace ck_tile; static_assert( @@ -361,17 +363,53 @@ struct UnifiedAttentionPipeline block_table_offset += num_blocks_start; index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx]; + // When row strides are provided, use pointer rebasing to avoid int32 overflow + // in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8). + // When strides are 0 (legacy callers), use the original set_window_origin approach. + // Use pointer rebasing to avoid int32 overflow in tensor_coordinate for large KV pools. + // Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs. + // hdim=128 configs have different buffer_view internals that cause issues with rebasing. + const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64); + + auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view(); + auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view(); + auto* k_base_ptr = k_view.buf_.p_data_; + auto* v_base_ptr = v_view.buf_.p_data_; + + if(use_ptr_rebase) + { + long_index_t k_off = + static_cast(kv_blk_idx_initial) * PageSize * k_row_stride; + k_view.buf_.p_data_ = k_base_ptr + k_off; + auto new_k = k_view.buf_.buffer_size_ - k_off; + k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim; + + long_index_t v_off = + static_cast(kv_blk_idx_initial) * PageSize * v_row_stride; + v_view.buf_.p_data_ = v_base_ptr + v_off; + auto new_v = v_view.buf_.buffer_size_ - v_off; + v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim; + } + else + { + // Legacy path: use original view with absolute window origin + k_view = k_dram_block_window_tmp.get_bottom_tensor_view(); + v_view = v_dram_block_window_tmp.get_bottom_tensor_view(); + } + + const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize; + auto k_dram_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tile_window(k_view, k_dram_block_window_tmp.get_window_lengths(), - {kv_blk_idx_initial * PageSize, 0}, + {init_origin, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tile_window(v_view, v_dram_block_window_tmp.get_window_lengths(), - {kv_blk_idx_initial * PageSize, 0}, + {init_origin, 0}, Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); @@ -469,16 +507,39 @@ struct UnifiedAttentionPipeline // index_t kv_block_idx = 0; // only for block 0 and thread if(blockIdx.x == 0 && threadIdx.x == 0) {} + + auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset, + auto buf_size_orig) { + window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset; + auto new_size = buf_size_orig - elem_offset; + window.bottom_tensor_view_.buf_.buffer_size_ = new_size > 0 ? new_size : kPageBlockSize * kHeadDim; + window.init_raw(); + window.set_window_origin({0, 0}); + }; + + const auto k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_; + const auto v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_; + auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); k_block_idx++; index_t k_page_blk_idx = block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)]; - k_dram_window.set_window_origin( - {k_page_blk_idx * PageSize + - (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize, - 0}); + if(use_ptr_rebase) + { + long_index_t k_row = + static_cast(k_page_blk_idx) * PageSize + + (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize; + rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig); + } + else + { + k_dram_window.set_window_origin( + {k_page_blk_idx * PageSize + + (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize, + 0}); + } }; auto V_mem_load = [&](auto v_lds_write_idx) { @@ -487,10 +548,20 @@ struct UnifiedAttentionPipeline index_t v_page_blk_idx = block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)]; - v_dram_window.set_window_origin( - {v_page_blk_idx * PageSize + - (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize, - 0}); + if(use_ptr_rebase) + { + long_index_t v_row = + static_cast(v_page_blk_idx) * PageSize + + (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize; + rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig); + } + else + { + v_dram_window.set_window_origin( + {v_page_blk_idx * PageSize + + (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize, + 0}); + } }; auto K_lds_load = [&](auto k_lds_read_idx) { @@ -1123,7 +1194,9 @@ struct UnifiedAttentionPipeline const index_t kv_page_size_in_blocks, FmhaMask mask, float scale_s, - void* smem_ptr) const + void* smem_ptr, + long_index_t k_row_stride = 0, + long_index_t v_row_stride = 0) const { using namespace ck_tile; @@ -1143,7 +1216,9 @@ struct UnifiedAttentionPipeline identity{}, mask, scale_s, - smem_ptr); + smem_ptr, + k_row_stride, + v_row_stride); } };