[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)

## Motivation

Adds sink token gradient support to the FMHA backward kernel (dot_do_o pipeline).

## Technical Details

- Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType
- Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs
- Compute per-head sink gradient via atomic accumulation in the pipeline
- Update example runner with reference validation for sink gradient

## Test Plan

Add a new test case.

## Test Result

WIP

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Linjun-AMD
2026-04-02 11:17:01 +08:00
committed by GitHub
parent 91dbdfa476
commit d06e2bfa2f
12 changed files with 380 additions and 130 deletions

View File

@@ -91,7 +91,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
drop_offset,
drop_prefs,
mask_str,
det, // deterministic
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -333,7 +334,8 @@ TEST_P(BasicQPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -419,7 +421,8 @@ TEST_P(BasicKVPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -513,7 +516,8 @@ TEST_P(QKVPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -620,7 +624,8 @@ TEST_P(ZeroLengthPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -741,7 +746,8 @@ TEST_P(VariedPaddingRatios, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -843,7 +849,8 @@ TEST_P(PaddingWithMask, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,
@@ -977,7 +984,8 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
drop_offset,
drop_prefs,
mask_str,
det,
false, // sink_grad
det, // deterministic
init_method,
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
1,