Add block_table kernel args for appendkv kernel

This commit is contained in:
PoYen, Chen
2024-08-07 04:27:15 +00:00
parent 15d0034a64
commit 443a528adc
2 changed files with 20 additions and 2 deletions

View File

@@ -97,11 +97,11 @@ struct FmhaFwdAppendKVKernel
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
/*
const void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
*/
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
@@ -160,6 +160,9 @@ struct FmhaFwdAppendKVKernel
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
@@ -188,6 +191,9 @@ struct FmhaFwdAppendKVKernel
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,
@@ -233,6 +239,9 @@ struct FmhaFwdAppendKVKernel
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
@@ -258,6 +267,9 @@ struct FmhaFwdAppendKVKernel
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,