diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py
index eb9e2fab..fcd8738a 100644
--- a/python/test/test_alltoallv_mscclpp.py
+++ b/python/test/test_alltoallv_mscclpp.py
@@ -12,14 +12,67 @@ Usage:
 
 import torch
 import torch.distributed as dist
 import os
+import sys
 import time
 import random
+import socket
+import struct
+import pickle
 from typing import Callable, List, Tuple
 
 # Must init torch.distributed before importing mscclpp modules
 # to set rank/world_size environment variables
+def _get_routable_ip() -> str:
+    """Get a routable IP address for this host (not 127.0.0.1)."""
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        s.connect(("8.8.8.8", 80))  # doesn't actually send data
+        ip = s.getsockname()[0]
+        s.close()
+        return ip
+    except Exception:
+        return socket.gethostbyname(socket.gethostname())
+
+
+def _tcp_broadcast_unique_id(unique_id_bytes: bytes, rank: int, world_size: int,
+                             master_addr: str, port: int = 18515) -> bytes:
+    """Broadcast UniqueId bytes from rank 0 to all other ranks via TCP."""
+    if rank == 0:
+        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        server.bind(("", port))
+        server.listen(world_size - 1)
+        for _ in range(world_size - 1):
+            conn, _ = server.accept()
+            length = len(unique_id_bytes)
+            conn.sendall(struct.pack("!I", length) + unique_id_bytes)
+            conn.close()
+        server.close()
+        return unique_id_bytes
+    else:
+        # Retry connection to rank 0
+        for attempt in range(120):
+            try:
+                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                s.connect((master_addr, port))
+                break
+            except (ConnectionRefusedError, OSError):
+                time.sleep(0.5)
+                if attempt == 119:
+                    raise RuntimeError(f"Rank {rank}: failed to connect to {master_addr}:{port}")
+        raw_len = b""
+        while len(raw_len) < 4:
+            raw_len += s.recv(4 - len(raw_len))
+        length = struct.unpack("!I", raw_len)[0]
+        data = b""
+        while len(data) < length:
+            data += s.recv(length - len(data))
+        s.close()
+        return data
+
+
 def main():
     # Get rank/world from MPI environment
     rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
@@ -31,11 +84,21 @@ def main():
                           os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
     torch.cuda.set_device(local_rank)
 
+    # Disable UCX in OpenMPI to avoid version mismatch crashes
+    os.environ.setdefault("OMPI_MCA_pml", "ob1")
+    os.environ.setdefault("OMPI_MCA_btl", "tcp,vader,self")
+
     # Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed,
     # otherwise gloo avoids IB configuration issues on some clusters.
     # Set ALLTOALLV_BACKEND=nccl to enable torch baseline comparison.
     backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
-    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
+    # For multi-node: detect a routable IP instead of 127.0.0.1
+    if "MASTER_ADDR" not in os.environ:
+        if rank == 0:
+            os.environ["MASTER_ADDR"] = _get_routable_ip()
+        else:
+            # Non-zero ranks: MASTER_ADDR must be set externally for multi-node
+            os.environ["MASTER_ADDR"] = "127.0.0.1"
     os.environ.setdefault("MASTER_PORT", "29500")
     os.environ["RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)
@@ -56,20 +119,22 @@ def main():
         UniqueId,
     )
     from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
-    from mpi4py import MPI
-    mpi_comm = MPI.COMM_WORLD
 
     # Create mscclpp communicator with TcpBootstrap
-    # Broadcast UniqueId raw bytes (128 bytes) via MPI to avoid NCCL interception issues
+    # Broadcast UniqueId via TCP sockets (avoids mpi4py/UCX issues)
+    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
+    uid_port = int(os.environ.get("MSCCLPP_UID_PORT", "18515"))
     bootstrap = TcpBootstrap(rank, world_size)
     if rank == 0:
         unique_id = bootstrap.create_unique_id()
+        uid_bytes = pickle.dumps(unique_id)
     else:
-        unique_id = UniqueId()
+        uid_bytes = b""
 
-    # UniqueId supports pickle (__getstate__/__setstate__), MPI bcast uses pickle
-    unique_id = mpi_comm.bcast(unique_id, root=0)
+    uid_bytes = _tcp_broadcast_unique_id(uid_bytes, rank, world_size, master_addr, uid_port)
+    if rank != 0:
+        unique_id = pickle.loads(uid_bytes)
     bootstrap.initialize(unique_id)
     comm = Communicator(bootstrap)
@@ -185,6 +250,21 @@ def main():
                                input_split_sizes=in_splits,
                                output_split_sizes=out_splits)
+    # Detect whether torch comparison is possible:
+    # - Need NCCL backend (gloo doesn't support all_to_all_single)
+    # - Need native NCCL (mscclpp shim doesn't implement all collectives)
+    use_torch_baseline = (backend == "nccl")
+    if use_torch_baseline:
+        try:
+            # Quick test: if the NCCL shim is active it may not support all_to_all_single
+            tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
+            tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
+            dist.all_to_all_single(tiny_out, tiny_in)
+        except Exception:
+            use_torch_baseline = False
+            if rank == 0:
+                print(" [INFO] torch all_to_all_single unavailable, skipping torch baseline")
+
     def torch_fn(inp, out, in_splits, out_splits):
         dist.all_to_all_single(out, inp,
                                output_split_sizes=out_splits,
                                input_split_sizes=in_splits)
@@ -199,22 +279,32 @@ def main():
     def print_header():
         if rank == 0:
-            print(f" {'Avg Size':>10s} "
-                  f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
-                  f"{'torch Lat':>10s} {'torch BW':>9s} "
-                  f"{'Speedup':>7s}")
-            print(f" {'-'*10} "
-                  f"{'-'*12} {'-'*11} "
-                  f"{'-'*10} {'-'*9} "
-                  f"{'-'*7}")
+            if use_torch_baseline:
+                print(f" {'Avg Size':>10s} "
+                      f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
+                      f"{'torch Lat':>10s} {'torch BW':>9s} "
+                      f"{'Speedup':>7s}")
+                print(f" {'-'*10} "
+                      f"{'-'*12} {'-'*11} "
+                      f"{'-'*10} {'-'*9} "
+                      f"{'-'*7}")
+            else:
+                print(f" {'Avg Size':>10s} "
+                      f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s}")
+                print(f" {'-'*10} "
+                      f"{'-'*12} {'-'*11}")
 
-    def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
+    def print_row(size_str, m_lat, m_bw, t_lat=None, t_bw=None):
         if rank == 0:
-            speedup = m_bw / t_bw if t_bw > 0 else float('inf')
-            print(f" {size_str:>10s} "
-                  f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
-                  f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
-                  f"{speedup:>6.2f}x")
+            if t_bw is not None and t_bw > 0:
+                speedup = m_bw / t_bw
+                print(f" {size_str:>10s} "
+                      f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
+                      f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
f"{speedup:>6.2f}x") + else: + print(f" {size_str:>10s} " + f"{m_lat:>10.1f}us {m_bw:>9.2f}GB") # ── Test 3: Synthetic variable-size sweep ───────────────────────────── if rank == 0: @@ -242,8 +332,11 @@ def main(): n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20) m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters) - t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters) - print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw) + if use_torch_baseline: + t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters) + print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw) + else: + print_row(fmt_size(avg_msg_size), m_lat, m_bw) # ── Test 4: Real MoE workloads ─────────────────────────────────────── # Token counts from real MoE training runs (rank 0's view, 8 GPUs). @@ -312,10 +405,13 @@ def main(): n_warmup, n_iters = 5, 20 m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters) - t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters) - - avg_bytes = total_bytes // world_size - print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw) + if use_torch_baseline: + t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters) + avg_bytes = total_bytes // world_size + print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw) + else: + avg_bytes = total_bytes // world_size + print_row(fmt_size(avg_bytes), m_lat, m_bw) else: if rank == 0: print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")