Replace sentinel value before storing

This commit is contained in:
PoYen, Chen
2024-06-04 09:59:51 +00:00
parent 5a6b8d8606
commit 064afc69d9

View File

@@ -271,8 +271,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
}
else
{
- lse_logsum(distributed_indices) =
- ck_tile::log(lse_sum(distributed_indices)) + lse_max(distributed_indices);
+ lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) +
+ get_validated_m(lse_max(distributed_indices));
}
#if defined(PRINT_LSE_LOGSUM)
@@ -328,6 +328,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
if constexpr(kStoreLSE)
{
+ constexpr auto out_spans = static_distributed_tensor<
+ LSEDataType,
+ decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
+ sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
+ constexpr auto distributed_indices = make_tuple(idx0);
+ if(lse_logsum(distributed_indices) == numeric<LSEDataType>::infinity())
+ {
+ lse_logsum(distributed_indices) = -numeric<LSEDataType>::infinity();
+ }
+ });
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}