#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Test script for MscclppAlltoAllV with optimized C++ kernels.
Uses MPI bootstrap for mscclpp and NCCL backend for torch.distributed.
Usage:
mpirun -np N python test_alltoallv_mscclpp.py
"""
import os
import pickle
import random
import time

import torch
import torch.distributed as dist

# Must init torch.distributed before importing mscclpp modules
# to set rank/world_size environment variables.

def main():
    # Get rank/world size from the MPI environment
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", os.environ.get("PMI_RANK", 0)))
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", os.environ.get("PMI_SIZE", 1)))

    # Bind this process to its GPU
    local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
    torch.cuda.set_device(local_rank)

    # Initialize torch.distributed with NCCL (needs MASTER_ADDR/MASTER_PORT)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        device_id=torch.device(f"cuda:{local_rank}"),
    )
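    # Note: device_id is accepted by recent PyTorch releases and lets the NCCL
    # process group initialize eagerly on the bound GPU; on older versions it
    # can simply be dropped.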

    if rank == 0:
        print(f"Testing MscclppAlltoAllV with {world_size} ranks")
        print("=" * 60)

    # Import mscclpp only after torch.distributed init
    from mscclpp._mscclpp import (
        Communicator,
        TcpBootstrap,
        UniqueId,
    )
    from mscclpp.ext.alltoallv_single import MscclppAlltoAllV

    # Create mscclpp communicator with TcpBootstrap.
    # Use torch.distributed to share the unique ID via pickle.
    bootstrap = TcpBootstrap(rank, world_size)
    if rank == 0:
        unique_id = bootstrap.create_unique_id()
        # Serialize UniqueId via pickle and broadcast
        pickled = pickle.dumps(unique_id)
        assert len(pickled) <= 256, "UniqueId pickle exceeds the 256-byte buffer"
        id_tensor = torch.zeros(256, dtype=torch.uint8, device="cuda")
        id_tensor[: len(pickled)] = torch.tensor(list(pickled), dtype=torch.uint8)
        # Also send the pickled length
        len_tensor = torch.tensor([len(pickled)], dtype=torch.int64, device="cuda")
    else:
        id_tensor = torch.zeros(256, dtype=torch.uint8, device="cuda")
        len_tensor = torch.zeros(1, dtype=torch.int64, device="cuda")
    dist.broadcast(len_tensor, src=0)
    dist.broadcast(id_tensor, src=0)
    if rank != 0:
        pickled_len = int(len_tensor.item())
        pickled = bytes(id_tensor[:pickled_len].cpu().tolist())
        unique_id = pickle.loads(pickled)
    bootstrap.initialize(unique_id)
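    # Note: where the installed PyTorch supports it, the manual tensor
    # round-trip above can be collapsed into one call (a sketch, assuming
    # UniqueId remains picklable, which the path above already relies on):
    #   obj = [unique_id if rank == 0 else None]
    #   dist.broadcast_object_list(obj, src=0)
    #   unique_id = obj[0]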
    comm = Communicator(bootstrap)

    # Create MscclppAlltoAllV on top of the existing communicator
    alltoallv = MscclppAlltoAllV(communicator=comm)
    if rank == 0:
        print("MscclppAlltoAllV initialized")
        print(f"Algorithm: {alltoallv._algo.name}")

    # Test 1: Uniform all-to-all (equal splits)
    if rank == 0:
        print("\n[Test 1] Uniform all-to-all (1024 elements per rank)")
    chunk_size = 1024
    input_data = torch.arange(
        rank * world_size * chunk_size,
        (rank + 1) * world_size * chunk_size,
        dtype=torch.float32,
        device="cuda",
    )
    output = alltoallv.all_to_all_single(input_data)
    torch.cuda.synchronize()

    # Verify the first output chunk, which comes from rank 0. Rank 0's input
    # starts at 0 and its chunk destined for rank r is
    # arange(r * chunk_size, (r + 1) * chunk_size), so its expected sum is
    # rank * chunk_size**2 + sum(range(chunk_size)).
    actual_total = output[:chunk_size].sum().item()
    expected = rank * chunk_size * chunk_size + sum(range(chunk_size))
    if rank == 0:
        print(f"  First chunk sum: {actual_total}, expected {expected}")
        print("  PASS" if abs(actual_total - expected) < 1 else "  FAIL")

    # Test 2: Variable-size all-to-all (simulating MoE)
    if rank == 0:
        print("\n[Test 2] Variable-size all-to-all (MoE-like)")
    # Simulate MoE token distribution with imbalanced routing.
    # Build a full send matrix so each rank has different per-peer sizes.
    # send_matrix[i][j] = number of elements rank i sends to rank j.
    # For consistency: rank i's output_split[j] = send_matrix[j][i].
    random.seed(42)  # same seed on every rank, so all ranks agree on the matrix
    send_matrix = []
    for i in range(world_size):
        row = [random.randint(128, 2048) for _ in range(world_size)]
        send_matrix.append(row)
    input_split_sizes = send_matrix[rank]  # what this rank sends to each peer
    output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]  # what it receives from each peer
    total_input = sum(input_split_sizes)
    total_output = sum(output_split_sizes)

    # Fill input with a rank-specific pattern for verification
    input_tensor = torch.arange(total_input, dtype=torch.float32, device="cuda") + rank * 100000
    output_tensor = torch.empty(total_output, dtype=torch.float32, device="cuda")
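    # Worked example (illustration): with world_size = 2 and
    # send_matrix = [[a, b], [c, d]], rank 0 sends [a, b] (a to itself, b to
    # rank 1) and receives [a, c]; rank 1 sends [c, d] and receives [b, d].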
    output = alltoallv.all_to_all_single(
        input_tensor,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        output=output_tensor,
    )
    torch.cuda.synchronize()

    # Verify: the local-to-local segment should match exactly
    local_send_offset = sum(input_split_sizes[:rank])
    local_recv_offset = sum(output_split_sizes[:rank])
    local_size = input_split_sizes[rank]  # == output_split_sizes[rank]
    expected_local = input_tensor[local_send_offset:local_send_offset + local_size]
    actual_local = output_tensor[local_recv_offset:local_recv_offset + local_size]
    local_ok = torch.allclose(expected_local, actual_local)
    if rank == 0:
        print(f"  Send matrix row (rank 0 sends): {input_split_sizes}")
        print(f"  Recv sizes (rank 0 receives): {output_split_sizes}")
        print(f"  Input total: {total_input}, Output total: {total_output}")
        print(f"  Local copy verified: {local_ok}")
        print(f"  {'PASS' if local_ok else 'FAIL'}")

    # Test 3: Performance benchmark across message sizes (1KB to 128MB per peer)
    if rank == 0:
        print("\n[Test 3] Performance benchmark (1KB to 128MB per peer)")
        print(f"  {'Msg Size':>10s} {'Iters':>5s} {'Total (ms)':>10s} {'Lat (us)':>10s} {'BW (GB/s)':>10s}")
        print(f"  {'-' * 10} {'-' * 5} {'-' * 10} {'-' * 10} {'-' * 10}")
    # Message sizes: powers of 4 from 1KB to 64MB, plus 128MB
    msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
    msg_sizes.append(128 * 1024 * 1024)
    for msg_size in msg_sizes:
        input_size = msg_size * world_size  # total bytes in the input buffer
        n_elems = input_size // 4  # float32 = 4 bytes
        input_tensor = torch.randn(n_elems, dtype=torch.float32, device="cuda")
        output_tensor = torch.empty_like(input_tensor)

        # Fewer warmup/benchmark iterations for very large sizes
        n_warmup = 3 if msg_size >= 16 * 1024 * 1024 else 5
        n_iters = 5 if msg_size >= 64 * 1024 * 1024 else (10 if msg_size >= 4 * 1024 * 1024 else 20)

        # Warmup
        for _ in range(n_warmup):
            alltoallv.all_to_all_single(input_tensor, output=output_tensor)
        torch.cuda.synchronize()

        # Benchmark
        start = time.perf_counter()
        for _ in range(n_iters):
            alltoallv.all_to_all_single(input_tensor, output=output_tensor)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

        total_bytes = 2 * input_size * n_iters  # each element is read once and written once
        bandwidth_gbps = total_bytes / elapsed / 1e9
        latency_us = elapsed / n_iters * 1e6
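        # Illustration: if input_size is 1.28e8 bytes and 5 iterations take
        # 10 ms, BW = 2 * 1.28e8 * 5 / 0.010 / 1e9 = 128 GB/s.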
        if rank == 0:
            if msg_size >= 1024 * 1024:
                size_str = f"{msg_size // (1024 * 1024)}MB"
            elif msg_size >= 1024:
                size_str = f"{msg_size // 1024}KB"
            else:
                size_str = f"{msg_size}B"
            print(f"  {size_str:>10s} {n_iters:>5d} {elapsed * 1000:>10.2f} {latency_us:>10.1f} {bandwidth_gbps:>10.2f}")

    # Cleanup
    dist.barrier()
    if rank == 0:
        print("\n" + "=" * 60)
        print("All tests completed")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()