diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index af22256883..014b320b51 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -271,8 +271,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline } else { - lse_logsum(distributed_indices) = - ck_tile::log(lse_sum(distributed_indices)) + lse_max(distributed_indices); + lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) + + get_validated_m(lse_max(distributed_indices)); } #if defined(PRINT_LSE_LOGSUM) @@ -328,6 +328,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline if constexpr(kStoreLSE) { + constexpr auto out_spans = static_distributed_tensor< + LSEDataType, + decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + constexpr auto distributed_indices = make_tuple(idx0); + + if(lse_logsum(distributed_indices) == numeric::infinity()) + { + lse_logsum(distributed_indices) = -numeric::infinity(); + } + }); + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); }