ck tile pagedkv prefill (#2405)

* add prefetching physical block id for pagedkv

* start add pagedkv prefill

* rename pipeline

* add kernel for pagedkv

* add an init version pagedkv prefill

* fix redefine issue

* add struct BlockFmhaFwdPagedKVPipelineProblem and fmha_fwd_pagedkv_args

* generate dispatch code

* add body generating code

* compiling pass

* remove dropout from pagedkv

* set lse to false in generating code

* start changing qr kernel to pagedkv

* init version of kernel with pagedkv

* change names of file that are generated

* change host validation for pagedkv prefill

* using iglp to change blockgemm

* add kernel files to op head file

* show parameters

* rewrite print parameter fun

* add fwd

* remove default parameter of GridSize

* format

* fix nhead issue and add seqlen_k_ptr to batch mode

* format code

* remove no-longer used code

* format

* fix some comments

---------

Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
ltqin
2025-07-07 16:16:54 +08:00
committed by GitHub
parent 0aecb5ab68
commit 9f4c5d7372
15 changed files with 3520 additions and 12 deletions

View File

@@ -51,6 +51,27 @@ struct TrivialPageBlockNavigator
return /*block_index=*/0;
}
/// Overload of move_tile_window that additionally accepts a table id.
///
/// This navigator does no block bookkeeping: the call is forwarded straight to
/// ck_tile::move_tile_window, and both the incoming block index and the id are
/// ignored.
///
/// @param tile_window window to move in place
/// @param step        displacement applied to the window
/// @return always 0
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t /*block_index*/,
                 TileWindow& tile_window,
                 const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
                 index_t /*id*/) const
{
    ck_tile::move_tile_window(tile_window, step);

    return /*block_index=*/0;
}
/// Table-id lookup for the trivial navigator.
///
/// Every argument is ignored and -1 is returned unconditionally.
///
/// The tile window is accepted by const reference rather than by value: it is
/// never read here, so copying it at each call would be pure overhead, and a
/// reference parameter matches the by-reference signature used by
/// PageBlockNavigator::prefetch_table_id. Template argument deduction is
/// unaffected for existing callers.
///
/// @return always -1
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
prefetch_table_id(index_t /*block_index*/,
                  const TileWindow& /*tile_window*/,
                  const typename remove_cvref_t<TileWindow>::BottomTensorIndex& /*step*/) const
{
    return -1;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin)
{
@@ -153,6 +174,56 @@ struct PageBlockNavigator
return new_block_index;
}
/// Moves `tile_window` by `step` and rebinds it to the block it lands in.
///
/// After the move, the window origin is translated to a global origin using
/// `block_index`; the new block index and the block-local origin are both
/// derived from that global origin. The window's tensor descriptor is taken
/// from `last_view` when the new block is the last one, and from
/// `complete_view` otherwise.
///
/// @param block_index block index used to translate the window origin into a
///                    global origin
/// @param tile_window window to update in place
/// @param step        displacement applied to the window
/// @param id          when >= 0, the window's data pointer becomes
///                    `physical_blocks + id * block_stride + fixed_offset`;
///                    when negative, the data pointer is set to nullptr.
///                    NOTE(review): presumably the value produced by
///                    prefetch_table_id - confirm against the callers.
/// @return index of the block that contains the moved window
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t block_index,
                 TileWindow& tile_window,
                 const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
                 index_t id) const
{
    ck_tile::move_tile_window(tile_window, step);

    const WindowOrigin global_origin =
        to_global_window_origin(block_index, tile_window.get_window_origin());
    const index_t dst_block_index   = get_block_index(global_origin);
    const WindowOrigin local_origin = to_local_window_origin(global_origin);

    /// TODO: only update necessary attributes
    // The descriptor is replaced before the origin is re-set, as in the
    // original statement order.
    const auto& dst_view = is_last_block(dst_block_index) ? last_view : complete_view;
    tile_window.bottom_tensor_view_.desc_ = dst_view.get_tensor_descriptor();
    tile_window.set_window_origin(local_origin);

    if(id < 0)
    {
        // Negative id: leave the window without a backing data pointer.
        tile_window.set_bottom_tensor_view_data_ptr(nullptr);
    }
    else
    {
        tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride +
                                                    fixed_offset);
    }

    return dst_block_index;
}
/// Looks up the `physical_block_indices` entry for the block that
/// `tile_window` would occupy after moving by `step`.
///
/// The move is applied to a copy of the window, so the caller's window is not
/// changed by this call.
///
/// @param block_index block index used to translate the moved window origin
///                    into a global origin
/// @param tile_window window whose prospective position is examined
/// @param step        displacement that would be applied to the window
/// @return `physical_block_indices[new_block_index]` when the new block index
///         is below `num_blocks`, otherwise -1
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
prefetch_table_id(index_t block_index,
                  TileWindow& tile_window,
                  const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
    // Work on a private copy so the original window is left untouched.
    auto probe_window = tile_window;
    ck_tile::move_tile_window(probe_window, step);

    const WindowOrigin next_global_origin =
        to_global_window_origin(block_index, probe_window.get_window_origin());
    const index_t next_block_index = get_block_index(next_global_origin);

    // Past the end of the block table: report "no entry".
    if(next_block_index >= num_blocks)
    {
        return -1;
    }

    return physical_block_indices[next_block_index];
}
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
{
return block_index == num_blocks - 1;