mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)
[CK Tile] Add sink token gradient support in FMHA backward pass (#5504) ## Motivation Adds sink token support to the FMHA backward kernel (dot_do_o pipeline): ## Technical Details - Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType - Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs - Compute per-head sink gradient via atomic accumulation in the pipeline - Update example runner with reference validation for sink gradient ## Test Plan Add new test case ## Test Result WIP ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c1127a36f5
commit
08792e0b31
@@ -533,6 +533,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
/* BlockSize = M0 = */ {F_bm0},
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
|
||||
Reference in New Issue
Block a user