mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Add attn sink (#2892)
* enable attn sink Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update attn_sink script Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix some errors Signed-off-by: JL-underdog <Jun.Lin@amd.com> * clang-format Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd_kernel's mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update block_fmha_pipeline_qr_ks_vs.hpp Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix ci error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix format error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_bwd_pipeline_default_policy.hpp * Update fmha_fwd_runner.hpp * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * update splitkv_pipeline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update splitkv&pagedkv pipeline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * add sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update attn_sink result log Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update smoke_test_fwd_sink.sh Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test file Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test script Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp * use constexpr kHasSink for sink in fmha pipeline Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update by pre-commit Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fmha_fwd.py * 
Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove causal mask setting logic from mask.hpp Removed the mask setting logic for causal masks. * fix ci error: some lambda usages are not supported in c++17 Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update remod.py * add smoke sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update fmha_pagedkv_prefill.py * Update FmhaFwdPipeline parameters in fmha_fwd.py * update block_fmha_pipeline_qr_ks_vs_async_trload.hpp Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix c++17 unsupported-feature error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp * Fix formatting of sink_seq_end assignment * Fix indentation for sink_seq_end assignment * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp --------- Signed-off-by: JL-underdog <Jun.Lin@amd.com> Signed-off-by: LJ-underdog <Jun.Lin@amd.com> Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -25,6 +25,7 @@ struct mask_info
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
ck_tile::index_t sink;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
@@ -58,13 +59,14 @@ struct mask_info
|
||||
ck_tile::index_t window_size = std::stoi(v);
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
ck_tile::index_t sink_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, t == "xt");
|
||||
left_size, right_size, sink_size, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
@@ -79,27 +81,54 @@ struct mask_info
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
|
||||
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
auto found_2 = v.find(',', found_1 + 1);
|
||||
ck_tile::index_t v1 = 0;
|
||||
ck_tile::index_t sink = 0;
|
||||
// ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
if(t == "t")
|
||||
{
|
||||
if(found_2 != std::string::npos)
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
sink = 0;
|
||||
}
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, true);
|
||||
v0, v1, sink, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
tmp.sink = sink;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
if(found_2 != std::string::npos)
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str());
|
||||
sink = atoi(v.substr(found_2 + 1).c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
sink = 0;
|
||||
}
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, false);
|
||||
v0, v1, sink, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
tmp.sink = sink;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
@@ -108,6 +137,7 @@ struct mask_info
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -126,6 +156,7 @@ struct mask_info
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else if(str == "2" || str == "b")
|
||||
{
|
||||
@@ -134,6 +165,7 @@ struct mask_info
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user