tests/ep: LL bench prints per_rank_bw and accepts size env vars

- Report both per-rank and aggregate bandwidth (BW) to align with NCCL-EP's
  ep_bench (which reports per-rank GB/s).
- Accept MSCCLPP_EP_LL_TOKENS/HIDDEN/TOPK/EXPERTS_PER_RANK env overrides
  so we can match external benchmark problem sizes (NCCL-EP LL defaults
  are num_tokens=128, hidden=7168, top_k=8).
This commit is contained in:
Qinghua Zhou
2026-04-23 22:20:40 +00:00
parent 63afb25ab3
commit 10cd0012f1

View File

@@ -61,10 +61,11 @@ def main():
rank_offset = 128
assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"
num_tokens = 64
hidden = 7168 # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
num_topk = 4
num_experts = num_ranks * 4
num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64"))
hidden = int(os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168")) # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4"))
num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4"))
num_experts = num_ranks * num_experts_per_rank
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
@@ -271,20 +272,29 @@ def main():
comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
# Aggregate = sum across ranks; per-rank avg = sum / num_ranks. Also report
# per-rank numbers to line up with NCCL-EP's `ep_bench.cu`, which prints the
# rank's own bytes / its own elapsed time.
disp_bw_agg = disp_bw_t.clone()
comb_bw_agg = comb_bw_t.clone()
dist.all_reduce(disp_bw_agg, op=dist.ReduceOp.SUM, group=group)
dist.all_reduce(comb_bw_agg, op=dist.ReduceOp.SUM, group=group)
if rank == 0:
print(
f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
f"num_experts={num_experts} warmup={warmup} iters={iters}",
f"num_experts={num_experts} num_topk={num_topk} warmup={warmup} iters={iters}",
flush=True,
)
print(
f" dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
f" dispatch: {disp_us_t.item():.1f}us (max) "
f"per_rank_bw={disp_bw_agg.item() / num_ranks:.2f} GB/s "
f"agg_bw={disp_bw_agg.item():.2f} GB/s",
flush=True,
)
print(
f" combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
f" combine : {comb_us_t.item():.1f}us (max) "
f"per_rank_bw={comb_bw_agg.item() / num_ranks:.2f} GB/s "
f"agg_bw={comb_bw_agg.item():.2f} GB/s",
flush=True,
)