[CK] Fix gptoss sink (#4313)

## Motivation

This PR removes the conditional handling of infinity values in the sink mechanism across multiple FMHA pipeline implementations, defaults `sink_size` to 0 in the argument structs, and adds a sink constraint to the kernel selection logic.

## Technical Details

Changes:

- Removed the `__builtin_isinf_sign(sink_v)` checks and the conditional initialization of the LSE accumulators across 7 pipeline files.
- Added a default initialization (`= 0`) for `sink_size` in 4 argument structs (a sketch of the intent follows below).
- Added an `F_sink == "f"` constraint to the kernel compatibility check.
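
The argument-struct change is not visible in the hunks below, so the following is a minimal sketch of what the `sink_size` default is meant to achieve. It assumes a CK-style forward-argument struct; the struct and member names are placeholders for illustration, not the actual identifiers touched by this PR.

```cpp
// Hypothetical sketch only: the names here are illustrative, not the real CK
// argument structs. The point is that sink_size now defaults to 0, so the
// attention-sink path stays disabled unless a caller sets it explicitly.
struct FmhaFwdArgsSketch
{
    const void* sink_ptr  = nullptr; // optional attention-sink values
    int         sink_size = 0;       // defaulted to 0: no sink unless opted in
};
```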

## Test Plan

Local test

## Test Result

passed

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: systems-assistant[bot] <systems-assistant[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Author: Linjun-AMD
Date: 2026-03-02 09:53:52 +08:00 (committed by GitHub)
Commit: c67a80187b (parent: f0d724135c)
7 changed files with 8 additions and 70 deletions

@@ -1254,6 +1254,7 @@ def get_product(receipt: int) -> Product:
         cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
         cond &= kernel_ctx.pipeline.F_qscale == "no"
         cond &= kernel_ctx.pipeline.F_skip == "f"
+        cond &= kernel_ctx.pipeline.F_sink == "f"
         return cond
     return Product(name="Flash attention integration", rule=fit)

@@ -540,16 +540,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             {
                 auto lse =
                     make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
-                if(__builtin_isinf_sign(sink_v) >= 0)
-                {
-                    set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
-                }
-                else
-                {
-                    set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
-                }
+                set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
                 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
             }
             buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)

@@ -275,16 +275,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
             {
                 auto lse =
                     make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
-                if(__builtin_isinf_sign(sink_v) >= 0)
-                {
-                    set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
-                }
-                else
-                {
-                    set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
-                }
+                set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
                 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
             }

@@ -293,16 +293,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
             {
                 auto lse_acc =
                     make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
-                if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
-                {
-                    set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
-                }
-                else
-                {
-                    set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
-                }
+                set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
                 if(get_thread_local_1d_id() < kM0)
                 {
                     store_tile(lse_acc_dram_window_tmp,

@@ -277,16 +277,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             {
                 auto lse_acc =
                     make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
-                if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
-                {
-                    set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
-                }
-                else
-                {
-                    set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
-                }
+                set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
                 store_tile(lse_acc_dram_window_tmp,
                            tile_elementwise_in(lse_acc_element_func, lse_acc));
             }

@@ -335,16 +335,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             {
                 auto lse =
                     make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
-                if(__builtin_isinf_sign(sink_v) >= 0)
-                {
-                    set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
-                }
-                else
-                {
-                    set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
-                }
+                set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
                 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
             }

@@ -228,16 +228,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
             {
                 auto lse_acc =
                     make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
-                if(__builtin_isinf_sign(sink_v) >= 0)
-                {
-                    set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
-                }
-                else
-                {
-                    set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
-                }
+                set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
                 store_tile(lse_acc_dram_window_tmp, lse_acc);
             }
@@ -757,16 +748,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
             {
                 auto lse_acc =
                     make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
-                if(__builtin_isinf_sign(sink_v) >= 0)
-                {
-                    set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
-                }
-                else
-                {
-                    set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
-                }
+                set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
                 store_tile(lse_acc_dram_window_tmp, lse_acc);
             }