mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)
[CK Tile] Add sink token gradient support in FMHA backward pass (#5504) ## Motivation Adds sink token support to the FMHA backward kernel (dot_do_o pipeline). ## Technical Details - Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType - Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs - Compute the per-head sink gradient via atomic accumulation in the pipeline - Update the example runner with reference validation for the sink gradient ## Test Plan Add a new test case. ## Test Result WIP ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c1127a36f5
commit
08792e0b31
@@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
uint64_t drop_offset,
|
||||
bool drop_prefs,
|
||||
std::string mask_str,
|
||||
bool sink_grad, // if true, compute and validate sink gradient
|
||||
bool deterministic,
|
||||
std::string init_method,
|
||||
uint32_t seed,
|
||||
@@ -284,6 +285,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<LSEDataType> sink_host(
|
||||
sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
|
||||
if(sink_grad)
|
||||
{
|
||||
std::uniform_real_distribution<float> sink_dist(30.0f, 60.0f);
|
||||
sink_host.ForEach([&](auto& self, auto i) {
|
||||
self(i) = static_cast<LSEDataType>(sink_dist(random_engine));
|
||||
});
|
||||
}
|
||||
ck_tile::HostTensor<DDataType> d_host(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host(
|
||||
@@ -301,6 +312,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
use_dbias
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<LSEDataType> d_sink_host(sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
|
||||
: std::array<ck_tile::index_t, 1>{0});
|
||||
if(sink_grad)
|
||||
{
|
||||
d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
|
||||
}
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
|
||||
|
||||
@@ -361,11 +378,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
@@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
if(sink_grad)
|
||||
{
|
||||
sink_buf.ToDevice(sink_host.data());
|
||||
d_sink_buf.ToDevice(d_sink_host.data());
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
auto layout_str = [&](bool permute){
|
||||
@@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
|
||||
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
|
||||
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
|
||||
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
|
||||
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
|
||||
<< ", deterministic:" << deterministic
|
||||
<< (deterministic
|
||||
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
|
||||
"MiB|" + std::to_string(nsplits) + "splits"
|
||||
@@ -479,7 +504,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
|
||||
const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
|
||||
const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
@@ -495,6 +519,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dv_buf.GetDeviceBuffer(),
|
||||
dbias_buf.GetDeviceBuffer(),
|
||||
dq_acc_buf.GetDeviceBuffer(),
|
||||
sink_buf.GetDeviceBuffer(),
|
||||
d_sink_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
seqlen_q_ptr_dev,
|
||||
@@ -589,6 +615,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
|
||||
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<AccDataType>> p_sink_host_refs;
|
||||
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
@@ -765,6 +792,46 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
|
||||
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
|
||||
|
||||
// Incorporate sink token into the softmax distribution (reference computation).
|
||||
// The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space),
|
||||
// which is a per-head random value in [30, 60].
|
||||
// lse_new = log(exp(lse_old) + exp(sink))
|
||||
// P_new = P_old * exp(lse_old - lse_new) (rescaled token attention)
|
||||
// P_sink = exp(sink - lse_new) (sink attention weight)
|
||||
ck_tile::HostTensor<AccDataType> p_sink_host_ref(
|
||||
sink_grad ? std::array<ck_tile::index_t, 2>{nhead, real_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 2>{0, 0});
|
||||
if(sink_grad)
|
||||
{
|
||||
for(int i_h = 0; i_h < nhead; ++i_h)
|
||||
{
|
||||
AccDataType sink_val = sink_host(wb, i_h);
|
||||
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
|
||||
{
|
||||
// Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink))
|
||||
// = max(lse_old, sink) + log(1 + exp(min - max))
|
||||
// This handles lse_old = -inf (fully-masked rows) without producing NaN:
|
||||
// if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink
|
||||
// It also avoids exp(lse_old) overflow when lse_old is large.
|
||||
// p_scale = exp(lse_old - lse_new) [fraction kept by regular tokens]
|
||||
// p_sink = exp(sink - lse_new) [sink attention weight]
|
||||
AccDataType lse_old = lse_host_ref(i_h, i_q);
|
||||
AccDataType hi = lse_old > sink_val ? lse_old : sink_val;
|
||||
AccDataType lo = lse_old > sink_val ? sink_val : lse_old;
|
||||
AccDataType lse_new =
|
||||
hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi));
|
||||
AccDataType p_scale = ck_tile::exp(lse_old - lse_new);
|
||||
|
||||
lse_host_ref(i_h, i_q) = lse_new;
|
||||
|
||||
for(int i_k = 0; i_k < real_seqlen_k; ++i_k)
|
||||
p_hp_host_ref(i_h, i_q, i_k) *= p_scale;
|
||||
|
||||
p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
@@ -823,6 +890,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
o_host_refs.push_back(o_host_ref);
|
||||
p_hp_host_refs.push_back(p_hp_host_ref);
|
||||
p_lp_host_refs.push_back(p_lp_host_ref);
|
||||
p_sink_host_refs.push_back(p_sink_host_ref);
|
||||
if(p_drop > 0)
|
||||
{
|
||||
randval_host_refs.push_back(randval_host_ref);
|
||||
@@ -842,6 +910,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dbias_buf.SetZero();
|
||||
if(sink_grad)
|
||||
d_sink_buf.SetZero();
|
||||
|
||||
if(launcher.needs_zero_dq_acc)
|
||||
dq_acc_buf.SetZero();
|
||||
@@ -853,10 +923,19 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dk_buf.FromDevice(dk_host.data());
|
||||
dv_buf.FromDevice(dv_host.data());
|
||||
dbias_buf.FromDevice(dbias_host.data());
|
||||
if(sink_grad)
|
||||
d_sink_buf.FromDevice(d_sink_host.data());
|
||||
|
||||
// Track the index into reference vectors (may differ from wb if batches were skipped)
|
||||
ck_tile::index_t ref_idx = 0;
|
||||
|
||||
// validation sink accumulator: global over batch, shape [nhead]
|
||||
ck_tile::HostTensor<AccDataType> d_sink_host_ref(
|
||||
sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
|
||||
: std::array<ck_tile::index_t, 1>{0});
|
||||
if(sink_grad)
|
||||
d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; });
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
// When padding is enabled, use logical lengths instead of computing from padded
|
||||
@@ -932,6 +1011,30 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(sink_grad)
|
||||
{
|
||||
// Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] )
|
||||
// where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop
|
||||
for(int i_h = 0; i_h < nhead; ++i_h)
|
||||
{
|
||||
AccDataType d_sink_head_acc = 0;
|
||||
for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
|
||||
{
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o +=
|
||||
ck_tile::type_convert<AccDataType>(do_host_ref(i_h, i_q, o)) *
|
||||
ck_tile::type_convert<AccDataType>(
|
||||
o_host_refs[ref_idx](i_h, i_q, o)) *
|
||||
p_undrop;
|
||||
}
|
||||
d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o;
|
||||
}
|
||||
d_sink_host_ref(i_h) += d_sink_head_acc;
|
||||
}
|
||||
}
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
@@ -1044,6 +1147,17 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ref_idx++;
|
||||
}
|
||||
|
||||
if(pass && sink_grad)
|
||||
{
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
|
||||
bool dsink_pass = ck_tile::check_err(d_sink_host,
|
||||
d_sink_host_ref,
|
||||
std::string("Error: SinkGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
pass &= dsink_pass;
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user