mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
kv paging
This commit is contained in:
@@ -404,7 +404,6 @@ struct UnifiedAttentionPipeline
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
index_t block_idx_prev = 0;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -577,22 +576,26 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
}
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t kv_blk_idx = block_tables_ptr[block_table_offset + i_total_loops];
|
||||
index_t kv_blk_idx_prev = 0;
|
||||
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0},
|
||||
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_window.init_raw();
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0}, // TODO: hdim split?
|
||||
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dram_window.init_raw();
|
||||
|
||||
// prefetch K tile
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
static_assert(1 == k0_loops);
|
||||
@@ -685,7 +688,10 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(),
|
||||
k_dram_window.get_window_lengths(),
|
||||
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
};
|
||||
|
||||
auto K_lds_load = [&](auto k_lds_read_idx) {
|
||||
@@ -696,7 +702,10 @@ struct UnifiedAttentionPipeline
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(),
|
||||
v_dram_window.get_window_lengths(),
|
||||
{(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
};
|
||||
|
||||
auto V_lds_load = [&](auto v_lds_read_idx) {
|
||||
|
||||
Reference in New Issue
Block a user