[fmha-bwd] Fix dK/dV left uninitialized for zero-length Q batches in group persistent

This commit is contained in:
Ding, Yi
2026-04-17 02:54:52 -05:00
parent 92f2ed758e
commit 3e3cb36c7a

View File

@@ -156,7 +156,7 @@ struct FmhaBwdWorkspaceManager
}
// Fill CPU prepared workspace and return size of non CPU prepared workspace size
template <bool kUseQrQtrDorPipeline, index_t kN0>
template <bool kUseQrQtrDorPipeline, index_t kN0, index_t kM0>
CK_TILE_HOST static size_t
PrepareWorkspaceHost(void* cpu_ws,
index_t batch_size,
@@ -202,26 +202,33 @@ struct FmhaBwdWorkspaceManager
auto* batch_states = reinterpret_cast<FmhaBwdBatchState*>(
reinterpret_cast<char*>(cpu_ws) + GetBatchStateOffset(batch_size));
// sq_work: sq aligned to kM0 for work-distribution purposes.
// If sq==0, use kM0 so CUs are still dispatched and write dK/dV=0.
const auto sq_work = [](index_t sq) -> index_t {
return sq == 0 ? kM0 : integer_least_multiple(sq, kM0);
};
prefix_batch[0] = 0;
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq;
prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq_work(sq);
}
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
// Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t sq_w = sq_work(sq);
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq_w : 0;
const index_t ns = 1 + (rest_workload > 0 && target_w > 0
? integer_divide_ceil(rest_workload, target_w)
: 0);
nsplits[b] = ns;
batch_states[b].sq = sq;
batch_states[b].sq = sq_w; // GPU uses sq_w for w_chunk tracking
batch_states[b].nc = nc;
batch_states[b].nsplits = ns;
}
@@ -246,10 +253,11 @@ struct FmhaBwdWorkspaceManager
index_t cu_lo = 0;
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t hw = nc * sq;
const index_t pb = prefix_batch[b];
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t sq_w = sq_work(sq); // kM0-aligned sq for work distribution
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t hw = nc * sq_w; // use sq_w so sq=0 batches get work
const index_t pb = prefix_batch[b];
const index_t cu_hi =
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
for(index_t c = cu_lo; c < cu_hi; ++c)
@@ -263,13 +271,13 @@ struct FmhaBwdWorkspaceManager
const index_t w_head = pb + head_start * hw;
const index_t wc_start = max(w_lo - w_head, index_t(0));
const index_t c_start =
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
wc_start > 0 ? integer_divide_ceil(wc_start, sq_w) : 0;
cu_states[c].isplit =
wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
cu_states[c].head_start = head_start;
cu_states[c].c_start = c_start;
// w_lo = true global start of first K-chunk for this CU
cu_states[c].w_lo = pb + head_start * hw + c_start * sq;
cu_states[c].w_lo = pb + head_start * hw + c_start * sq_w;
}
else
{
@@ -451,7 +459,8 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST static constexpr auto PrepareWorkspaceHost(Args&&... args)
{
return WorkspaceManager::template PrepareWorkspaceHost<kUseQrQtrDorPipeline,
FmhaPipeline::BlockFmhaShape::kN0>(
FmhaPipeline::BlockFmhaShape::kN0,
FmhaPipeline::BlockFmhaShape::kM0>(
std::forward<Args>(args)...);
}
template <typename... Args>