mscclpp/python/mscclpp/ext/alltoallv_single.py
Qinghua Zhou 935cc70534 fix: resolve illegal memory access and kernel correctness issues in alltoallv
1. Fix pinned buffer race condition (alltoallv_single.py):
   - The shared pinned CPU buffer was reused for 4 sequential non_blocking
     H2D copies, so the GPU DMA could read stale data after the CPU had
     already overwritten the buffer with the next field. This corrupted
     sendCounts/recvCounts and made the kernel write to wrong addresses.
     Fixed by using 5 dedicated pinned buffers, one per field (send_counts,
     send_displs, recv_counts, recv_displs, remote_recv_displs); a sketch of
     the pattern follows the commit message.

2. Remove C++ periodic reset (alltoallv_fullmesh.cu):
   - A hardcoded static counter reset destroyed MemoryChannels and
     semaphores every 1000 kernel calls while inter-GPU signaling was
     still in progress, causing semaphore epoch mismatch and illegal
     memory access.

3. Fix semaphore wait (alltoallv_kernel.hpp):
   - Make wait() unconditional after signal(). Skipping wait() when
     recvCounts==0 desynced the semaphore epoch counter, so wait() in
     subsequent calls returned immediately, before the peer finished writing.

4. Add memory fence (alltoallv_kernel.hpp):
   - Add __threadfence_system() after wait() outside the primary-block
     guard so ALL thread blocks execute it before kernel exit. Ensures
     NVLink remote writes from put() are globally visible to subsequent
     kernels on the receiving GPU.
2026-04-20 17:18:05 +00:00
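
Fix 1, as a minimal standalone host-side sketch (buffer names and sizes here are illustrative, not the exact ones in alltoallv_single.py): a single pinned staging buffer reused across non_blocking H2D copies can be overwritten before the pending DMA has read it, while one dedicated pinned buffer per field removes the race.

import torch

world_size = 8
d_send_counts = torch.zeros(world_size, dtype=torch.int64, device="cuda")
d_recv_counts = torch.zeros(world_size, dtype=torch.int64, device="cuda")

# Racy pattern: one shared pinned buffer for several async H2D copies.
# copy_(..., non_blocking=True) from pinned memory returns before the DMA
# has read the source, so refilling the buffer for the next field can
# clobber the data the previous copy is still reading.
staging = torch.zeros(world_size, dtype=torch.int64, pin_memory=True)
staging.copy_(torch.arange(world_size))           # stage sendCounts
d_send_counts.copy_(staging, non_blocking=True)   # DMA may still be pending...
staging.copy_(torch.arange(world_size) * 2)       # ...when recvCounts overwrite it
d_recv_counts.copy_(staging, non_blocking=True)

# Fixed pattern: a dedicated pinned buffer per field, so no pending copy's
# source is ever overwritten.
h_send_counts = torch.zeros(world_size, dtype=torch.int64, pin_memory=True)
h_recv_counts = torch.zeros(world_size, dtype=torch.int64, pin_memory=True)
h_send_counts.copy_(torch.arange(world_size))
h_recv_counts.copy_(torch.arange(world_size) * 2)
d_send_counts.copy_(h_send_counts, non_blocking=True)
d_recv_counts.copy_(h_recv_counts, non_blocking=True)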

511 lines
21 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
PyTorch-compatible all_to_all_single API using mscclpp optimized kernels.
This module provides:
- MscclppAlltoAllV: A class to manage mscclpp alltoallv state
- all_to_all_single: Drop-in replacement for torch.distributed.all_to_all_single
Uses optimized C++ kernels (alltoallvKernel, alltoallvRingKernel, alltoallvPipelinedKernel)
via the NativeAlgorithm framework with size-adaptive algorithm selection.
"""
from __future__ import annotations
import os
import sys
import torch
import torch.distributed as dist
from typing import Optional, List, Tuple
_DEBUG_A2AV = bool(int(os.environ.get("DEBUG_ALL2ALL_MSG_SIZE", "0")))


def _a2av_dbg(msg: str):
    if _DEBUG_A2AV:
        print(msg, file=sys.stderr, flush=True)


from mscclpp._mscclpp import (
    Communicator,
    TcpBootstrap,
    DataType,
    ReduceOp,
    CommResult,
)
from mscclpp.ext.algorithm_collection_builder import AlgorithmCollectionBuilder

import ctypes as _ctypes

try:
    _cudart = _ctypes.CDLL("libcudart.so")
except Exception:
    _cudart = None

_DEBUG = os.environ.get("MSCCLPP_DEBUG_ALLTOALLV", "0") == "1"

__all__ = ["MscclppAlltoAllV", "all_to_all_single"]


def _torch_dtype_to_mscclpp(dtype: torch.dtype) -> DataType:
    """Convert PyTorch dtype to mscclpp DataType."""
    if dtype == torch.float32:
        return DataType.float32
    elif dtype == torch.float16:
        return DataType.float16
    elif dtype == torch.bfloat16:
        return DataType.bfloat16
    elif dtype == torch.int32:
        return DataType.int32
    elif dtype == torch.int64:
        return DataType.int64
    elif dtype == torch.uint8:
        return DataType.uint8
    elif dtype == torch.float64:
        return DataType.float64
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


def _dtype_size(dtype: torch.dtype) -> int:
    """Get byte size for dtype."""
    return torch.tensor([], dtype=dtype).element_size()


class MscclppAlltoAllV:
    """
    Manages mscclpp state for alltoallv operations.

    Uses optimized C++ kernels from alltoallv_fullmesh.cu with size-adaptive selection:
    - Small messages (<1MB): alltoallvKernel (lower latency)
    - Large messages + small world (<=16): alltoallvPipelinedKernel
    - Large messages + large world (>16): alltoallvRingKernel (avoids congestion)

    Example:
        mscclpp_alltoallv = MscclppAlltoAllV(
            rank=rank, world_size=world_size,
            ip_port="10.0.0.1:50000"
        )
        # or use existing communicator:
        # mscclpp_alltoallv = MscclppAlltoAllV(communicator=comm)

        # Later:
        output = mscclpp_alltoallv.all_to_all_single(
            input_tensor,
            output_split_sizes=[1024, 2048, ...],  # per-rank sizes
            input_split_sizes=[1024, 2048, ...]
        )
    """

    def __init__(
        self,
        rank: Optional[int] = None,
        world_size: Optional[int] = None,
        ip_port: Optional[str] = None,
        communicator: Optional[Communicator] = None,
        scratch_buffer_size: int = 256 * 1024 * 1024,  # 256MB default
    ):
        """
        Initialize MscclppAlltoAllV.

        Args:
            rank: Local rank (required if communicator not provided)
            world_size: Total number of ranks (required if communicator not provided)
            ip_port: IP:port for bootstrap (required if communicator not provided)
            communicator: Existing mscclpp Communicator (alternative to rank/world_size/ip_port)
            scratch_buffer_size: Size of scratch buffer in bytes
        """
        if communicator is not None:
            self._comm = communicator
            self._rank = self._comm.bootstrap().get_rank()
            self._world_size = self._comm.bootstrap().get_n_ranks()
            self._owns_comm = False
        else:
            if rank is None or world_size is None or ip_port is None:
                raise ValueError("Must provide either communicator or (rank, world_size, ip_port)")
            self._rank = rank
            self._world_size = world_size
            # Create bootstrap
            bootstrap = TcpBootstrap(rank, world_size)
            if rank == 0:
                unique_id = bootstrap.create_unique_id()
                # Broadcast unique_id to other ranks via torch.distributed.
                # Pad to a fixed 128 bytes so the broadcast shape matches the
                # receive tensor allocated on the other ranks.
                id_bytes = list(unique_id.encode())
                id_bytes += [0] * (128 - len(id_bytes))
                id_tensor = torch.tensor(id_bytes, dtype=torch.uint8).cuda()
            else:
                id_tensor = torch.zeros(128, dtype=torch.uint8).cuda()
            dist.broadcast(id_tensor, src=0)
            unique_id = bytes(id_tensor.cpu().tolist()).decode().rstrip('\x00')
            bootstrap.initialize(unique_id)
            self._comm = Communicator(bootstrap)
            self._owns_comm = True

        # Allocate scratch buffer
        self._scratch_buffer = torch.zeros(scratch_buffer_size, dtype=torch.uint8, device='cuda')
        self._scratch_ptr = self._scratch_buffer.data_ptr()
        self._scratch_size = scratch_buffer_size

        # Build algorithm collection with default algorithms including alltoallv
        builder = AlgorithmCollectionBuilder()
        self._algo_collection = builder.build_default_algorithms(
            self._scratch_ptr,
            self._scratch_size,
            self._rank
        )

        # Get the alltoallv algorithm
        alltoallv_algos = self._algo_collection.get_by_collective("alltoallv")
        if not alltoallv_algos:
            raise RuntimeError("No alltoallv algorithm found. Make sure mscclpp is built correctly.")
        self._algo = alltoallv_algos[0]

        # Pre-allocate count/displacement buffers on GPU (reused across calls)
        # Using int64 (8 bytes) instead of size_t for safety
        self._d_send_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_send_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_recv_counts = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
        self._d_remote_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')

        # Cache for split sizes to avoid redundant bootstrap exchanges and GPU copies.
        # Key: (tuple(send_counts_bytes), tuple(recv_counts_bytes))
        self._cached_splits_key = None
        self._cached_input_size = 0
        self._cached_output_size = 0
        self._cached_total_output_elems = 0
        self._cached_dtype = None

        # One-time check for untyped_storage (available since PyTorch 1.13)
        self._has_untyped_storage = hasattr(torch.Tensor, 'untyped_storage')

        # Event used to order the collective against the caller's stream.
        # Unlike torch.cuda.synchronize(), which stalls the host (and added
        # ~20GB of peak memory pressure / OOM risk), record()/wait_event()
        # keeps the dependency on-device.
        self._exec_event = torch.cuda.Event()

        # Pre-built extras dict (GPU pointers don't change)
        self._extras = {
            "sendCounts": self._d_send_counts.data_ptr(),
            "sendDispls": self._d_send_displs.data_ptr(),
            "recvCounts": self._d_recv_counts.data_ptr(),
            "recvDispls": self._d_recv_displs.data_ptr(),
            "remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
        }
        self._a2av_call_count = 0

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    def all_to_all_single(
        self,
        input: torch.Tensor,
        output_split_sizes: Optional[List[int]] = None,
        input_split_sizes: Optional[List[int]] = None,
        output: Optional[torch.Tensor] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> torch.Tensor:
        """
        Perform all-to-all exchange with variable-sized chunks.

        Compatible with torch.distributed.all_to_all_single signature.

        Args:
            input: Input tensor (contiguous, CUDA)
            output_split_sizes: List of sizes to receive from each rank (in elements)
            input_split_sizes: List of sizes to send to each rank (in elements)
            output: Pre-allocated output tensor (optional)
            stream: CUDA stream (optional, uses current stream if not specified)

        Returns:
            Output tensor with received data
        """
        if not input.is_cuda or not input.is_contiguous():
            raise ValueError("Input must be a contiguous CUDA tensor")

        dtype = input.dtype
        elem_size = _dtype_size(dtype)
        world_size = self._world_size

        # Handle split sizes
        if input_split_sizes is None:
            # Equal split
            assert input.numel() % world_size == 0
            chunk_size = input.numel() // world_size
            input_split_sizes = [chunk_size] * world_size
        if output_split_sizes is None:
            # All-to-all uniform: send and recv same sizes
            output_split_sizes = input_split_sizes.copy()

        # Calculate total output size and allocate if needed
        total_output = sum(output_split_sizes)
        if output is None:
            output = torch.empty(total_output, dtype=dtype, device='cuda')
        elif output.numel() < total_output:
            raise ValueError(f"Output tensor too small: {output.numel()} < {total_output}")
        # Calculate displacements
        send_displs = [0]
        for i in range(world_size - 1):
            send_displs.append(send_displs[-1] + input_split_sizes[i])
        recv_displs = [0]
        for i in range(world_size - 1):
            recv_displs.append(recv_displs[-1] + output_split_sizes[i])

        # Convert to byte sizes/offsets for the kernel
        send_counts_bytes = [s * elem_size for s in input_split_sizes]
        send_displs_bytes = [d * elem_size for d in send_displs]
        recv_counts_bytes = [s * elem_size for s in output_split_sizes]
        recv_displs_bytes = [d * elem_size for d in recv_displs]

        # Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
        splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
        if splits_key != self._cached_splits_key:
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: splits changed, doing bootstrap exchange", flush=True)
            # NOTE: Do NOT call self._algo.reset() here.
            # With persistent fixed-size buffers, the C++ context key is stable
            # (same ptr + same untyped_storage size). The illegal memory access
            # bug was caused by the shared pinned buffer race (now fixed with
            # 5 separate pinned buffers), NOT by stale contexts.
            # Calling reset() on every split change causes ~20 GiB memory growth
            # on GPU0 over 60k+ calls due to CudaIpc driver resource leaks.

            # Copy counts/displacements to GPU using separate pinned CPU buffers.
            # Each field has its own buffer so non_blocking DMA won't race with
            # CPU overwrites (the old 2-buffer approach aliased send/recv).
            if not hasattr(self, '_h_send_counts'):
                ws = self._world_size
                self._h_send_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_send_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_recv_counts = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_recv_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)
                self._h_remote_displs = torch.zeros(ws, dtype=torch.int64, pin_memory=True)

            # Write directly to pinned buffers — no torch.tensor() temporaries.
            # Each torch.tensor() call creates a temporary CPU tensor that
            # accumulates PyTorch allocator overhead over 60k+ split changes.
            for _i in range(self._world_size):
                self._h_send_counts[_i] = send_counts_bytes[_i]
                self._h_send_displs[_i] = send_displs_bytes[_i]
                self._h_recv_counts[_i] = recv_counts_bytes[_i]
                self._h_recv_displs[_i] = recv_displs_bytes[_i]
            self._d_send_counts.copy_(self._h_send_counts, non_blocking=True)
            self._d_send_displs.copy_(self._h_send_displs, non_blocking=True)
            self._d_recv_counts.copy_(self._h_recv_counts, non_blocking=True)
            self._d_recv_displs.copy_(self._h_recv_displs, non_blocking=True)

            # Exchange recv displacements with peers via bootstrap
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: starting _exchange_recv_displs", flush=True)
            remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: _exchange_recv_displs done", flush=True)
            for _i in range(self._world_size):
                self._h_remote_displs[_i] = remote_recv_displs[_i]
            self._d_remote_recv_displs.copy_(self._h_remote_displs, non_blocking=True)

            # Cache for subsequent calls
            self._cached_splits_key = splits_key
            self._cached_input_size = sum(send_counts_bytes)
            self._cached_output_size = sum(recv_counts_bytes)

            # Barrier: all ranks must finish the displacement exchange before any
            # rank enters algo.execute(), which on the first call does its own
            # bootstrap operations (comm->connect, setupRemoteMemories).
            # Without this barrier, fast ranks' bootstrap messages from
            # initialize() can collide with slow ranks still in _exchange_recv_displs.
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: waiting on bootstrap barrier", flush=True)
            self._comm.bootstrap().barrier()
            if _DEBUG:
                print(f" [rank {self._rank}] alltoallv: bootstrap barrier done", flush=True)
        # Get stream
        if stream is None:
            stream = torch.cuda.current_stream()
        cuda_stream = stream.cuda_stream

        # Use the full underlying storage size for context key stability.
        # When the test reuses the same large tensor with different split sizes,
        # storage size stays constant → same context key → reuses channels.
        if self._has_untyped_storage:
            input_alloc_size = input.untyped_storage().size()
            output_alloc_size = output.untyped_storage().size()
        else:
            input_alloc_size = input.nelement() * input.element_size()
            output_alloc_size = output.nelement() * output.element_size()

        self._a2av_call_count += 1
        _cid = self._a2av_call_count

        # NOTE: Pre-execute sync removed to reduce peak GPU memory pressure.
        # The post-execute sync is sufficient for correctness.
        # 2 syncs per call prevents PyTorch caching allocator from overlapping
        # memory reclamation with the collective, causing ~20GB extra peak.
        _a2av_dbg(f"[A2AV R{self._rank}] #{_cid} pre-barrier in={input_alloc_size} out={output_alloc_size}")
        # No per-call barrier: the kernel's semaphore wait() blocks on-GPU
        # until the peer signals. A host-side TCP barrier stalls the pipeline
        # and causes ~20GB peak memory overhead vs NCCL's async model.
        _a2av_dbg(f"[A2AV R{self._rank}] #{_cid} post-barrier, launching kernel")

        # Execute the optimized kernel
        if _DEBUG:
            # Clear stale CUDA errors (the C++ code checks cudaGetLastError
            # after the kernel and returns INTERNAL_ERROR if any was pending).
            if _cudart is not None:
                _last_err = _cudart.cudaGetLastError()
                if _last_err != 0:
                    print(f" [rank {self._rank}] WARNING: cleared stale CUDA error code {_last_err} before execute", flush=True)
            print(f" [rank {self._rank}] alltoallv: calling algo.execute(input_alloc={input_alloc_size}, output_alloc={output_alloc_size})", flush=True)
        result = self._algo.execute(
            self._comm,
            input.data_ptr(),
            output.data_ptr(),
            input_alloc_size,
            output_alloc_size,
            _torch_dtype_to_mscclpp(dtype),
            ReduceOp.NOP,
            cuda_stream,
            None,  # executor (not needed for native algos)
            0,  # nblocks (auto)
            0,  # nthreads_per_block (auto)
            self._extras,
        )
        if _DEBUG:
            print(f" [rank {self._rank}] alltoallv: algo.execute returned {result}", flush=True)

        # Make the caller's current stream wait for the collective launched on `stream`.
        self._exec_event.record(stream)
        torch.cuda.current_stream().wait_event(self._exec_event)

        if result != CommResult.COMM_SUCCESS:
            # Get detailed CUDA error before raising
            try:
                torch.cuda.synchronize()
            except Exception as cuda_err:
                raise RuntimeError(f"alltoallv execution failed with code {result}; CUDA error: {cuda_err}")
            raise RuntimeError(f"alltoallv execution failed with code {result}")

        # Lightweight async error probe: check for CUDA errors accumulated
        # from the kernel we just launched (no full synchronize — just peek).
        # This catches illegal-address faults at the source call instead of
        # letting them propagate to the next unrelated CUDA API call.
        err = torch.cuda.last_status() if hasattr(torch.cuda, 'last_status') else None
        if err is not None and err != 0:
            torch.cuda.synchronize()  # force the full error to surface
            raise RuntimeError(
                f"[alltoallv #{_cid}] CUDA error detected after execute: "
                f"err={err}, send_counts={send_counts_bytes}, "
                f"recv_counts={recv_counts_bytes}"
            )

        return output

    def _exchange_recv_displs(self, recv_displs_bytes: list) -> list:
        """
        Exchange recv displacement arrays between all ranks via bootstrap allGather.

        Each rank needs to know where to write in each peer's output buffer:
        remoteRecvDispls[peer] = peer's recvDispls[rank] — the byte offset in
        peer's output buffer where data from this rank should be placed.

        Uses bootstrap.all_gather() (ring sockets, pre-established during
        initialize()) instead of pairwise TCP send/recv to avoid deadlocks.

        Args:
            recv_displs_bytes: This rank's recv displacement array (in bytes)

        Returns:
            List of remote recv displacements (one per rank, in bytes).
            remoteRecvDispls[rank] == recv_displs_bytes[rank] (self, unused by kernel)
        """
        import numpy as np

        rank = self._rank
        world_size = self._world_size
        bootstrap = self._comm.bootstrap()

        # All-gather: each rank contributes world_size int64 values
        all_data = np.zeros((world_size, world_size), dtype=np.int64)
        all_data[rank, :] = recv_displs_bytes
        per_rank_bytes = world_size * 8  # world_size x sizeof(int64)
        bootstrap.all_gather(all_data.ctypes.data, per_rank_bytes)

        # remoteRecvDispls[peer] = peer's recv_displs[rank]
        remote_recv_displs = [int(all_data[peer, rank]) for peer in range(world_size)]
        return remote_recv_displs

    def __del__(self):
        """Cleanup resources."""
        # Let CUDA handle tensor cleanup automatically
        pass


# Module-level singleton for convenience
_default_instance: Optional[MscclppAlltoAllV] = None


def get_default_instance(**kwargs) -> MscclppAlltoAllV:
    """Get or create a default MscclppAlltoAllV instance."""
    global _default_instance
    if _default_instance is None:
        _default_instance = MscclppAlltoAllV(**kwargs)
    return _default_instance


def all_to_all_single(
    output: torch.Tensor,
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]] = None,
    input_split_sizes: Optional[List[int]] = None,
    group=None,
    async_op: bool = False,
) -> Optional[torch.Tensor]:
    """
    Drop-in replacement for torch.distributed.all_to_all_single.

    Uses mscclpp optimized kernels internally for better performance,
    especially with imbalanced message sizes (e.g., MoE workloads).

    Note: This function requires prior initialization via get_default_instance();
    otherwise it falls back to PyTorch's native implementation.

    Args:
        output: Pre-allocated output tensor
        input: Input tensor
        output_split_sizes: Sizes to receive from each rank
        input_split_sizes: Sizes to send to each rank
        group: Process group (unused, for compatibility)
        async_op: If True, return async handle (not supported, falls back)

    Returns:
        None (modifies output in-place) or async handle if async_op=True
    """
    global _default_instance

    # Fall back to PyTorch if not initialized or async requested
    if _default_instance is None or async_op:
        return dist.all_to_all_single(
            output, input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op
        )

    # Use optimized mscclpp implementation
    _default_instance.all_to_all_single(
        input=input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        output=output,
    )
    return None  # Matches torch.distributed API (async_op=False returns None)
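
For reference, a minimal usage sketch (assumptions: torch.distributed is already initialized with a CUDA backend, the module is importable as mscclpp.ext.alltoallv_single per the file path above, and 10.0.0.1:50000 is a placeholder bootstrap address):

import torch
import torch.distributed as dist
from mscclpp.ext.alltoallv_single import all_to_all_single, get_default_instance

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# One-time setup: builds the bootstrap, scratch buffer, and alltoallv algorithm.
get_default_instance(rank=rank, world_size=world_size, ip_port="10.0.0.1:50000")

# Per-rank element counts; equal here for brevity, but they may differ per peer
# as long as every rank's output splits match what its peers actually send.
input_splits = [1024] * world_size
output_splits = [1024] * world_size
x = torch.randn(sum(input_splits), dtype=torch.float16, device="cuda")
out = torch.empty(sum(output_splits), dtype=torch.float16, device="cuda")

# Same call shape as torch.distributed.all_to_all_single; writes into `out`.
all_to_all_single(out, x, output_split_sizes=output_splits, input_split_sizes=input_splits)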