mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
ext/ep: fix kRDMASender epilogue tail-write race (unblocks chunk_send>16)
The kRDMASender warp loop processes tokens in a warp-stride pattern under a per-channel sequential lock. In-loop tail writes are monotonic by construction (each warp owns a strictly-increasing slot range while it holds the lock) so a plain st_release_cta is correct and avoids the L2-serialized atomicMax codepath that the compiler emits for global-space atomics.

The epilogue tail write, however, sits outside that monotonicity contract: when multiple sender warps reach the epilogue out of order, a later-exiting warp owning a smaller last_rdma_tail_idx can clobber the tail with a smaller value, leaving the kRDMASenderCoordinator wedged waiting on a slot that is already produced.

This is invisible at deep receive windows (chunk_send=16, num_chunked_recv=128 -> depth 8) but deterministically hangs at shallow ones (chunk_send=32 -> depth 4). Diagnostic instrumentation in kRDMASenderCoordinator caught peer=1 last_issued_tail=384 processed_tail=408 to_send=25, i.e. exactly one missing token at the boundary.

Fix: replace the epilogue st_release_cta with atomicMax. Confirmed chunk_send=32 PASS (previously deterministic hang) and chunk_send=16 unaffected (no perf regression vs the previous lazy-head fix).
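The failure mode described above can be modeled on the host with a small deterministic sketch. This is not the kernel code: `publish_plain`, `publish_max`, and `replay` are hypothetical helpers standing in for the epilogue `st_release_cta`, the replacement `atomicMax`, and an arbitrary arrival order of sender warps. The tail values 408 and 409 are illustrative, chosen on the assumption that the diagnostic above corresponds to a lost final update.

```cpp
#include <atomic>
#include <initializer_list>

// Hypothetical stand-in for the old epilogue write: an unconditional store,
// so whichever writer arrives last decides the final tail.
void publish_plain(std::atomic<int>& tail, int value) {
    tail.store(value, std::memory_order_release);
}

// Hypothetical stand-in for atomicMax: a CAS loop that can only raise the tail.
void publish_max(std::atomic<int>& tail, int value) {
    int seen = tail.load(std::memory_order_relaxed);
    while (seen < value &&
           !tail.compare_exchange_weak(seen, value,
                                       std::memory_order_release,
                                       std::memory_order_relaxed)) {
        // A failed exchange refreshes `seen`; retry only while the tail is still smaller.
    }
}

// Replays epilogue writes in the given arrival order and returns the final tail.
// The starting value 384 is illustrative (taken from last_issued_tail in the diagnostic).
template <typename Publish>
int replay(Publish publish, std::initializer_list<int> arrival_order) {
    std::atomic<int> tail{384};
    for (int value : arrival_order) publish(tail, value);
    return tail.load(std::memory_order_acquire);
}
```

Under these assumptions, `replay(publish_plain, {409, 408})` ends at 408 because the later writer overwrites the larger value, while `replay(publish_max, {409, 408})` ends at 409 regardless of arrival order. When the writes arrive in increasing order, as the in-loop writes are stated to, both variants agree, which is why the plain store remains acceptable inside the loop.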
@@ -892,7 +892,13 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
 if (lane_id < kNumRDMARanks and not kCachedMode)
     send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
 
-// Update last token tail
+// Update last token tail. In-loop writes are sequenced by the
+// per-channel sequential lock and the warp-stride property of the
+// token loop, so monotonicity is guaranteed and a plain
+// st_release_cta is correct AND faster than atomicMax (which
+// would serialize through L2 if the compiler can't infer shared
+// address space). The epilogue (out of the seq-lock contract for
+// the highest in-rank slot) needs atomicMax separately.
 if (last_rdma_tail_idx >= 0)
     st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
 last_rdma_tail_idx = rdma_tail_idx;
@@ -962,9 +968,9 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
     ;
 __syncwarp();
 
-// Update last token tail
+// Update last token tail (epilogue). See in-loop note on atomicMax.
 if (last_rdma_tail_idx >= 0)
-    st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
+    atomicMax(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
 
 // Release sequential lock
 lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;