From ee843d445f6e84e3547538f4813afbabe13b350a Mon Sep 17 00:00:00 2001
From: Qinghua Zhou
Date: Wed, 25 Feb 2026 12:39:48 +0000
Subject: [PATCH] Add test of real MoE workloads

---
 python/test/test_alltoallv_mscclpp.py | 183 +++++++++++++++++---
 1 file changed, 121 insertions(+), 62 deletions(-)

diff --git a/python/test/test_alltoallv_mscclpp.py b/python/test/test_alltoallv_mscclpp.py
index 611726bb..275df78e 100644
--- a/python/test/test_alltoallv_mscclpp.py
+++ b/python/test/test_alltoallv_mscclpp.py
@@ -14,7 +14,7 @@ import torch.distributed as dist
 import os
 import time
 import random
-from typing import Callable, List, Optional
+from typing import Callable, List, Tuple
 
 # Must init torch.distributed before importing mscclpp modules
 # to set rank/world_size environment variables
@@ -156,21 +156,7 @@ def main():
         print(f"  Local copy verified: {local_ok}")
         print(f"  {'PASS' if local_ok else 'FAIL'}")
 
-    # ── Unified benchmark helper ──────────────────────────────────────────
-    def build_variable_send_matrix(avg_msg_size: int, world_size: int):
-        """Build a deterministic variable-size send matrix (0.5×–1.5× of avg)."""
-        random.seed(12345)
-        avg_elems = avg_msg_size // 4  # float32
-        send_matrix = []
-        for i in range(world_size):
-            row = []
-            for j in range(world_size):
-                factor = 0.5 + random.random()
-                elems = max(1, int(avg_elems * factor))
-                row.append(elems)
-            send_matrix.append(row)
-        return send_matrix
-
+    # ── Shared benchmark helpers ──────────────────────────────────────────
     def bench_alltoallv(
         fn: Callable,
         input_tensor: torch.Tensor,
@@ -179,8 +165,10 @@ def main():
         output_split_sizes: List[int],
         n_warmup: int,
         n_iters: int,
-    ) -> tuple:
-        """Benchmark an all_to_all_single implementation. Returns (latency_us, algbw_gbps)."""
+    ) -> Tuple[float, float]:
+        """Benchmark an all_to_all_single impl. Returns (latency_us, algbw_gbps)."""
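+        # Warmup iterations run untimed, absorbing one-time costs (e.g.
+        # lazy channel/buffer setup) before the measured loop starts.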
         for _ in range(n_warmup):
             fn(input_tensor, output_tensor, input_split_sizes, output_split_sizes)
         torch.cuda.synchronize()
@@ -191,12 +179,12 @@
         torch.cuda.synchronize()
         elapsed = time.perf_counter() - start
 
-        total_recv_bytes = sum(output_split_sizes) * 4  # float32
+        elem_size = input_tensor.element_size()
+        total_recv_bytes = sum(output_split_sizes) * elem_size
         algbw = total_recv_bytes * n_iters / elapsed / 1e9
         lat = elapsed / n_iters * 1e6
         return lat, algbw
 
-    # Wrap mscclpp and torch.dist into the same calling convention
     def mscclpp_fn(inp, out, in_splits, out_splits):
         alltoallv.all_to_all_single(inp, output=out,
                                     input_split_sizes=in_splits,
@@ -207,56 +195,127 @@ def main():
                                  output_split_sizes=out_splits,
                                  input_split_sizes=in_splits)
 
-    # ── Test 3: Side-by-side comparison ───────────────────────────────────
-    if rank == 0:
-        print("\n[Test 3] Variable-size benchmark: mscclpp vs torch.dist (1KB–128MB avg/peer)")
-        print(f"  {'Avg Size':>10s} "
-              f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
-              f"{'torch Lat':>10s} {'torch BW':>9s} "
-              f"{'Speedup':>7s}")
-        print(f"  {'-'*10} "
-              f"{'-'*12} {'-'*11} "
-              f"{'-'*10} {'-'*9} "
-              f"{'-'*7}")
-
-    msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
-    msg_sizes.append(128 * 1024 * 1024)
-
-    for avg_msg_size in msg_sizes:
-        send_matrix = build_variable_send_matrix(avg_msg_size, world_size)
-
-        input_split_sizes = send_matrix[rank]
-        output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]
-
-        total_send = sum(input_split_sizes)
-        total_recv = sum(output_split_sizes)
-
-        input_tensor = torch.randn(total_send, dtype=torch.float32, device='cuda')
-        output_tensor = torch.empty(total_recv, dtype=torch.float32, device='cuda')
-
-        n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
-        n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
-
-        m_lat, m_bw = bench_alltoallv(mscclpp_fn, input_tensor, output_tensor,
-                                      input_split_sizes, output_split_sizes,
-                                      n_warmup, n_iters)
-        t_lat, t_bw = bench_alltoallv(torch_fn, input_tensor, output_tensor,
-                                      input_split_sizes, output_split_sizes,
-                                      n_warmup, n_iters)
+    def fmt_size(nbytes: int) -> str:
+        if nbytes >= 1024 * 1024:
+            return f"{nbytes // (1024*1024)}MB"
+        elif nbytes >= 1024:
+            return f"{nbytes // 1024}KB"
+        return f"{nbytes}B"
 
+    def print_header():
+        if rank == 0:
+            print(f"  {'Avg Size':>10s} "
+                  f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
+                  f"{'torch Lat':>10s} {'torch BW':>9s} "
+                  f"{'Speedup':>7s}")
+            print(f"  {'-'*10} "
+                  f"{'-'*12} {'-'*11} "
+                  f"{'-'*10} {'-'*9} "
+                  f"{'-'*7}")
+
+    def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
         if rank == 0:
-            if avg_msg_size >= 1024 * 1024:
-                size_str = f"{avg_msg_size // (1024*1024)}MB"
-            elif avg_msg_size >= 1024:
-                size_str = f"{avg_msg_size // 1024}KB"
-            else:
-                size_str = f"{avg_msg_size}B"
             speedup = m_bw / t_bw if t_bw > 0 else float('inf')
             print(f"  {size_str:>10s} "
                   f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
                   f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
                   f"{speedup:>6.2f}x")
 
+    # ── Test 3: Synthetic variable-size sweep ─────────────────────────────
+    if rank == 0:
+        print("\n[Test 3] Synthetic variable-size benchmark: mscclpp vs torch.dist")
+        print_header()
+
+    msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
+    msg_sizes.append(128 * 1024 * 1024)
+
+    for avg_msg_size in msg_sizes:
+        random.seed(12345)
+        avg_elems = avg_msg_size // 4  # float32 elements
+        send_matrix = []
+        for _ in range(world_size):
+            row = [max(1, int(avg_elems * (0.5 + random.random()))) for _ in range(world_size)]
+            send_matrix.append(row)
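+
+        # send_matrix[i][j] = elements rank i sends to rank j. Every rank
+        # seeds the same RNG, so all build the identical matrix locally.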
+        in_splits = send_matrix[rank]
+        out_splits = [send_matrix[j][rank] for j in range(world_size)]
+
+        inp = torch.randn(sum(in_splits), dtype=torch.float32, device='cuda')
+        out = torch.empty(sum(out_splits), dtype=torch.float32, device='cuda')
+
+        n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
+        n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
+
+        m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+        t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+        print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
+
+    # ── Test 4: Real MoE workloads ────────────────────────────────────────
+    # Token counts captured from real MoE training runs (rank 0's view,
+    # 8 GPUs). Each token is 5120 bytes: hidden_dim=2560 in bf16, the
+    # same dtype this test uses.
+    #
+    # To build a consistent 8×8 send matrix, we rotate the input_tokens
+    # per rank, so every rank has the same total send and each NVLink
+    # carries a realistically imbalanced load.
+
+    MOE_WORKLOADS = [
+        {
+            "name": "MoE-A",
+            # input_splits=[3976,3916,4497,4838,2888,3839,4355,4459]
+            # total_send=167,772,160  total_recv=148,316,160
+            "input_tokens": [3976, 3916, 4497, 4838, 2888, 3839, 4355, 4459],
+        },
+        {
+            "name": "MoE-B",
+            # input_splits=[3009,7161,2719,2766,3428,3010,6290,4385]
+            # total_send=167,772,160  total_recv=163,722,240
+            "input_tokens": [3009, 7161, 2719, 2766, 3428, 3010, 6290, 4385],
+        },
+    ]
+    ELEMS_PER_TOKEN = 2560  # 5120 bytes / 2 bytes per bfloat16
+
+    if world_size == 8:
+        if rank == 0:
+            print("\n[Test 4] Real MoE workloads (hidden=2560, bf16, 8 GPUs)")
+
+        for wl in MOE_WORKLOADS:
+            tokens = wl["input_tokens"]
+            min_tok, max_tok = min(tokens), max(tokens)
+            imbalance = max_tok / min_tok
+            total_bytes = sum(tokens) * 5120  # bytes sent per rank
+
+            if rank == 0:
+                print(f"\n  {wl['name']}: {sum(tokens)} tokens/rank, "
+                      f"{total_bytes / 1e6:.1f}MB, imbalance={imbalance:.1f}x")
+                print(f"  Token distribution: {tokens}")
+                print_header()
+
+            # Build a consistent send_matrix: rotate the token list per rank
+            moe_send_matrix = []
+            for i in range(world_size):
+                row = tokens[i:] + tokens[:i]
+                moe_send_matrix.append(row)
+
+            in_splits = [moe_send_matrix[rank][j] * ELEMS_PER_TOKEN for j in range(world_size)]
+            out_splits = [moe_send_matrix[j][rank] * ELEMS_PER_TOKEN for j in range(world_size)]
+
+            inp = torch.randn(sum(in_splits), dtype=torch.bfloat16, device='cuda')
+            out = torch.empty(sum(out_splits), dtype=torch.bfloat16, device='cuda')
+
+            n_warmup, n_iters = 5, 20
+
+            m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+            t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+
+            avg_bytes = total_bytes // world_size
+            print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
+    else:
+        if rank == 0:
+            print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")
+
     # Cleanup
     dist.barrier()
     if rank == 0: