mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK_TILE] Fix FMHA BWD workspace upper-bound undersizing in group mode
GetWorkspaceDeviceSizeUpperBound was computing
max_batch * nhead_q * max_seqlen_q * hdim_q
in non-deterministic group mode, but PrepareWorkspaceHost actually returns
nhead_q * seqstart_q[batch] * hdim_q
i.e. it scales with the sum of *padded* per-batch seqlen_q, not max_batch
times the *logical* max. When per-batch padding makes seqstart_q[batch]
exceed max_batch * max_seqlen_q the launcher under-allocates dq_acc, the
kernel writes past the buffer, and tests see either ~42% wrong QGrad
values or a GPU page fault (e.g. test_ck_tile_fmha_bwd_bf16
QKVPadding/23,24,26 corrupt; /27 page-faults).
Fix: replace the (max_batch, max_seqlen_q) pair with a single
total_seqlen_q_padded parameter holding the true total padded q tokens.
Launcher derives it from the trait (group: t.seqlen_q already is the
padded total; batch: t.batch * t.seqlen_q). The four mode formulas
collapse to one:
size = sizeof(AccDataType) * nhead_q * nsplits_factor * total_seqlen_q_padded * hdim_q
where nsplits_factor is 1 for non-deterministic, ceil(max_seqlen_k / kN0)
for deterministic group, and the persistent worker computation for
deterministic non-group (the only branch that still needs max_batch).
No caller-side API change: FA, AITER and the CK runner already pass
q.shape[0] (the padded total) as traits.seqlen_q in group mode.
Verified on gfx1201: full test_ck_tile_fmha_bwd_{bf16,fp16} 672/672 PASS,
0 fail, 0 crash (was 27/28 QKVPadding fails + 1 GPU illegal access).
This commit is contained in:
@@ -178,11 +178,11 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
template <>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_k)
|
||||
ck_tile::index_t total_seqlen_q_padded, ck_tile::index_t max_seqlen_k)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetWorkspaceDeviceSizeUpperBound(
|
||||
max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k);
|
||||
max_batch, hdim_q, nhead_q, total_seqlen_q_padded, max_seqlen_k);
|
||||
}}
|
||||
|
||||
template <>
|
||||
|
||||
@@ -469,11 +469,14 @@ int fmha_bwd_dq_dk_dv_maxq_();
|
||||
struct fmha_bwd_traits;
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
|
||||
// `total_seqlen_q_padded` is total q tokens across all batches (incl. per-batch padding):
|
||||
// - batch mode: max_batch * seqlen_q
|
||||
// - group mode: seqstart_q[batch] (== varlen q tensor's first dim)
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t total_seqlen_q_padded,
|
||||
ck_tile::index_t max_seqlen_k);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
|
||||
@@ -730,8 +733,12 @@ struct fmha_bwd_launcher
|
||||
size_t device_ws_size = 0;
|
||||
if(host_ws_size_ > 0)
|
||||
{
|
||||
// In group mode t.seqlen_q is already the padded total (== seqstart_q[batch]);
|
||||
// in batch mode it's per-batch and the total is batch * seqlen_q.
|
||||
const ck_tile::index_t total_seqlen_q_padded =
|
||||
t.is_group_mode ? t.seqlen_q : t.batch * t.seqlen_q;
|
||||
device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_<T1, Arch>(
|
||||
t.batch, t.hdim_q, t.nhead_q, t.max_seqlen_q, t.max_seqlen_k);
|
||||
t.batch, t.hdim_q, t.nhead_q, total_seqlen_q_padded, t.max_seqlen_k);
|
||||
pack_workspace_host_ = [batch = t.batch,
|
||||
hdim_q = t.hdim_q,
|
||||
nhead_q = t.nhead_q,
|
||||
|
||||
@@ -167,45 +167,47 @@ struct FmhaBwdWorkspaceManager
|
||||
return kHasMask;
|
||||
}
|
||||
|
||||
// Mirrors PrepareWorkspaceHost's return value but uses worst-case totals so
|
||||
// device workspace can be pre-allocated before host has the seqstart values.
|
||||
// Upper bound on PrepareWorkspaceHost's size, computable without seqstart so
|
||||
// the device workspace can be allocated before any D2H.
|
||||
//
|
||||
// total_seqlen_q_padded: total q tokens incl. per-batch padding.
|
||||
// Batch: max_batch * seqlen_q. Group: seqstart_q[batch].
|
||||
// max_seqlen_k: deterministic-only; pass per-batch padded max if the caller
|
||||
// does internal k padding, otherwise the logical max is fine.
|
||||
template <bool kUseQrQtrDorPipeline, index_t kN0>
|
||||
CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch,
|
||||
index_t hdim_q,
|
||||
index_t nhead_q,
|
||||
index_t max_seqlen_q,
|
||||
index_t total_seqlen_q_padded,
|
||||
index_t max_seqlen_k)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
|
||||
if constexpr(!kIsDeterministic)
|
||||
index_t nsplits_factor = 1;
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
|
||||
max_seqlen_q * hdim_q;
|
||||
}
|
||||
else if constexpr(kIsGroupMode)
|
||||
{
|
||||
const index_t nsplits_max = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q *
|
||||
nsplits_max * max_seqlen_q * hdim_q;
|
||||
}
|
||||
else // deterministic non-group mode (kUsePersistent)
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
const index_t total_jobs = max_batch * nhead_q * jobs_per_head;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
|
||||
index_t nsplits;
|
||||
if(jobs_per_head % jobs_per_worker == 0)
|
||||
nsplits = jobs_per_head / jobs_per_worker;
|
||||
else if(jobs_per_worker % jobs_per_head == 0)
|
||||
nsplits = 1;
|
||||
else
|
||||
nsplits = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(max_batch) * nhead_q * nsplits *
|
||||
max_seqlen_q * hdim_q;
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
nsplits_factor = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
}
|
||||
else // persistent
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0);
|
||||
const index_t total_jobs = max_batch * nhead_q * jobs_per_head;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
|
||||
if(jobs_per_head % jobs_per_worker == 0)
|
||||
nsplits_factor = jobs_per_head / jobs_per_worker;
|
||||
else if(jobs_per_worker % jobs_per_head == 0)
|
||||
nsplits_factor = 1;
|
||||
else
|
||||
nsplits_factor = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
|
||||
}
|
||||
}
|
||||
|
||||
return sizeof(AccDataType) * static_cast<long_index_t>(nhead_q) * nsplits_factor *
|
||||
total_seqlen_q_padded * hdim_q;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user