diff --git a/test/python/ext/ep/test_low_latency_multirank.py b/test/python/ext/ep/test_low_latency_multirank.py
index d82ea309..0fb0eca9 100644
--- a/test/python/ext/ep/test_low_latency_multirank.py
+++ b/test/python/ext/ep/test_low_latency_multirank.py
@@ -61,10 +61,11 @@ def main():
     rank_offset = 128
     assert num_ranks - rank_offset < 257, "too many ranks for bf16 precision anchor"
 
-    num_tokens = 64
-    hidden = 7168  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
-    num_topk = 4
-    num_experts = num_ranks * 4
+    num_tokens = int(os.environ.get("MSCCLPP_EP_LL_TOKENS", "64"))
+    hidden = int(os.environ.get("MSCCLPP_EP_LL_HIDDEN", "7168"))  # LL kernels are compiled for a fixed set; see SWITCH_HIDDEN
+    num_topk = int(os.environ.get("MSCCLPP_EP_LL_TOPK", "4"))
+    num_experts_per_rank = int(os.environ.get("MSCCLPP_EP_LL_EXPERTS_PER_RANK", "4"))
+    num_experts = num_ranks * num_experts_per_rank
 
     assert num_experts % num_ranks == 0
     num_local_experts = num_experts // num_ranks
@@ -271,20 +272,29 @@ def main():
         comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
         dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
         dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
-        dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
-        dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
+        # Aggregate = sum across ranks; per-rank avg = sum / num_ranks. Also report
+        # per-rank numbers to line up with NCCL-EP's `ep_bench.cu`, which prints the
+        # rank's own bytes / its own elapsed time.
+        disp_bw_agg = disp_bw_t.clone()
+        comb_bw_agg = comb_bw_t.clone()
+        dist.all_reduce(disp_bw_agg, op=dist.ReduceOp.SUM, group=group)
+        dist.all_reduce(comb_bw_agg, op=dist.ReduceOp.SUM, group=group)
 
         if rank == 0:
             print(
                 f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
-                f"num_experts={num_experts} warmup={warmup} iters={iters}",
+                f"num_experts={num_experts} num_topk={num_topk} warmup={warmup} iters={iters}",
                 flush=True,
             )
             print(
-                f"  dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
+                f"  dispatch: {disp_us_t.item():.1f}us (max) "
+                f"per_rank_bw={disp_bw_agg.item() / num_ranks:.2f} GB/s "
+                f"agg_bw={disp_bw_agg.item():.2f} GB/s",
                 flush=True,
             )
             print(
-                f"  combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
+                f"  combine : {comb_us_t.item():.1f}us (max) "
+                f"per_rank_bw={comb_bw_agg.item() / num_ranks:.2f} GB/s "
+                f"agg_bw={comb_bw_agg.item():.2f} GB/s",
                 flush=True,
             )
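
For reference, a minimal standalone sketch of the reporting convention this diff adopts (a hypothetical repro script, not part of the PR): latencies are all-reduced with MAX because the slowest rank sets the step time, bandwidths are all-reduced with SUM, and the per-rank figure is the sum divided by the world size, which matches how NCCL-EP's `ep_bench.cu` prints one rank's own bytes over its own elapsed time. The `report` helper, the gloo backend, and the payload numbers below are illustrative assumptions only.

```python
import torch
import torch.distributed as dist


def report(elapsed_us: float, bytes_moved: int, group=None):
    """Reduce and print latency/bandwidth using the convention in the diff above."""
    num_ranks = dist.get_world_size(group)
    bw = bytes_moved / (elapsed_us * 1e-6) / 1e9  # this rank's GB/s

    us_t = torch.tensor([elapsed_us], dtype=torch.float64)
    bw_t = torch.tensor([bw], dtype=torch.float64)
    dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)  # slowest rank
    dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)  # aggregate bytes/s

    if dist.get_rank(group) == 0:
        print(
            f"{us_t.item():.1f}us (max) "
            f"per_rank_bw={bw_t.item() / num_ranks:.2f} GB/s "  # sum / num_ranks
            f"agg_bw={bw_t.item():.2f} GB/s",
            flush=True,
        )


if __name__ == "__main__":
    # e.g. torchrun --nproc-per-node=2 this_script.py; gloo keeps it CPU-only.
    dist.init_process_group("gloo")
    # Illustrative payload: 64 tokens x 7168 hidden x 2 bytes (bf16), 100us.
    report(elapsed_us=100.0, bytes_moved=64 * 7168 * 2)
    dist.destroy_process_group()
```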