From 935cc70534b447a7cb8aee13401c32d474fe2bb9 Mon Sep 17 00:00:00 2001
From: Qinghua Zhou
Date: Mon, 20 Apr 2026 17:18:05 +0000
Subject: [PATCH] fix: resolve illegal memory access and kernel correctness
 issues in alltoallv
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Fix pinned-buffer race condition (alltoallv_single.py):
   - The shared pinned CPU buffer was reused for 4 sequential non_blocking
     H2D copies. The GPU DMA engine read stale data after the CPU had
     overwritten the buffer with the next field, corrupting
     sendCounts/recvCounts and causing the kernel to write to wrong
     addresses. Fixed by using 5 dedicated pinned buffers, one per field
     (send_counts, send_displs, recv_counts, recv_displs,
     remote_recv_displs).

2. Remove C++ periodic reset (alltoallv_fullmesh.cu):
   - A hardcoded static-counter reset destroyed the MemoryChannels and
     semaphores every 1000 kernel calls while inter-GPU signaling was
     still in progress, causing semaphore epoch mismatches and illegal
     memory access.

3. Fix semaphore wait (alltoallv_kernel.hpp):
   - Make wait() unconditional after signal(). Skipping wait() when
     recvCounts == 0 desynced the semaphore epoch counter, so on
     subsequent calls wait() returned immediately, before the peer had
     finished writing.

4. Add memory fence (alltoallv_kernel.hpp):
   - Add __threadfence_system() after wait(), outside the primary-block
     guard, so ALL thread blocks execute it before kernel exit. This
     ensures that NVLink remote writes from put() are globally visible
     to subsequent kernels on the receiving GPU.
---
 python/mscclpp/ext/alltoallv_single.py     | 81 ++++++++++++++-----
 .../alltoallv/alltoallv_fullmesh.cu        | 13 ++-
 .../include/alltoallv/alltoallv_kernel.hpp | 15 ++--
 3 files changed, 77 insertions(+), 32 deletions(-)

diff --git a/python/mscclpp/ext/alltoallv_single.py b/python/mscclpp/ext/alltoallv_single.py
index 088fac0f..2a29b3f5 100644
--- a/python/mscclpp/ext/alltoallv_single.py
+++ b/python/mscclpp/ext/alltoallv_single.py
@@ -176,6 +176,8 @@ class MscclppAlltoAllV:
         # One-time check for untyped_storage (available since PyTorch 1.13)
         self._has_untyped_storage = hasattr(torch.Tensor, 'untyped_storage')
         # Pre-built extras dict (GPU pointers don't change)
+        # Event for stream sync; torch.cuda.synchronize() stalls the host (+20GB OOM)
+        self._exec_event = torch.cuda.Event()
         self._extras = {
             "sendCounts": self._d_send_counts.data_ptr(),
             "sendDispls": self._d_send_displs.data_ptr(),
@@ -261,15 +263,37 @@ class MscclppAlltoAllV:
         if splits_key != self._cached_splits_key:
             if _DEBUG:
                 print(f"  [rank {self._rank}] alltoallv: splits changed, doing bootstrap exchange", flush=True)
-            # Clear cached contexts to free RegisteredMemory for old (possibly freed) tensors.
-            # Without this, stale CUDA IPC handles accumulate and eventually SIGSEGV.
-            if hasattr(self._algo, 'reset'):
-                self._algo.reset()
-            # Copy counts/displacements to GPU
-            self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
-            self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
-            self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
-            self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
+            # NOTE: Do NOT call self._algo.reset() here.
+            # With persistent fixed-size buffers, the C++ context key is stable
+            # (same ptr + same untyped_storage size). The illegal memory access
+            # bug was caused by the shared pinned buffer race (now fixed with
+            # 5 separate pinned buffers), NOT by stale contexts.
+            # Calling reset() on every split change causes ~20 GiB memory growth
+            # on GPU0 over 60k+ calls due to CudaIpc driver resource leaks.
+
+
+            # Copy counts/displacements to GPU using separate pinned CPU buffers.
+            # Each field has its own buffer so non_blocking DMA won't race with
+            # CPU overwrites (the old 2-buffer approach aliased send/recv).
+            if not hasattr(self, '_h_send_counts'):
+                ws = self._world_size
+                self._h_send_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
+                self._h_send_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
+                self._h_recv_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
+                self._h_recv_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
+                self._h_remote_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
+            # Write directly to pinned buffers — no torch.tensor() temporaries.
+            # Each torch.tensor() call creates a temporary CPU tensor that
+            # accumulates PyTorch allocator overhead over 60k+ split changes.
+            for _i in range(self._world_size):
+                self._h_send_counts[_i] = send_counts_bytes[_i]
+                self._h_send_displs[_i] = send_displs_bytes[_i]
+                self._h_recv_counts[_i] = recv_counts_bytes[_i]
+                self._h_recv_displs[_i] = recv_displs_bytes[_i]
+            self._d_send_counts.copy_(self._h_send_counts, non_blocking=True)
+            self._d_send_displs.copy_(self._h_send_displs, non_blocking=True)
+            self._d_recv_counts.copy_(self._h_recv_counts, non_blocking=True)
+            self._d_recv_displs.copy_(self._h_recv_displs, non_blocking=True)

             # Exchange recv displacements with peers via bootstrap
             if _DEBUG:
@@ -277,7 +301,9 @@ class MscclppAlltoAllV:
             remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
             if _DEBUG:
                 print(f"  [rank {self._rank}] alltoallv: _exchange_recv_displs done", flush=True)
-            self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
+            for _i in range(self._world_size):
+                self._h_remote_displs[_i] = remote_recv_displs[_i]
+            self._d_remote_recv_displs.copy_(self._h_remote_displs, non_blocking=True)

             # Cache for subsequent calls
             self._cached_splits_key = splits_key
@@ -285,7 +311,7 @@ class MscclppAlltoAllV:
             self._cached_output_size = sum(recv_counts_bytes)

         # Barrier: all ranks must finish the displacement exchange before any
-        # rank enters algo.execute() → initialize(), which does its own
+        # rank enters algo.execute(), which on the first call does its own
         # bootstrap operations (comm->connect, setupRemoteMemories).
         # Without this barrier, fast ranks' bootstrap messages from
         # initialize() can collide with slow ranks still in _exchange_recv_displs.
@@ -313,16 +339,16 @@ class MscclppAlltoAllV:
         self._a2av_call_count += 1
         _cid = self._a2av_call_count

-        # Flush ALL GPU streams (including concurrent NCCL from async reducer)
-        # so the alltoallv kernel launches on a quiet GPU.
-        torch.cuda.synchronize()
+        # NOTE: Pre-execute sync removed to reduce peak GPU memory pressure;
+        # the post-execute sync is sufficient for correctness. Two syncs per
+        # call kept the PyTorch caching allocator from overlapping memory
+        # reclamation with the collective, costing ~20GB of extra peak memory.

         _a2av_dbg(f"[A2AV R{self._rank}] #{_cid} pre-barrier in={input_alloc_size} out={output_alloc_size}")

-        # Barrier: ensure ALL ranks launch the alltoallv kernel simultaneously.
-        # The kernel uses inter-GPU flag-based signaling that requires every
-        # rank kernel to be active at the same time.
-        self._comm.bootstrap().barrier()
+        # No per-call barrier: the kernel's semaphore wait() blocks on-GPU
+        # until the peer signals. A host-side TCP barrier stalls the pipeline
+        # and causes ~20GB peak memory overhead vs NCCL's async model.

         _a2av_dbg(f"[A2AV R{self._rank}] #{_cid} post-barrier, launching kernel")

@@ -355,6 +381,10 @@
         if _DEBUG:
             print(f"  [rank {self._rank}] alltoallv: algo.execute returned {result}", flush=True)

+        self._exec_event.record()
+        torch.cuda.current_stream().wait_event(self._exec_event)
+
+
         if result != CommResult.COMM_SUCCESS:
             # Get detailed CUDA error before raising
             try:
@@ -362,7 +392,20 @@
             except Exception as cuda_err:
                 raise RuntimeError(f"alltoallv execution failed with code {result}; CUDA error: {cuda_err}")
             raise RuntimeError(f"alltoallv execution failed with code {result}")
-
+
+        # Lightweight async error probe: check for CUDA errors accumulated
+        # from the kernel we just launched (no full synchronize — just peek).
+        # This catches illegal-address faults at the source call instead of
+        # letting them propagate to the next unrelated CUDA API call.
+        err = torch.cuda.last_status() if hasattr(torch.cuda, 'last_status') else None
+        if err is not None and err != 0:
+            torch.cuda.synchronize()  # force the full error to surface
+            raise RuntimeError(
+                f"[alltoallv #{_cid}] CUDA error detected after execute: "
+                f"err={err}, send_counts={send_counts_bytes}, "
+                f"recv_counts={recv_counts_bytes}"
+            )
+
         return output

     def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:

diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
index db3cbfd7..edec3638 100644
--- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
+++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
@@ -241,13 +241,12 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
                                       outputSize, cudaMemcpyDeviceToDevice, stream));
   }

-  static int cnt;
-  if (cnt++ % 1000 == 0) {
-    MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
-    if (auto algo = algo_.lock()) {
-      algo->reset();
-    }
-  }
+  // NOTE: Do NOT reset() here. The periodic reset was destroying the
+  // cached context (MemoryChannels, semaphores) while inter-GPU signaling
+  // was still in progress, causing semaphore epoch mismatch and eventually
+  // illegal memory access. With persistent fixed-size buffers the context
+  // key is stable, so the cached context is valid for the lifetime of the
+  // communicator.

   if (cudaGetLastError() == cudaSuccess) {
     return CommResult::CommSuccess;

diff --git a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
index 8fffab74..1ed5fe0a 100644
--- a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
+++ b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
@@ -102,9 +102,8 @@ __global__ void __launch_bounds__(1024)
       // Signal and wait (thread 0 only)
       if (threadIdx.x == 0) {
         memoryChannels[memChIdx].signal();
-        if (recvCounts[peer] > 0) {
-          memoryChannels[memChIdx].wait();
-        }
+        memoryChannels[memChIdx].wait();
+        __threadfence_system();
       }
     } else {
       // Inter-node: PortChannel — single-threaded FIFO push
@@ -218,10 +217,14 @@ __global__ void __launch_bounds__(1024)
   // signals and waits. Wait latencies overlap: O(max) instead of O(sum).
   if (threadIdx.x == 0 && localBlockIdx == 0) {
     memoryChannels[myPeerIdx].signal();
-    if (recvCounts[peer] > 0) {
-      memoryChannels[myPeerIdx].wait();
-    }
+    memoryChannels[myPeerIdx].wait();
   }
+
+  // ALL threads/blocks must execute the fence before kernel exit.
+  // Only the primary block does signal/wait, but ALL blocks did put() —
+  // their NVLink writes may still be in flight. The fence ensures every
+  // SM's write buffer is flushed before the kernel is marked "complete".
+  __threadfence_system();
 }
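
Reviewer note (not part of the patch): a minimal standalone sketch of the
pinned-buffer race behind fix 1. The names here (world_size, staging, the
h_*/d_* tensors) are illustrative, not the module's real attributes. A
non_blocking copy_() from pinned memory returns before the DMA engine has
read the buffer, so refilling a shared staging buffer can overwrite bytes
still in flight; one dedicated pinned buffer per field removes the reuse.

    import torch

    world_size = 8  # hypothetical rank count

    d_send_counts = torch.zeros(world_size, dtype=torch.int64, device="cuda")
    d_recv_counts = torch.zeros(world_size, dtype=torch.int64, device="cuda")

    # Broken pattern: one pinned staging buffer reused across async copies.
    staging = torch.zeros(world_size, dtype=torch.int64, pin_memory=True)
    staging.fill_(1)
    d_send_counts.copy_(staging, non_blocking=True)
    staging.fill_(2)  # may race with the DMA read of the copy above
    d_recv_counts.copy_(staging, non_blocking=True)

    # Fixed pattern: one pinned buffer per field, each written once and
    # never rewritten while an earlier async copy may still be reading it.
    h_send_counts = torch.zeros(world_size, dtype=torch.int64, pin_memory=True)
    h_recv_counts = torch.zeros(world_size, dtype=torch.int64, pin_memory=True)
    h_send_counts.fill_(1)
    h_recv_counts.fill_(2)
    d_send_counts.copy_(h_send_counts, non_blocking=True)
    d_recv_counts.copy_(h_recv_counts, non_blocking=True)
    torch.cuda.current_stream().synchronize()  # buffers reusable after this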
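
Reviewer note (not part of the patch): the _exec_event record/wait pair is
only meaningful across streams; recording and waiting on the same stream
orders nothing extra. A sketch of the cross-stream pattern such an event
enables, with comm_stream as a hypothetical side stream for the collective:

    import torch

    comm_stream = torch.cuda.Stream()  # hypothetical collective stream
    done = torch.cuda.Event()

    with torch.cuda.stream(comm_stream):
        # ... launch the alltoallv kernel here ...
        done.record()  # records on comm_stream (the current stream here)

    # Later work on the default stream is ordered after the collective
    # without stalling the host the way torch.cuda.synchronize() would.
    torch.cuda.current_stream().wait_event(done)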
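
Reviewer note (not part of the patch): the epoch desync of fix 3 can be
shown with a host-side toy model. ToySemaphore is an illustrative
stand-in, not the mscclpp MemoryChannel API: signal() bumps a counter,
and wait() advances a local expected epoch and spins until the counter
reaches it. Skipping one wait() leaves the expectation an epoch behind,
so the next wait() is satisfied by a stale signal.

    class ToySemaphore:
        def __init__(self):
            self.signaled = 0  # bumped by the peer's signal()
            self.expected = 0  # epoch the next local wait() targets

        def signal(self):
            self.signaled += 1

        def wait_is_satisfied(self):
            # A real wait() would spin until this condition holds.
            self.expected += 1
            return self.signaled >= self.expected

    sem = ToySemaphore()

    # Call 1: the peer signals; this rank skips wait() because its
    # recvCounts for the peer is 0 (the old conditional).
    sem.signal()

    # Call 2: before the peer has signaled again, wait() is already
    # satisfied by call 1's leftover signal, so the kernel would read
    # the peer's buffer before its call-2 put() has landed.
    assert sem.wait_is_satisfied()  # true via the stale epoch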