This commit is contained in:
fsx950223
2025-06-03 07:17:56 +00:00
parent 8e1dd4e7f9
commit 7ce4f50da6

View File

@@ -287,7 +287,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
int32_t num_page_blocks,
int32_t num_total_pages,
int32_t max_num_blocks_per_seq,
const void* block_table_ptr,
ck_tile::index_t page_block_size,
float scale_s,
@@ -380,7 +381,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
}
if constexpr(kKVCacheEnum == KVCacheEnum::VLLM)
{
kargs.num_page_blocks = num_page_blocks;
kargs.num_page_blocks = num_total_pages;
kargs.max_num_blocks_per_seq = max_num_blocks_per_seq;
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.page_block_size = page_block_size;
}
@@ -410,6 +412,7 @@ template <bool Cond = kIsGroupMode>
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
int32_t num_total_pages,
int32_t max_num_blocks_per_seq,
const void* block_table_ptr,
ck_tile::index_t page_block_size,
@@ -497,6 +500,7 @@ template <bool Cond = kIsGroupMode>
if constexpr(kKVCacheEnum == KVCacheEnum::VLLM)
{
kargs.max_num_blocks_per_seq = max_num_blocks_per_seq;
kargs.num_page_blocks = num_total_pages;
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.page_block_size = page_block_size;
@@ -528,7 +532,7 @@ template <bool Cond = kIsGroupMode>
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
int32_t max_num_blocks_per_seq,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
@@ -823,7 +827,6 @@ template <bool Cond = kIsGroupMode>
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
int32_t num_page_blocks = 0;
int32_t* block_table_seq = nullptr;
int32_t* kv_page_indices = nullptr;
if constexpr(kKVCacheEnum == KVCacheEnum::SGLANG){
num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
@@ -831,7 +834,7 @@ template <bool Cond = kIsGroupMode>
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
#endif
}else if constexpr(kKVCacheEnum == KVCacheEnum::VLLM){
num_page_blocks = kargs.num_page_blocks;
num_page_blocks = kargs.num_total_pages;
}
if constexpr(kIsGroupMode)
{
@@ -839,9 +842,9 @@ template <bool Cond = kIsGroupMode>
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
if constexpr(kKVCacheEnum == KVCacheEnum::SGLANG){
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;