mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
tests/ep: align intranode HT bench with NCCL-EP accounting
- Add MSCCLPP_EP_BENCH_EXPERTS / _TOPK env knobs so the bench phase can match NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional check above continues to use the smaller (num_ranks*4 experts, topk=4) configuration. - Switch BW accounting from recv_tokens*hidden to bench_tokens*hidden, matching NCCL-EP's `RDMA_send` per-rank byte count. The previous formula counted DeepEP's expanded recv layout (one row per (token,src_rank) pair), inflating reported GB/s ~5x and making cross-stack comparisons misleading.
This commit is contained in:
@@ -189,6 +189,24 @@ def main():
|
||||
iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
|
||||
bench_tokens = int(os.environ.get("MSCCLPP_EP_BENCH_TOKENS", "4096"))
|
||||
bench_hidden = int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "7168"))
|
||||
# Allow overriding num_experts / num_topk for the bench phase to match
|
||||
# NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The functional
|
||||
# check above still uses the smaller (num_experts=num_ranks*4, topk=4)
|
||||
# configuration.
|
||||
bench_num_experts = int(os.environ.get(
|
||||
"MSCCLPP_EP_BENCH_EXPERTS", str(num_experts)))
|
||||
bench_num_topk = int(os.environ.get(
|
||||
"MSCCLPP_EP_BENCH_TOPK", str(num_topk)))
|
||||
if bench_num_experts % num_ranks != 0:
|
||||
if rank == 0:
|
||||
print(f"[bench] skip: num_experts={bench_num_experts} not divisible "
|
||||
f"by num_ranks={num_ranks}", flush=True)
|
||||
return
|
||||
if bench_num_topk > bench_num_experts:
|
||||
if rank == 0:
|
||||
print(f"[bench] skip: topk={bench_num_topk} > experts={bench_num_experts}",
|
||||
flush=True)
|
||||
return
|
||||
|
||||
# Rebuild inputs at bench size. Keep same layout recipe as above but at
|
||||
# larger (num_tokens, hidden); Buffer is sized off the original cfg+hidden,
|
||||
@@ -203,14 +221,14 @@ def main():
|
||||
)
|
||||
return
|
||||
|
||||
scores_b = torch.randn((bench_tokens, num_experts), device="cuda", dtype=torch.float32).abs() + 1
|
||||
topk_idx_b = torch.topk(scores_b, num_topk, dim=-1, sorted=False).indices
|
||||
topk_weights_b = torch.ones((bench_tokens, num_topk), dtype=torch.float32, device="cuda")
|
||||
rank_idx_b = topk_idx_b // (num_experts // num_ranks)
|
||||
scores_b = torch.randn((bench_tokens, bench_num_experts), device="cuda", dtype=torch.float32).abs() + 1
|
||||
topk_idx_b = torch.topk(scores_b, bench_num_topk, dim=-1, sorted=False).indices
|
||||
topk_weights_b = torch.ones((bench_tokens, bench_num_topk), dtype=torch.float32, device="cuda")
|
||||
rank_idx_b = topk_idx_b // (bench_num_experts // num_ranks)
|
||||
rank_idx_b.masked_fill_(topk_idx_b == -1, -1)
|
||||
inplace_unique(rank_idx_b, num_ranks)
|
||||
num_tokens_per_expert_b = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert_b = torch.zeros((bench_num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(bench_num_experts):
|
||||
num_tokens_per_expert_b[i] = (topk_idx_b == i).sum()
|
||||
num_tokens_per_rank_b = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
|
||||
token_idx_in_rank_b = torch.full((num_ranks, bench_tokens), -1, dtype=torch.long, device="cuda")
|
||||
@@ -254,7 +272,6 @@ def main():
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
recv_tokens = dout[0].size(0)
|
||||
|
||||
# Time combine alone (reusing the same dispatch output each iter).
|
||||
dist.barrier(group=group)
|
||||
@@ -265,8 +282,12 @@ def main():
|
||||
torch.cuda.synchronize()
|
||||
comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
|
||||
# One-way payload bytes (per phase) per rank.
|
||||
bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
|
||||
# Per-rank "send bytes" matches NCCL-EP's `ep_bench` accounting (`RDMA_send`):
|
||||
# bench_tokens * hidden * sizeof(bf16). Each rank ships its `bench_tokens`
|
||||
# input rows out (some replicated to multiple peers); NCCL-EP normalizes by
|
||||
# the input footprint, not by the recv-side fan-out. We use the same
|
||||
# convention here so `per_rank_bw` is directly comparable across stacks.
|
||||
bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
|
||||
disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
|
||||
comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
|
||||
|
||||
@@ -281,6 +302,7 @@ def main():
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
|
||||
f"experts={bench_num_experts} topk={bench_num_topk} "
|
||||
f"warmup={warmup} iters={iters}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user