[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)

[CK Tile] Add sink token gradient support in FMHA backward
 pass (#5504)

## Motivation

Adds sink token support to the FMHA backward kernel (dot_do_o pipeline).

## Technical Details

- Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType
- Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs
- Compute per-head sink gradient via atomic accumulation in the pipeline
- Update example runner with reference validation for sink gradient

## Test Plan

Add a new test case exercising the sink-gradient path (`sink_grad = true`) with reference validation of `dSink`.

## Test Result

Work in progress — test results to be posted once the new test case has been run.

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Linjun-AMD
2026-04-02 03:17:45 +00:00
committed by assistant-librarian[bot]
parent c1127a36f5
commit 08792e0b31
12 changed files with 380 additions and 130 deletions

View File

@@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
uint64_t drop_offset,
bool drop_prefs,
std::string mask_str,
bool sink_grad, // if true, compute and validate sink gradient
bool deterministic,
std::string init_method,
uint32_t seed,
@@ -284,6 +285,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<LSEDataType> sink_host(
sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
: std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
if(sink_grad)
{
std::uniform_real_distribution<float> sink_dist(30.0f, 60.0f);
sink_host.ForEach([&](auto& self, auto i) {
self(i) = static_cast<LSEDataType>(sink_dist(random_engine));
});
}
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
@@ -301,6 +312,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<LSEDataType> d_sink_host(sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
: std::array<ck_tile::index_t, 1>{0});
if(sink_grad)
{
d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
}
ck_tile::HostTensor<AccDataType> dq_acc_host(
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
@@ -361,11 +378,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0);
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0);
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
@@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
if(sink_grad)
{
sink_buf.ToDevice(sink_host.data());
d_sink_buf.ToDevice(d_sink_host.data());
}
// clang-format off
auto layout_str = [&](bool permute){
@@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic
<< (deterministic
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
"MiB|" + std::to_string(nsplits) + "splits"
@@ -479,7 +504,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
@@ -495,6 +519,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
sink_buf.GetDeviceBuffer(),
d_sink_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
seqlen_q_ptr_dev,
@@ -589,6 +615,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_sink_host_refs;
randval_buf.FromDevice(randval_host.data());
@@ -765,6 +792,46 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
// Incorporate sink token into the softmax distribution (reference computation).
// The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space),
// which is a per-head random value in [30, 60].
// lse_new = log(exp(lse_old) + exp(sink))
// P_new = P_old * exp(lse_old - lse_new) (rescaled token attention)
// P_sink = exp(sink - lse_new) (sink attention weight)
ck_tile::HostTensor<AccDataType> p_sink_host_ref(
sink_grad ? std::array<ck_tile::index_t, 2>{nhead, real_seqlen_q}
: std::array<ck_tile::index_t, 2>{0, 0});
if(sink_grad)
{
for(int i_h = 0; i_h < nhead; ++i_h)
{
AccDataType sink_val = sink_host(wb, i_h);
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
{
// Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink))
// = max(lse_old, sink) + log(1 + exp(min - max))
// This handles lse_old = -inf (fully-masked rows) without producing NaN:
// if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink
// It also avoids exp(lse_old) overflow when lse_old is large.
// p_scale = exp(lse_old - lse_new) [fraction kept by regular tokens]
// p_sink = exp(sink - lse_new) [sink attention weight]
AccDataType lse_old = lse_host_ref(i_h, i_q);
AccDataType hi = lse_old > sink_val ? lse_old : sink_val;
AccDataType lo = lse_old > sink_val ? sink_val : lse_old;
AccDataType lse_new =
hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi));
AccDataType p_scale = ck_tile::exp(lse_old - lse_new);
lse_host_ref(i_h, i_q) = lse_new;
for(int i_k = 0; i_k < real_seqlen_k; ++i_k)
p_hp_host_ref(i_h, i_q, i_k) *= p_scale;
p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new);
}
}
}
if(p_drop > 0)
{
p_dropped_hp_host_ref = p_hp_host_ref;
@@ -823,6 +890,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
o_host_refs.push_back(o_host_ref);
p_hp_host_refs.push_back(p_hp_host_ref);
p_lp_host_refs.push_back(p_lp_host_ref);
p_sink_host_refs.push_back(p_sink_host_ref);
if(p_drop > 0)
{
randval_host_refs.push_back(randval_host_ref);
@@ -842,6 +910,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dbias_buf.SetZero();
if(sink_grad)
d_sink_buf.SetZero();
if(launcher.needs_zero_dq_acc)
dq_acc_buf.SetZero();
@@ -853,10 +923,19 @@ bwd_result fmha_bwd_run(mode_enum mode,
dk_buf.FromDevice(dk_host.data());
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
if(sink_grad)
d_sink_buf.FromDevice(d_sink_host.data());
// Track the index into reference vectors (may differ from wb if batches were skipped)
ck_tile::index_t ref_idx = 0;
// validation sink accumulator: global over batch, shape [nhead]
ck_tile::HostTensor<AccDataType> d_sink_host_ref(
sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
: std::array<ck_tile::index_t, 1>{0});
if(sink_grad)
d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; });
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
// When padding is enabled, use logical lengths instead of computing from padded
@@ -932,6 +1011,30 @@ bwd_result fmha_bwd_run(mode_enum mode,
ds_hp_host_ref.mDesc.get_lengths()[1],
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
if(sink_grad)
{
// Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] )
// where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop
for(int i_h = 0; i_h < nhead; ++i_h)
{
AccDataType d_sink_head_acc = 0;
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
{
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
do_dot_o +=
ck_tile::type_convert<AccDataType>(do_host_ref(i_h, i_q, o)) *
ck_tile::type_convert<AccDataType>(
o_host_refs[ref_idx](i_h, i_q, o)) *
p_undrop;
}
d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o;
}
d_sink_host_ref(i_h) += d_sink_head_acc;
}
}
if(use_dbias)
{
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
@@ -1044,6 +1147,17 @@ bwd_result fmha_bwd_run(mode_enum mode,
ref_idx++;
}
if(pass && sink_grad)
{
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dsink_pass = ck_tile::check_err(d_sink_host,
d_sink_host_ref,
std::string("Error: SinkGrad Incorrect results!"),
rtol,
atol);
pass &= dsink_pass;
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}