Merge multinode branch

This commit is contained in:
Qinghua Zhou
2026-03-25 02:51:24 +00:00
9 changed files with 857 additions and 125 deletions

View File

@@ -62,6 +62,7 @@ void register_algorithm(nb::module_& m) {
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
.def_prop_ro("constraint", &Algorithm::constraint)
.def_prop_ro("type", &Algorithm::type)
.def("reset", &Algorithm::reset)
.def(
"execute",
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,

View File

@@ -248,6 +248,10 @@ class MscclppAlltoAllV:
# Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
if splits_key != self._cached_splits_key:
# Clear cached contexts to free RegisteredMemory for old (possibly freed) tensors.
# Without this, stale CUDA IPC handles accumulate and eventually SIGSEGV.
if hasattr(self._algo, 'reset'):
self._algo.reset()
# Copy counts/displacements to GPU
self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
@@ -268,13 +272,16 @@ class MscclppAlltoAllV:
stream = torch.cuda.current_stream()
cuda_stream = stream.cuda_stream
# Use full buffer sizes (not actual data sizes) so the C++ context
# key (input_ptr, output_ptr, inputSize, outputSize) is always the
# same when using persistent buffers. This ensures only ONE context
# is ever created, avoiding bootstrap TCP on every unique size combo.
# The kernel uses per-peer sendCounts/recvCounts for actual data bounds.
input_size = input.numel() * elem_size
output_size = output.numel() * elem_size
# Use the full underlying storage size (not just the view's active data)
# for the context key, so that reusing views of the same tensor with
# different split sizes doesn't create new contexts (which leak
# RegisteredMemory for stale buffers).
try:
input_alloc_size = input.untyped_storage().size()
output_alloc_size = output.untyped_storage().size()
except Exception:
input_alloc_size = input.nelement() * input.element_size()
output_alloc_size = output.nelement() * output.element_size()
self._a2av_call_count += 1
_cid = self._a2av_call_count
@@ -297,8 +304,8 @@ class MscclppAlltoAllV:
self._comm,
input.data_ptr(),
output.data_ptr(),
input_size,
output_size,
input_alloc_size,
output_alloc_size,
_torch_dtype_to_mscclpp(dtype),
ReduceOp.NOP,
cuda_stream,

View File

@@ -12,30 +12,113 @@ Usage:
import torch
import torch.distributed as dist
import os
import sys
import time
import random
import socket
import struct
import pickle
from typing import Callable, List, Tuple
# Must init torch.distributed before importing mscclpp modules
# to set rank/world_size environment variables
def _get_routable_ip() -> str:
"""Get a routable IP address for this host (not 127.0.0.1)."""
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80)) # doesn't actually send data
ip = s.getsockname()[0]
s.close()
return ip
except Exception:
return socket.gethostbyname(socket.gethostname())
def _tcp_broadcast_unique_id(unique_id_bytes: bytes, rank: int, world_size: int,
master_addr: str, port: int = 18515) -> bytes:
"""Broadcast UniqueId bytes from rank 0 to all other ranks via TCP."""
if rank == 0:
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind(("", port))
server.listen(world_size - 1)
for _ in range(world_size - 1):
conn, _ = server.accept()
length = len(unique_id_bytes)
conn.sendall(struct.pack("!I", length) + unique_id_bytes)
conn.close()
server.close()
return unique_id_bytes
else:
# Retry connection to rank 0
for attempt in range(120):
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((master_addr, port))
break
except (ConnectionRefusedError, OSError):
time.sleep(0.5)
if attempt == 119:
raise RuntimeError(f"Rank {rank}: failed to connect to {master_addr}:{port}")
raw_len = b""
while len(raw_len) < 4:
raw_len += s.recv(4 - len(raw_len))
length = struct.unpack("!I", raw_len)[0]
data = b""
while len(data) < length:
data += s.recv(length - len(data))
s.close()
return data
def main():
# Get rank/world from MPI environment
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", os.environ.get("PMI_SIZE", 1)))
# Set CUDA device
local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
# Set CUDA device — prefer MPI-provided local rank to handle any rank mapping
local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK",
os.environ.get("MPI_LOCALRANKID",
os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
torch.cuda.set_device(local_rank)
# Initialize torch.distributed with NCCL (need MASTER_ADDR/PORT)
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
# Disable UCX in OpenMPI to avoid version mismatch crashes
os.environ.setdefault("OMPI_MCA_pml", "ob1")
os.environ.setdefault("OMPI_MCA_btl", "tcp,vader,self")
# Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed,
# otherwise gloo avoids IB configuration issues on some clusters.
# Set ALLTOALLV_BACKEND=nccl to enable torch baseline comparison.
backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
# For multi-node: MASTER_ADDR must be set to rank 0's routable IP.
# Single-node auto-detects; multi-node requires it from the launcher.
if "MASTER_ADDR" not in os.environ:
if rank == 0:
os.environ["MASTER_ADDR"] = _get_routable_ip()
else:
# Check if we're single-node (all ranks on same host)
n_gpus = torch.cuda.device_count()
if world_size <= n_gpus:
# Likely single-node 127.0.0.1 works
os.environ["MASTER_ADDR"] = "127.0.0.1"
else:
raise RuntimeError(
f"Rank {rank}: MASTER_ADDR not set for multi-node run "
f"(world_size={world_size} > local GPUs={n_gpus}). "
f"Set it in your launcher, e.g.:\n"
f" mpirun -x MASTER_ADDR=<node0_ip> -x MASTER_PORT=29500 ..."
)
os.environ.setdefault("MASTER_PORT", "29500")
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size,
device_id=torch.device(f"cuda:{local_rank}"))
if backend == "nccl":
# Don't use device_id= eager init — it triggers an immediate NCCL allreduce
# that fails on some platforms (e.g. GB200 with NCCL 2.28.9).
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
if rank == 0:
print(f"Testing MscclppAlltoAllV with {world_size} ranks")
@@ -48,33 +131,51 @@ def main():
UniqueId,
)
from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
import pickle
# Create mscclpp communicator with TcpBootstrap
# Use torch.distributed to share the unique ID via pickle
# Broadcast UniqueId via TCP sockets (avoids mpi4py/UCX issues)
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
uid_port = int(os.environ.get("MSCCLPP_UID_PORT", "18515"))
bootstrap = TcpBootstrap(rank, world_size)
if rank == 0:
unique_id = bootstrap.create_unique_id()
# Serialize UniqueId via pickle and broadcast
pickled = pickle.dumps(unique_id)
id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
id_tensor[:len(pickled)] = torch.tensor(list(pickled), dtype=torch.uint8)
# Also send length
len_tensor = torch.tensor([len(pickled)], dtype=torch.int64, device='cuda')
uid_bytes = pickle.dumps(unique_id)
else:
id_tensor = torch.zeros(256, dtype=torch.uint8, device='cuda')
len_tensor = torch.zeros(1, dtype=torch.int64, device='cuda')
dist.broadcast(len_tensor, src=0)
dist.broadcast(id_tensor, src=0)
uid_bytes = b""
uid_bytes = _tcp_broadcast_unique_id(uid_bytes, rank, world_size, master_addr, uid_port)
if rank != 0:
pickled_len = int(len_tensor.item())
pickled = bytes(id_tensor[:pickled_len].cpu().tolist())
unique_id = pickle.loads(pickled)
unique_id = pickle.loads(uid_bytes)
bootstrap.initialize(unique_id)
# ── Multi-node diagnostics ─────────────────────────────────────────
import subprocess, platform
hostname = platform.node()
n_ranks_per_node = bootstrap.get_n_ranks_per_node()
is_multi_node = (world_size > n_ranks_per_node)
# Check IB device availability
try:
ib_out = subprocess.check_output(["ibv_devinfo", "-l"], stderr=subprocess.DEVNULL, timeout=5).decode().strip()
ib_devices = [l.strip() for l in ib_out.splitlines() if l.strip() and "device" not in l.lower()]
except Exception:
ib_devices = []
if rank == 0:
print(f" Hostname: {hostname}")
print(f" nRanksPerNode: {n_ranks_per_node}, isMultiNode: {is_multi_node}")
print(f" IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
print(f" MSCCLPP_SOCKET_IFNAME: {os.environ.get('MSCCLPP_SOCKET_IFNAME', '<not set>')}")
if is_multi_node and not ib_devices:
print(f" WARNING: Multi-node detected but no IB devices! Cross-node will fail.")
# Also print from rank n_ranks_per_node (first rank on node 1) for comparison
if is_multi_node and rank == n_ranks_per_node:
print(f" [Node 1] Hostname: {hostname}, rank={rank}")
print(f" [Node 1] IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
# ── End diagnostics ────────────────────────────────────────────────
comm = Communicator(bootstrap)
# Create MscclppAlltoAllV with existing communicator
@@ -188,6 +289,21 @@ def main():
input_split_sizes=in_splits,
output_split_sizes=out_splits)
# Detect whether torch comparison is possible:
# - Need NCCL backend (gloo doesn't support all_to_all_single)
# - Need native NCCL (mscclpp shim doesn't implement all collectives)
use_torch_baseline = (backend == "nccl")
if use_torch_baseline:
try:
tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
dist.all_to_all_single(tiny_out, tiny_in)
torch.cuda.synchronize()
except Exception:
use_torch_baseline = False
if rank == 0:
print(" [INFO] torch all_to_all_single unavailable, skipping torch baseline")
def torch_fn(inp, out, in_splits, out_splits):
dist.all_to_all_single(out, inp,
output_split_sizes=out_splits,
@@ -202,22 +318,32 @@ def main():
def print_header():
if rank == 0:
print(f" {'Avg Size':>10s} "
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
f"{'torch Lat':>10s} {'torch BW':>9s} "
f"{'Speedup':>7s}")
print(f" {'-'*10} "
f"{'-'*12} {'-'*11} "
f"{'-'*10} {'-'*9} "
f"{'-'*7}")
if use_torch_baseline:
print(f" {'Avg Size':>10s} "
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
f"{'torch Lat':>10s} {'torch BW':>9s} "
f"{'Speedup':>7s}")
print(f" {'-'*10} "
f"{'-'*12} {'-'*11} "
f"{'-'*10} {'-'*9} "
f"{'-'*7}")
else:
print(f" {'Avg Size':>10s} "
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s}")
print(f" {'-'*10} "
f"{'-'*12} {'-'*11}")
def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
def print_row(size_str, m_lat, m_bw, t_lat=None, t_bw=None):
if rank == 0:
speedup = m_bw / t_bw if t_bw > 0 else float('inf')
print(f" {size_str:>10s} "
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
f"{speedup:>6.2f}x")
if t_bw is not None and t_bw > 0:
speedup = m_bw / t_bw
print(f" {size_str:>10s} "
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
f"{speedup:>6.2f}x")
else:
print(f" {size_str:>10s} "
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB")
# ── Test 3: Synthetic variable-size sweep ─────────────────────────────
if rank == 0:
@@ -227,6 +353,13 @@ def main():
msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
msg_sizes.append(128 * 1024 * 1024)
# Pre-compute max split sizes across all sweep iterations to allocate
# fixed-size tensors. Reusing the same tensors keeps the NativeAlgorithm
# context key stable (same ptrs + sizes) and avoids the context cache
# leak that causes SIGSEGV when stale RegisteredMemory accumulates.
max_in_elems = 0
max_out_elems = 0
sweep_params = [] # (avg_msg_size, in_splits, out_splits)
for avg_msg_size in msg_sizes:
random.seed(12345)
avg_elems = avg_msg_size // 4
@@ -234,19 +367,41 @@ def main():
for i in range(world_size):
row = [max(1, int(avg_elems * (0.5 + random.random()))) for _ in range(world_size)]
send_matrix.append(row)
in_splits = send_matrix[rank]
out_splits = [send_matrix[j][rank] for j in range(world_size)]
max_in_elems = max(max_in_elems, sum(in_splits))
max_out_elems = max(max_out_elems, sum(out_splits))
sweep_params.append((avg_msg_size, in_splits, out_splits))
inp = torch.randn(sum(in_splits), dtype=torch.float32, device='cuda')
out = torch.empty(sum(out_splits), dtype=torch.float32, device='cuda')
# Allocate once at max size
inp = torch.randn(max_in_elems, dtype=torch.float32, device='cuda')
out = torch.empty(max_out_elems, dtype=torch.float32, device='cuda')
for avg_msg_size, in_splits, out_splits in sweep_params:
n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
# Use views into the fixed buffers (same data_ptr → same context key)
inp_view = inp[:sum(in_splits)]
out_view = out[:sum(out_splits)]
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
if use_torch_baseline:
try:
t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
except Exception as e:
if rank == 0:
print(f" [WARN] torch baseline failed: {e}")
print(f" [INFO] Disabling torch baseline for remaining sizes")
use_torch_baseline = False
try:
torch.cuda.synchronize()
except Exception:
pass
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
else:
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
# ── Test 4: Real MoE workloads ───────────────────────────────────────
# Token counts from real MoE training runs (rank 0's view, 8 GPUs).
@@ -257,19 +412,30 @@ def main():
# per rank so every rank has the same total send and each NVLink
# carries a realistically imbalanced load.
# 10 workloads picked from 3M dispatch records in a real MoE training run,
# covering the full imbalance spectrum from nearly uniform (1.05×) to
# extremely skewed (10×). Each has 32768 total tokens → 167.8MB.
MOE_WORKLOADS = [
{
"name": "MoE-A",
# input_splits=[3976,3916,4497,4838,2888,3839,4355,4459]
# total_send=167,772,160 total_recv=148,316,160
"input_tokens": [3976, 3916, 4497, 4838, 2888, 3839, 4355, 4459],
},
{
"name": "MoE-B",
# input_splits=[3009,7161,2719,2766,3428,3010,6290,4385]
# total_send=167,772,160 total_recv=163,722,240
"input_tokens": [3009, 7161, 2719, 2766, 3428, 3010, 6290, 4385],
},
{"name": "MoE-A", # imbalance ≈ 1.05× (near-uniform)
"input_tokens": [4122, 4115, 4000, 4200, 4126, 4046, 4035, 4124]},
{"name": "MoE-B", # imbalance ≈ 1.20×
"input_tokens": [3770, 4236, 3966, 4046, 4524, 4132, 3825, 4269]},
{"name": "MoE-C", # imbalance ≈ 1.35×
"input_tokens": [4142, 4489, 4563, 3380, 3957, 4133, 3958, 4146]},
{"name": "MoE-D", # imbalance ≈ 1.50× (median)
"input_tokens": [4232, 3697, 4619, 4788, 4420, 3192, 3971, 3849]},
{"name": "MoE-E", # imbalance ≈ 1.75×
"input_tokens": [4178, 3209, 4678, 5085, 3108, 3365, 5439, 3706]},
{"name": "MoE-F", # imbalance ≈ 2.00×
"input_tokens": [4582, 3903, 3949, 3727, 4823, 5106, 2553, 4125]},
{"name": "MoE-G", # imbalance ≈ 2.50×
"input_tokens": [4036, 4438, 4804, 6180, 2913, 2472, 4105, 3820]},
{"name": "MoE-H", # imbalance ≈ 3.50×
"input_tokens": [3152, 1722, 4406, 4027, 5365, 6027, 4895, 3174]},
{"name": "MoE-I", # imbalance ≈ 5.00×
"input_tokens": [4384, 4194, 7840, 3079, 3460, 3506, 1568, 4737]},
{"name": "MoE-J", # imbalance ≈ 10.00× (extreme skew)
"input_tokens": [2710, 7661, 3354, 4457, 4609, 766, 3423, 5788]},
]
ELEMS_PER_TOKEN = 2560 # 5120 bytes / 2 bytes-per-bfloat16
@@ -304,10 +470,23 @@ def main():
n_warmup, n_iters = 5, 20
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
avg_bytes = total_bytes // world_size
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
if use_torch_baseline:
try:
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
except Exception as e:
if rank == 0:
print(f" [WARN] torch baseline failed: {e}")
print(f" [INFO] Disabling torch baseline for remaining workloads")
use_torch_baseline = False
try:
torch.cuda.synchronize()
except Exception:
pass
print_row(fmt_size(avg_bytes), m_lat, m_bw)
else:
print_row(fmt_size(avg_bytes), m_lat, m_bw)
else:
if rank == 0:
print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")