[CK_TILE] Fix dq_acc per-nhead stride in FMHA BWD group mode

In group mode the dq_acc workspace layout uses physical (padded)
seqlen_q for the per-nhead stride (see FmhaBwdWorkspaceManager doc;
also matches FmhaBwdConvertQGradKernel reads). The unified-workspace
refactor inlined this stride as kargs.seqlen_q, which in group mode is
the LOGICAL length whenever seqlen_q_ptr is provided. As a result, for
any batch whose logical length differs from its padded length, the main
kernel writes the nhead > 0 dq_acc slices at offsets the convert kernel
never reads, so dQ comes back zero at those positions.

Hoist physical_seqlen_q to the outer scope and use it for both the
non-deterministic and deterministic stride computations in the
dq_dram_window lambda. Batch mode is unaffected since kargs.seqlen_q
already equals the physical length there.

Fixes 135 padding-related failures in test_ck_tile_fmha_bwd_fp16
(BasicQPadding / MultiBatchPadding / PaddingWithMask / QKVPadding /
VariedPaddingRatios / ZeroLengthPadding / Deterministic /
ElementwiseBias). Verified locally: full suite 672 PASSED / 0 FAILED.
SGPR usage drops by 1; VGPR/AGPR/spill/occupancy unchanged.
Author: Ding, Yi
Date:   2026-04-27 01:54:53 -05:00
parent: f122fc731f
commit: b3a5e7ff64


@@ -983,6 +983,9 @@ struct FmhaBwdDQDKDVKernel
         long_index_t batch_offset_dk = 0;
         long_index_t batch_offset_dv = 0;
         long_index_t batch_offset_dbias = 0;
+        // dq_acc per-nhead stride uses padded seqlen_q in group mode; equals kargs.seqlen_q
+        // in batch mode. See FmhaBwdWorkspaceManager doc.
+        index_t physical_seqlen_q = kargs.seqlen_q;
         if constexpr(kIsGroupMode)
         {
@@ -990,6 +993,9 @@ struct FmhaBwdDQDKDVKernel
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
+            physical_seqlen_q =
+                static_cast<index_t>(kargs.seqstart_q_ptr[i_batch + 1] - query_start);
             batch_offset_q = query_start * kargs.stride_q;
             batch_offset_k = key_start * kargs.stride_k;
             batch_offset_v = key_start * kargs.stride_v;
@@ -1030,10 +1036,6 @@ struct FmhaBwdDQDKDVKernel
             }
             else
             {
-                // get real # queries & # keys under group mode
-                const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
-                const ck_tile::index_t physical_seqlen_q =
-                    adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
                 kargs.seqlen_q =
                     kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
             }
@@ -1212,12 +1214,13 @@ struct FmhaBwdDQDKDVKernel
            else if constexpr(!kIsDeterministic)
            {
                return batch_offset_dq_acc +
-                      static_cast<long_index_t>(i_nhead_) * kargs.seqlen_q * kargs.hdim_q;
+                      static_cast<long_index_t>(i_nhead_) * physical_seqlen_q * kargs.hdim_q;
            }
            else
            {
-               const long_index_t split_stride = kargs.seqlen_q * kargs.hdim_q;
-               const auto nsplits = [&]() {
+               const long_index_t split_stride =
+                   static_cast<long_index_t>(physical_seqlen_q) * kargs.hdim_q;
+               const auto nsplits = [&]() {
                    if constexpr(!kIsGroupMode)
                        return n_splits;
                    else