mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[rocm-libraries] ROCm/rocm-libraries#5790 (commit c132b5a)
[CK_TILE] Fix NaN for FMHA BWD When seq_q=0 ## Motivation This PR addresses NaNs in the FMHA backward (dQ/dK/dV) path when the effective query sequence length for a tile is zero, by ensuring the per-tile pipelines exit early with zeroed accumulators and by avoiding an early kernel return that prevented writing out cleared gradients. ## Technical Details - Add unconditional early-exit in the dK/dV pipelines when `num_total_loop <= 0` (no work), returning zeroed accumulators. - Adjust group-mode kernel early-return logic to only return when **both** `seqlen_q` and `seqlen_k` are zero, allowing blocks to run and store cleared dK/dV when `seqlen_q == 0`. ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e2470e837a
commit
47a04fda08
@@ -872,7 +872,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
|
||||
// skip if logical lengths are zero
|
||||
if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
|
||||
if(kargs.seqlen_q == 0 && kargs.seqlen_k == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
// check early exit if no work to do.
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
// Note: here dk_acc&dv_acc are all cleared, return it
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
KDataType* k_lds_ptr =
|
||||
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
|
||||
@@ -161,15 +161,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
// check early exit if no work to do.
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
// Note: here dk_acc&dv_acc are all cleared, return it
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
KDataType* k_lds_ptr =
|
||||
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
|
||||
@@ -247,15 +247,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
// check early exit if no work to do.
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
// Note: here dk_acc&dv_acc are all cleared, return it
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
|
||||
Reference in New Issue
Block a user