Match the message size for the EP bench HT run on 16 GPUs in test 6

This commit is contained in:
Qinghua Zhou
2026-03-30 03:40:05 +00:00
parent 62ab8883a6
commit 36940dbacf

View File

@@ -665,10 +665,6 @@ def main():
# Matches the data volume of:
# mpirun -np N ep_bench -a ht -t 4096 -d 7168
#
# Target byte counts (per rank avg, 8 GPUs):
# RDMA_send = 58.72 MB (4096 tokens × 7168 × 2 bytes)
# total_recv = 469.76 MB (32768 tokens = 8 peers × 4096 tokens each)
#
# ep_bench config: 4096 tokens/rank, 256 experts, top_k=8,
# hidden=7168, bf16. Each token is dispatched to top_k=8 experts,
# so each rank receives ~4096 token-expert pairs from each peer.
@@ -677,12 +673,21 @@ def main():
# srand(rank + 42), for each of 4096 tokens pick a random first_expert
# in [0, num_experts), then assign top_k=8 consecutive experts.
# target_rank = expert_id // num_local_experts.
#
# Target send bytes vary by GPU count (to match ep_bench reports):
# 8 GPUs: 4096 tokens/rank → 58.72 MB (no cross-boundary inflation)
# 16 GPUs: 4317 tokens/rank → 61.88 MB (matches ep_bench RDMA_send)
EP_NUM_TOKENS = 4096 # tokens per rank (input)
EP_NUM_EXPERTS = 256
EP_TOP_K = 8
EP_HIDDEN = 7168 # bf16 elements per token
# Target send tokens per rank, keyed by world_size.
    # 8 GPUs: top_k=8 <= num_local_experts=32, so no boundary-crossing → 4096
# 16 GPUs: num_local_experts=16, boundary crossing inflates to ~4317
EP_TARGET_TOKENS = {8: 4096, 16: 4317}
if world_size >= 2:
num_local_experts = EP_NUM_EXPERTS // world_size
@@ -705,9 +710,9 @@ def main():
for tr in target_ranks_seen:
send_counts[tr] += 1
# Normalize send_counts so each rank sends exactly EP_NUM_TOKENS
# tokens total, ensuring total_send_bytes = 4096 × 7168 × 2 = 58,720,256 bytes.
TARGET_SEND_TOKENS = EP_NUM_TOKENS # 4096
# Normalize send_counts to the target for this world_size.
# For unknown world_size, keep raw counts.
TARGET_SEND_TOKENS = EP_TARGET_TOKENS.get(world_size, sum(send_counts))
raw_total = sum(send_counts)
if raw_total > 0 and raw_total != TARGET_SEND_TOKENS:
scaled = [int(c * TARGET_SEND_TOKENS / raw_total) for c in send_counts]
@@ -720,7 +725,7 @@ def main():
scaled[indices[i % world_size]] -= 1
send_counts = scaled
# Gather 8×8 send matrix via allgather
# Gather send matrix via allgather
send_tensor = torch.tensor(send_counts, dtype=torch.int32, device='cuda')
all_sends = [torch.zeros(world_size, dtype=torch.int32, device='cuda')
for _ in range(world_size)]
@@ -739,6 +744,10 @@ def main():
total_send_bytes = sum(in_splits) * 2
total_recv_bytes = sum(out_splits) * 2
target_send_mb = TARGET_SEND_TOKENS * EP_HIDDEN * 2 / 1e6
target_recv_tokens = world_size * EP_NUM_TOKENS
target_recv_mb = target_recv_tokens * EP_HIDDEN * 2 / 1e6
if rank == 0:
print(f"\n[Test 6] NCCL EP HT-equivalent workload "
f"(tokens={EP_NUM_TOKENS}, experts={EP_NUM_EXPERTS}, "
@@ -747,8 +756,10 @@ def main():
print(f" Rank 0 recv tokens: {out_splits_tokens} (total {total_recv_tokens})")
print(f" Send {total_send_bytes / 1e6:.2f}MB, "
f"Recv {total_recv_bytes / 1e6:.2f}MB")
print(f" Target: RDMA_send=58.72 MB, total_recv=469.76 MB (8 GPUs)")
# Show imbalance
print(f" Target: RDMA_send={target_send_mb:.2f} MB "
f"({TARGET_SEND_TOKENS} tokens), "
f"total_recv={target_recv_mb:.2f} MB "
f"({target_recv_tokens} tokens)")
max_out = max(out_splits_tokens)
min_out = min(out_splits_tokens)
print(f" Recv imbalance: {max_out/min_out:.2f}x "