Add attn sink (#2892)

* enable attn sink

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update attn_sink script

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix some error

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* clang-format

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd_kernel's mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update block_fmha_pipeline_qr_ks_vs.hpp

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix ci error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix format error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_bwd_pipeline_default_policy.hpp

* Update fmha_fwd_runner.hpp

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* update splitkv_pipeline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update splitkv&pagedkv pipeline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* add sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update attn_sink result log

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update smoke_test_fwd_sink.sh

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test file

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test script

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

* use constexpr kHasSink for sink in fmha pipeline

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update by pre-commit

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fmha_fwd.py

* Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Remove causal mask setting logic from mask.hpp

Removed the mask setting logic for causal masks.

* fix CI error: some lambda usages are not supported in C++17

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update remod.py

* add smoke sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update fmha_pagedkv_prefill.py

* Update FmhaFwdPipeline parameters in fmha_fwd.py

* update block_fmha_pipeline_qr_ks_vs_async_trload.hpp

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix C++17 unsupported-feature error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

* Fix formatting of sink_seq_end assignment

* Fix indentation for sink_seq_end assignment

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

---------

Signed-off-by: JL-underdog <Jun.Lin@amd.com>
Signed-off-by: LJ-underdog <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Linjun-AMD
2025-11-20 19:24:05 +08:00
committed by GitHub
parent 84540edff3
commit 9fa4e8d5ab
25 changed files with 940 additions and 195 deletions

View File

@@ -198,7 +198,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -362,6 +362,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -425,6 +426,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -509,6 +511,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_v,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -570,6 +573,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -1026,6 +1030,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -56,6 +56,7 @@ struct FmhaFwdKernel
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -112,7 +113,7 @@ struct FmhaFwdKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload") + (kHasSink ? "_sink" : "_nsink");
#undef _SS_
#undef _TS_
// clang-format on
@@ -200,7 +201,7 @@ struct FmhaFwdKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -374,6 +375,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -432,6 +434,7 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -518,6 +521,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -565,6 +569,7 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -615,6 +620,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -662,6 +668,7 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -706,6 +713,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -765,6 +773,7 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -848,6 +857,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -891,6 +901,7 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -937,6 +948,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -980,6 +992,7 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -1471,6 +1484,7 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
@@ -2200,6 +2214,7 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -55,6 +55,7 @@ struct FmhaFwdPagedKVKernel
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -101,7 +102,7 @@ struct FmhaFwdPagedKVKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -189,7 +190,7 @@ struct FmhaFwdPagedKVKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -326,6 +327,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -379,6 +381,7 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -453,6 +456,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
return MakeKargsImpl(q_ptr,
@@ -495,6 +499,7 @@ struct FmhaFwdPagedKVKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type);
}
@@ -536,6 +541,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -590,6 +596,7 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -660,6 +667,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -699,6 +707,7 @@ struct FmhaFwdPagedKVKernel
batch_stride_v,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q);
}
@@ -1257,6 +1266,7 @@ struct FmhaFwdPagedKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -51,6 +51,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
@@ -101,7 +102,7 @@ struct FmhaFwdSplitKVKernel
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -198,7 +199,7 @@ struct FmhaFwdSplitKVKernel
struct MaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -325,6 +326,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -384,6 +386,7 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -451,6 +454,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -508,6 +512,7 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -994,6 +999,7 @@ struct FmhaFwdSplitKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);