From 8554618d6a7b41ff2e56ebc7bdb26720fbdc4fcf Mon Sep 17 00:00:00 2001 From: Yi DING Date: Fri, 27 Mar 2026 15:54:01 +0800 Subject: [PATCH] [CK_TILE] Fix NaN for FMHA BWD When seq_q=0 (#5790) ## Motivation This PR addresses NaNs in the FMHA backward (dQ/dK/dV) path when the effective query sequence length for a tile is zero, by ensuring the per-tile pipelines exit early with zeroed accumulators and by avoiding an early kernel return that prevented writing out cleared gradients. ## Technical Details - Add unconditional early-exit in the dK/dV pipelines when `num_total_loop <= 0` (no work), returning zeroed accumulators. - Adjust group-mode kernel early-return logic to only return when **both** `seqlen_q` and `seqlen_k` are zero, allowing blocks to run and store cleared dK/dV when `seqlen_q == 0`. ## Test Plan ## Test Result ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp | 2 +- .../block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 12 ++++-------- ...ock_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 12 ++++-------- ...k_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 12 ++++-------- 4 files changed, 13 insertions(+), 25 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index d32d5a321d..5659162c97 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -872,7 +872,7 @@ struct FmhaBwdDQDKDVKernel } // skip if logical lengths are zero - if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0) + if(kargs.seqlen_q == 0 && kargs.seqlen_k == 0) { return; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index e4332df930..d12310add3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit if no work to do. + if(num_total_loop <= 0) { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. - return make_tuple(dk_acc, dv_acc); - } + // Note: here dk_acc&dv_acc are all cleared, return it + return make_tuple(dk_acc, dv_acc); } KDataType* k_lds_ptr = static_cast(static_cast(static_cast(smem_ptr))); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 03ee1486da..79bf963cf7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit if no work to do. + if(num_total_loop <= 0) { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. - return make_tuple(dk_acc, dv_acc); - } + // Note: here dk_acc&dv_acc are all cleared, return it + return make_tuple(dk_acc, dv_acc); } KDataType* k_lds_ptr = static_cast(static_cast(static_cast(smem_ptr))); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 7f893a93ba..966e2ddff4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -247,15 +247,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit if no work to do. + if(num_total_loop <= 0) { - if(num_total_loop <= 0) - { - // Note: here dk_acc&dv_acc are all cleard, return it - // Note: v loaded but no fence, ignore it. - return make_tuple(dk_acc, dv_acc); - } + // Note: here dk_acc&dv_acc are all cleared, return it + return make_tuple(dk_acc, dv_acc); } auto k_lds = make_tensor_view(