tests/ep: align internode HT bench with NCCL-EP accounting

Same change as the intra-node bench (commit 4ed6f229), applied to the
cross-node test:

- Add MSCCLPP_EP_BENCH_EXPERTS / _TOPK env knobs so the bench phase can
  match NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8).
- Switch BW accounting from recv_tokens*hidden to bench_tokens*hidden,
  matching NCCL-EP's `RDMA_send` per-rank byte count.

Author: Qinghua Zhou
Date:   2026-04-27 17:39:17 +00:00
Parent: 4ed6f229f2
Commit: 48540bc11e

@@ -222,6 +222,24 @@ def main():
     iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
     bench_tokens = int(os.environ.get("MSCCLPP_EP_BENCH_TOKENS", "4096"))
     bench_hidden = int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "7168"))
+    # Allow overriding num_experts / num_topk for the bench phase to match
+    # NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional
+    # check above still uses the smaller (num_experts=num_ranks*4, topk=4)
+    # configuration.
+    bench_num_experts = int(os.environ.get(
+        "MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
+    bench_num_topk = int(os.environ.get(
+        "MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
+    if bench_num_experts % num_ranks != 0:
+        if rank == 0:
+            print(f"[bench] skip: num_experts={bench_num_experts} not divisible "
+                  f"by num_ranks={num_ranks}", flush=True)
+        return
+    if bench_num_topk > bench_num_experts:
+        if rank == 0:
+            print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}",
+                  flush=True)
+        return
 
     # Respect the Buffer's pre-sized num_nvl_bytes / num_rdma_bytes budget.
     per_peer_nvl = num_nvl_bytes // max(1, num_ranks)
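
Note: a minimal driver-side sketch of how the two new knobs might be exported so the bench phase lines up with NCCL-EP's `ep_bench -a ht` defaults; the wrapper itself is hypothetical, only the variable names and the 256-expert / top-8 values come from this change:

    import os

    # Hypothetical wrapper: set the knobs before the test process reads them.
    # Left unset, the bench falls back to the functional check's smaller shape
    # (num_experts=num_ranks*4, topk=4), as noted in the diff above.
    os.environ.setdefault("MSCCLPP_EP_BENCH_EXPERTS", "256")
    os.environ.setdefault("MSCCLPP_EP_BENCH_TOPK", "8")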
@@ -236,18 +254,18 @@ def main():
             )
         return
 
-    scores_b = torch.randn((bench_tokens, num_experts), device="cuda", dtype=torch.float32).abs() + 1
-    topk_idx_b = torch.topk(scores_b, num_topk, dim=-1, sorted=False).indices
-    topk_weights_b = torch.ones((bench_tokens, num_topk), dtype=torch.float32, device="cuda")
-    rank_idx_b = topk_idx_b // (num_experts // num_ranks)
+    scores_b = torch.randn((bench_tokens, bench_num_experts), device="cuda", dtype=torch.float32).abs() + 1
+    topk_idx_b = torch.topk(scores_b, bench_num_topk, dim=-1, sorted=False).indices
+    topk_weights_b = torch.ones((bench_tokens, bench_num_topk), dtype=torch.float32, device="cuda")
+    rank_idx_b = topk_idx_b // (bench_num_experts // num_ranks)
     rank_idx_b.masked_fill_(topk_idx_b == -1, -1)
     inplace_unique(rank_idx_b, num_ranks)
     rdma_rank_idx_b = rank_idx_b // num_local_ranks
     rdma_rank_idx_b.masked_fill_(rank_idx_b == -1, -1)
     inplace_unique(rdma_rank_idx_b, num_nodes)
 
-    num_tokens_per_expert_b = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
-    for i in range(num_experts):
+    num_tokens_per_expert_b = torch.zeros((bench_num_experts,), dtype=torch.int, device="cuda")
+    for i in range(bench_num_experts):
         num_tokens_per_expert_b[i] = (topk_idx_b == i).sum()
     num_tokens_per_rank_b = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
     num_tokens_per_rdma_rank_b = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
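
Note: the integer divisions above route each selected expert id to its owning rank and then to its node. A tiny self-contained illustration, assuming a 256-expert, 16-rank, 8-GPUs-per-node shape (example numbers only, not taken from the test):

    # expert id -> owning rank -> owning node, mirroring the tensor ops above
    bench_num_experts, num_ranks, num_local_ranks = 256, 16, 8
    experts_per_rank = bench_num_experts // num_ranks          # 16 experts per rank
    topk_idx = [3, 100, 255]                                    # expert ids chosen for one token
    rank_idx = [e // experts_per_rank for e in topk_idx]        # -> [0, 6, 15]
    rdma_rank_idx = [r // num_local_ranks for r in rank_idx]    # node ids -> [0, 0, 1]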
@@ -303,7 +321,6 @@ def main():
     end_ev.record()
     torch.cuda.synchronize()
     disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
-    recv_tokens = dout[0].size(0)
 
     # Required guard before combine sees the dispatch outputs (see correctness
     # path's XXX note). Not included in either phase's timing.
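
Note: `torch.cuda.Event.elapsed_time` reports milliseconds, so the `* 1e3 / iters` above yields an average per-iteration latency in microseconds. A quick worked example with made-up numbers:

    # hypothetical: 20 timed dispatch iterations took 42.0 ms in total
    total_ms, iters = 42.0, 20
    disp_us = total_ms * 1e3 / iters    # 2100.0 us per dispatch iteration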
@@ -318,7 +335,12 @@ def main():
     torch.cuda.synchronize()
     comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
 
-    bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
+    # Per-rank "send bytes" matches NCCL-EP's `ep_bench` accounting (`RDMA_send`):
+    # bench_tokens * hidden * sizeof(bf16). Each rank ships its `bench_tokens`
+    # input rows out (some replicated to multiple peers); NCCL-EP normalizes by
+    # the input footprint, not by the recv-side fan-out. We use the same
+    # convention here so `per_rank_bw` is directly comparable across stacks.
+    bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
 
     disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
     comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
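
Note: with the bench defaults (4096 tokens, hidden 7168, bf16 inputs) the per-rank one-way byte count and the derived bandwidth work out as below; the 2.1 ms dispatch time is purely illustrative:

    bench_tokens, bench_hidden, elem_size = 4096, 7168, 2    # bf16 = 2 bytes/element
    bytes_one_way = bench_tokens * bench_hidden * elem_size  # 58,720,256 bytes (~58.7 MB)
    disp_us = 2100.0                                          # hypothetical per-iteration time, us
    disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9          # ~28.0 GB/s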
@@ -334,6 +356,7 @@ def main():
         print(
             f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
             f"tokens={bench_tokens} hidden={bench_hidden} "
+            f"experts={bench_num_experts} topk={bench_num_topk} "
             f"warmup={warmup} iters={iters}",
             flush=True,
         )