mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Fix nsplits to min(8, nc) for chunk-level dyn_naive load balancing under NCCL overlap
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -218,17 +218,17 @@ struct FmhaBwdWorkspaceManager
|
||||
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
|
||||
|
||||
// Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
|
||||
// Chunk-level dyn_naive: fix nsplits = min(8, nc) (1 when nc == 0) for load balancing under NCCL overlap.
|
||||
// Each CU claims one (ibatch, ihead, isplit) chunk dynamically; exclusive ownership
|
||||
// guarantees deterministic dq_acc accumulation without cross-CU races.
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t sq_w = sq_work(sq);
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq_w : 0;
|
||||
const index_t ns = 1 + (rest_workload > 0 && target_w > 0
|
||||
? integer_divide_ceil(rest_workload, target_w)
|
||||
: 0);
|
||||
const index_t ns = (nc > 0) ? min(index_t(8), nc) : 1;
|
||||
nsplits[b] = ns;
|
||||
batch_states[b].sq = sq_w; // GPU uses sq_w for w_chunk tracking
|
||||
batch_states[b].sq = sq_w;
|
||||
batch_states[b].nc = nc;
|
||||
batch_states[b].nsplits = ns;
|
||||
}
|
||||
@@ -245,66 +245,10 @@ struct FmhaBwdWorkspaceManager
|
||||
offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
|
||||
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
|
||||
|
||||
// Step 4: fill cu_states via two-pointer scan.
|
||||
// w_lo = global position of the first K-chunk: pb + head_start*hw + c_start*sq.
|
||||
// This makes w_chunk track true global K-chunk positions on GPU, so w_chunk < w_hi
|
||||
// correctly identifies boundaries without off-by-one overlap between adjacent CUs.
|
||||
// w_hi is set in a post-pass to cu_states[c+1].w_lo.
|
||||
index_t cu_lo = 0;
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t sq_w = sq_work(sq); // kM0-aligned sq for work distribution
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t hw = nc * sq_w; // use sq_w so sq=0 batches get work
|
||||
const index_t pb = prefix_batch[b];
|
||||
const index_t cu_hi =
|
||||
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
|
||||
for(index_t c = cu_lo; c < cu_hi; ++c)
|
||||
{
|
||||
const index_t w_lo = c * target_w;
|
||||
cu_states[c].ibatch = b;
|
||||
if(hw > 0)
|
||||
{
|
||||
const index_t head_start =
|
||||
max(static_cast<index_t>((w_lo - pb) / hw), index_t(0));
|
||||
const index_t w_head = pb + head_start * hw;
|
||||
const index_t wc_start = max(w_lo - w_head, index_t(0));
|
||||
const index_t c_start =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, sq_w) : 0;
|
||||
cu_states[c].isplit =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
|
||||
cu_states[c].head_start = head_start;
|
||||
cu_states[c].c_start = c_start;
|
||||
// w_lo = true global start of first K-chunk for this CU
|
||||
cu_states[c].w_lo = pb + head_start * hw + c_start * sq_w;
|
||||
}
|
||||
else
|
||||
{
|
||||
cu_states[c].isplit = 0;
|
||||
cu_states[c].head_start = 0;
|
||||
cu_states[c].c_start = 0;
|
||||
cu_states[c].w_lo = pb; // hw==0: degenerate, w_lo=batch start
|
||||
}
|
||||
}
|
||||
cu_lo = cu_hi;
|
||||
}
|
||||
// Inactive CUs: use total_w as w_lo sentinel so the post-pass sets
|
||||
// the last active CU's w_hi = total_w correctly.
|
||||
const index_t total_w = prefix_batch[batch_size];
|
||||
for(index_t c = cu_lo; c < num_cus; ++c)
|
||||
{
|
||||
cu_states[c].w_lo = total_w;
|
||||
cu_states[c].w_hi = total_w;
|
||||
cu_states[c].ibatch = batch_size; // sentinel → early return on GPU
|
||||
cu_states[c].isplit = 0;
|
||||
cu_states[c].head_start = 0;
|
||||
cu_states[c].c_start = 0;
|
||||
}
|
||||
// Post-pass: set w_hi[c] = w_lo[c+1] (global start of next CU's first K-chunk).
|
||||
for(index_t c = 0; c < num_cus - 1; ++c)
|
||||
cu_states[c].w_hi = cu_states[c + 1].w_lo;
|
||||
cu_states[num_cus - 1].w_hi = total_w;
|
||||
// Step 4: zero cu_states. The first index_t word is repurposed as the
|
||||
// global chunk counter (init=0) for the dynamic naive GPU scheduler;
|
||||
// the remainder of the array is unused.
|
||||
memset(cu_states, 0, GetCuStateSize(num_cus));
|
||||
|
||||
return sizeof(AccDataType) * dq_acc_elems;
|
||||
}
|
||||
@@ -1162,61 +1106,63 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
// Group mode persistent: variable seqlen per batch, dispatch via gist algo.
|
||||
// Per-CU state (w_lo, w_hi, ibatch, isplit, head_start, c_start) is packed
|
||||
// in a single struct array to minimise kargs pointer count (5 → 1).
|
||||
// Remap block→CU: interleave SEs so consecutive blocks hit different SEs,
|
||||
// spreading dq_acc writes across HBM channels.
|
||||
const index_t cu_id = blockIdx.x / 8 + (blockIdx.x % 8) * 32;
|
||||
// Group mode persistent: chunk-level dyn_naive scheduling.
|
||||
// Each CU atomically claims one (ibatch, ihead, isplit) chunk and processes
|
||||
// all nc/nsplits KV tiles in that chunk sequentially. Exclusive isplit
|
||||
// ownership guarantees deterministic dq_acc accumulation: atomic_adds from
|
||||
// this CU are ordered by the __syncthreads inside run_(), and no other CU
|
||||
// writes to the same dq_acc[ibatch][ihead][isplit] slice.
|
||||
// cu_states[0] is repurposed as the chunk counter (zeroed by
|
||||
// PrepareWorkspaceHost before the workspace is copied to the GPU).
|
||||
index_t* chunk_ctr =
|
||||
const_cast<index_t*>(
|
||||
reinterpret_cast<const index_t*>(kargs.cu_state_ptr));
|
||||
|
||||
// Load all per-CU fields through a single pointer; pointer dies after loads.
|
||||
const FmhaBwdGroupPersistentCuState* cs = kargs.cu_state_ptr + cu_id;
|
||||
const index_t w_hi = amd_wave_read_first_lane(cs->w_hi);
|
||||
index_t ibatch = amd_wave_read_first_lane(cs->ibatch);
|
||||
index_t isplit = amd_wave_read_first_lane(cs->isplit);
|
||||
index_t head_start = amd_wave_read_first_lane(cs->head_start);
|
||||
index_t c_start_0 = amd_wave_read_first_lane(cs->c_start);
|
||||
index_t w_chunk = amd_wave_read_first_lane(cs->w_lo);
|
||||
if(ibatch >= kargs.batch)
|
||||
return; // this CU has no work (sentinel: ibatch == batch_size)
|
||||
// Compute total chunks = sum_b(nhead * nsplits[b]) once per CU.
|
||||
index_t total_chunks = 0;
|
||||
for(index_t b = 0; b < kargs.batch; ++b)
|
||||
total_chunks +=
|
||||
kargs.nhead_q *
|
||||
amd_wave_read_first_lane(kargs.batch_state_ptr[b].nsplits);
|
||||
|
||||
// w_chunk tracks the global K-chunk position; the loop exits when w_chunk
|
||||
// reaches w_hi (this CU's exclusive upper bound). ibatch < batch is guaranteed
|
||||
// on entry; the check inside guards against the rare case where head_start
|
||||
// reaches nhead_q and ibatch is incremented past batch before w_chunk catches
|
||||
// up.
|
||||
do
|
||||
index_t chunk_id;
|
||||
while((chunk_id = atomicAdd(chunk_ctr, 1)) < total_chunks)
|
||||
{
|
||||
if(ibatch >= kargs.batch)
|
||||
return;
|
||||
// sq/nc/nsplits are read inline (not pre-hoisted) to shorten their live
|
||||
// range across the inlined run_() body, reducing SGPR pressure.
|
||||
const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch;
|
||||
|
||||
while(head_start < kargs.nhead_q)
|
||||
// Decode chunk_id → (ibatch, ihead, isplit) via linear scan over batches.
|
||||
index_t prefix = 0;
|
||||
index_t ibatch = 0;
|
||||
for(; ibatch < kargs.batch - 1; ++ibatch)
|
||||
{
|
||||
while(c_start_0 < amd_wave_read_first_lane(bs->nc) && w_chunk < w_hi)
|
||||
{
|
||||
run_(kargs,
|
||||
dim3(tile_n_interleave(c_start_0,
|
||||
amd_wave_read_first_lane(bs->nc)),
|
||||
head_start,
|
||||
ibatch),
|
||||
isplit,
|
||||
amd_wave_read_first_lane(bs->nsplits));
|
||||
w_chunk += amd_wave_read_first_lane(bs->sq);
|
||||
++c_start_0;
|
||||
}
|
||||
if(w_chunk >= w_hi)
|
||||
return;
|
||||
// w_chunk is now at the start of the next head
|
||||
c_start_0 = 0;
|
||||
isplit = 0;
|
||||
++head_start;
|
||||
const index_t batch_chunks =
|
||||
kargs.nhead_q *
|
||||
amd_wave_read_first_lane(kargs.batch_state_ptr[ibatch].nsplits);
|
||||
if(prefix + batch_chunks > chunk_id)
|
||||
break;
|
||||
prefix += batch_chunks;
|
||||
}
|
||||
head_start = 0;
|
||||
++ibatch;
|
||||
} while(w_chunk < w_hi);
|
||||
const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch;
|
||||
const index_t nc = amd_wave_read_first_lane(bs->nc);
|
||||
const index_t nsplits = amd_wave_read_first_lane(bs->nsplits);
|
||||
const index_t rel = chunk_id - prefix;
|
||||
const index_t ihead = rel / nsplits;
|
||||
const index_t isplit = rel % nsplits;
|
||||
|
||||
// KV tile range owned by this chunk: [c_start, c_end).
|
||||
// Integer division distributes nc tiles evenly across nsplits chunks;
|
||||
// chunk sizes differ by at most one tile (any remainder is spread across chunks).
|
||||
const index_t c_start = isplit * nc / nsplits;
|
||||
const index_t c_end = (isplit + 1) * nc / nsplits;
|
||||
|
||||
// Process all KV tiles in this chunk. The __syncthreads inside run_()
|
||||
// orders the atomic_adds to dq_acc[ibatch][ihead][isplit] sequentially.
|
||||
for(index_t c = c_start; c < c_end; ++c)
|
||||
{
|
||||
run_(kargs,
|
||||
dim3(tile_n_interleave(c, nc), ihead, ibatch),
|
||||
isplit,
|
||||
nsplits);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user