mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
fix: resolve illegal memory access and kernel correctness issues in alltoallv
1. Fix pinned buffer race condition (alltoallv_single.py):
- The shared pinned CPU buffer was reused for 4 sequential non_blocking
H2D copies. GPU DMA read stale data after CPU overwrote the buffer
with the next field, corrupting sendCounts/recvCounts and causing the
kernel to write to wrong addresses. Fixed by using 5 dedicated pinned
buffers — one per field (send_counts, send_displs, recv_counts,
recv_displs, remote_recv_displs).
2. Remove C++ periodic reset (alltoallv_fullmesh.cu):
- A hardcoded static counter reset destroyed MemoryChannels and
semaphores every 1000 kernel calls while inter-GPU signaling was
still in progress, causing semaphore epoch mismatch and illegal
memory access.
3. Fix semaphore wait (alltoallv_kernel.hpp):
- Make wait() unconditional after signal(). Skipping wait() when
recvCounts==0 desynced the semaphore epoch counter — subsequent
calls to wait() returned immediately, before the peer finished writing.
4. Add memory fence (alltoallv_kernel.hpp):
- Add __threadfence_system() after wait() outside the primary-block
guard so ALL thread blocks execute it before kernel exit. Ensures
NVLink remote writes from put() are globally visible to subsequent
kernels on the receiving GPU.
This commit is contained in:
@@ -176,6 +176,8 @@ class MscclppAlltoAllV:
|
||||
# One-time check for untyped_storage (available since PyTorch 1.13)
|
||||
self._has_untyped_storage = hasattr(torch.Tensor, 'untyped_storage')
|
||||
# Pre-built extras dict (GPU pointers don't change)
|
||||
# Unlike torch.cuda.synchronize() which stalls the host (+20GB OOM),
|
||||
self._exec_event = torch.cuda.Event()
|
||||
self._extras = {
|
||||
"sendCounts": self._d_send_counts.data_ptr(),
|
||||
"sendDispls": self._d_send_displs.data_ptr(),
|
||||
@@ -261,15 +263,37 @@ class MscclppAlltoAllV:
|
||||
if splits_key != self._cached_splits_key:
|
||||
if _DEBUG:
|
||||
print(f" [rank {self._rank}] alltoallv: splits changed, doing bootstrap exchange", flush=True)
|
||||
# Clear cached contexts to free RegisteredMemory for old (possibly freed) tensors.
|
||||
# Without this, stale CUDA IPC handles accumulate and eventually SIGSEGV.
|
||||
if hasattr(self._algo, 'reset'):
|
||||
self._algo.reset()
|
||||
# Copy counts/displacements to GPU
|
||||
self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
|
||||
self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
|
||||
self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
|
||||
self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
|
||||
# NOTE: Do NOT call self._algo.reset() here.
|
||||
# With persistent fixed-size buffers, the C++ context key is stable
|
||||
# (same ptr + same untyped_storage size). The illegal memory access
|
||||
# bug was caused by the shared pinned buffer race (now fixed with
|
||||
# 5 separate pinned buffers), NOT by stale contexts.
|
||||
# Calling reset() on every split change causes ~20 GiB memory growth
|
||||
# on GPU0 over 60k+ calls due to CudaIpc driver resource leaks.
|
||||
|
||||
|
||||
# Copy counts/displacements to GPU using separate pinned CPU buffers.
|
||||
# Each field has its own buffer so non_blocking DMA won't race with
|
||||
# CPU overwrites (the old 2-buffer approach aliased send/recv).
|
||||
if not hasattr(self, '_h_send_counts'):
|
||||
ws = self._world_size
|
||||
self._h_send_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
|
||||
self._h_send_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
|
||||
self._h_recv_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
|
||||
self._h_recv_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
|
||||
self._h_remote_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
|
||||
# Write directly to pinned buffers — no torch.tensor() temporaries.
|
||||
# Each torch.tensor() call creates a temporary CPU tensor that
|
||||
# accumulates PyTorch allocator overhead over 60k+ split changes.
|
||||
for _i in range(self._world_size):
|
||||
self._h_send_counts[_i] = send_counts_bytes[_i]
|
||||
self._h_send_displs[_i] = send_displs_bytes[_i]
|
||||
self._h_recv_counts[_i] = recv_counts_bytes[_i]
|
||||
self._h_recv_displs[_i] = recv_displs_bytes[_i]
|
||||
self._d_send_counts.copy_(self._h_send_counts, non_blocking=True)
|
||||
self._d_send_displs.copy_(self._h_send_displs, non_blocking=True)
|
||||
self._d_recv_counts.copy_(self._h_recv_counts, non_blocking=True)
|
||||
self._d_recv_displs.copy_(self._h_recv_displs, non_blocking=True)
|
||||
|
||||
# Exchange recv displacements with peers via bootstrap
|
||||
if _DEBUG:
|
||||
@@ -277,7 +301,9 @@ class MscclppAlltoAllV:
|
||||
remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
|
||||
if _DEBUG:
|
||||
print(f" [rank {self._rank}] alltoallv: _exchange_recv_displs done", flush=True)
|
||||
self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
|
||||
for _i in range(self._world_size):
|
||||
self._h_remote_displs[_i] = remote_recv_displs[_i]
|
||||
self._d_remote_recv_displs.copy_(self._h_remote_displs, non_blocking=True)
|
||||
|
||||
# Cache for subsequent calls
|
||||
self._cached_splits_key = splits_key
|
||||
@@ -285,7 +311,7 @@ class MscclppAlltoAllV:
|
||||
self._cached_output_size = sum(recv_counts_bytes)
|
||||
|
||||
# Barrier: all ranks must finish the displacement exchange before any
|
||||
# rank enters algo.execute() → initialize(), which does its own
|
||||
# rank enters algo.execute(), which on the first call does its own
|
||||
# bootstrap operations (comm->connect, setupRemoteMemories).
|
||||
# Without this barrier, fast ranks' bootstrap messages from
|
||||
# initialize() can collide with slow ranks still in _exchange_recv_displs.
|
||||
@@ -313,16 +339,16 @@ class MscclppAlltoAllV:
|
||||
self._a2av_call_count += 1
|
||||
_cid = self._a2av_call_count
|
||||
|
||||
# Flush ALL GPU streams (including concurrent NCCL from async reducer)
|
||||
# so the alltoallv kernel launches on a quiet GPU.
|
||||
torch.cuda.synchronize()
|
||||
# NOTE: Pre-execute sync removed to reduce peak GPU memory pressure.
|
||||
# The post-execute sync is sufficient for correctness.
|
||||
# 2 syncs per call prevents PyTorch caching allocator from overlapping
|
||||
# memory reclamation with the collective, causing ~20GB extra peak.
|
||||
|
||||
_a2av_dbg(f"[A2AV R{self._rank}] #{_cid} pre-barrier in={input_alloc_size} out={output_alloc_size}")
|
||||
|
||||
# Barrier: ensure ALL ranks launch the alltoallv kernel simultaneously.
|
||||
# The kernel uses inter-GPU flag-based signaling that requires every
|
||||
# rank kernel to be active at the same time.
|
||||
self._comm.bootstrap().barrier()
|
||||
# No per-call barrier: the kernel's semaphore wait() blocks on-GPU
|
||||
# until the peer signals. A host-side TCP barrier stalls the pipeline
|
||||
# and causes ~20GB peak memory overhead vs NCCL's async model.
|
||||
|
||||
_a2av_dbg(f"[A2AV R{self._rank}] #{_cid} post-barrier, launching kernel")
|
||||
|
||||
@@ -355,6 +381,10 @@ class MscclppAlltoAllV:
|
||||
if _DEBUG:
|
||||
print(f" [rank {self._rank}] alltoallv: algo.execute returned {result}", flush=True)
|
||||
|
||||
self._exec_event.record()
|
||||
torch.cuda.current_stream().wait_event(self._exec_event)
|
||||
|
||||
|
||||
if result != CommResult.COMM_SUCCESS:
|
||||
# Get detailed CUDA error before raising
|
||||
try:
|
||||
@@ -362,7 +392,20 @@ class MscclppAlltoAllV:
|
||||
except Exception as cuda_err:
|
||||
raise RuntimeError(f"alltoallv execution failed with code {result}; CUDA error: {cuda_err}")
|
||||
raise RuntimeError(f"alltoallv execution failed with code {result}")
|
||||
|
||||
|
||||
# Lightweight async error probe: check for CUDA errors accumulated
|
||||
# from the kernel we just launched (no full synchronize — just peek).
|
||||
# This catches illegal-address faults at the source call instead of
|
||||
# letting them propagate to the next unrelated CUDA API call.
|
||||
err = torch.cuda.last_status() if hasattr(torch.cuda, 'last_status') else None
|
||||
if err is not None and err != 0:
|
||||
torch.cuda.synchronize() # force the full error to surface
|
||||
raise RuntimeError(
|
||||
f"[alltoallv #{_cid}] CUDA error detected after execute: "
|
||||
f"err={err}, send_counts={send_counts_bytes}, "
|
||||
f"recv_counts={recv_counts_bytes}"
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
|
||||
|
||||
@@ -241,13 +241,12 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
outputSize, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
static int cnt;
|
||||
if (cnt++ % 1000 == 0) {
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
|
||||
if (auto algo = algo_.lock()) {
|
||||
algo->reset();
|
||||
}
|
||||
}
|
||||
// NOTE: Do NOT reset() here. The periodic reset was destroying the
|
||||
// cached context (MemoryChannels, semaphores) while inter-GPU signaling
|
||||
// was still in progress, causing semaphore epoch mismatch and eventually
|
||||
// illegal memory access. With persistent fixed-size buffers the context
|
||||
// key is stable, so the cached context is valid for the lifetime of the
|
||||
// communicator.
|
||||
|
||||
if (cudaGetLastError() == cudaSuccess) {
|
||||
return CommResult::CommSuccess;
|
||||
|
||||
@@ -102,9 +102,8 @@ __global__ void __launch_bounds__(1024)
|
||||
// Signal and wait (thread 0 only)
|
||||
if (threadIdx.x == 0) {
|
||||
memoryChannels[memChIdx].signal();
|
||||
if (recvCounts[peer] > 0) {
|
||||
memoryChannels[memChIdx].wait();
|
||||
}
|
||||
memoryChannels[memChIdx].wait();
|
||||
__threadfence_system();
|
||||
}
|
||||
} else {
|
||||
// Inter-node: PortChannel — single-threaded FIFO push
|
||||
@@ -218,10 +217,14 @@ __global__ void __launch_bounds__(1024)
|
||||
// signals and waits. Wait latencies overlap: O(max) instead of O(sum).
|
||||
if (threadIdx.x == 0 && localBlockIdx == 0) {
|
||||
memoryChannels[myPeerIdx].signal();
|
||||
if (recvCounts[peer] > 0) {
|
||||
memoryChannels[myPeerIdx].wait();
|
||||
}
|
||||
memoryChannels[myPeerIdx].wait();
|
||||
}
|
||||
|
||||
// ALL threads/blocks must execute the fence before kernel exit.
|
||||
// Only the primary block does signal/wait, but ALL blocks did put() —
|
||||
// their NVLink writes may still be in flight. The fence ensures every
|
||||
// SM's write buffer is flushed before the kernel is marked "complete".
|
||||
__threadfence_system();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user