kv paging

This commit is contained in:
Tianxing Wu
2025-10-14 12:04:11 +00:00
parent c87f2e3ca9
commit ec29289bb1

View File

@@ -404,7 +404,6 @@ struct UnifiedAttentionPipeline
{
using namespace ck_tile;
index_t block_idx_prev = 0;
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -577,22 +576,26 @@ struct UnifiedAttentionPipeline
}
}
index_t i_total_loops = 0;
index_t kv_blk_idx = block_tables_ptr[block_table_offset + i_total_loops];
index_t kv_blk_idx_prev = 0;
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0},
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0},
Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window.init_raw();
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0}, // TODO: hdim split?
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
v_dram_window.init_raw();
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 == k0_loops);
@@ -685,7 +688,10 @@ struct UnifiedAttentionPipeline
/// FIXME: use the future-predicting method to move the window
// move K tile windows
move_tile_window(k_dram_window, {kN0, 0});
auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(),
k_dram_window.get_window_lengths(),
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
Policy::template MakeVDramTileDistribution<Problem>());
};
auto K_lds_load = [&](auto k_lds_read_idx) {
@@ -696,7 +702,10 @@ struct UnifiedAttentionPipeline
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
/// FIXME: use the future-predicting method to move the window
move_tile_window(v_dram_window, {kK1, 0});
auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(),
v_dram_window.get_window_lengths(),
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
Policy::template MakeVDramTileDistribution<Problem>());
};
auto V_lds_load = [&](auto v_lds_read_idx) {