mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK_TILE] Fix dq_acc per-nhead stride in FMHA BWD group mode
In group mode the dq_acc workspace layout uses physical (padded) seqlen_q for the per-nhead stride (see FmhaBwdWorkspaceManager doc; also matches FmhaBwdConvertQGradKernel reads). The unified-workspace refactor inlined this stride as kargs.seqlen_q, which is the LOGICAL length when seqlen_q_ptr is provided. The result: main kernel writes batch i nhead>0 dq_acc at offsets that the convert kernel never reads, so dQ ends up zero for those positions. Hoist physical_seqlen_q to the outer scope and use it for both the non-deterministic and deterministic stride computations in the dq_dram_window lambda. Batch mode is unaffected since kargs.seqlen_q already equals the physical length there. Fixes 135 padding-related failures in test_ck_tile_fmha_bwd_fp16 (BasicQPadding / MultiBatchPadding / PaddingWithMask / QKVPadding / VariedPaddingRatios / ZeroLengthPadding / Deterministic / ElementwiseBias). Verified locally: full suite 672 PASSED / 0 FAILED. SGPR usage drops by 1; VGPR/AGPR/spill/occupancy unchanged.
This commit is contained in:
@@ -983,6 +983,9 @@ struct FmhaBwdDQDKDVKernel
         long_index_t batch_offset_dk    = 0;
         long_index_t batch_offset_dv    = 0;
         long_index_t batch_offset_dbias = 0;
+        // dq_acc per-nhead stride uses padded seqlen_q in group mode; equals kargs.seqlen_q
+        // in batch mode. See FmhaBwdWorkspaceManager doc.
+        index_t physical_seqlen_q = kargs.seqlen_q;

         if constexpr(kIsGroupMode)
         {
@@ -990,6 +993,9 @@ struct FmhaBwdDQDKDVKernel
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

+            physical_seqlen_q =
+                static_cast<index_t>(kargs.seqstart_q_ptr[i_batch + 1] - query_start);
+
             batch_offset_q = query_start * kargs.stride_q;
             batch_offset_k = key_start * kargs.stride_k;
             batch_offset_v = key_start * kargs.stride_v;
@@ -1030,10 +1036,6 @@ struct FmhaBwdDQDKDVKernel
         }
         else
         {
-            // get real # queries & # keys under group mode
-            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-            const ck_tile::index_t physical_seqlen_q =
-                adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
             kargs.seqlen_q =
                 kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
         }
@@ -1212,12 +1214,13 @@ struct FmhaBwdDQDKDVKernel
             else if constexpr(!kIsDeterministic)
             {
                 return batch_offset_dq_acc +
-                       static_cast<long_index_t>(i_nhead_) * kargs.seqlen_q * kargs.hdim_q;
+                       static_cast<long_index_t>(i_nhead_) * physical_seqlen_q * kargs.hdim_q;
             }
             else
             {
-                const long_index_t split_stride = kargs.seqlen_q * kargs.hdim_q;
+                const long_index_t split_stride =
+                    static_cast<long_index_t>(physical_seqlen_q) * kargs.hdim_q;
                 const auto nsplits = [&]() {
                     if constexpr(!kIsGroupMode)
                         return n_splits;
                     else
Reference in New Issue
Block a user