mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Broadcast UniqueId via TCP; Detect whether torch comparison is possible
This commit is contained in:
@@ -12,14 +12,67 @@ Usage:
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import socket
|
||||
import struct
|
||||
import pickle
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
# Must init torch.distributed before importing mscclpp modules
|
||||
# to set rank/world_size environment variables
|
||||
|
||||
|
||||
def _get_routable_ip() -> str:
    """Return an IPv4 address of this host that other machines can reach.

    The preferred strategy opens a UDP socket and "connects" it to a public
    address. Per the original author's note this does not actually send any
    data; it is only used so that ``getsockname()`` reports the local address
    of the interface that would be used for outbound traffic.

    If that probe fails, the address the local hostname resolves to is
    returned instead.
    NOTE(review): depending on host configuration that fallback may still be
    a loopback address -- confirm this is acceptable for multi-node runs.

    Returns:
        The detected IPv4 address in dotted-quad form.

    Raises:
        OSError: if the probe fails and the hostname cannot be resolved either.
    """
    try:
        # The context manager closes the probe socket even when connect() or
        # getsockname() raises, so no file descriptor is leaked on failure.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))  # doesn't actually send data
            return probe.getsockname()[0]
    except OSError:
        # Best-effort fallback (e.g. no default route available).
        return socket.gethostbyname(socket.gethostname())
|
||||
|
||||
|
||||
def _tcp_broadcast_unique_id(unique_id_bytes: bytes, rank: int, world_size: int,
|
||||
master_addr: str, port: int = 18515) -> bytes:
|
||||
"""Broadcast UniqueId bytes from rank 0 to all other ranks via TCP."""
|
||||
if rank == 0:
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind(("", port))
|
||||
server.listen(world_size - 1)
|
||||
for _ in range(world_size - 1):
|
||||
conn, _ = server.accept()
|
||||
length = len(unique_id_bytes)
|
||||
conn.sendall(struct.pack("!I", length) + unique_id_bytes)
|
||||
conn.close()
|
||||
server.close()
|
||||
return unique_id_bytes
|
||||
else:
|
||||
# Retry connection to rank 0
|
||||
for attempt in range(120):
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect((master_addr, port))
|
||||
break
|
||||
except (ConnectionRefusedError, OSError):
|
||||
time.sleep(0.5)
|
||||
if attempt == 119:
|
||||
raise RuntimeError(f"Rank {rank}: failed to connect to {master_addr}:{port}")
|
||||
raw_len = b""
|
||||
while len(raw_len) < 4:
|
||||
raw_len += s.recv(4 - len(raw_len))
|
||||
length = struct.unpack("!I", raw_len)[0]
|
||||
data = b""
|
||||
while len(data) < length:
|
||||
data += s.recv(length - len(data))
|
||||
s.close()
|
||||
return data
|
||||
|
||||
|
||||
def main():
|
||||
# Get rank/world from MPI environment
|
||||
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
|
||||
@@ -31,11 +84,21 @@ def main():
|
||||
os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# Disable UCX in OpenMPI to avoid version mismatch crashes
|
||||
os.environ.setdefault("OMPI_MCA_pml", "ob1")
|
||||
os.environ.setdefault("OMPI_MCA_btl", "tcp,vader,self")
|
||||
|
||||
# Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed,
|
||||
# otherwise gloo avoids IB configuration issues on some clusters.
|
||||
# Set ALLTOALLV_BACKEND=nccl to enable torch baseline comparison.
|
||||
backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
|
||||
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
|
||||
# For multi-node: detect a routable IP instead of 127.0.0.1
|
||||
if "MASTER_ADDR" not in os.environ:
|
||||
if rank == 0:
|
||||
os.environ["MASTER_ADDR"] = _get_routable_ip()
|
||||
else:
|
||||
# Non-zero ranks: MASTER_ADDR must be set externally for multi-node
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ.setdefault("MASTER_PORT", "29500")
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
@@ -56,20 +119,22 @@ def main():
|
||||
UniqueId,
|
||||
)
|
||||
from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
|
||||
from mpi4py import MPI
|
||||
mpi_comm = MPI.COMM_WORLD
|
||||
|
||||
# Create mscclpp communicator with TcpBootstrap
|
||||
# Broadcast UniqueId raw bytes (128 bytes) via MPI to avoid NCCL interception issues
|
||||
# Broadcast UniqueId via TCP sockets (avoids mpi4py/UCX issues)
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
uid_port = int(os.environ.get("MSCCLPP_UID_PORT", "18515"))
|
||||
bootstrap = TcpBootstrap(rank, world_size)
|
||||
|
||||
if rank == 0:
|
||||
unique_id = bootstrap.create_unique_id()
|
||||
uid_bytes = pickle.dumps(unique_id)
|
||||
else:
|
||||
unique_id = UniqueId()
|
||||
uid_bytes = b""
|
||||
|
||||
# UniqueId supports pickle (__getstate__/__setstate__), MPI bcast uses pickle
|
||||
unique_id = mpi_comm.bcast(unique_id, root=0)
|
||||
uid_bytes = _tcp_broadcast_unique_id(uid_bytes, rank, world_size, master_addr, uid_port)
|
||||
if rank != 0:
|
||||
unique_id = pickle.loads(uid_bytes)
|
||||
|
||||
bootstrap.initialize(unique_id)
|
||||
comm = Communicator(bootstrap)
|
||||
@@ -185,6 +250,21 @@ def main():
|
||||
input_split_sizes=in_splits,
|
||||
output_split_sizes=out_splits)
|
||||
|
||||
# Detect whether torch comparison is possible:
|
||||
# - Need NCCL backend (gloo doesn't support all_to_all_single)
|
||||
# - Need native NCCL (mscclpp shim doesn't implement all collectives)
|
||||
use_torch_baseline = (backend == "nccl")
|
||||
if use_torch_baseline:
|
||||
try:
|
||||
# Quick test: if the NCCL shim is active it may not support all_to_all_single
|
||||
tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
|
||||
tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
|
||||
dist.all_to_all_single(tiny_out, tiny_in)
|
||||
except Exception:
|
||||
use_torch_baseline = False
|
||||
if rank == 0:
|
||||
print(" [INFO] torch all_to_all_single unavailable, skipping torch baseline")
|
||||
|
||||
def torch_fn(inp, out, in_splits, out_splits):
|
||||
dist.all_to_all_single(out, inp,
|
||||
output_split_sizes=out_splits,
|
||||
@@ -199,22 +279,32 @@ def main():
|
||||
|
||||
def print_header():
|
||||
if rank == 0:
|
||||
print(f" {'Avg Size':>10s} "
|
||||
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
|
||||
f"{'torch Lat':>10s} {'torch BW':>9s} "
|
||||
f"{'Speedup':>7s}")
|
||||
print(f" {'-'*10} "
|
||||
f"{'-'*12} {'-'*11} "
|
||||
f"{'-'*10} {'-'*9} "
|
||||
f"{'-'*7}")
|
||||
if use_torch_baseline:
|
||||
print(f" {'Avg Size':>10s} "
|
||||
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
|
||||
f"{'torch Lat':>10s} {'torch BW':>9s} "
|
||||
f"{'Speedup':>7s}")
|
||||
print(f" {'-'*10} "
|
||||
f"{'-'*12} {'-'*11} "
|
||||
f"{'-'*10} {'-'*9} "
|
||||
f"{'-'*7}")
|
||||
else:
|
||||
print(f" {'Avg Size':>10s} "
|
||||
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s}")
|
||||
print(f" {'-'*10} "
|
||||
f"{'-'*12} {'-'*11}")
|
||||
|
||||
def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
|
||||
def print_row(size_str, m_lat, m_bw, t_lat=None, t_bw=None):
|
||||
if rank == 0:
|
||||
speedup = m_bw / t_bw if t_bw > 0 else float('inf')
|
||||
print(f" {size_str:>10s} "
|
||||
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
|
||||
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
|
||||
f"{speedup:>6.2f}x")
|
||||
if t_bw is not None and t_bw > 0:
|
||||
speedup = m_bw / t_bw
|
||||
print(f" {size_str:>10s} "
|
||||
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
|
||||
f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
|
||||
f"{speedup:>6.2f}x")
|
||||
else:
|
||||
print(f" {size_str:>10s} "
|
||||
f"{m_lat:>10.1f}us {m_bw:>9.2f}GB")
|
||||
|
||||
# ── Test 3: Synthetic variable-size sweep ─────────────────────────────
|
||||
if rank == 0:
|
||||
@@ -242,8 +332,11 @@ def main():
|
||||
n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
|
||||
if use_torch_baseline:
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
|
||||
else:
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw)
|
||||
|
||||
# ── Test 4: Real MoE workloads ───────────────────────────────────────
|
||||
# Token counts from real MoE training runs (rank 0's view, 8 GPUs).
|
||||
@@ -312,10 +405,13 @@ def main():
|
||||
n_warmup, n_iters = 5, 20
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
|
||||
avg_bytes = total_bytes // world_size
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
if use_torch_baseline:
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
avg_bytes = total_bytes // world_size
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
else:
|
||||
avg_bytes = total_bytes // world_size
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw)
|
||||
else:
|
||||
if rank == 0:
|
||||
print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")
|
||||
|
||||
Reference in New Issue
Block a user