Broadcast UniqueId via TCP; Detect whether torch comparison is possible

This commit is contained in:
Qinghua Zhou
2026-03-16 10:01:35 +00:00
parent f47e97659d
commit bdb30b56a5

View File

@@ -12,14 +12,67 @@ Usage:
import torch
import torch.distributed as dist
import os
import sys
import time
import random
import socket
import struct
import pickle
from typing import Callable, List, Tuple
# Must init torch.distributed before importing mscclpp modules
# to set rank/world_size environment variables
def _get_routable_ip() -> str:
    """Return an IP address of this host that is not the loopback address.

    The address is discovered by "connecting" a UDP socket to a public
    address and reading back the local address the OS picked for that
    route. Connecting a datagram socket only records the peer; no packet
    is sent.

    Returns:
        The local IP address as a dotted-quad string. If the UDP probe fails
        (for example, no default route), falls back to resolving this host's
        own hostname, which may itself yield a loopback address depending on
        the host's name-resolution setup.

    Raises:
        OSError: if the probe fails and the hostname cannot be resolved.
    """
    try:
        # The context manager closes the probe socket even when connect() or
        # getsockname() raises, so a failed probe does not leak a descriptor.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))  # doesn't actually send data
            return probe.getsockname()[0]
    except OSError:
        # Socket-level failures are all OSError subclasses; anything else is
        # a programming error and should surface.
        return socket.gethostbyname(socket.gethostname())
def _recv_exact(sock: socket.socket, num_bytes: int, rank: int) -> bytes:
    """Read exactly ``num_bytes`` from ``sock`` or raise if the peer closes early."""
    received = bytearray()
    while len(received) < num_bytes:
        chunk = sock.recv(num_bytes - len(received))
        if not chunk:
            # recv() returns b"" once the peer has closed the connection;
            # without this check the loop would spin forever.
            raise RuntimeError(
                f"Rank {rank}: connection closed after {len(received)} of "
                f"{num_bytes} expected bytes"
            )
        received += chunk
    return bytes(received)


def _tcp_broadcast_unique_id(unique_id_bytes: bytes, rank: int, world_size: int,
                             master_addr: str, port: int = 18515,
                             connect_retries: int = 120,
                             retry_interval: float = 0.5) -> bytes:
    """Broadcast UniqueId bytes from rank 0 to all other ranks via TCP.

    Wire format: a 4-byte big-endian unsigned length followed by the payload.

    Rank 0 listens on ``port`` on all interfaces, accepts ``world_size - 1``
    connections one after another and sends the framed payload on each.
    Every other rank connects to ``master_addr:port`` (retrying while rank 0
    is not listening yet) and reads one framed payload.

    Args:
        unique_id_bytes: Payload to send. Only used on rank 0.
        rank: Rank of the calling process; rank 0 is the sender.
        world_size: Total number of ranks taking part.
        master_addr: Host name or address of rank 0. Only used on other ranks.
        port: TCP port rank 0 listens on.
        connect_retries: Maximum number of connection attempts on non-zero
            ranks before giving up.
        retry_interval: Seconds to sleep after each failed attempt.

    Returns:
        The payload: ``unique_id_bytes`` itself on rank 0, the received bytes
        on every other rank.

    Raises:
        RuntimeError: a non-zero rank could not connect within the retry
            budget, or the connection closed before the full frame arrived.
        OSError: rank 0 could not bind/listen/accept/send.
    """
    if rank == 0:
        # Build the frame once; it is identical for every receiver.
        frame = struct.pack("!I", len(unique_id_bytes)) + unique_id_bytes
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind(("", port))
            server.listen(world_size - 1)
            for _ in range(world_size - 1):
                conn, _peer = server.accept()
                with conn:
                    conn.sendall(frame)
        return unique_id_bytes

    # Non-zero ranks: rank 0 may not be listening yet, so retry the connect.
    last_error = None
    for _ in range(connect_retries):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((master_addr, port))
        except OSError as exc:  # includes ConnectionRefusedError
            # A socket whose connect() failed must not be reused; close it so
            # repeated retries do not leak descriptors.
            sock.close()
            last_error = exc
            time.sleep(retry_interval)
        else:
            break
    else:
        raise RuntimeError(
            f"Rank {rank}: failed to connect to {master_addr}:{port}"
        ) from last_error

    with sock:
        (length,) = struct.unpack("!I", _recv_exact(sock, 4, rank))
        return _recv_exact(sock, length, rank)
def main():
# Get rank/world from MPI environment
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
@@ -31,11 +84,21 @@ def main():
os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
torch.cuda.set_device(local_rank)
# Disable UCX in OpenMPI to avoid version mismatch crashes
os.environ.setdefault("OMPI_MCA_pml", "ob1")
os.environ.setdefault("OMPI_MCA_btl", "tcp,vader,self")
# Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed,
# otherwise gloo avoids IB configuration issues on some clusters.
# Set ALLTOALLV_BACKEND=nccl to enable torch baseline comparison.
backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
# For multi-node: detect a routable IP instead of 127.0.0.1
if "MASTER_ADDR" not in os.environ:
if rank == 0:
os.environ["MASTER_ADDR"] = _get_routable_ip()
else:
# Non-zero ranks: MASTER_ADDR must be set externally for multi-node
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ.setdefault("MASTER_PORT", "29500")
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
@@ -56,20 +119,22 @@ def main():
UniqueId,
)
from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
from mpi4py import MPI
mpi_comm = MPI.COMM_WORLD
# Create mscclpp communicator with TcpBootstrap
# Broadcast UniqueId raw bytes (128 bytes) via MPI to avoid NCCL interception issues
# Broadcast UniqueId via TCP sockets (avoids mpi4py/UCX issues)
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
uid_port = int(os.environ.get("MSCCLPP_UID_PORT", "18515"))
bootstrap = TcpBootstrap(rank, world_size)
if rank == 0:
unique_id = bootstrap.create_unique_id()
uid_bytes = pickle.dumps(unique_id)
else:
unique_id = UniqueId()
uid_bytes = b""
# UniqueId supports pickle (__getstate__/__setstate__), MPI bcast uses pickle
unique_id = mpi_comm.bcast(unique_id, root=0)
uid_bytes = _tcp_broadcast_unique_id(uid_bytes, rank, world_size, master_addr, uid_port)
if rank != 0:
unique_id = pickle.loads(uid_bytes)
bootstrap.initialize(unique_id)
comm = Communicator(bootstrap)
@@ -185,6 +250,21 @@ def main():
input_split_sizes=in_splits,
output_split_sizes=out_splits)
# Detect whether torch comparison is possible:
# - Need NCCL backend (gloo doesn't support all_to_all_single)
# - Need native NCCL (mscclpp shim doesn't implement all collectives)
use_torch_baseline = (backend == "nccl")
if use_torch_baseline:
try:
# Quick test: if the NCCL shim is active it may not support all_to_all_single
tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
dist.all_to_all_single(tiny_out, tiny_in)
except Exception:
use_torch_baseline = False
if rank == 0:
print(" [INFO] torch all_to_all_single unavailable, skipping torch baseline")
def torch_fn(inp, out, in_splits, out_splits):
dist.all_to_all_single(out, inp,
output_split_sizes=out_splits,
@@ -199,22 +279,32 @@ def main():
def print_header():
if rank == 0:
print(f" {'Avg Size':>10s} "
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
f"{'torch Lat':>10s} {'torch BW':>9s} "
f"{'Speedup':>7s}")
print(f" {'-'*10} "
f"{'-'*12} {'-'*11} "
f"{'-'*10} {'-'*9} "
f"{'-'*7}")
if use_torch_baseline:
print(f" {'Avg Size':>10s} "
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
f"{'torch Lat':>10s} {'torch BW':>9s} "
f"{'Speedup':>7s}")
print(f" {'-'*10} "
f"{'-'*12} {'-'*11} "
f"{'-'*10} {'-'*9} "
f"{'-'*7}")
else:
print(f" {'Avg Size':>10s} "
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s}")
print(f" {'-'*10} "
f"{'-'*12} {'-'*11}")
def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
def print_row(size_str, m_lat, m_bw, t_lat=None, t_bw=None):
if rank == 0:
speedup = m_bw / t_bw if t_bw > 0 else float('inf')
print(f" {size_str:>10s} "
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
f"{speedup:>6.2f}x")
if t_bw is not None and t_bw > 0:
speedup = m_bw / t_bw
print(f" {size_str:>10s} "
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
f"{speedup:>6.2f}x")
else:
print(f" {size_str:>10s} "
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB")
# ── Test 3: Synthetic variable-size sweep ─────────────────────────────
if rank == 0:
@@ -242,8 +332,11 @@ def main():
n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
if use_torch_baseline:
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
else:
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
# ── Test 4: Real MoE workloads ───────────────────────────────────────
# Token counts from real MoE training runs (rank 0's view, 8 GPUs).
@@ -312,10 +405,13 @@ def main():
n_warmup, n_iters = 5, 20
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
avg_bytes = total_bytes // world_size
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
if use_torch_baseline:
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
avg_bytes = total_bytes // world_size
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
else:
avg_bytes = total_bytes // world_size
print_row(fmt_size(avg_bytes), m_lat, m_bw)
else:
if rank == 0:
print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")