[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)

[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)

## Motivation

Adds sink token gradient support to the FMHA backward kernel (dot_do_o pipeline).

## Technical Details

- Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType
- Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs
- Compute per-head sink gradient via atomic accumulation in the pipeline
- Update example runner with reference validation for sink gradient
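
As a rough illustration of what a host-side reference for the per-head sink gradient might look like, here is a minimal sketch. It is not taken from this PR. It assumes the common sink formulation in which the sink logit joins the softmax denominator (so LSE already includes it) but contributes no value vector, which gives `d_sink[h] = -sum_i exp(sink[h] - lse[h][i]) * dot(dO[h][i], O[h][i])`. The function name `reference_sink_grad`, the single-batch `[nhead][seqlen_q][hdim]` layout, and the natural-log LSE are all assumptions made for this example; the kernel itself accumulates the per-row contributions atomically, whereas this sketch sums them sequentially.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host reference, NOT the CK Tile implementation.
// Assumed layouts (single batch):
//   o, d_o : [nhead][seqlen_q][hdim]
//   lse    : [nhead][seqlen_q], natural-log LSE that already includes the sink logit
//   sink   : [nhead]
std::vector<float> reference_sink_grad(const std::vector<float>& o,
                                       const std::vector<float>& d_o,
                                       const std::vector<float>& lse,
                                       const std::vector<float>& sink,
                                       std::size_t nhead,
                                       std::size_t seqlen_q,
                                       std::size_t hdim)
{
    std::vector<float> d_sink(nhead, 0.0f);
    for(std::size_t h = 0; h < nhead; ++h)
    {
        for(std::size_t i = 0; i < seqlen_q; ++i)
        {
            const std::size_t row = h * seqlen_q + i;

            // D_i = rowsum(dO_i * O_i), the quantity the dot_do_o stage produces.
            float d = 0.0f;
            for(std::size_t k = 0; k < hdim; ++k)
                d += d_o[row * hdim + k] * o[row * hdim + k];

            // Softmax probability mass assigned to the sink for this row.
            const float p_sink = std::exp(sink[h] - lse[row]);

            // The sink has no value vector, so its dP term is zero and only -p * D remains.
            d_sink[h] -= p_sink * d;
        }
    }
    return d_sink;
}
```

A device result could then be compared against this output per head with a tolerance suited to the accumulation data type.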

## Test Plan

Add new test case

## Test Result

WIP

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Linjun-AMD
2026-04-02 03:17:45 +00:00
committed by assistant-librarian[bot]
parent c1127a36f5
commit 08792e0b31
12 changed files with 380 additions and 130 deletions


@@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[])
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
"will not be used")
.insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
@@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
bool deterministic = arg_parser.get_bool("deterministic");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
bool sink_grad = arg_parser.get_bool("sink_grad");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
drop_offset,
drop_prefs,
mask_str,
sink_grad,
deterministic,
init_method,
seed,