diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp index 03b457d9db..23b4fc2574 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp @@ -287,7 +287,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, - int32_t num_page_blocks, + int32_t num_total_pages, + int32_t max_num_blocks_per_seq, const void* block_table_ptr, ck_tile::index_t page_block_size, float scale_s, @@ -380,7 +381,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel } if constexpr(kKVCacheEnum == KVCacheEnum::VLLM) { - kargs.num_page_blocks = num_page_blocks; + kargs.num_page_blocks = num_total_pages; + kargs.max_num_blocks_per_seq = max_num_blocks_per_seq; kargs.block_table_ptr = reinterpret_cast(block_table_ptr); kargs.page_block_size = page_block_size; } @@ -410,6 +412,7 @@ template ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, + int32_t num_total_pages, int32_t max_num_blocks_per_seq, const void* block_table_ptr, ck_tile::index_t page_block_size, @@ -497,6 +500,7 @@ template if constexpr(kKVCacheEnum == KVCacheEnum::VLLM) { kargs.max_num_blocks_per_seq = max_num_blocks_per_seq; + kargs.num_page_blocks = num_total_pages; kargs.block_table_ptr = reinterpret_cast(block_table_ptr); kargs.page_block_size = page_block_size; @@ -528,7 +532,7 @@ template ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, - int32_t max_num_blocks_per_seq, + int32_t num_total_pages, const void* kv_indptr, const void* kv_page_indices, #if 0 // we assume page_block_size=1 for now @@ -823,7 +827,6 @@ template long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_o_acc = 0; int32_t num_page_blocks = 0; - int32_t* block_table_seq = nullptr; int32_t* kv_page_indices = nullptr; if constexpr(kKVCacheEnum == KVCacheEnum::SGLANG){ num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; @@ -831,7 +834,7 @@ template const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; #endif }else if constexpr(kKVCacheEnum == KVCacheEnum::VLLM){ - num_page_blocks = kargs.num_page_blocks; + num_page_blocks = kargs.num_total_pages; } if constexpr(kIsGroupMode) { @@ -839,9 +842,9 @@ template const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; batch_offset_q = query_start * kargs.stride_q; - - kargs.kv_page_indices += kargs.kv_indptr[i_batch]; - + if constexpr(kKVCacheEnum == KVCacheEnum::SGLANG){ + kargs.kv_page_indices += kargs.kv_indptr[i_batch]; + } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias;