Add block_table kernel args for appendkv kernel

This commit is contained in:
PoYen, Chen
2024-08-07 04:27:15 +00:00
parent 15d0034a64
commit 443a528adc
2 changed files with 20 additions and 2 deletions

View File

@@ -502,6 +502,9 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.stride_q,
args.stride_k,
args.stride_knew,
@@ -532,6 +535,9 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.stride_q,
args.stride_k,
args.stride_knew,

View File

@@ -97,11 +97,11 @@ struct FmhaFwdAppendKVKernel
// For MQA/GQA, the number of Q heads and K heads can differ. This parameter is nhead_q / nhead_k;
// a value larger than 1 indicates the MQA/GQA case.
ck_tile::index_t nhead_ratio_qk;
/*
const void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
*/
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
@@ -160,6 +160,9 @@ struct FmhaFwdAppendKVKernel
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
@@ -188,6 +191,9 @@ struct FmhaFwdAppendKVKernel
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,
@@ -233,6 +239,9 @@ struct FmhaFwdAppendKVKernel
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
@@ -258,6 +267,9 @@ struct FmhaFwdAppendKVKernel
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,