Add attn sink (#2892)

* enable attn sink

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update attn_sink script

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix some error

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* clang-format

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd_kernel'mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update block_fmha_pipeline_qr_ks_vs.hpp

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix ci error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix format error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_bwd_pipeline_default_policy.hpp

* Update fmha_fwd_runner.hpp

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* update splitkv_pipline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update splitkv&pagedkv pipeline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* add sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update attn_sink result log

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update smoke_test_fwd_sink.sh

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test file

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test script

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

* use constexpr kHasSink for sink in fmha pipeline

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update by pre-commit

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fmha_fwd.py

* Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Remove causal mask setting logic from mask.hpp

Removed the mask setting logic for causal masks.

* fix ci error that some usage of lamada not support in c++17

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update remod.py

* add smoke sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update fmha_pagedkv_prefill.py

* Update FmhaFwdPipeline parameters in fmha_fwd.py

* update block_fmha_pipeline_qr_ks_vs_async_trload.hpp

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix c++17 unsupprot error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

* Fix formatting of sink_seq_end assignment

* Fix indentation for sink_seq_end assignment

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

---------

Signed-off-by: JL-underdog <Jun.Lin@amd.com>
Signed-off-by: LJ-underdog <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Linjun-AMD
2025-11-20 19:24:05 +08:00
committed by GitHub
parent 84540edff3
commit 9fa4e8d5ab
25 changed files with 940 additions and 195 deletions

View File

@@ -265,6 +265,7 @@ struct fmha_fwd_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
@@ -351,6 +352,7 @@ struct fmha_fwd_pagedkv_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
};
@@ -441,6 +443,7 @@ struct fmha_fwd_splitkv_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t sink_size;
ck_tile::index_t mask_type;
};
@@ -611,6 +614,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q,
args.p_drop,
@@ -660,6 +664,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.p_drop,
args.s_randval,
@@ -727,6 +732,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q);
}
@@ -772,6 +778,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
}();
@@ -838,6 +845,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
else
@@ -885,6 +893,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
}
}();
@@ -1131,7 +1140,8 @@ template <ck_tile::index_t HDim_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false>
bool kSkipMinSeqlenQ_ = false,
bool kHasSink_ = false>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -1157,6 +1167,7 @@ struct fmha_fwd_traits_
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1183,7 +1194,8 @@ template <ck_tile::index_t HDim_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
bool kSkipMinSeqlenQ_ = false,
bool kHasSink_ = false>
struct fmha_fwd_pagedkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -1208,6 +1220,7 @@ struct fmha_fwd_pagedkv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1230,6 +1243,7 @@ template <ck_tile::index_t HDim_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kHasSink_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
@@ -1257,6 +1271,7 @@ struct fmha_fwd_splitkv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
static constexpr bool kHasSink = kHasSink_;
};
template <typename Traits_, typename Arch = void>
@@ -1343,6 +1358,7 @@ struct fmha_fwd_traits
bool has_dropout;
bool do_fp8_static_quant;
bool skip_min_seqlen_q = false;
bool has_sink = false;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
@@ -1361,6 +1377,7 @@ struct fmha_fwd_pagedkv_traits
bool use_pagedkv = true;
bool do_fp8_static_quant = false;
bool skip_min_seqlen_q = false;
bool has_sink = false;
// TODO: padding check is inside this api
};
@@ -1380,6 +1397,7 @@ struct fmha_fwd_splitkv_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
bool has_sink = false;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,