Fix nsplits to min(8, nc) for chunk-level dyn_naive load balancing under NCCL overlap

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Ye Wang
2026-06-02 14:58:12 -05:00
parent 3e3cb36c7a
commit 82ee798809

View File

@@ -218,17 +218,17 @@ struct FmhaBwdWorkspaceManager
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
// Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
// Chunk-level dyn_naive: pin nsplits to min(8, nc) (1 when nc == 0) for load
// balancing under NCCL overlap.
// Each CU claims one (ibatch, ihead, isplit) chunk dynamically; exclusive ownership
// guarantees deterministic dq_acc accumulation without cross-CU races.
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t sq_w = sq_work(sq);
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq_w : 0;
const index_t ns = 1 + (rest_workload > 0 && target_w > 0
? integer_divide_ceil(rest_workload, target_w)
: 0);
const index_t ns = (nc > 0) ? min(index_t(8), nc) : 1;
nsplits[b] = ns;
batch_states[b].sq = sq_w; // GPU uses sq_w for w_chunk tracking
batch_states[b].sq = sq_w;
batch_states[b].nc = nc;
batch_states[b].nsplits = ns;
}
@@ -245,66 +245,10 @@ struct FmhaBwdWorkspaceManager
offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
// Step 4: fill cu_states via two-pointer scan.
// w_lo = global position of the first K-chunk: pb + head_start*hw + c_start*sq.
// This makes w_chunk track true global K-chunk positions on GPU, so w_chunk < w_hi
// correctly identifies boundaries without off-by-one overlap between adjacent CUs.
// w_hi is set in a post-pass to cu_states[c+1].w_lo.
index_t cu_lo = 0;
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t sq_w = sq_work(sq); // kM0-aligned sq for work distribution
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t hw = nc * sq_w; // use sq_w so sq=0 batches get work
const index_t pb = prefix_batch[b];
const index_t cu_hi =
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
for(index_t c = cu_lo; c < cu_hi; ++c)
{
const index_t w_lo = c * target_w;
cu_states[c].ibatch = b;
if(hw > 0)
{
const index_t head_start =
max(static_cast<index_t>((w_lo - pb) / hw), index_t(0));
const index_t w_head = pb + head_start * hw;
const index_t wc_start = max(w_lo - w_head, index_t(0));
const index_t c_start =
wc_start > 0 ? integer_divide_ceil(wc_start, sq_w) : 0;
cu_states[c].isplit =
wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
cu_states[c].head_start = head_start;
cu_states[c].c_start = c_start;
// w_lo = true global start of first K-chunk for this CU
cu_states[c].w_lo = pb + head_start * hw + c_start * sq_w;
}
else
{
cu_states[c].isplit = 0;
cu_states[c].head_start = 0;
cu_states[c].c_start = 0;
cu_states[c].w_lo = pb; // hw==0: degenerate, w_lo=batch start
}
}
cu_lo = cu_hi;
}
// Inactive CUs: use total_w as w_lo sentinel so the post-pass sets
// the last active CU's w_hi = total_w correctly.
const index_t total_w = prefix_batch[batch_size];
for(index_t c = cu_lo; c < num_cus; ++c)
{
cu_states[c].w_lo = total_w;
cu_states[c].w_hi = total_w;
cu_states[c].ibatch = batch_size; // sentinel → early return on GPU
cu_states[c].isplit = 0;
cu_states[c].head_start = 0;
cu_states[c].c_start = 0;
}
// Post-pass: set w_hi[c] = w_lo[c+1] (global start of next CU's first K-chunk).
for(index_t c = 0; c < num_cus - 1; ++c)
cu_states[c].w_hi = cu_states[c + 1].w_lo;
cu_states[num_cus - 1].w_hi = total_w;
// Step 4: zero cu_states. The first index_t word is repurposed as the
// global chunk counter (init=0) for the dynamic naive GPU scheduler;
// the remainder of the array is unused.
memset(cu_states, 0, GetCuStateSize(num_cus));
return sizeof(AccDataType) * dq_acc_elems;
}
@@ -1162,61 +1106,63 @@ struct FmhaBwdDQDKDVKernel
}
else
{
// Group mode persistent: variable seqlen per batch, dispatch via gist algo.
// Per-CU state (w_lo, w_hi, ibatch, isplit, head_start, c_start) is packed
// in a single struct array to minimise kargs pointer count (5 → 1).
// Remap block→CU: interleave SEs so consecutive blocks hit different SEs,
// spreading dq_acc writes across HBM channels.
const index_t cu_id = blockIdx.x / 8 + (blockIdx.x % 8) * 32;
// Group mode persistent: chunk-level dyn_naive scheduling.
// Each CU atomically claims one (ibatch, ihead, isplit) chunk and processes
// all nc/nsplits KV tiles in that chunk sequentially. Exclusive isplit
// ownership guarantees deterministic dq_acc accumulation: atomic_adds from
// this CU are ordered by the __syncthreads inside run_(), and no other CU
// writes to the same dq_acc[ibatch][ihead][isplit] slice.
// cu_states[0] is repurposed as the chunk counter (zeroed by
// PrepareWorkspaceHost before the workspace is copied to the GPU).
index_t* chunk_ctr =
const_cast<index_t*>(
reinterpret_cast<const index_t*>(kargs.cu_state_ptr));
// Load all per-CU fields through a single pointer; pointer dies after loads.
const FmhaBwdGroupPersistentCuState* cs = kargs.cu_state_ptr + cu_id;
const index_t w_hi = amd_wave_read_first_lane(cs->w_hi);
index_t ibatch = amd_wave_read_first_lane(cs->ibatch);
index_t isplit = amd_wave_read_first_lane(cs->isplit);
index_t head_start = amd_wave_read_first_lane(cs->head_start);
index_t c_start_0 = amd_wave_read_first_lane(cs->c_start);
index_t w_chunk = amd_wave_read_first_lane(cs->w_lo);
if(ibatch >= kargs.batch)
return; // this CU has no work (sentinel: ibatch == batch_size)
// Compute total chunks = sum_b(nhead * nsplits[b]) once per CU.
index_t total_chunks = 0;
for(index_t b = 0; b < kargs.batch; ++b)
total_chunks +=
kargs.nhead_q *
amd_wave_read_first_lane(kargs.batch_state_ptr[b].nsplits);
// w_chunk tracks the global K-chunk position; the loop exits when w_chunk
// reaches w_hi (this CU's exclusive upper bound). ibatch < batch is guaranteed
// on entry; the check inside guards against the rare case where head_start
// reaches nhead_q and ibatch is incremented past batch before w_chunk catches
// up.
do
index_t chunk_id;
while((chunk_id = atomicAdd(chunk_ctr, 1)) < total_chunks)
{
if(ibatch >= kargs.batch)
return;
// sq/nc/nsplits are read inline (not pre-hoisted) to shorten their live
// range across the inlined run_() body, reducing SGPR pressure.
const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch;
while(head_start < kargs.nhead_q)
// Decode chunk_id → (ibatch, ihead, isplit) via linear scan over batches.
index_t prefix = 0;
index_t ibatch = 0;
for(; ibatch < kargs.batch - 1; ++ibatch)
{
while(c_start_0 < amd_wave_read_first_lane(bs->nc) && w_chunk < w_hi)
{
run_(kargs,
dim3(tile_n_interleave(c_start_0,
amd_wave_read_first_lane(bs->nc)),
head_start,
ibatch),
isplit,
amd_wave_read_first_lane(bs->nsplits));
w_chunk += amd_wave_read_first_lane(bs->sq);
++c_start_0;
}
if(w_chunk >= w_hi)
return;
// w_chunk is now at the start of the next head
c_start_0 = 0;
isplit = 0;
++head_start;
const index_t batch_chunks =
kargs.nhead_q *
amd_wave_read_first_lane(kargs.batch_state_ptr[ibatch].nsplits);
if(prefix + batch_chunks > chunk_id)
break;
prefix += batch_chunks;
}
head_start = 0;
++ibatch;
} while(w_chunk < w_hi);
const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch;
const index_t nc = amd_wave_read_first_lane(bs->nc);
const index_t nsplits = amd_wave_read_first_lane(bs->nsplits);
const index_t rel = chunk_id - prefix;
const index_t ihead = rel / nsplits;
const index_t isplit = rel % nsplits;
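// KV tile range owned by this chunk: [c_start, c_end).
// Boundaries are floor(isplit * nc / nsplits), so the nc tiles are split as
// evenly as possible: chunk sizes differ by at most one tile and any remainder
// is spread across chunks (e.g. nc=10, nsplits=8 -> 1,1,1,2,1,1,1,2).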
const index_t c_start = isplit * nc / nsplits;
const index_t c_end = (isplit + 1) * nc / nsplits;
// Process all KV tiles in this chunk. The __syncthreads inside run_()
// orders the atomic_adds to dq_acc[ibatch][ihead][isplit] sequentially.
for(index_t c = c_start; c < c_end; ++c)
{
run_(kargs,
dim3(tile_n_interleave(c, nc), ihead, ibatch),
isplit,
nsplits);
}
}
}
}
}