[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)

[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)

## Motivation

Adds sink token gradient support to the FMHA backward kernel's dot_do_o pipeline.

## Technical Details

- Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType
- Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs
- Compute per-head sink gradient via atomic accumulation in the pipeline (see the derivation after this list)
- Update example runner with reference validation for sink gradient
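
For reference, the quantity the pipeline accumulates follows from the sink-augmented softmax (notation mine, consistent with the comments in the diff below). With a per-head sink logit $\tilde{s}$, row $q$'s normalizer is $Z_q = e^{\tilde{s}} + \sum_k e^{s_{qk}}$, so $\mathrm{lse}_q = \log Z_q$ and the sink absorbs probability $P_{\mathrm{sink}}[q] = e^{\tilde{s} - \mathrm{lse}_q}$. The sink contributes no value vector, so $\partial p_{qk} / \partial \tilde{s} = -\,p_{qk}\,P_{\mathrm{sink}}[q]$ and the chain rule collapses to:

```math
\frac{\partial L}{\partial \tilde{s}}
= -\sum_q P_{\mathrm{sink}}[q] \sum_k p_{qk}\,\bigl(dO_q \cdot V_k\bigr)
= -\sum_q P_{\mathrm{sink}}[q]\, D[q],
\qquad D[q] = \sum_j O[q,j]\, dO[q,j].
```

This is exactly the per-query contribution `-P_sink[q] * D[q]` computed in the pipeline, reusing the `D` tensor the dot_do_o kernel already produces.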

## Test Plan

Add a new test case covering the sink gradient path.

## Test Result

WIP

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Linjun-AMD
2026-04-02 03:17:45 +00:00
committed by assistant-librarian[bot]
parent c1127a36f5
commit 08792e0b31
12 changed files with 380 additions and 130 deletions

View File

@@ -1324,6 +1324,7 @@ struct FmhaBwdOGradDotOKernel
using DDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::LSEDataType>;
static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
@@ -1365,25 +1366,31 @@ struct FmhaBwdOGradDotOKernel
const void* o_ptr;
const void* do_ptr;
void* d_ptr;
const void* lse_ptr; // log-sum-exp from forward pass, shape [batch, nhead, seqlen_q]
const LSEDataType* sink_ptr; // sink scores, shape [batch, nhead]; nullptr disables sink
LSEDataType* d_sink_ptr; // sink gradient output, shape [nhead]; nullptr disables sink grad
float p_undrop;
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead; // used to index sink_ptr / d_sink_ptr
ck_tile::index_t stride_do;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_o;
-ck_tile::index_t nhead_stride_d;
+// LSE and D always share the same layout; this stride covers both.
+ck_tile::index_t nhead_stride_lsed;
};
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
{
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_d;
// LSE and D always share the same layout; this stride covers both.
ck_tile::index_t batch_stride_lsed;
};
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
@@ -1401,32 +1408,40 @@ struct FmhaBwdOGradDotOKernel
MakeKargs(const void* o_ptr,
const void* do_ptr,
void* d_ptr,
const void* lse_ptr,
const void* sink_ptr,
void* d_sink_ptr,
float p_undrop,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t nhead,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
-ck_tile::index_t nhead_stride_d,
+ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_o,
-ck_tile::index_t batch_stride_d)
+ck_tile::index_t batch_stride_lsed)
{
Kargs kargs{{o_ptr,
do_ptr,
d_ptr,
lse_ptr,
reinterpret_cast<const LSEDataType*>(sink_ptr),
reinterpret_cast<LSEDataType*>(d_sink_ptr),
p_undrop,
seqlen_q,
hdim_v,
nhead,
stride_do,
stride_o,
nhead_stride_do,
nhead_stride_o,
-nhead_stride_d},
+nhead_stride_lsed},
batch_stride_do,
batch_stride_o,
-batch_stride_d};
+batch_stride_lsed};
return kargs;
}
@@ -1436,28 +1451,36 @@ struct FmhaBwdOGradDotOKernel
MakeKargs(const void* o_ptr,
const void* do_ptr,
void* d_ptr,
const void* lse_ptr,
const void* sink_ptr,
void* d_sink_ptr,
float p_undrop,
const void* seqstart_q_ptr,
const void* seqlen_q_ptr,
const void* cu_seqlen_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t nhead,
ck_tile::index_t stride_do,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
-ck_tile::index_t nhead_stride_d)
+ck_tile::index_t nhead_stride_lsed)
{
Kargs kargs{{o_ptr,
do_ptr,
d_ptr,
lse_ptr,
reinterpret_cast<const LSEDataType*>(sink_ptr),
reinterpret_cast<LSEDataType*>(d_sink_ptr),
p_undrop,
-1, // placeholder; per-batch seqlen_q is resolved from the seqstart_q/seqlen_q pointers
hdim_v,
nhead,
stride_do,
stride_o,
nhead_stride_do,
nhead_stride_o,
-nhead_stride_d},
+nhead_stride_lsed},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
@@ -1491,18 +1514,18 @@ struct FmhaBwdOGradDotOKernel
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
-long_index_t batch_offset_o = 0;
-long_index_t batch_offset_do = 0;
-long_index_t batch_offset_d = 0;
+long_index_t batch_offset_o = 0;
+long_index_t batch_offset_do = 0;
+long_index_t batch_offset_lsed = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
-batch_offset_o = query_start * kargs.stride_o;
-batch_offset_do = query_start * kargs.stride_do;
-batch_offset_d = query_start;
+batch_offset_o = query_start * kargs.stride_o;
+batch_offset_do = query_start * kargs.stride_do;
+batch_offset_lsed = query_start;
// Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
if(kargs.cu_seqlen_q_ptr != nullptr)
@@ -1530,11 +1553,20 @@ struct FmhaBwdOGradDotOKernel
}
else
{
-batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
-batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
-batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
+batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
+batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
+batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
}
// Read per-head sink score and convert to log2 domain so the pipeline can use exp2.
// Pre-multiply by log2e so that exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
// -inf is left unchanged (log2e * -inf == -inf) to keep P_sink -> 0 when sink is disabled.
const LSEDataType sink_value =
kargs.sink_ptr != nullptr
? log2e_v<LSEDataType> *
kargs.sink_ptr[static_cast<long_index_t>(i_batch) * kargs.nhead + i_nhead]
: -numeric<LSEDataType>::infinity();
// for simplicity, apply the batch offset by adjusting the base pointers directly
const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
@@ -1542,9 +1574,13 @@ struct FmhaBwdOGradDotOKernel
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
batch_offset_do;
const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
batch_offset_lsed;
DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
-                  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
-                  batch_offset_d;
+                  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
+                  batch_offset_lsed;
// O/dO/D DRAM and DRAM window
const auto o_dram = [&]() {
@@ -1578,13 +1614,31 @@ struct FmhaBwdOGradDotOKernel
auto o_dram_window =
make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
auto do_dram_window =
make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
-FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
+// nullptr when sink grad is disabled; the pipeline checks this to skip the sink path
+LSEDataType* atomic_sink_grad_ptr =
+    kargs.d_sink_ptr == nullptr ? nullptr : kargs.d_sink_ptr + i_nhead;
+// lse_ptr is always valid (also needed by the main bwd kernel).
+// The actual load happens inside the pipeline only when atomic_sink_grad_ptr != nullptr.
+auto lse_dram = [&]() {
+    const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
+        lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
+    return pad_tensor_view(
+        lse_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
+}();
+auto lse_dram_window = make_tile_window(lse_dram, make_tuple(number<kM0>{}), {i_m0});
+FmhaBwdOGradDotO{}(o_dram_window,
+                   do_dram_window,
+                   lse_dram_window,
+                   d_dram_window,
+                   sink_value,
+                   kargs.p_undrop,
+                   atomic_sink_grad_ptr);
}
};
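
As a quick aside, the log2-domain identity used for `sink_value` above can be checked standalone. The sketch below is mine, not part of the diff; the constant is what CK Tile's `log2e_v<float>` holds, everything else is hypothetical test scaffolding.

```cpp
// Verify exp2(log2e*s - log2e*lse) == exp(s - lse) up to float rounding,
// i.e. pre-multiplying the sink score by log2e lets the pipeline use exp2,
// which maps to a single hardware instruction on AMD GPUs.
#include <cmath>
#include <cstdio>

int main()
{
    const float log2e = 1.4426950408889634f; // log2(e)
    const float sink = 0.7f, lse = 2.3f;
    const float via_exp2 = std::exp2(log2e * sink - log2e * lse);
    const float via_exp  = std::exp(sink - lse);
    std::printf("%.9g vs %.9g\n", via_exp2, via_exp); // both ~0.201896518
    // Disabled sink: the kernel passes -inf, and exp2(-inf) == 0, so P_sink vanishes.
    std::printf("%g\n", std::exp2(-INFINITY)); // 0
    return 0;
}
```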

View File

@@ -14,6 +14,7 @@ struct BlockFmhaBwdOGradDotO
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; // needed for sink gradient
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -32,11 +33,18 @@ struct BlockFmhaBwdOGradDotO
template <typename ODramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp>
// Computes D = diag(dO * O) and optionally accumulates the sink token gradient.
// sink_value: log-space sink score; pass -inf and atomic_sink_grad_ptr=nullptr to skip sink.
// atomic_sink_grad_ptr: per-head accumulator in global memory; nullptr disables sink path.
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
DDramBlockWindowTmp& d_dram_block_window_tmp,
-                                    float p_undrop) const
+                                    const LSEDataType sink_value,
+                                    float p_undrop,
+                                    LSEDataType* atomic_sink_grad_ptr = nullptr) const
{
static_assert(
std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
@@ -44,6 +52,10 @@ struct BlockFmhaBwdOGradDotO
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
"wrong!");
// atomic_sink_grad_ptr is reinterpret_cast to float* in the sink path;
// ensure LSEDataType is float so the cast is well-defined.
static_assert(std::is_same_v<LSEDataType, float>,
"sink gradient atomicAdd requires LSEDataType == float");
static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kBlockSize ==
@@ -67,14 +79,13 @@ struct BlockFmhaBwdOGradDotO
auto do_ = load_tile(do_dram_window);
// declare d
// D[q] = sum_j(O[q,j] * dO[q,j]), used in softmax backward
constexpr auto d_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
auto d = make_static_distributed_tensor<DDataType>(d_dstr);
-clear_tile(d); // Initialize D
+clear_tile(d);
constexpr auto o_spans = decltype(o)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -86,9 +97,67 @@ struct BlockFmhaBwdOGradDotO
});
});
// Scale by p_undrop (=1 when dropout is disabled)
tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);
store_tile(d_dram_block_window_tmp, d);
// Sink gradient path: skipped entirely when atomic_sink_grad_ptr is nullptr
if(atomic_sink_grad_ptr != nullptr)
{
// Load LSE only on the sink path to avoid unnecessary global memory reads
constexpr auto lse_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
o.get_tile_distribution().get_static_tile_distribution_encoding(),
sequence<1>{}));
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
lse_dram_block_window_tmp.get_window_origin(),
lse_dstr);
auto lse_ = load_tile(lse_dram_window);
// Compute per-query contribution: -P_sink[q] * D[q]
// where P_sink[q] = exp2(sink_value - log2e*lse[q])
// sink_value has already been pre-multiplied by log2e at the kernel call site,
// so exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
// exp2 maps directly to the v_exp_f32 hardware instruction on AMD GPUs.
// Always accumulate in float regardless of DDataType to avoid precision loss
// and to ensure atomicAdd works correctly on all architectures.
auto sink_val_tensor = make_static_distributed_tensor<float>(d_dstr);
tile_elementwise_inout(
[&](auto& s_out, const auto& l_in, const auto& d_in) {
float p_sink = exp2(type_convert<float>(sink_value) -
log2e_v<float> * type_convert<float>(l_in));
s_out = -p_sink * type_convert<float>(d_in);
},
sink_val_tensor,
lse_,
d);
// Reduce contributions held by this thread
float thread_sum = 0.f;
constexpr auto s_spans = decltype(sink_val_tensor)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
thread_sum += sink_val_tensor(i_idx);
});
// Warp-level reduction: fold thread_sum across lanes so only one
// atomicAdd per warp is issued instead of one per thread.
#if defined(__HIP_DEVICE_COMPILE__) || defined(__CUDA_ARCH__)
const index_t warp_sz = get_warp_size();
for(index_t offset = warp_sz >> 1; offset > 0; offset >>= 1)
thread_sum += warp_shuffle_down(thread_sum, offset);
// Only lane 0 of each warp writes to global memory.
// Note: this atomicAdd is non-deterministic across runs regardless of the
// -deterministic flag, because d_sink is a single scalar per head accumulated
// across all thread-blocks. The practical impact is negligible for this value.
if(get_lane_id() == 0)
atomicAdd(reinterpret_cast<float*>(atomic_sink_grad_ptr), thread_sum);
#endif
}
}
};
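
The warp-level reduction at the end of the sink path is the standard "one atomic per wavefront" pattern. Below is a minimal standalone HIP sketch of the same idea, with a hypothetical helper name; the PR itself uses CK Tile's warp_shuffle_down/get_lane_id wrappers rather than raw intrinsics, and the sketch assumes a 1D thread block.

```cpp
#include <hip/hip_runtime.h>

// Fold each lane's partial sum across the wavefront, then let only lane 0
// touch global memory, cutting atomics from one-per-thread to one-per-wave.
__device__ void atomic_add_wave_reduced(float* dst, float v)
{
    // warpSize is 64 on current AMD GPUs; the halving loop also handles 32.
    for(int offset = warpSize / 2; offset > 0; offset >>= 1)
        v += __shfl_down(v, offset); // read v from lane (lane_id + offset)
    if((threadIdx.x % warpSize) == 0) // lane 0 now holds the wavefront total
        atomicAdd(dst, v);
}
```

As the in-diff comment notes, this only reduces the number of atomics; the cross-block accumulation order into d_sink remains non-deterministic.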

View File

@@ -67,6 +67,7 @@ struct BlockFmhaBwdPipelineProblem
template <typename ODataType_,
typename OGradDataType_,
typename DDataType_,
typename LSEDataType_,
index_t kBlockSize_,
index_t kVHeaddim_,
bool kIsGroupMode_,
@@ -76,6 +77,7 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using OGradDataType = remove_cvref_t<OGradDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using Traits = remove_cvref_t<Traits_>;
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
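
Finally, a hedged sketch of the host-side reference the updated example runner presumably validates d_sink against. The naive loops, names, and packed [batch][nhead][seqlen][hdim] layout are assumptions for brevity; the p_undrop scaling is omitted (dropout disabled).

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// d_sink[h] = -sum over batch b and query q of P_sink[b,h,q] * D[b,h,q],
// with P_sink = exp(sink[b,h] - lse[b,h,q]) and D = dot(O row, dO row).
// Note d_sink is per-head only: contributions are summed across the batch,
// matching the kernel's single [nhead]-sized atomic accumulator.
std::vector<float> reference_dsink(const std::vector<float>& o,    // [B][H][Q][D]
                                   const std::vector<float>& dout, // [B][H][Q][D]
                                   const std::vector<float>& lse,  // [B][H][Q]
                                   const std::vector<float>& sink, // [B][H]
                                   int B, int H, int Q, int D)
{
    std::vector<float> d_sink(H, 0.f);
    for(int b = 0; b < B; ++b)
        for(int h = 0; h < H; ++h)
            for(int q = 0; q < Q; ++q)
            {
                float d = 0.f; // D[q] = sum_j O[q,j] * dO[q,j]
                for(int j = 0; j < D; ++j)
                {
                    const std::size_t idx =
                        (((std::size_t)b * H + h) * Q + q) * (std::size_t)D + j;
                    d += o[idx] * dout[idx];
                }
                const std::size_t bhq = ((std::size_t)b * H + h) * Q + q;
                const float p_sink = std::exp(sink[(std::size_t)b * H + h] - lse[bhq]);
                d_sink[h] -= p_sink * d;
            }
    return d_sink;
}
```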