mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)
## Motivation Adds sink token support to the FMHA backward kernel (dot_do_o pipeline): ## Technical Details - Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType - Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs - Compute per-head sink gradient via atomic accumulation in the pipeline - Update example runner with reference validation for sink gradient ## Test Plan Add new test case ## Test Result WIP ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -91,7 +91,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det, // deterministic
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -333,7 +334,8 @@ TEST_P(BasicQPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -419,7 +421,8 @@ TEST_P(BasicKVPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -513,7 +516,8 @@ TEST_P(QKVPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -620,7 +624,8 @@ TEST_P(ZeroLengthPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -741,7 +746,8 @@ TEST_P(VariedPaddingRatios, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -843,7 +849,8 @@ TEST_P(PaddingWithMask, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
@@ -977,7 +984,8 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
det,
|
||||
false, // sink_grad
|
||||
det, // deterministic
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
1,
|
||||
|
||||
Reference in New Issue
Block a user