mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4313 (commit 080ac66)
[CK] Fix gptoss sink. ## Motivation This PR removes the conditional logic for handling infinity values in the sink mechanism across multiple FMHA pipeline implementations, defaults sink_size to 0, and adds a constraint in the kernel selection logic. ## Technical Details Changes: removed the __builtin_isinf_sign(sink_v) checks and the conditional initialization of the LSE accumulators across 7 pipeline files; added default initialization (= 0) for sink_size in 4 argument structs; added an F_sink == "f" constraint to the kernel compatibility checking. ## Test Plan Local test. ## Test Result Passed. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
d32d515f64
commit
78ae3835a6
@@ -540,16 +540,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
|
||||
|
||||
@@ -275,16 +275,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
{
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
|
||||
@@ -293,16 +293,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
{
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp,
|
||||
|
||||
@@ -277,16 +277,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
{
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
store_tile(lse_acc_dram_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
}
|
||||
|
||||
@@ -335,16 +335,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
|
||||
@@ -228,16 +228,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
{
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
@@ -757,16 +748,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
{
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user