diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index aa29633edc..c9a6013c40 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1254,6 +1254,7 @@ def get_product(receipt: int) -> Product: cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] cond &= kernel_ctx.pipeline.F_qscale == "no" cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_sink == "f" return cond return Product(name="Flash attention integration", rule=fit) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 6672940576..1bc84836d3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -540,16 +540,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse, -numeric::infinity()); - } - + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index d7696f0f76..3f6b9bc44f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -275,16 +275,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse, -numeric::infinity()); - } - + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index adc8ea5a90..1af244751b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -293,16 +293,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); if(get_thread_local_1d_id() < kM0) { store_tile(lse_acc_dram_window_tmp, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ef6ed8b4e8..842b48013a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -277,16 +277,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 81bd8d5ab5..911f059932 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -335,16 +335,7 @@ struct BlockFmhaPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse, -numeric::infinity()); - } - + set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index aab79c52ae..e9ed9ac072 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -228,16 +228,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_acc_dram_window_tmp, lse_acc); } @@ -757,16 +748,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - - if(__builtin_isinf_sign(sink_v) >= 0) - { - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); - } - else - { - set_tile(lse_acc, -numeric::infinity()); - } - + set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); store_tile(lse_acc_dram_window_tmp, lse_acc); }