mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)
[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)

## Motivation

Adds sink token support to the FMHA backward kernel (dot_do_o pipeline).

## Technical Details

- Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType
- Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs
- Compute per-head sink gradient via atomic accumulation in the pipeline
- Update example runner with reference validation for sink gradient

## Test Plan

Add new test case

## Test Result

WIP

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c1127a36f5
commit
08792e0b31
@@ -1324,6 +1324,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
using DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
|
||||
using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::LSEDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
|
||||
@@ -1365,25 +1366,31 @@ struct FmhaBwdOGradDotOKernel
|
||||
const void* o_ptr;
|
||||
const void* do_ptr;
|
||||
void* d_ptr;
|
||||
const void* lse_ptr; // log-sum-exp from forward pass, shape [batch, nhead, seqlen_q]
|
||||
const LSEDataType* sink_ptr; // sink scores, shape [batch, nhead]; nullptr disables sink
|
||||
LSEDataType* d_sink_ptr; // sink gradient output, shape [nhead]; nullptr disables sink grad
|
||||
|
||||
float p_undrop;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead; // used to index sink_ptr / d_sink_ptr
|
||||
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_o;
|
||||
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_d;
|
||||
// LSE and D always share the same layout; this stride covers both.
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
};
|
||||
|
||||
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_d;
|
||||
// LSE and D always share the same layout; this stride covers both.
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
};
|
||||
|
||||
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
|
||||
@@ -1401,32 +1408,40 @@ struct FmhaBwdOGradDotOKernel
|
||||
MakeKargs(const void* o_ptr,
|
||||
const void* do_ptr,
|
||||
void* d_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* sink_ptr,
|
||||
void* d_sink_ptr,
|
||||
float p_undrop,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_d,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_d)
|
||||
ck_tile::index_t batch_stride_lsed)
|
||||
{
|
||||
Kargs kargs{{o_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
lse_ptr,
|
||||
reinterpret_cast<const LSEDataType*>(sink_ptr),
|
||||
reinterpret_cast<LSEDataType*>(d_sink_ptr),
|
||||
p_undrop,
|
||||
seqlen_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
stride_do,
|
||||
stride_o,
|
||||
nhead_stride_do,
|
||||
nhead_stride_o,
|
||||
nhead_stride_d},
|
||||
nhead_stride_lsed},
|
||||
batch_stride_do,
|
||||
batch_stride_o,
|
||||
batch_stride_d};
|
||||
batch_stride_lsed};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -1436,28 +1451,36 @@ struct FmhaBwdOGradDotOKernel
|
||||
MakeKargs(const void* o_ptr,
|
||||
const void* do_ptr,
|
||||
void* d_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* sink_ptr,
|
||||
void* d_sink_ptr,
|
||||
float p_undrop,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_d)
|
||||
ck_tile::index_t nhead_stride_lsed)
|
||||
{
|
||||
Kargs kargs{{o_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
lse_ptr,
|
||||
reinterpret_cast<const LSEDataType*>(sink_ptr),
|
||||
reinterpret_cast<LSEDataType*>(d_sink_ptr),
|
||||
p_undrop,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_v,
|
||||
nhead,
|
||||
stride_do,
|
||||
stride_o,
|
||||
nhead_stride_do,
|
||||
nhead_stride_o,
|
||||
nhead_stride_d},
|
||||
nhead_stride_lsed},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
|
||||
@@ -1491,18 +1514,18 @@ struct FmhaBwdOGradDotOKernel
|
||||
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
|
||||
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_do = 0;
|
||||
long_index_t batch_offset_d = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
long_index_t batch_offset_do = 0;
|
||||
long_index_t batch_offset_lsed = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
batch_offset_do = query_start * kargs.stride_do;
|
||||
batch_offset_d = query_start;
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
batch_offset_do = query_start * kargs.stride_do;
|
||||
batch_offset_lsed = query_start;
|
||||
|
||||
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
|
||||
if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
@@ -1530,11 +1553,20 @@ struct FmhaBwdOGradDotOKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
|
||||
batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
|
||||
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
|
||||
}
|
||||
|
||||
// Read per-head sink score and convert to log2 domain so the pipeline can use exp2.
|
||||
// Pre-multiply by log2e so that exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
|
||||
// -inf is left unchanged (log2e * -inf == -inf) to keep P_sink -> 0 when sink is disabled.
|
||||
const LSEDataType sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? log2e_v<LSEDataType> *
|
||||
kargs.sink_ptr[static_cast<long_index_t>(i_batch) * kargs.nhead + i_nhead]
|
||||
: -numeric<LSEDataType>::infinity();
|
||||
|
||||
// For simplicity, fold the per-batch and per-head offsets directly into the base pointers.
|
||||
const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
@@ -1542,9 +1574,13 @@ struct FmhaBwdOGradDotOKernel
|
||||
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
|
||||
batch_offset_do;
|
||||
const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
|
||||
batch_offset_lsed;
|
||||
|
||||
DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
|
||||
batch_offset_d;
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
|
||||
batch_offset_lsed;
|
||||
|
||||
// O/dO/D DRAM and DRAM window
|
||||
const auto o_dram = [&]() {
|
||||
@@ -1578,13 +1614,31 @@ struct FmhaBwdOGradDotOKernel
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
|
||||
|
||||
auto do_dram_window =
|
||||
make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
|
||||
|
||||
auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
|
||||
|
||||
FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
|
||||
// nullptr when sink grad is disabled; the pipeline checks this to skip the sink path
|
||||
LSEDataType* atomic_sink_grad_ptr =
|
||||
kargs.d_sink_ptr == nullptr ? nullptr : kargs.d_sink_ptr + i_nhead;
|
||||
|
||||
// lse_ptr is always valid (also needed by the main bwd kernel).
|
||||
// The actual load happens inside the pipeline only when atomic_sink_grad_ptr != nullptr.
|
||||
auto lse_dram = [&]() {
|
||||
const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
|
||||
return pad_tensor_view(
|
||||
lse_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
|
||||
}();
|
||||
auto lse_dram_window = make_tile_window(lse_dram, make_tuple(number<kM0>{}), {i_m0});
|
||||
|
||||
FmhaBwdOGradDotO{}(o_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
sink_value,
|
||||
kargs.p_undrop,
|
||||
atomic_sink_grad_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user