mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[VLLM V1] Add chunked prefill for FA to pass seq with small seqlen_q (#2221)
* fix splitkv compiler issue since lse is used to select kernel instances
* bypass seqlen == 1
* add chunked prefill into mha varlen
This reverts commit aa9847e42d.
* skip compile when receipt 2-4 and add comments
* fix
---------
Co-authored-by: fsx950223 <fsx950223@outlook.com>
This commit is contained in:
@@ -53,6 +53,8 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
@@ -257,6 +259,11 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdSkipMinSeqlenQKargs
|
||||
{
|
||||
ck_tile::index_t min_seqlen_q = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
@@ -287,7 +294,8 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -664,6 +672,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
@@ -698,6 +707,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for min_seqlen_q
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -753,6 +763,10 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
kargs.min_seqlen_q = min_seqlen_q;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -1053,6 +1067,14 @@ struct FmhaFwdKernel
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
if(kargs.seqlen_q <= kargs.min_seqlen_q)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
|
||||
Reference in New Issue
Block a user