From 78ae3835a66d725dcf19b4e3877203e4e1aaf4a4 Mon Sep 17 00:00:00 2001 From: Linjun-AMD <105184542+LJ-underdog@users.noreply.github.com> Date: Mon, 2 Mar 2026 01:54:46 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#4313 (commit 080ac66) [CK] Fix gptoss sink ## Motivation This PR removes conditional logic for handling infinity values in the sink mechanism across multiple FMHA pipeline implementations, defaulting sink_size to 0 and adding a constraint in the kernel selection logic. ## Technical Details Changes: Removed __builtin_isinf_sign(sink_v) checks and conditional initialization of LSE accumulators across 7 pipeline files Added default initialization (= 0) for sink_size in 4 argument structs Added F_sink == "f" constraint to kernel compatibility checking ## Test Plan Local test ## Test Result passed ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 1 + ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 11 +--------- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 11 +--------- ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 11 +--------- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 11 +--------- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 11 +--------- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 22 ++----------------- 7 files changed, 8 insertions(+), 70 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index aa29633edc..c9a6013c40 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1254,6 +1254,7 @@ def get_product(receipt: int) -> Product: cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] cond &= kernel_ctx.pipeline.F_qscale == "no" cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_sink == "f" return cond return Product(name="Flash attention integration", rule=fit) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 6672940576..1bc84836d3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -540,16 +540,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse, -numeric::infinity()); - } - + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index d7696f0f76..3f6b9bc44f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -275,16 +275,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse, -numeric::infinity()); - } - + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index adc8ea5a90..1af244751b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -293,16 +293,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); if(get_thread_local_1d_id() < kM0) { store_tile(lse_acc_dram_window_tmp, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ef6ed8b4e8..842b48013a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -277,16 +277,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 81bd8d5ab5..911f059932 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -335,16 +335,7 @@ struct BlockFmhaPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse, -numeric::infinity()); - } - + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index aab79c52ae..e9ed9ac072 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -228,16 +228,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_acc_dram_window_tmp, lse_acc); } @@ -757,16 +748,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_acc_dram_window_tmp, lse_acc); }