diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
index c69e52b0af..5be57af61b 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
@@ -156,7 +156,7 @@ struct FmhaBwdWorkspaceManager
     }

    // Fill CPU-prepared workspace and return the size of the non-CPU-prepared workspace
-    template <index_t kN0>
+    template <index_t kN0, index_t kM0>
    CK_TILE_HOST static size_t PrepareWorkspaceHost(void* cpu_ws,
                                                    index_t batch_size,
@@ -202,26 +202,33 @@ struct FmhaBwdWorkspaceManager
        auto* batch_states = reinterpret_cast<BatchState*>(
            reinterpret_cast<char*>(cpu_ws) + GetBatchStateOffset(batch_size));

+        // sq_work: sq aligned to kM0 for work-distribution purposes.
+        // If sq==0, use kM0 so CUs are still dispatched and write dK/dV=0.
+        const auto sq_work = [](index_t sq) -> index_t {
+            return sq == 0 ? kM0 : integer_least_multiple(sq, kM0);
+        };
+
        prefix_batch[0] = 0;
        for(index_t b = 0; b < batch_size; ++b)
        {
            const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
            const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
-            prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq;
+            prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq_work(sq);
        }
        const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);

        // Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
        for(index_t b = 0; b < batch_size; ++b)
        {
-            const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
-            const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
-            const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
+            const index_t sq   = seqstart_qs[b + 1] - seqstart_qs[b];
+            const index_t sq_w = sq_work(sq);
+            const index_t nc   = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
+            const index_t rest_workload = (nc > 0) ? (nc - 1) * sq_w : 0;
            const index_t ns =
                1 + (rest_workload > 0 && target_w > 0 ? integer_divide_ceil(rest_workload, target_w)
                                                       : 0);
            nsplits[b] = ns;
-            batch_states[b].sq = sq;
+            batch_states[b].sq = sq_w; // GPU uses sq_w for w_chunk tracking
            batch_states[b].nc = nc;
            batch_states[b].nsplits = ns;
        }
@@ -246,10 +253,11 @@ struct FmhaBwdWorkspaceManager
        index_t cu_lo = 0;
        for(index_t b = 0; b < batch_size; ++b)
        {
-            const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
-            const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
-            const index_t hw = nc * sq;
-            const index_t pb = prefix_batch[b];
+            const index_t sq   = seqstart_qs[b + 1] - seqstart_qs[b];
+            const index_t sq_w = sq_work(sq); // kM0-aligned sq for work distribution
+            const index_t nc   = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
+            const index_t hw   = nc * sq_w; // use sq_w so sq=0 batches get work
+            const index_t pb   = prefix_batch[b];
            const index_t cu_hi =
                min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
            for(index_t c = cu_lo; c < cu_hi; ++c)
@@ -263,13 +271,13 @@ struct FmhaBwdWorkspaceManager
                    const index_t w_head   = pb + head_start * hw;
                    const index_t wc_start = max(w_lo - w_head, index_t(0));
                    const index_t c_start =
-                        wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
+                        wc_start > 0 ? integer_divide_ceil(wc_start, sq_w) : 0;
                    cu_states[c].isplit =
                        wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
                    cu_states[c].head_start = head_start;
                    cu_states[c].c_start    = c_start;
                    // w_lo = true global start of first K-chunk for this CU
-                    cu_states[c].w_lo = pb + head_start * hw + c_start * sq;
+                    cu_states[c].w_lo = pb + head_start * hw + c_start * sq_w;
                }
                else
                {
@@ -451,7 +459,8 @@ struct FmhaBwdDQDKDVKernel
    CK_TILE_HOST static constexpr auto PrepareWorkspaceHost(Args&&... args)
    {
        return WorkspaceManager::template PrepareWorkspaceHost<
-            FmhaPipeline::BlockFmhaShape::kN0>(
+            FmhaPipeline::BlockFmhaShape::kN0,
+            FmhaPipeline::BlockFmhaShape::kM0>(
            std::forward<Args>(args)...);
    }
    template
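
The distribution math this patch changes can be exercised host-side in isolation. Below is a minimal standalone sketch of the new sq_work alignment, not ck_tile code: index_t, integer_divide_ceil, and integer_least_multiple are stand-in reimplementations assumed to mirror the ck_tile utilities, and the kM0/kN0/nhead_q/num_cus values are arbitrary examples. It shows how a batch with an empty query sequence (sq == 0) still contributes kM0-sized work units to prefix_batch, so a CU is dispatched for it and can zero-fill its dK/dV.

// Standalone sketch of the sq_work alignment above (assumption: these helpers
// mirror ck_tile's integer_divide_ceil / integer_least_multiple semantics).
#include <cstdint>
#include <cstdio>
#include <vector>

using index_t = int32_t;

constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }
constexpr index_t integer_least_multiple(index_t a, index_t b)
{
    return integer_divide_ceil(a, b) * b;
}

int main()
{
    constexpr index_t kM0 = 64, kN0 = 128; // example tile sizes, not the real config
    constexpr index_t nhead_q = 8, num_cus = 80;

    // sq aligned up to kM0; sq == 0 is clamped to one kM0 tile so the batch
    // still owns work units (its dK/dV then get zero-filled on the GPU).
    const auto sq_work = [](index_t sq) -> index_t {
        return sq == 0 ? kM0 : integer_least_multiple(sq, kM0);
    };

    // Toy varlen layout; batch 1 has an empty query sequence (sq == 0).
    const std::vector<index_t> seqstart_qs{0, 100, 100, 300};
    const std::vector<index_t> seqstart_ks{0, 256, 512, 1024};
    const index_t batch_size = 3;

    // Same prefix-sum pass as Step 1 of PrepareWorkspaceHost.
    std::vector<index_t> prefix_batch(batch_size + 1, 0);
    for(index_t b = 0; b < batch_size; ++b)
    {
        const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
        const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
        prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq_work(sq);
    }
    const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);

    for(index_t b = 0; b < batch_size; ++b)
        std::printf("batch %d: work units [%d, %d)\n", b, prefix_batch[b], prefix_batch[b + 1]);
    std::printf("target workload per CU: %d\n", target_w);
    // With the old code, batch 1 would contribute 0 work units and never be
    // visited by any CU; with sq_work it contributes nhead_q * nc * kM0 = 1024.
}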