From 0d13ef7329e0d4e9f3104fe3edd3b968a4dfd47b Mon Sep 17 00:00:00 2001
From: Jeff Huang
Date: Tue, 13 Jan 2026 13:52:26 +0800
Subject: [PATCH] [CK Tile] Fix FMHA LSE calculation and potential division by
 zero (#3326)

This commit addresses numerical stability issues in the
BlockFmhaPipelineQRKSVS pipeline when bias has -inf masking values:

1. Explicitly handle the case where the accumulated exponential sum (l)
   is zero. In this case, the LSE is now correctly set to negative
   infinity, preventing log(0) errors.

2. Extend the zero-check protection in the normalization step to cover
   the ELEMENTWISE_BIAS case, preventing potential division by zero.

[ROCm/composable_kernel commit: 141f77aa122a453184919e00fb8239b26a873a50]
---
 .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 30 +++++++++++++------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
index fe825a370a..d54ade9f7b 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
@@ -714,26 +714,35 @@ struct BlockFmhaPipelineQRKSVS
             constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
             sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
-#if CK_TILE_FMHA_FWD_FAST_EXP2
-                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
-                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
+                // In the masked-bias case, the entire row can be suppressed and the accumulated
+                // softmax denominator becomes zero; treat it as log(0) = -inf to avoid NaNs.
+                if(l_[i_idx] == 0.0f)
                 {
-                    lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
+                    lse(i_idx) = -numeric::infinity();
                 }
                 else
                 {
-                    if constexpr(kHasLogitsSoftCap)
+#if CK_TILE_FMHA_FWD_FAST_EXP2
+                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                     {
                         lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                     }
                     else
                     {
-                        lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
+                        if constexpr(kHasLogitsSoftCap)
+                        {
+                            lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
+                        }
+                        else
+                        {
+                            lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
+                        }
                     }
-                }
 #else
-                lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
+                    lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
 #endif
+                }
             });
 
             store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
@@ -745,7 +754,10 @@ struct BlockFmhaPipelineQRKSVS
             sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
                 const auto tmp       = [&]() {
-                    if constexpr(FmhaMask::IsMasking)
+                    // When bias carries -inf masks, the denominator can be zero; guard the normalization
+                    // so we do not divide by zero after a fully masked row.
+                    if constexpr(FmhaMask::IsMasking ||
+                                 BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     {
                         return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                     }