tests/ep: align intranode HT bench with NCCL-EP accounting

- Add MSCCLPP_EP_BENCH_EXPERTS / _TOPK env knobs so the bench phase can
  match NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The
  functional check above continues to use the smaller (num_ranks*4
  experts, topk=4) configuration.

- Switch BW accounting from recv_tokens*hidden to bench_tokens*hidden,
  matching NCCL-EP's `RDMA_send` per-rank byte count. The previous
  formula counted DeepEP's expanded recv layout (one row per
  (token,src_rank) pair), inflating reported GB/s ~5x and making
  cross-stack comparisons misleading.
This commit is contained in:
Qinghua Zhou
2026-04-27 17:14:42 +00:00
parent 9c129b8b5a
commit 4ed6f229f2

View File

@@ -189,6 +189,24 @@ def main():
iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
bench_tokens = int(os.environ.get("MSCCLPP_EP_BENCH_TOKENS", "4096"))
bench_hidden = int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "7168"))
# Allow overriding num_experts / num_topk for the bench phase to match
# NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional
# check above still uses the smaller (num_experts=num_ranks*4, topk=4)
# configuration.
bench_num_experts = int(os.environ.get(
"MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
bench_num_topk = int(os.environ.get(
"MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
if bench_num_experts % num_ranks != 0:
if rank == 0:
print(f"[bench] skip: num_experts={bench_num_experts} not divisible "
f"by num_ranks={num_ranks}", flush=True)
return
if bench_num_topk > bench_num_experts:
if rank == 0:
print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}",
flush=True)
return
# Rebuild inputs at bench size. Keep same layout recipe as above but at
# larger (num_tokens, hidden); Buffer is sized off the original cfg+hidden,
@@ -203,14 +221,14 @@ def main():
)
return
scores_b = torch.randn((bench_tokens, num_experts), device="cuda", dtype=torch.float32).abs() + 1
topk_idx_b = torch.topk(scores_b, num_topk, dim=-1, sorted=False).indices
topk_weights_b = torch.ones((bench_tokens, num_topk), dtype=torch.float32, device="cuda")
rank_idx_b = topk_idx_b // (num_experts // num_ranks)
scores_b = torch.randn((bench_tokens, bench_num_experts), device="cuda", dtype=torch.float32).abs() + 1
topk_idx_b = torch.topk(scores_b, bench_num_topk, dim=-1, sorted=False).indices
topk_weights_b = torch.ones((bench_tokens, bench_num_topk), dtype=torch.float32, device="cuda")
rank_idx_b = topk_idx_b // (bench_num_experts // num_ranks)
rank_idx_b.masked_fill_(topk_idx_b == -1, -1)
inplace_unique(rank_idx_b, num_ranks)
num_tokens_per_expert_b = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
for i in range(num_experts):
num_tokens_per_expert_b = torch.zeros((bench_num_experts,), dtype=torch.int, device="cuda")
for i in range(bench_num_experts):
num_tokens_per_expert_b[i] = (topk_idx_b == i).sum()
num_tokens_per_rank_b = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
token_idx_in_rank_b = torch.full((num_ranks, bench_tokens), -1, dtype=torch.long, device="cuda")
@@ -254,7 +272,6 @@ def main():
end_ev.record()
torch.cuda.synchronize()
disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
recv_tokens = dout[0].size(0)
# Time combine alone (reusing the same dispatch output each iter).
dist.barrier(group=group)
@@ -265,8 +282,12 @@ def main():
torch.cuda.synchronize()
comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
# One-way payload bytes (per phase) per rank.
bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
# Per-rank "send bytes" matches NCCL-EP's `ep_bench` accounting (`RDMA_send`):
# bench_tokens * hidden * sizeof(bf16). Each rank ships its `bench_tokens`
# input rows out (some replicated to multiple peers); NCCL-EP normalizes by
# the input footprint, not by the recv-side fan-out. We use the same
# convention here so `per_rank_bw` is directly comparable across stacks.
bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
@@ -281,6 +302,7 @@ def main():
if rank == 0:
print(
f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
f"experts={bench_num_experts} topk={bench_num_topk} "
f"warmup={warmup} iters={iters}",
flush=True,
)