[VLLM V1] Add chunked prefill for FA to pass seq with small seqlen_q (#2221)

* fix splitkv compiler issue since lse is used to select kernel instances

* bypass seqlen == 1

* add chunked prefill into mha varlen

This reverts commit aa9847e42d.

* skip compile when receipt 2-4 and add comments

* fix

---------

Co-authored-by: fsx950223 <fsx950223@outlook.com>
This commit is contained in:
Zzz9990
2025-05-26 19:17:18 +08:00
committed by GitHub
parent 8146e471f1
commit ece38b9d7a
7 changed files with 72 additions and 29 deletions

View File

@@ -53,6 +53,8 @@ struct FmhaFwdKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -257,6 +259,11 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaFwdSkipMinSeqlenQKargs
{
ck_tile::index_t min_seqlen_q = 0;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
@@ -287,7 +294,8 @@ struct FmhaFwdKernel
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -664,6 +672,7 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
@@ -698,6 +707,7 @@ struct FmhaFwdKernel
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
{}, // placeholder for min_seqlen_q
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -753,6 +763,10 @@ struct FmhaFwdKernel
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
if constexpr(kSkipMinSeqlenQ)
{
kargs.min_seqlen_q = min_seqlen_q;
}
return kargs;
}
@@ -1053,6 +1067,14 @@ struct FmhaFwdKernel
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if constexpr(kSkipMinSeqlenQ)
{
if(kargs.seqlen_q <= kargs.min_seqlen_q)
{
return;
}
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)