mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Match with message size of NCCL EP bench test
This commit is contained in:
@@ -200,6 +200,14 @@ def main():
|
||||
return f"{nbytes // 1024}KB"
|
||||
return f"{nbytes}B"
|
||||
|
||||
def fmt_size_decimal(nbytes: int) -> str:
|
||||
"""Format size using decimal MB (÷1000000) to match NCCL EP reporting."""
|
||||
if nbytes >= 1000000:
|
||||
return f"{nbytes / 1000000:.2f}MB"
|
||||
elif nbytes >= 1000:
|
||||
return f"{nbytes / 1000:.1f}KB"
|
||||
return f"{nbytes}B"
|
||||
|
||||
def print_header():
|
||||
if rank == 0:
|
||||
print(f" {'Avg Size':>10s} "
|
||||
@@ -324,30 +332,37 @@ def main():
|
||||
print("\n[Test 4] Skipped (real MoE workloads require exactly 8 ranks)")
|
||||
|
||||
# ── Test 5: NCCL EP Low-Latency equivalent workload ──────────────────
|
||||
# Detect if torch baseline is available for Tests 5 & 6
|
||||
use_torch_baseline = True
|
||||
try:
|
||||
tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
|
||||
tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
|
||||
dist.all_to_all_single(tiny_out, tiny_in)
|
||||
except Exception:
|
||||
use_torch_baseline = False
|
||||
if rank == 0:
|
||||
print(" [INFO] torch all_to_all_single unavailable, skipping torch baseline in Tests 5/6")
|
||||
|
||||
# Matches the data volume of:
|
||||
# mpirun -np 8 ep_bench -a ll -t 128 -d 7168
|
||||
# mpirun -np N ep_bench -a ll -t 128 -d 7168
|
||||
#
|
||||
# ep_bench LL config: 128 tokens/rank, 256 experts (32/rank), top_k=8,
|
||||
# ep_bench LL config: 128 tokens/rank, 256 experts, top_k=8,
|
||||
# hidden=7168, bf16.
|
||||
# Target byte counts: dispatch=14.55 MB, combine=14.55 MB, selections=1015
|
||||
#
|
||||
# Expert assignment: for each token, generate 256 scores = abs(N(0,1))+1,
|
||||
# pick top-8 expert indices. Then mask ~10 random (token,k) slots with -1.
|
||||
# pick top-8 expert indices. Then mask 9 random (token,k) slots with -1
|
||||
# to get exactly 1015 valid selections (128*8 - 9 = 1015).
|
||||
# Seed: mt19937(1 + rank).
|
||||
#
|
||||
# Since Python's numpy MT19937 differs from C++ std::mt19937 in the
|
||||
# normal distribution transform, we reproduce the *structure* (uniform
|
||||
# top-8 from 256 experts) with numpy, giving statistically equivalent
|
||||
# non-uniform splits. Each token sends its hidden vector to ~8 experts
|
||||
# across ranks → ~1014 valid selections per rank → ~14.5 MB per rank.
|
||||
|
||||
LL_NUM_TOKENS = 128 # tokens per rank
|
||||
LL_NUM_EXPERTS = 256
|
||||
LL_TOP_K = 8
|
||||
LL_HIDDEN = 7168 # bf16 elements per token
|
||||
LL_NUM_MASKED = 10 # random slots set to -1
|
||||
LL_NUM_MASKED = 9 # 128*8 - 9 = 1015 valid selections
|
||||
|
||||
if world_size == 8:
|
||||
num_local_experts = LL_NUM_EXPERTS // world_size # 32
|
||||
if world_size >= 2:
|
||||
num_local_experts = LL_NUM_EXPERTS // world_size
|
||||
|
||||
# Replicate LL expert assignment with numpy mt19937
|
||||
import numpy as np
|
||||
@@ -378,6 +393,21 @@ def main():
|
||||
for tr in target_ranks_seen:
|
||||
send_counts[tr] += 1
|
||||
|
||||
# Normalize send_counts so each rank sends exactly TARGET_SELECTIONS
|
||||
# tokens total, matching ep_bench's reported selections=1015.
|
||||
# This ensures total_send_bytes = 1015 × 7168 × 2 = 14,551,040 bytes.
|
||||
TARGET_SELECTIONS = 1015
|
||||
raw_total = sum(send_counts)
|
||||
if raw_total > 0:
|
||||
# Scale proportionally, then fix rounding to hit exact target
|
||||
scaled = [int(c * TARGET_SELECTIONS / raw_total) for c in send_counts]
|
||||
remainder = TARGET_SELECTIONS - sum(scaled)
|
||||
# Distribute remainder to largest buckets first
|
||||
indices = sorted(range(world_size), key=lambda i: send_counts[i], reverse=True)
|
||||
for i in range(remainder):
|
||||
scaled[indices[i % world_size]] += 1
|
||||
send_counts = scaled
|
||||
|
||||
# Gather 8×8 send matrix
|
||||
send_tensor = torch.tensor(send_counts, dtype=torch.int32, device='cuda')
|
||||
all_sends = [torch.zeros(world_size, dtype=torch.int32, device='cuda')
|
||||
@@ -399,11 +429,12 @@ def main():
|
||||
if rank == 0:
|
||||
print(f"\n[Test 5] NCCL EP LL-equivalent workload "
|
||||
f"(tokens={LL_NUM_TOKENS}, experts={LL_NUM_EXPERTS}, "
|
||||
f"top_k={LL_TOP_K}, hidden={LL_HIDDEN}, bf16)")
|
||||
f"top_k={LL_TOP_K}, hidden={LL_HIDDEN}, bf16, {world_size} ranks)")
|
||||
print(f" Rank 0 send tokens: {in_splits_tokens} (total {total_send_tokens})")
|
||||
print(f" Rank 0 recv tokens: {out_splits_tokens} (total {total_recv_tokens})")
|
||||
print(f" Send {total_send_bytes / 1e6:.1f}MB, "
|
||||
f"Recv {total_recv_bytes / 1e6:.1f}MB")
|
||||
print(f" Send {total_send_bytes / 1e6:.2f}MB, "
|
||||
f"Recv {total_recv_bytes / 1e6:.2f}MB")
|
||||
print(f" Target: dispatch=14.55 MB, selections=1015")
|
||||
max_out = max(out_splits_tokens)
|
||||
min_out = min(out_splits_tokens)
|
||||
print(f" Recv imbalance: {max_out/min_out:.2f}x "
|
||||
@@ -416,35 +447,39 @@ def main():
|
||||
n_warmup, n_iters = 10, 50
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
|
||||
print_row(fmt_size(total_recv_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
if use_torch_baseline:
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size_decimal(total_send_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
else:
|
||||
print_row(fmt_size_decimal(total_send_bytes), m_lat, m_bw)
|
||||
else:
|
||||
if rank == 0:
|
||||
print("\n[Test 5] Skipped (NCCL EP LL-equivalent requires exactly 8 ranks)")
|
||||
print("\n[Test 5] Skipped (NCCL EP LL-equivalent requires >= 2 ranks)")
|
||||
|
||||
# ── Test 6: NCCL EP High-Throughput equivalent workload ──────────────
|
||||
# Matches the data volume of:
|
||||
# mpirun -np 8 ep_bench -a ht -t 4096 -d 7168
|
||||
# mpirun -np N ep_bench -a ht -t 4096 -d 7168
|
||||
#
|
||||
# ep_bench config: 4096 tokens/rank, 256 experts (32/rank), top_k=8,
|
||||
# Target byte counts (per rank avg, 8 GPUs):
|
||||
# RDMA_send = 58.72 MB (4096 tokens × 7168 × 2 bytes)
|
||||
# total_recv = 469.76 MB (32768 tokens = 8 peers × 4096 tokens each)
|
||||
#
|
||||
# ep_bench config: 4096 tokens/rank, 256 experts, top_k=8,
|
||||
# hidden=7168, bf16. Each token is dispatched to top_k=8 experts,
|
||||
# so each rank receives 4096 × 8 = 32768 token-expert pairs.
|
||||
# so each rank receives ~4096 token-expert pairs from each peer.
|
||||
#
|
||||
# We replicate the ep_bench expert assignment logic:
|
||||
# srand(rank + 42), for each of 4096 tokens pick a random first_expert
|
||||
# in [0,256), then assign top_k=8 consecutive experts.
|
||||
# target_rank = expert_id // 32.
|
||||
# This produces a non-uniform send matrix (most tokens go to 1-2 ranks).
|
||||
# Total recv per rank ≈ 32768 tokens (≈ 469.76 MB), matching ep_bench.
|
||||
# in [0, num_experts), then assign top_k=8 consecutive experts.
|
||||
# target_rank = expert_id // num_local_experts.
|
||||
|
||||
EP_NUM_TOKENS = 4096 # tokens per rank (input)
|
||||
EP_NUM_EXPERTS = 256
|
||||
EP_TOP_K = 8
|
||||
EP_HIDDEN = 7168 # bf16 elements per token
|
||||
|
||||
if world_size == 8:
|
||||
num_local_experts = EP_NUM_EXPERTS // world_size # 32
|
||||
if world_size >= 2:
|
||||
num_local_experts = EP_NUM_EXPERTS // world_size
|
||||
|
||||
# Use C's srand/rand to replicate ep_bench's exact token distribution
|
||||
import ctypes
|
||||
@@ -465,6 +500,21 @@ def main():
|
||||
for tr in target_ranks_seen:
|
||||
send_counts[tr] += 1
|
||||
|
||||
# Normalize send_counts so each rank sends exactly EP_NUM_TOKENS
|
||||
# tokens total, ensuring total_send_bytes = 4096 × 7168 × 2 = 58,720,256 bytes.
|
||||
TARGET_SEND_TOKENS = EP_NUM_TOKENS # 4096
|
||||
raw_total = sum(send_counts)
|
||||
if raw_total > 0 and raw_total != TARGET_SEND_TOKENS:
|
||||
scaled = [int(c * TARGET_SEND_TOKENS / raw_total) for c in send_counts]
|
||||
remainder = TARGET_SEND_TOKENS - sum(scaled)
|
||||
indices = sorted(range(world_size), key=lambda i: send_counts[i], reverse=True)
|
||||
for i in range(abs(remainder)):
|
||||
if remainder > 0:
|
||||
scaled[indices[i % world_size]] += 1
|
||||
else:
|
||||
scaled[indices[i % world_size]] -= 1
|
||||
send_counts = scaled
|
||||
|
||||
# Gather 8×8 send matrix via allgather
|
||||
send_tensor = torch.tensor(send_counts, dtype=torch.int32, device='cuda')
|
||||
all_sends = [torch.zeros(world_size, dtype=torch.int32, device='cuda')
|
||||
@@ -487,11 +537,12 @@ def main():
|
||||
if rank == 0:
|
||||
print(f"\n[Test 6] NCCL EP HT-equivalent workload "
|
||||
f"(tokens={EP_NUM_TOKENS}, experts={EP_NUM_EXPERTS}, "
|
||||
f"top_k={EP_TOP_K}, hidden={EP_HIDDEN}, bf16)")
|
||||
f"top_k={EP_TOP_K}, hidden={EP_HIDDEN}, bf16, {world_size} ranks)")
|
||||
print(f" Rank 0 send tokens: {in_splits_tokens} (total {total_send_tokens})")
|
||||
print(f" Rank 0 recv tokens: {out_splits_tokens} (total {total_recv_tokens})")
|
||||
print(f" Send {total_send_bytes / 1e6:.1f}MB, "
|
||||
f"Recv {total_recv_bytes / 1e6:.1f}MB")
|
||||
print(f" Send {total_send_bytes / 1e6:.2f}MB, "
|
||||
f"Recv {total_recv_bytes / 1e6:.2f}MB")
|
||||
print(f" Target: RDMA_send=58.72 MB, total_recv=469.76 MB (8 GPUs)")
|
||||
# Show imbalance
|
||||
max_out = max(out_splits_tokens)
|
||||
min_out = min(out_splits_tokens)
|
||||
@@ -505,13 +556,14 @@ def main():
|
||||
n_warmup, n_iters = 10, 50 # match ep_bench defaults
|
||||
|
||||
m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
|
||||
avg_bytes = total_recv_bytes
|
||||
print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
if use_torch_baseline:
|
||||
t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
|
||||
print_row(fmt_size_decimal(total_send_bytes), m_lat, m_bw, t_lat, t_bw)
|
||||
else:
|
||||
print_row(fmt_size_decimal(total_send_bytes), m_lat, m_bw)
|
||||
else:
|
||||
if rank == 0:
|
||||
print("\n[Test 6] Skipped (NCCL EP HT-equivalent requires exactly 8 ranks)")
|
||||
print("\n[Test 6] Skipped (NCCL EP HT-equivalent requires >= 2 ranks)")
|
||||
|
||||
# Cleanup
|
||||
dist.barrier()
|
||||
|
||||
Reference in New Issue
Block a user