[rocm-libraries] ROCm/rocm-libraries#5174 (commit a358a21)

[CK_TILE] FMHA BWD Use Persistent Kernels in Deterministic
 Mode (#5174)

## Motivation
This PR enables a persistent-kernel execution path for FMHA backward
(dQ/dK/dV) in deterministic mode, adjusting how dQ accumulation is
split, stored, and converted back to final gradients.

## Technical Details
- Introduces a persistent-kernel grid mapping in deterministic mode and
updates split-count calculation accordingly.
- Extends kernel kargs to carry batch-related info needed for persistent
scheduling and dQ conversion.
- Refactors dQ store conditions and adds mask-type traits/utilities and
runner logging updates.

## Test Plan
- Jenkins
[base](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/10/pipeline)
- Jenkins
[AITER](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/12/pipeline)
- Jenkins
[FMHA](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/11/pipeline)
- local FA tests

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yi DING
2026-03-13 06:14:31 +00:00
committed by assistant-librarian[bot]
parent e2f5ab8000
commit 574c1c121a
7 changed files with 175 additions and 49 deletions

View File

@@ -169,10 +169,17 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
}}
template <>
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(ck_tile::index_t seqlen_k)
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const fmha_bwd_traits& t)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetDqAccSplits(seqlen_k);
return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k);
}}
template <>
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::NeedsZeroDqAcc();
}}
template <>
@@ -192,6 +199,7 @@ fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
{F_launcher}
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
dq_acc_splits = 1;
needs_zero_dq_acc = false;
}}
@@ -231,7 +239,8 @@ FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
}};
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t.max_seqlen_k);
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t);
needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_, {F_arch.tag}>();
return;
}}
"""
@@ -447,7 +456,7 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
results = KernelComponentFactoryGfx9.get_dq_dk_dv_tiles(dtype, tr_load)
if dtype in ["fp16", "bf16"] and tr_load == "t":
results.extend([
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 32, 256, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
@@ -823,7 +832,7 @@ class FmhaBwdApiTrait:
@property
def extra_cond(self) -> str:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128:
return " && (t.seqlen_k <= 256)"
else:
return ""