From 82ee798809be5e09f2185c8a2b97b835db0c6b71 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 2 Jun 2026 14:58:12 -0500 Subject: [PATCH] Fix nsplits to min(8, nc) for chunk-level dyn_naive load balancing under NCCL overlap Co-Authored-By: Claude Sonnet 4.6 --- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 176 ++++++------------ 1 file changed, 61 insertions(+), 115 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 5be57af61b..b88b44e4bb 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -218,17 +218,17 @@ struct FmhaBwdWorkspaceManager const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus); // Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch) + // Chunk-level dyn_naive: fix nsplits to min(8, nc) for load balancing under NCCL overlap. + // Each CU claims one (ibatch, ihead, isplit) chunk dynamically; exclusive ownership + // guarantees deterministic dq_acc accumulation without cross-CU races. for(index_t b = 0; b < batch_size; ++b) { const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b]; const index_t sq_w = sq_work(sq); const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0); - const index_t rest_workload = (nc > 0) ? (nc - 1) * sq_w : 0; - const index_t ns = 1 + (rest_workload > 0 && target_w > 0 - ? integer_divide_ceil(rest_workload, target_w) - : 0); + const index_t ns = (nc > 0) ? min(index_t(8), nc) : 1; nsplits[b] = ns; - batch_states[b].sq = sq_w; // GPU uses sq_w for w_chunk tracking + batch_states[b].sq = sq_w; batch_states[b].nc = nc; batch_states[b].nsplits = ns; } @@ -245,66 +245,10 @@ struct FmhaBwdWorkspaceManager offsets[i] + static_cast(nhead_q) * nsplits[i] * (seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q; - // Step 4: fill cu_states via two-pointer scan. 
- // w_lo = global position of the first K-chunk: pb + head_start*hw + c_start*sq. - // This makes w_chunk track true global K-chunk positions on GPU, so w_chunk < w_hi - // correctly identifies boundaries without off-by-one overlap between adjacent CUs. - // w_hi is set in a post-pass to cu_states[c+1].w_lo. - index_t cu_lo = 0; - for(index_t b = 0; b < batch_size; ++b) - { - const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b]; - const index_t sq_w = sq_work(sq); // kM0-aligned sq for work distribution - const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0); - const index_t hw = nc * sq_w; // use sq_w so sq=0 batches get work - const index_t pb = prefix_batch[b]; - const index_t cu_hi = - min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w)); - for(index_t c = cu_lo; c < cu_hi; ++c) - { - const index_t w_lo = c * target_w; - cu_states[c].ibatch = b; - if(hw > 0) - { - const index_t head_start = - max(static_cast((w_lo - pb) / hw), index_t(0)); - const index_t w_head = pb + head_start * hw; - const index_t wc_start = max(w_lo - w_head, index_t(0)); - const index_t c_start = - wc_start > 0 ? integer_divide_ceil(wc_start, sq_w) : 0; - cu_states[c].isplit = - wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0; - cu_states[c].head_start = head_start; - cu_states[c].c_start = c_start; - // w_lo = true global start of first K-chunk for this CU - cu_states[c].w_lo = pb + head_start * hw + c_start * sq_w; - } - else - { - cu_states[c].isplit = 0; - cu_states[c].head_start = 0; - cu_states[c].c_start = 0; - cu_states[c].w_lo = pb; // hw==0: degenerate, w_lo=batch start - } - } - cu_lo = cu_hi; - } - // Inactive CUs: use total_w as w_lo sentinel so the post-pass sets - // the last active CU's w_hi = total_w correctly. 
- const index_t total_w = prefix_batch[batch_size]; - for(index_t c = cu_lo; c < num_cus; ++c) - { - cu_states[c].w_lo = total_w; - cu_states[c].w_hi = total_w; - cu_states[c].ibatch = batch_size; // sentinel → early return on GPU - cu_states[c].isplit = 0; - cu_states[c].head_start = 0; - cu_states[c].c_start = 0; - } - // Post-pass: set w_hi[c] = w_lo[c+1] (global start of next CU's first K-chunk). - for(index_t c = 0; c < num_cus - 1; ++c) - cu_states[c].w_hi = cu_states[c + 1].w_lo; - cu_states[num_cus - 1].w_hi = total_w; + // Step 4: zero cu_states. The first index_t word is repurposed as the + // global tile counter (init=0) for the dynamic naive GPU scheduler; + // the remainder of the array is unused. + memset(cu_states, 0, GetCuStateSize(num_cus)); return sizeof(AccDataType) * dq_acc_elems; } @@ -1162,61 +1106,63 @@ struct FmhaBwdDQDKDVKernel } else { - // Group mode persistent: variable seqlen per batch, dispatch via gist algo. - // Per-CU state (w_lo, w_hi, ibatch, isplit, head_start, c_start) is packed - // in a single struct array to minimise kargs pointer count (5 → 1). - // Remap block→CU: interleave SEs so consecutive blocks hit different SEs, - // spreading dq_acc writes across HBM channels. - const index_t cu_id = blockIdx.x / 8 + (blockIdx.x % 8) * 32; + // Group mode persistent: chunk-level dyn_naive scheduling. + // Each CU atomically claims one (ibatch, ihead, isplit) chunk and processes + // all nc/nsplits KV tiles in that chunk sequentially. Exclusive isplit + // ownership guarantees deterministic dq_acc accumulation: atomic_adds from + // this CU are ordered by the __syncthreads inside run_(), and no other CU + // writes to the same dq_acc[ibatch][ihead][isplit] slice. + // cu_states[0] is repurposed as the chunk counter (zeroed by + // PrepareWorkspaceHost before the workspace is copied to the GPU). 
+ index_t* chunk_ctr = + const_cast<index_t*>( + reinterpret_cast<const index_t*>(kargs.cu_state_ptr)); - // Load all per-CU fields through a single pointer; pointer dies after loads. - const FmhaBwdGroupPersistentCuState* cs = kargs.cu_state_ptr + cu_id; - const index_t w_hi = amd_wave_read_first_lane(cs->w_hi); - index_t ibatch = amd_wave_read_first_lane(cs->ibatch); - index_t isplit = amd_wave_read_first_lane(cs->isplit); - index_t head_start = amd_wave_read_first_lane(cs->head_start); - index_t c_start_0 = amd_wave_read_first_lane(cs->c_start); - index_t w_chunk = amd_wave_read_first_lane(cs->w_lo); - if(ibatch >= kargs.batch) - return; // this CU has no work (sentinel: ibatch == batch_size) + // Compute total chunks = sum_b(nhead * nsplits[b]) once per CU. + index_t total_chunks = 0; + for(index_t b = 0; b < kargs.batch; ++b) + total_chunks += + kargs.nhead_q * + amd_wave_read_first_lane(kargs.batch_state_ptr[b].nsplits); - // w_chunk tracks the global K-chunk position; the loop exits when w_chunk - // reaches w_hi (this CU's exclusive upper bound). ibatch < batch is guaranteed - // on entry; the check inside guards against the rare case where head_start - // reaches nhead_q and ibatch is incremented past batch before w_chunk catches - // up. - do + index_t chunk_id; + while((chunk_id = atomicAdd(chunk_ctr, 1)) < total_chunks) { - if(ibatch >= kargs.batch) - return; - // sq/nc/nsplits are read inline (not pre-hoisted) to shorten their live - // range across the inlined run_() body, reducing SGPR pressure. - const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch; - - while(head_start < kargs.nhead_q) + // Decode chunk_id → (ibatch, ihead, isplit) via linear scan over batches. 
+ index_t prefix = 0; + index_t ibatch = 0; + for(; ibatch < kargs.batch - 1; ++ibatch) { - while(c_start_0 < amd_wave_read_first_lane(bs->nc) && w_chunk < w_hi) - { - run_(kargs, - dim3(tile_n_interleave(c_start_0, - amd_wave_read_first_lane(bs->nc)), - head_start, - ibatch), - isplit, - amd_wave_read_first_lane(bs->nsplits)); - w_chunk += amd_wave_read_first_lane(bs->sq); - ++c_start_0; - } - if(w_chunk >= w_hi) - return; - // w_chunk is now at the start of the next head - c_start_0 = 0; - isplit = 0; - ++head_start; + const index_t batch_chunks = + kargs.nhead_q * + amd_wave_read_first_lane(kargs.batch_state_ptr[ibatch].nsplits); + if(prefix + batch_chunks > chunk_id) + break; + prefix += batch_chunks; } - head_start = 0; - ++ibatch; - } while(w_chunk < w_hi); + const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch; + const index_t nc = amd_wave_read_first_lane(bs->nc); + const index_t nsplits = amd_wave_read_first_lane(bs->nsplits); + const index_t rel = chunk_id - prefix; + const index_t ihead = rel / nsplits; + const index_t isplit = rel % nsplits; + + // KV tile range owned by this chunk: [c_start, c_end). + // Integer division distributes nc tiles evenly across nsplits chunks; + // chunk sizes differ by at most one tile (any remainder is spread across chunks). + const index_t c_start = isplit * nc / nsplits; + const index_t c_end = (isplit + 1) * nc / nsplits; + + // Process all KV tiles in this chunk. The __syncthreads inside run_() + // orders the atomic_adds to dq_acc[ibatch][ihead][isplit] sequentially. + for(index_t c = c_start; c < c_end; ++c) + { + run_(kargs, + dim3(tile_n_interleave(c, nc), ihead, ibatch), + isplit, + nsplits); + } + } } } }