mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Replace the +infinity sentinel in lse_logsum with -infinity before storing the LSE tile
This commit is contained in:
@@ -271,8 +271,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_logsum(distributed_indices) =
|
||||
ck_tile::log(lse_sum(distributed_indices)) + lse_max(distributed_indices);
|
||||
lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) +
|
||||
get_validated_m(lse_max(distributed_indices));
|
||||
}
|
||||
|
||||
#if defined(PRINT_LSE_LOGSUM)
|
||||
@@ -328,6 +328,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
constexpr auto out_spans = static_distributed_tensor<
|
||||
LSEDataType,
|
||||
decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
|
||||
if(lse_logsum(distributed_indices) == numeric<LSEDataType>::infinity())
|
||||
{
|
||||
lse_logsum(distributed_indices) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user