mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Add test of real MoE workloads
This commit is contained in:
@@ -14,7 +14,7 @@ import torch.distributed as dist
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
# Must init torch.distributed before importing mscclpp modules
|
||||
# to set rank/world_size environment variables
|
||||
@@ -156,21 +156,7 @@ def main():
|
||||
print(f" Local copy verified: {local_ok}")
|
||||
print(f" {'PASS' if local_ok else 'FAIL'}")
|
||||
|
||||
# ── Unified benchmark helper ──────────────────────────────────────────
|
||||
def build_variable_send_matrix(avg_msg_size: int, world_size: int, elem_size: int = 4):
    """Build a deterministic variable-size send matrix (0.5×–1.5× of avg).

    Args:
        avg_msg_size: Average per-peer message size in bytes.
        world_size: Number of ranks; the result is world_size × world_size.
        elem_size: Bytes per element used to convert the byte size into an
            element count. Defaults to 4 (float32).

    Returns:
        A list of rows where ``matrix[i][j]`` is the number of elements
        rank ``i`` sends to rank ``j``. Every entry is at least 1.

    Raises:
        ValueError: If ``elem_size`` is not positive.
    """
    if elem_size <= 0:
        raise ValueError(f"elem_size must be positive, got {elem_size}")

    # A private, fixed-seed generator produces the same matrix on every rank
    # without resetting the process-wide ``random`` module state.
    rng = random.Random(12345)
    avg_elems = avg_msg_size // elem_size
    return [
        [max(1, int(avg_elems * (0.5 + rng.random()))) for _ in range(world_size)]
        for _ in range(world_size)
    ]
|
||||
|
||||
# ── Shared benchmark helpers ──────────────────────────────────────────
|
||||
def bench_alltoallv(
|
||||
fn: Callable,
|
||||
input_tensor: torch.Tensor,
|
||||
@@ -179,8 +165,8 @@ def main():
|
||||
output_split_sizes: List[int],
|
||||
n_warmup: int,
|
||||
n_iters: int,
|
||||
) -> tuple:
|
||||
"""Benchmark an all_to_all_single implementation. Returns (latency_us, algbw_gbps)."""
|
||||
) -> Tuple[float, float]:
|
||||
"""Benchmark an all_to_all_single impl. Returns (latency_us, algbw_gbps)."""
|
||||
for _ in range(n_warmup):
|
||||
fn(input_tensor, output_tensor, input_split_sizes, output_split_sizes)
|
||||
torch.cuda.synchronize()
|
||||
@@ -191,12 +177,12 @@ def main():
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
total_recv_bytes = sum(output_split_sizes) * 4 # float32
|
||||
elem_size = input_tensor.element_size()
|
||||
total_recv_bytes = sum(output_split_sizes) * elem_size
|
||||
algbw = total_recv_bytes * n_iters / elapsed / 1e9
|
||||
lat = elapsed / n_iters * 1e6
|
||||
return lat, algbw
|
||||
|
||||
# Wrap mscclpp and torch.dist into the same calling convention
|
||||
def mscclpp_fn(inp, out, in_splits, out_splits):
|
||||
alltoallv.all_to_all_single(inp, output=out,
|
||||
input_split_sizes=in_splits,
|
||||
@@ -207,56 +193,125 @@ def main():
|
||||
output_split_sizes=out_splits,
|
||||
input_split_sizes=in_splits)
|
||||
|
||||
# ── Test 3: Side-by-side comparison ───────────────────────────────────
|
||||
if rank == 0:
|
||||
print("\n[Test 3] Variable-size benchmark: mscclpp vs torch.dist (1KB–128MB avg/peer)")
|
||||
print(f" {'Avg Size':>10s} "
|
||||
f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
|
||||
f"{'torch Lat':>10s} {'torch BW':>9s} "
|
||||
f"{'Speedup':>7s}")
|
||||
print(f" {'-'*10} "
|
||||
f"{'-'*12} {'-'*11} "
|
||||
f"{'-'*10} {'-'*9} "
|
||||
f"{'-'*7}")
|
||||
|
||||
msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
|
||||
msg_sizes.append(128 * 1024 * 1024)
|
||||
|
||||
for avg_msg_size in msg_sizes:
|
||||
send_matrix = build_variable_send_matrix(avg_msg_size, world_size)
|
||||
|
||||
input_split_sizes = send_matrix[rank]
|
||||
output_split_sizes = [send_matrix[j][rank] for j in range(world_size)]
|
||||
|
||||
total_send = sum(input_split_sizes)
|
||||
total_recv = sum(output_split_sizes)
|
||||
|
||||
input_tensor = torch.randn(total_send, dtype=torch.float32, device='cuda')
|
||||
output_tensor = torch.empty(total_recv, dtype=torch.float32, device='cuda')
|
||||
|
||||
n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
|
||||
n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, input_tensor, output_tensor,
|
||||
input_split_sizes, output_split_sizes,
|
||||
n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, input_tensor, output_tensor,
|
||||
input_split_sizes, output_split_sizes,
|
||||
n_warmup, n_iters)
|
||||
def fmt_size(nbytes: int) -> str:
    """Format a byte count as a short label in B, KB or MB (1024-based).

    Exact multiples of a unit print as a whole number ("20MB"). Other sizes
    print with one decimal digit, truncated rather than rounded ("1.5KB"),
    so a size is not silently shown as a smaller whole unit.

    Args:
        nbytes: Size in bytes.

    Returns:
        The formatted label, e.g. "512B", "4KB", "1.5KB", "128MB".
    """
    for unit_bytes, suffix in ((1024 * 1024, "MB"), (1024, "KB")):
        if nbytes >= unit_bytes:
            whole, remainder = divmod(nbytes, unit_bytes)
            if remainder == 0:
                return f"{whole}{suffix}"
            # Integer arithmetic keeps the fractional digit truncated, so a
            # value just below the next unit never rounds up to it.
            tenths = nbytes * 10 // unit_bytes
            return f"{tenths // 10}.{tenths % 10}{suffix}"
    return f"{nbytes}B"
|
||||
|
||||
def print_header():
    """Print the comparison table's column titles and the dashed rule under them.

    Only rank 0 prints; every other rank returns without output.
    """
    if rank != 0:
        return

    titles = (f" {'Avg Size':>10s} "
              f"{'mscclpp Lat':>12s} {'mscclpp BW':>11s} "
              f"{'torch Lat':>10s} {'torch BW':>9s} "
              f"{'Speedup':>7s}")
    rule = (f" {'-'*10} "
            f"{'-'*12} {'-'*11} "
            f"{'-'*10} {'-'*9} "
            f"{'-'*7}")
    print(titles)
    print(rule)
|
||||
|
||||
def print_row(size_str, m_lat, m_bw, t_lat, t_bw):
    """Print one result row of the mscclpp vs torch.dist comparison table.

    Only rank 0 prints. The row label is exactly the ``size_str`` passed by
    the caller; it is not recomputed from any enclosing loop variable, so
    callers that format their own size get the label they asked for.

    Args:
        size_str: Label for the first column (e.g. from ``fmt_size``).
        m_lat: mscclpp latency, printed with a "us" suffix.
        m_bw: mscclpp bandwidth, printed with a "GB" suffix.
        t_lat: torch.dist latency, printed with a "us" suffix.
        t_bw: torch.dist bandwidth, printed with a "GB" suffix.
    """
    if rank != 0:
        return

    # Speedup is the bandwidth ratio; a non-positive torch bandwidth would
    # otherwise divide by zero, so report infinity instead.
    speedup = m_bw / t_bw if t_bw > 0 else float('inf')
    print(f" {size_str:>10s} "
          f"{m_lat:>10.1f}us {m_bw:>9.2f}GB "
          f"{t_lat:>8.1f}us {t_bw:>7.2f}GB "
          f"{speedup:>6.2f}x")
|
||||
|
||||
# ── Test 3: Synthetic variable-size sweep ─────────────────────────────
|
||||
if rank == 0:
|
||||
print("\n[Test 3] Synthetic variable-size benchmark: mscclpp vs torch.dist")
|
||||
print_header()
|
||||
|
||||
msg_sizes = [1 << s for s in range(10, 28) if s % 2 == 0]
|
||||
msg_sizes.append(128 * 1024 * 1024)
|
||||
|
||||
for avg_msg_size in msg_sizes:
|
||||
random.seed(12345)
|
||||
avg_elems = avg_msg_size // 4
|
||||
send_matrix = []
|
||||
for i in range(world_size):
|
||||
row = [max(1, int(avg_elems * (0.5 + random.random()))) for _ in range(world_size)]
|
||||
send_matrix.append(row)
|
||||
|
||||
in_splits = send_matrix[rank]
|
||||
out_splits = [send_matrix[j][rank] for j in range(world_size)]
|
||||
|
||||
inp = torch.randn(sum(in_splits), dtype=torch.float32, device='cuda')
|
||||
out = torch.empty(sum(out_splits), dtype=torch.float32, device='cuda')
|
||||
|
||||
n_warmup = 3 if avg_msg_size >= 16 * 1024 * 1024 else 5
|
||||
n_iters = 5 if avg_msg_size >= 64 * 1024 * 1024 else (10 if avg_msg_size >= 4 * 1024 * 1024 else 20)
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
|
||||
|
||||
# ── Test 4: Real MoE workloads ───────────────────────────────────────
|
||||
# Token counts from real MoE training runs (rank 0's view, 8 GPUs).
|
||||
# Each token = 5120 bytes (hidden_dim=2560, bf16).
|
||||
# We use bf16 with 2560 elements/token to match real workload dtype.
|
||||
#
|
||||
# To build a consistent 8×8 send matrix, we rotate the input_tokens
|
||||
# per rank so every rank has the same total send and each NVLink
|
||||
# carries a realistically imbalanced load.
|
||||
|
||||
MOE_WORKLOADS = [
|
||||
{
|
||||
"name": "MoE-A",
|
||||
# input_splits=[3976,3916,4497,4838,2888,3839,4355,4459]
|
||||
# original run (bytes): total_send=167,772,160 total_recv=148,316,160; the rotated matrix below makes each rank's recv total equal its send total
|
||||
"input_tokens": [3976, 3916, 4497, 4838, 2888, 3839, 4355, 4459],
|
||||
},
|
||||
{
|
||||
"name": "MoE-B",
|
||||
# input_splits=[3009,7161,2719,2766,3428,3010,6290,4385]
|
||||
# original run (bytes): total_send=167,772,160 total_recv=163,722,240; the rotated matrix below makes each rank's recv total equal its send total
|
||||
"input_tokens": [3009, 7161, 2719, 2766, 3428, 3010, 6290, 4385],
|
||||
},
|
||||
]
|
||||
ELEMS_PER_TOKEN = 2560 # 5120 bytes / 2 bytes-per-bfloat16
|
||||
|
||||
if world_size == 8:
|
||||
if rank == 0:
|
||||
print(f"\n[Test 4] Real MoE workloads (hidden=2560, bf16, 8 GPUs)")
|
||||
|
||||
for wl_idx, wl in enumerate(MOE_WORKLOADS):
|
||||
tokens = wl["input_tokens"]
|
||||
min_tok, max_tok = min(tokens), max(tokens)
|
||||
imbalance = max_tok / min_tok
|
||||
total_bytes = sum(tokens) * 5120
|
||||
|
||||
if rank == 0:
|
||||
print(f"\n {wl['name']}: {sum(tokens)} tokens/rank, "
|
||||
f"{total_bytes / 1e6:.1f}MB, imbalance={imbalance:.1f}x")
|
||||
print(f" Token distribution: {tokens}")
|
||||
print_header()
|
||||
|
||||
# Build consistent send_matrix: rotate token list per rank
|
||||
moe_send_matrix = []
|
||||
for i in range(world_size):
|
||||
row = tokens[i:] + tokens[:i]
|
||||
moe_send_matrix.append(row)
|
||||
|
||||
in_splits = [moe_send_matrix[rank][j] * ELEMS_PER_TOKEN for j in range(world_size)]
|
||||
out_splits = [moe_send_matrix[j][rank] * ELEMS_PER_TOKEN for j in range(world_size)]
|
||||
|
||||
inp = torch.randn(sum(in_splits), dtype=torch.bfloat16, device='cuda')
|
||||
out = torch.empty(sum(out_splits), dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
n_warmup, n_iters = 5, 20
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
|
||||
avg_bytes = total_bytes // world_size
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
else:
|
||||
if rank == 0:
|
||||
print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")
|
||||
|
||||
# Cleanup
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
|
||||
Reference in New Issue
Block a user