[rocm-libraries] ROCm/rocm-libraries#5790 (commit c132b5a)

[CK_TILE] Fix NaN for FMHA BWD When seq_q=0

## Motivation
This PR addresses NaNs in the FMHA backward (dQ/dK/dV) path when the
effective query sequence length for a tile is zero, by ensuring the
per-tile pipelines exit early with zeroed accumulators and by avoiding
an early kernel return that prevented writing out cleared gradients.

## Technical Details
- Add unconditional early-exit in the dK/dV pipelines when
`num_total_loop <= 0` (no work), returning zeroed accumulators.
- Adjust group-mode kernel early-return logic to only return when
**both** `seqlen_q` and `seqlen_k` are zero, allowing blocks to run and
store cleared dK/dV when `seqlen_q == 0`.

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yi DING
2026-03-27 07:54:53 +00:00
committed by assistant-librarian[bot]
parent e2470e837a
commit 47a04fda08
4 changed files with 13 additions and 25 deletions

View File

@@ -872,7 +872,7 @@ struct FmhaBwdDQDKDVKernel
}
// skip if logical lengths are zero
if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
if(kargs.seqlen_q == 0 && kargs.seqlen_k == 0)
{
return;
}

View File

@@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
// check early exit if no work to do.
if(num_total_loop <= 0)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return make_tuple(dk_acc, dv_acc);
}
// Note: here dk_acc&dv_acc are all cleared, return it
return make_tuple(dk_acc, dv_acc);
}
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));

View File

@@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
// check early exit if no work to do.
if(num_total_loop <= 0)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return make_tuple(dk_acc, dv_acc);
}
// Note: here dk_acc&dv_acc are all cleared, return it
return make_tuple(dk_acc, dv_acc);
}
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));

View File

@@ -247,15 +247,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
// check early exit if no work to do.
if(num_total_loop <= 0)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return make_tuple(dk_acc, dv_acc);
}
// Note: here dk_acc&dv_acc are all cleared, return it
return make_tuple(dk_acc, dv_acc);
}
auto k_lds = make_tensor_view<address_space_enum::lds>(