mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)
[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)

## Motivation

Adds sink token support to the FMHA backward kernel (dot_do_o pipeline).

## Technical Details

- Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType
- Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs
- Compute per-head sink gradient via atomic accumulation in the pipeline
- Update example runner with reference validation for sink gradient

## Test Plan

Add new test case

## Test Result

WIP

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c1127a36f5
commit
08792e0b31
@@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[])
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
|
||||
"will not be used")
|
||||
.insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
|
||||
|
||||
@@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
bool sink_grad = arg_parser.get_bool("sink_grad");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
sink_grad,
|
||||
deterministic,
|
||||
init_method,
|
||||
seed,
|
||||
|
||||
Reference in New Issue
Block a user