#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Test script for MscclppAlltoAllV with optimized C++ kernels.
Uses MPI bootstrap for mscclpp and NCCL backend for torch.distributed.
Usage:
mpirun -np N python test_alltoallv_mscclpp.py
"""
import os
import random
import time
from typing import Callable, List, Tuple

import torch
import torch.distributed as dist

# Must init torch.distributed before importing mscclpp modules
# to set rank/world_size environment variables


def main():
    # Get rank/world from MPI environment
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", os.environ.get("PMI_SIZE", 1)))

    # Set CUDA device — prefer the MPI-provided local rank to handle any rank mapping
    local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK",
                                    os.environ.get("MPI_LOCALRANKID",
                                                   os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))))
    torch.cuda.set_device(local_rank)

    # Initialize torch.distributed — use NCCL when torch_fn benchmarks are needed;
    # otherwise gloo avoids IB configuration issues on some clusters.
    # Set ALLTOALLV_BACKEND=nccl to enable the torch baseline comparison.
    backend = os.environ.get("ALLTOALLV_BACKEND", "gloo")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    if backend == "nccl":
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size,
                                device_id=torch.device(f"cuda:{local_rank}"))
    else:
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if rank == 0:
        print(f"Testing MscclppAlltoAllV with {world_size} ranks")
        print("=" * 60)

    # Import after torch.distributed init
    from mscclpp._mscclpp import (
        Communicator,
        TcpBootstrap,
        UniqueId,
    )
    from mscclpp.ext.alltoallv_single import MscclppAlltoAllV
    from mpi4py import MPI

    mpi_comm = MPI.COMM_WORLD
    # Create mscclpp communicator with TcpBootstrap.
    # Broadcast the UniqueId via MPI (pickled) to avoid NCCL interception issues.
    bootstrap = TcpBootstrap(rank, world_size)
    if rank == 0:
        unique_id = bootstrap.create_unique_id()
    else:
        unique_id = UniqueId()
    # UniqueId supports pickle (__getstate__/__setstate__); MPI bcast uses pickle
    unique_id = mpi_comm.bcast(unique_id, root=0)
    bootstrap.initialize(unique_id)
    comm = Communicator(bootstrap)

    # Create MscclppAlltoAllV with the existing communicator
    alltoallv = MscclppAlltoAllV(communicator=comm)
    if rank == 0:
        print("MscclppAlltoAllV initialized")
        print(f"Algorithm: {alltoallv._algo.name}")
    # Test 1: Uniform all-to-all (equal splits)
    if rank == 0:
        print("\n[Test 1] Uniform all-to-all (1024 elements per rank)")
    chunk_size = 1024
    input_data = torch.arange(
        rank * world_size * chunk_size,
        (rank + 1) * world_size * chunk_size,
        dtype=torch.float32,
        device='cuda'
    )
    output = alltoallv.all_to_all_single(input_data)
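    # all_to_all_single semantics with equal splits: output chunk j on rank i
    # is rank j's input chunk i. On rank 0 the first output chunk is therefore
    # rank 0's own input chunk 0, i.e. arange(0, chunk_size).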
    torch.cuda.synchronize()
    # Check only the first chunk (on rank 0 it should sum to sum(range(chunk_size)))
    actual_total = output[:chunk_size].sum().item()
    expected = sum(range(chunk_size))
    if rank == 0:
        print(f" First chunk sum: {actual_total}, expected ~{expected}")
        print(" PASS" if abs(actual_total - expected) < 1 else " FAIL")
    # Test 2: Variable-size all-to-all (simulating MoE)
    if rank == 0:
        print("\n[Test 2] Variable-size all-to-all (MoE-like)")
    # Simulate MoE token distribution with imbalanced routing.
    # Build a full send matrix so each rank has different per-peer sizes.
    # send_matrix[i][j] = number of elements rank i sends to rank j.
    # For consistency: rank i's output_split[j] = send_matrix[j][i].
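    # Example with world_size = 2: send_matrix = [[s00, s01], [s10, s11]];
    # rank 0 sends row 0 ([s00, s01]) and receives column 0 ([s00, s10]).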
    random.seed(42)  # same seed on every rank, so all ranks agree on the matrix
    send_matrix = []
    for _ in range(world_size):
        row = [random.randint(128, 2048) for _ in range(world_size)]
        send_matrix.append(row)
    input_split_sizes = send_matrix[rank]  # what this rank sends to each peer
    output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]  # what it receives from each peer
    total_input = sum(input_split_sizes)
    total_output = sum(output_split_sizes)

    # Fill input with a rank-specific pattern for verification
    input_tensor = torch.arange(total_input, dtype=torch.float32, device='cuda') + rank * 100000
    output_tensor = torch.empty(total_output, dtype=torch.float32, device='cuda')
    output = alltoallv.all_to_all_single(
        input_tensor,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        output=output_tensor
    )
    torch.cuda.synchronize()

    # Verify: the local-to-local segment should match exactly
    local_send_offset = sum(input_split_sizes[:rank])
    local_recv_offset = sum(output_split_sizes[:rank])
    local_size = input_split_sizes[rank]  # == output_split_sizes[rank] (diagonal entry)
    expected_local = input_tensor[local_send_offset:local_send_offset + local_size]
    actual_local = output_tensor[local_recv_offset:local_recv_offset + local_size]
    local_ok = torch.allclose(expected_local, actual_local)
    if rank == 0:
        print(f" Send matrix row (rank 0 sends): {input_split_sizes}")
        print(f" Recv sizes (rank 0 receives): {output_split_sizes}")
        print(f" Input total: {total_input}, Output total: {total_output}")
        print(f" Local copy verified: {local_ok}")
        print(f" {'PASS' if local_ok else 'FAIL'}")
    # ── Shared benchmark helpers ──────────────────────────────────────────
    def bench_alltoallv(
        fn: Callable,
        input_tensor: torch.Tensor,
        output_tensor: torch.Tensor,
        input_split_sizes: List[int],
        output_split_sizes: List[int],
        n_warmup: int,
        n_iters: int,
    ) -> Tuple[float, float]:
        """Benchmark an all_to_all_single impl. Returns (latency_us, algbw_gbps)."""
        for _ in range(n_warmup):
            fn(input_tensor, output_tensor, input_split_sizes, output_split_sizes)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            fn(input_tensor, output_tensor, input_split_sizes, output_split_sizes)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        elem_size = input_tensor.element_size()
        # Algorithm bandwidth: bytes this rank receives per iteration / time
        total_recv_bytes = sum(output_split_sizes) * elem_size
        algbw = total_recv_bytes * n_iters / elapsed / 1e9
        lat = elapsed / n_iters * 1e6
        return lat, algbw
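    # Timing is host wall-clock around a synchronized batch of iterations, so
    # the reported latency is a per-iteration average and includes kernel
    # launch overhead.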
    def mscclpp_fn(inp, out, in_splits, out_splits):
        alltoallv.all_to_all_single(inp, output=out,
                                    input_split_sizes=in_splits,
                                    output_split_sizes=out_splits)

    def torch_fn(inp, out, in_splits, out_splits):
        dist.all_to_all_single(out, inp,
                               output_split_sizes=out_splits,
                               input_split_sizes=in_splits)
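    # The gloo backend does not implement all_to_all_single, so torch_fn is
    # benchmarked only when running with ALLTOALLV_BACKEND=nccl; under gloo
    # the torch columns below are reported as zero.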
    def fmt_size(nbytes: int) -> str:
        if nbytes >= 1024 * 1024:
            return f"{nbytes // (1024 * 1024)}MB"
        elif nbytes >= 1024:
            return f"{nbytes // 1024}KB"
        return f"{nbytes}B"

    def print_header():
        if rank == 0:
            print(f" {'Avg Size':>10s} "
                  f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
                  f"{'torch Lat':>10s} {'torch BW':>9s} "
                  f"{'Speedup':>7s}")
            print(f" {'-' * 10} "
                  f"{'-' * 12} {'-' * 11} "
                  f"{'-' * 10} {'-' * 9} "
                  f"{'-' * 7}")

    def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
        if rank == 0:
            speedup = m_bw / t_bw if t_bw > 0 else float('inf')
            # Bandwidth columns are GB/s (field widths preserve header alignment)
            print(f" {size_str:>10s} "
                  f"{m_lat:>10.1f}us {m_bw:>7.2f}GB/s "
                  f"{t_lat:>8.1f}us {t_bw:>5.2f}GB/s "
                  f"{speedup:>6.2f}x")
    # ── Test 3: Synthetic variable-size sweep ─────────────────────────────
    if rank == 0:
        print("\n[Test 3] Synthetic variable-size benchmark: mscclpp vs torch.dist")
        print_header()
    msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]  # 1KB → 64MB in 4× steps
    msg_sizes.append(128 * 1024 * 1024)
    for avg_msg_size in msg_sizes:
        random.seed(12345)
        avg_elems = avg_msg_size // 4  # float32 elements per peer, on average
        send_matrix = []
        for _ in range(world_size):
            # Per-peer sizes drawn uniformly from [0.5, 1.5] × the average
            row = [max(1, int(avg_elems * (0.5 + random.random()))) for _ in range(world_size)]
            send_matrix.append(row)
        in_splits = send_matrix[rank]
        out_splits = [send_matrix[j][rank] for j in range(world_size)]
        inp = torch.randn(sum(in_splits), dtype=torch.float32, device='cuda')
        out = torch.empty(sum(out_splits), dtype=torch.float32, device='cuda')
        n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
        n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
        m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
        if backend == "nccl":
            t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
        else:
            t_lat, t_bw = 0.0, 0.0  # torch baseline unavailable under gloo
        print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
    # ── Test 4: Real MoE workloads ───────────────────────────────────────
    # Token counts from real MoE training runs (rank 0's view, 8 GPUs).
    # Each token is 5120 bytes (hidden_dim=2560, bf16), so the tensors below
    # use bf16 with 2560 elements per token to match the real workload dtype.
    #
    # To build a consistent 8×8 send matrix, we rotate input_tokens per rank
    # so every rank has the same total send and each NVLink carries a
    # realistically imbalanced load (see the worked example inside the loop).
    #
    # The 10 workloads were picked from 3M dispatch records of a real MoE
    # training run, covering the imbalance spectrum from nearly uniform
    # (1.05×) to extremely skewed (10×). Each totals 32768 tokens → 167.8MB.
    MOE_WORKLOADS = [
        {"name": "MoE-A",  # imbalance ≈ 1.05× (near-uniform)
         "input_tokens": [4122, 4115, 4000, 4200, 4126, 4046, 4035, 4124]},
        {"name": "MoE-B",  # imbalance ≈ 1.20×
         "input_tokens": [3770, 4236, 3966, 4046, 4524, 4132, 3825, 4269]},
        {"name": "MoE-C",  # imbalance ≈ 1.35×
         "input_tokens": [4142, 4489, 4563, 3380, 3957, 4133, 3958, 4146]},
        {"name": "MoE-D",  # imbalance ≈ 1.50× (median)
         "input_tokens": [4232, 3697, 4619, 4788, 4420, 3192, 3971, 3849]},
        {"name": "MoE-E",  # imbalance ≈ 1.75×
         "input_tokens": [4178, 3209, 4678, 5085, 3108, 3365, 5439, 3706]},
        {"name": "MoE-F",  # imbalance ≈ 2.00×
         "input_tokens": [4582, 3903, 3949, 3727, 4823, 5106, 2553, 4125]},
        {"name": "MoE-G",  # imbalance ≈ 2.50×
         "input_tokens": [4036, 4438, 4804, 6180, 2913, 2472, 4105, 3820]},
        {"name": "MoE-H",  # imbalance ≈ 3.50×
         "input_tokens": [3152, 1722, 4406, 4027, 5365, 6027, 4895, 3174]},
        {"name": "MoE-I",  # imbalance ≈ 5.00×
         "input_tokens": [4384, 4194, 7840, 3079, 3460, 3506, 1568, 4737]},
        {"name": "MoE-J",  # imbalance ≈ 10.00× (extreme skew)
         "input_tokens": [2710, 7661, 3354, 4457, 4609, 766, 3423, 5788]},
    ]
    ELEMS_PER_TOKEN = 2560  # 5120 bytes / 2 bytes per bfloat16
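    # Sanity check: every workload above totals exactly 32768 tokens (167.8MB)
    for _wl in MOE_WORKLOADS:
        assert sum(_wl["input_tokens"]) == 32768, _wl["name"]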
    if world_size == 8:
        if rank == 0:
            print("\n[Test 4] Real MoE workloads (hidden=2560, bf16, 8 GPUs)")
        for wl in MOE_WORKLOADS:
            tokens = wl["input_tokens"]
            min_tok, max_tok = min(tokens), max(tokens)
            imbalance = max_tok / min_tok
            total_bytes = sum(tokens) * 5120
            if rank == 0:
                print(f"\n {wl['name']}: {sum(tokens)} tokens/rank, "
                      f"{total_bytes / 1e6:.1f}MB, imbalance={imbalance:.1f}x")
                print(f" Token distribution: {tokens}")
                print_header()
            # Build a consistent send matrix: rotate the token list per rank
            moe_send_matrix = []
            for i in range(world_size):
                row = tokens[i:] + tokens[:i]
                moe_send_matrix.append(row)
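            # Worked example with tokens = [a, b, c] (world_size = 3):
            #   rank 0 row: [a, b, c]
            #   rank 1 row: [b, c, a]
            #   rank 2 row: [c, a, b]
            # Every row and every column sums to a + b + c, so all ranks send
            # and receive the same total while per-link loads stay skewed.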
            in_splits = [moe_send_matrix[rank][j] * ELEMS_PER_TOKEN for j in range(world_size)]
            out_splits = [moe_send_matrix[j][rank] * ELEMS_PER_TOKEN for j in range(world_size)]
            inp = torch.randn(sum(in_splits), dtype=torch.bfloat16, device='cuda')
            out = torch.empty(sum(out_splits), dtype=torch.bfloat16, device='cuda')
            n_warmup, n_iters = 5, 20
            m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
            if backend == "nccl":
                t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
            else:
                t_lat, t_bw = 0.0, 0.0  # torch baseline unavailable under gloo
            avg_bytes = total_bytes // world_size
            print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
    else:
        if rank == 0:
            print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")

    # Cleanup
    dist.barrier()
    if rank == 0:
        print("\n" + "=" * 60)
        print("All tests completed.")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()