mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[fmha-bwd] Fix dK/dV left uninitialized for zero-length Q batches in group persistent
This commit is contained in:
@@ -156,7 +156,7 @@ struct FmhaBwdWorkspaceManager
|
||||
}
|
||||
|
||||
// Fill the CPU-prepared workspace and return the size of the non-CPU-prepared workspace
|
||||
template <bool kUseQrQtrDorPipeline, index_t kN0>
|
||||
template <bool kUseQrQtrDorPipeline, index_t kN0, index_t kM0>
|
||||
CK_TILE_HOST static size_t
|
||||
PrepareWorkspaceHost(void* cpu_ws,
|
||||
index_t batch_size,
|
||||
@@ -202,26 +202,33 @@ struct FmhaBwdWorkspaceManager
|
||||
auto* batch_states = reinterpret_cast<FmhaBwdBatchState*>(
|
||||
reinterpret_cast<char*>(cpu_ws) + GetBatchStateOffset(batch_size));
|
||||
|
||||
// sq_work: sq aligned to kM0 for work-distribution purposes.
|
||||
// If sq==0, use kM0 so CUs are still dispatched and write dK/dV=0.
|
||||
const auto sq_work = [](index_t sq) -> index_t {
|
||||
return sq == 0 ? kM0 : integer_least_multiple(sq, kM0);
|
||||
};
|
||||
|
||||
prefix_batch[0] = 0;
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq;
|
||||
prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq_work(sq);
|
||||
}
|
||||
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
|
||||
|
||||
// Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t sq_w = sq_work(sq);
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq_w : 0;
|
||||
const index_t ns = 1 + (rest_workload > 0 && target_w > 0
|
||||
? integer_divide_ceil(rest_workload, target_w)
|
||||
: 0);
|
||||
nsplits[b] = ns;
|
||||
batch_states[b].sq = sq;
|
||||
batch_states[b].sq = sq_w; // GPU uses sq_w for w_chunk tracking
|
||||
batch_states[b].nc = nc;
|
||||
batch_states[b].nsplits = ns;
|
||||
}
|
||||
@@ -246,10 +253,11 @@ struct FmhaBwdWorkspaceManager
|
||||
index_t cu_lo = 0;
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t hw = nc * sq;
|
||||
const index_t pb = prefix_batch[b];
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t sq_w = sq_work(sq); // kM0-aligned sq for work distribution
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t hw = nc * sq_w; // use sq_w so sq=0 batches get work
|
||||
const index_t pb = prefix_batch[b];
|
||||
const index_t cu_hi =
|
||||
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
|
||||
for(index_t c = cu_lo; c < cu_hi; ++c)
|
||||
@@ -263,13 +271,13 @@ struct FmhaBwdWorkspaceManager
|
||||
const index_t w_head = pb + head_start * hw;
|
||||
const index_t wc_start = max(w_lo - w_head, index_t(0));
|
||||
const index_t c_start =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, sq_w) : 0;
|
||||
cu_states[c].isplit =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
|
||||
cu_states[c].head_start = head_start;
|
||||
cu_states[c].c_start = c_start;
|
||||
// w_lo = true global start of first K-chunk for this CU
|
||||
cu_states[c].w_lo = pb + head_start * hw + c_start * sq;
|
||||
cu_states[c].w_lo = pb + head_start * hw + c_start * sq_w;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -451,7 +459,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
CK_TILE_HOST static constexpr auto PrepareWorkspaceHost(Args&&... args)
|
||||
{
|
||||
return WorkspaceManager::template PrepareWorkspaceHost<kUseQrQtrDorPipeline,
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(
|
||||
FmhaPipeline::BlockFmhaShape::kN0,
|
||||
FmhaPipeline::BlockFmhaShape::kM0>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
template <typename... Args>
|
||||
|
||||
Reference in New Issue
Block a user