Add attn sink (#2892)

* enable attn sink

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update attn_sink script

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix some error

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* clang-format

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update fmha_bwd_kernel'mask

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* update block_fmha_pipeline_qr_ks_vs.hpp

Signed-off-by: JL-underdog <Jun.Lin@amd.com>

* fix ci error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix format error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_bwd_pipeline_default_policy.hpp

* Update fmha_fwd_runner.hpp

* Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* Update fmha_fwd_runner.hpp

* update splitkv_pipline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update splitkv&pagedkv pipeline

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* add sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update attn_sink result log

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update smoke_test_fwd_sink.sh

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test file

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* update test script

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

* use constexpr kHasSink for sink in fmha pipeline

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* update by pre-commit

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fmha_fwd.py

* Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Remove causal mask setting logic from mask.hpp

Removed the mask setting logic for causal masks.

* fix ci error that some usage of lamada not support in c++17

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update remod.py

* add smoke sink test

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update fmha_pagedkv_prefill.py

* Update FmhaFwdPipeline parameters in fmha_fwd.py

* update block_fmha_pipeline_qr_ks_vs_async_trload.hpp

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* fix c++17 unsupprot error

Signed-off-by: LJ-underdog <Jun.Lin@amd.com>

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

* Fix formatting of sink_seq_end assignment

* Fix indentation for sink_seq_end assignment

* Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

---------

Signed-off-by: JL-underdog <Jun.Lin@amd.com>
Signed-off-by: LJ-underdog <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Linjun-AMD
2025-11-20 19:24:05 +08:00
committed by GitHub
parent 84540edff3
commit 9fa4e8d5ab
25 changed files with 940 additions and 195 deletions

View File

@@ -25,6 +25,7 @@ struct mask_info
ck_tile::index_t seqlen_k;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
ck_tile::index_t sink;
void serialize(std::ostream& os) const
{
@@ -58,13 +59,14 @@ struct mask_info
ck_tile::index_t window_size = std::stoi(v);
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
ck_tile::index_t sink_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, t == "xt");
left_size, right_size, sink_size, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
@@ -79,27 +81,54 @@ struct mask_info
{
throw std::invalid_argument("invalid mask value: " + str);
}
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
auto found_2 = v.find(',', found_1 + 1);
ck_tile::index_t v1 = 0;
ck_tile::index_t sink = 0;
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
if(found_2 != std::string::npos)
{
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
sink = atoi(v.substr(found_2 + 1).c_str());
}
else
{
v1 = atoi(v.substr(found_1 + 1).c_str());
sink = 0;
}
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
v0, v1, sink, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
tmp.sink = sink;
}
else if(t == "b")
{
if(found_2 != std::string::npos)
{
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
sink = atoi(v.substr(found_2 + 1).c_str());
}
else
{
v1 = atoi(v.substr(found_1 + 1).c_str());
sink = 0;
}
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
v0, v1, sink, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
tmp.sink = sink;
}
else if(t == "g")
{
@@ -108,6 +137,7 @@ struct mask_info
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
tmp.sink = 0;
}
}
else
@@ -126,6 +156,7 @@ struct mask_info
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
tmp.sink = 0;
}
else if(str == "2" || str == "b")
{
@@ -134,6 +165,7 @@ struct mask_info
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
tmp.sink = 0;
}
else
{