From 422b6d6c164622586862825091eda75f755a603f Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Wed, 13 May 2026 02:20:09 -0400 Subject: [PATCH] [CK_TILE] Fix FMHA BWD workspace upper-bound undersizing in group mode GetWorkspaceDeviceSizeUpperBound was computing max_batch * nhead_q * max_seqlen_q * hdim_q in non-deterministic group mode, but PrepareWorkspaceHost actually returns nhead_q * seqstart_q[batch] * hdim_q i.e. it scales with the sum of *padded* per-batch seqlen_q, not max_batch times the *logical* max. When per-batch padding makes seqstart_q[batch] exceed max_batch * max_seqlen_q the launcher under-allocates dq_acc, the kernel writes past the buffer, and tests see either ~42% wrong QGrad values or a GPU page fault (e.g. test_ck_tile_fmha_bwd_bf16 QKVPadding/23,24,26 corrupt; /27 page-faults). Fix: replace the (max_batch, max_seqlen_q) pair with a single total_seqlen_q_padded parameter holding the true total padded q tokens. Launcher derives it from the trait (group: t.seqlen_q already is the padded total; batch: t.batch * t.seqlen_q). The four mode formulas collapse to one: size = nhead_q * nsplits_factor * total_seqlen_q_padded * hdim_q where nsplits_factor is 1 for non-deterministic, ceil(max_seqlen_k, kN0) for deterministic group, and the persistent worker computation for deterministic non-group (the only branch that still needs max_batch). No caller-side API change: FA, AITER and the CK runner already pass q.shape[0] (the padded total) as traits.seqlen_q in group mode. Verified on gfx1201: full test_ck_tile_fmha_bwd_{bf16,fp16} 672/672 PASS, 0 fail, 0 crash (was 27/28 QKVPadding fails + 1 GPU illegal access). --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 4 +- example/ck_tile/01_fmha/fmha_bwd.hpp | 11 +++- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 58 ++++++++++--------- 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index d6493eb533..5ec6de2d4a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -178,11 +178,11 @@ size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_( template <> size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_( ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q, - ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_k) + ck_tile::index_t total_seqlen_q_padded, ck_tile::index_t max_seqlen_k) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; return k_::GetWorkspaceDeviceSizeUpperBound( - max_batch, hdim_q, nhead_q, max_seqlen_q, max_seqlen_k); + max_batch, hdim_q, nhead_q, total_seqlen_q_padded, max_seqlen_k); }} template <> diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 23b5bec8d4..76c72fc159 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -469,11 +469,14 @@ int fmha_bwd_dq_dk_dv_maxq_(); struct fmha_bwd_traits; template size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size); +// `total_seqlen_q_padded` is total q tokens across all batches (incl. per-batch padding): +// - batch mode: max_batch * seqlen_q +// - group mode: seqstart_q[batch] (== varlen q tensor's first dim) template size_t fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_(ck_tile::index_t max_batch, ck_tile::index_t hdim_q, ck_tile::index_t nhead_q, - ck_tile::index_t max_seqlen_q, + ck_tile::index_t total_seqlen_q_padded, ck_tile::index_t max_seqlen_k); template size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws, @@ -730,8 +733,12 @@ struct fmha_bwd_launcher size_t device_ws_size = 0; if(host_ws_size_ > 0) { + // In group mode t.seqlen_q is already the padded total (== seqstart_q[batch]); + // in batch mode it's per-batch and the total is batch * seqlen_q. + const ck_tile::index_t total_seqlen_q_padded = + t.is_group_mode ? t.seqlen_q : t.batch * t.seqlen_q; device_ws_size = fmha_bwd_dq_dk_dv_dq_ws_device_upper_bound_( - t.batch, t.hdim_q, t.nhead_q, t.max_seqlen_q, t.max_seqlen_k); + t.batch, t.hdim_q, t.nhead_q, total_seqlen_q_padded, t.max_seqlen_k); pack_workspace_host_ = [batch = t.batch, hdim_q = t.hdim_q, nhead_q = t.nhead_q, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 8391a14832..7aff21530d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -167,45 +167,47 @@ struct FmhaBwdWorkspaceManager return kHasMask; } - // Mirrors PrepareWorkspaceHost's return value but uses worst-case totals so - // device workspace can be pre-allocated before host has the seqstart values. + // Upper bound on PrepareWorkspaceHost's size, computable without seqstart so + // the device workspace can be allocated before any D2H. + // + // total_seqlen_q_padded: total q tokens incl. per-batch padding. + // Batch: max_batch * seqlen_q. Group: seqstart_q[batch]. + // max_seqlen_k: deterministic-only; pass per-batch padded max if the caller + // does internal k padding, otherwise the logical max is fine. template CK_TILE_HOST static size_t GetWorkspaceDeviceSizeUpperBound(index_t max_batch, index_t hdim_q, index_t nhead_q, - index_t max_seqlen_q, + index_t total_seqlen_q_padded, index_t max_seqlen_k) { if constexpr(kUseQrQtrDorPipeline) return 0; - if constexpr(!kIsDeterministic) + index_t nsplits_factor = 1; + if constexpr(kIsDeterministic) { - return sizeof(AccDataType) * static_cast(max_batch) * nhead_q * - max_seqlen_q * hdim_q; - } - else if constexpr(kIsGroupMode) - { - const index_t nsplits_max = integer_divide_ceil(max_seqlen_k, kN0); - return sizeof(AccDataType) * static_cast(max_batch) * nhead_q * - nsplits_max * max_seqlen_q * hdim_q; - } - else // deterministic non-group mode (kUsePersistent) - { - const index_t dqdqkdv_workers = get_num_cus(); - const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0); - const index_t total_jobs = max_batch * nhead_q * jobs_per_head; - const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers); - index_t nsplits; - if(jobs_per_head % jobs_per_worker == 0) - nsplits = jobs_per_head / jobs_per_worker; - else if(jobs_per_worker % jobs_per_head == 0) - nsplits = 1; - else - nsplits = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker); - return sizeof(AccDataType) * static_cast(max_batch) * nhead_q * nsplits * - max_seqlen_q * hdim_q; + if constexpr(kIsGroupMode) + { + nsplits_factor = integer_divide_ceil(max_seqlen_k, kN0); + } + else // persistent + { + const index_t dqdqkdv_workers = get_num_cus(); + const index_t jobs_per_head = integer_divide_ceil(max_seqlen_k, kN0); + const index_t total_jobs = max_batch * nhead_q * jobs_per_head; + const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers); + if(jobs_per_head % jobs_per_worker == 0) + nsplits_factor = jobs_per_head / jobs_per_worker; + else if(jobs_per_worker % jobs_per_head == 0) + nsplits_factor = 1; + else + nsplits_factor = 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker); + } } + + return sizeof(AccDataType) * static_cast(nhead_q) * nsplits_factor * + total_seqlen_q_padded * hdim_q; } };