mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
fix bugs
This commit is contained in:
@@ -287,7 +287,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
int32_t num_page_blocks,
|
||||
int32_t num_total_pages,
|
||||
int32_t max_num_blocks_per_seq,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t page_block_size,
|
||||
float scale_s,
|
||||
@@ -380,7 +381,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
|
||||
}
|
||||
if constexpr(kKVCacheEnum == KVCacheEnum::VLLM)
|
||||
{
|
||||
kargs.num_page_blocks = num_page_blocks;
|
||||
kargs.num_page_blocks = num_total_pages;
|
||||
kargs.max_num_blocks_per_seq = max_num_blocks_per_seq;
|
||||
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
|
||||
kargs.page_block_size = page_block_size;
|
||||
}
|
||||
@@ -410,6 +412,7 @@ template <bool Cond = kIsGroupMode>
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
int32_t num_total_pages,
|
||||
int32_t max_num_blocks_per_seq,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t page_block_size,
|
||||
@@ -497,6 +500,7 @@ template <bool Cond = kIsGroupMode>
|
||||
if constexpr(kKVCacheEnum == KVCacheEnum::VLLM)
|
||||
{
|
||||
kargs.max_num_blocks_per_seq = max_num_blocks_per_seq;
|
||||
kargs.num_page_blocks = num_total_pages;
|
||||
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
|
||||
kargs.page_block_size = page_block_size;
|
||||
|
||||
@@ -528,7 +532,7 @@ template <bool Cond = kIsGroupMode>
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
int32_t max_num_blocks_per_seq,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
@@ -823,7 +827,6 @@ template <bool Cond = kIsGroupMode>
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
int32_t num_page_blocks = 0;
|
||||
int32_t* block_table_seq = nullptr;
|
||||
int32_t* kv_page_indices = nullptr;
|
||||
if constexpr(kKVCacheEnum == KVCacheEnum::SGLANG){
|
||||
num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
|
||||
@@ -831,7 +834,7 @@ template <bool Cond = kIsGroupMode>
|
||||
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
|
||||
#endif
|
||||
}else if constexpr(kKVCacheEnum == KVCacheEnum::VLLM){
|
||||
num_page_blocks = kargs.num_page_blocks;
|
||||
num_page_blocks = kargs.num_total_pages;
|
||||
}
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -839,9 +842,9 @@ template <bool Cond = kIsGroupMode>
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
|
||||
if constexpr(kKVCacheEnum == KVCacheEnum::SGLANG){
|
||||
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
|
||||
Reference in New Issue
Block a user