From ec29289bb18df4ea918700ff28cc891832747395 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 14 Oct 2025 12:04:11 +0000 Subject: [PATCH] kv paging --- /* Review notes (patch commentary only): this patch removes block_idx_prev, hoists i_total_loops above the K/V DRAM window setup, takes the initial K/V window origin from block_tables_ptr[block_table_offset + i_total_loops] * BLOCK_SIZE instead of seqlen_k_start, and replaces the two move_tile_window calls with make_tile_window calls keyed on the same block-table lookup. */ .../pipeline/unified_attention_pipeline.hpp | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 7bc3dc1d7d..15a5e339ea 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -404,7 +404,6 @@ struct UnifiedAttentionPipeline { using namespace ck_tile; - index_t block_idx_prev = 0; static_assert( std::is_same_v> && @@ -577,22 +576,26 @@ struct UnifiedAttentionPipeline } } + index_t i_total_loops = 0; + index_t kv_blk_idx = block_tables_ptr[block_table_offset + i_total_loops]; + index_t kv_blk_idx_prev = 0; /* NOTE(review): kv_blk_idx_prev is set to 0 here and is not updated in any hunk of this patch, so the origins below reduce to kv_blk_idx * BLOCK_SIZE - confirm whether a running previous-block index was intended. */ + + auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}, + {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, /* initial K origin now comes from the first block-table entry read above; seqlen_k_start is no longer used for it */ Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}, // TODO: hdim split? + {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split? 
Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); // prefetch K tile - index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; static_assert(1 == k0_loops); @@ -685,7 +688,10 @@ struct UnifiedAttentionPipeline /// FIXME: use the future-predicting method to move the window // move K tile windows - move_tile_window(k_dram_window, {kN0, 0}); + auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(), /* NOTE(review): this `auto` declaration introduces a new variable named k_dram_window and its initializer refers to that same name, which looks ill-formed for an auto-typed variable; it also replaces move_tile_window, so the previously created window is not advanced - presumably an update of the existing window was meant, confirm. */ + k_dram_window.get_window_lengths(), + {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, /* NOTE(review): no hunk in this patch advances i_total_loops, so whether this reads the next block-table entry cannot be confirmed from the patch alone. */ + Policy::template MakeVDramTileDistribution()); /* NOTE(review): the K window is rebuilt with the V distribution, whereas the initial k_dram_window uses MakeKDramTileDistribution - likely a copy/paste slip, confirm. */ }; auto K_lds_load = [&](auto k_lds_read_idx) { @@ -696,7 +702,10 @@ struct UnifiedAttentionPipeline async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); /// FIXME: use the future-predicting method to move the window - move_tile_window(v_dram_window, {kK1, 0}); + auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(), /* NOTE(review): same pattern as the K case - a new auto variable named v_dram_window initialised from itself, replacing move_tile_window; the window passed to async_load_tile_raw above is not re-positioned by this statement - confirm intent. */ + v_dram_window.get_window_lengths(), + {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, + Policy::template MakeVDramTileDistribution()); }; auto V_lds_load = [&](auto v_lds_read_idx) {