ep tests: report dispatch/combine min, avg, max time and use avg for BW

Aligns with NCCL-EP's ep_bench convention (BW computed from the average time
across ranks). Previously we reported only the max time across ranks and
computed BW from each rank's own elapsed time, which made our numbers look
more pessimistic than NCCL-EP's.
Qinghua Zhou
2026-04-29 16:50:33 +00:00
parent afbdcd6a3d
commit 9213587ffe
3 changed files with 72 additions and 53 deletions
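
For context, all three diffs below apply the same reduction pattern; here is a standalone sketch of it (the helper names `reduce_timing_us` and `bw_gbps` and the bare `elapsed_us` argument are illustrative, not from the patch — the scripts inline this per phase as `disp_*`/`comb_*`):

import torch
import torch.distributed as dist

def reduce_timing_us(elapsed_us: float, num_ranks: int, group) -> tuple[float, float, float]:
    # Three reductions over the per-rank elapsed time: MIN/MAX expose the
    # fastest/slowest rank, SUM / num_ranks gives the average used for BW.
    t_min = torch.tensor([elapsed_us], dtype=torch.float64, device="cuda")
    t_avg = t_min.clone()
    t_max = t_min.clone()
    dist.all_reduce(t_min, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(t_avg, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(t_max, op=dist.ReduceOp.MAX, group=group)
    return t_min.item(), t_avg.item() / num_ranks, t_max.item()

def bw_gbps(nbytes: int, avg_us: float) -> float:
    # NCCL-EP-style BW: per-rank bytes over the cross-rank average time;
    # the aggregate is simply per-rank BW times num_ranks.
    return nbytes / (avg_us * 1e-6) / 1e9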

View File

@@ -342,17 +342,25 @@ def main():
     # the input footprint, not by the recv-side fan-out. We use the same
     # convention here so `per_rank_bw` is directly comparable across stacks.
     bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
-    disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
-    comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
-    disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
-    comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
-    disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
-    comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
-    dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
-    dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
-    dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
-    dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
+    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
+    # `ep_bench.cu` convention.
+    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    disp_avg_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    disp_max_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    comb_min_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    comb_avg_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    comb_max_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    dist.all_reduce(disp_min_t, op=dist.ReduceOp.MIN, group=group)
+    dist.all_reduce(disp_avg_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(disp_max_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(comb_min_t, op=dist.ReduceOp.MIN, group=group)
+    dist.all_reduce(comb_avg_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(comb_max_t, op=dist.ReduceOp.MAX, group=group)
+    disp_avg_us = disp_avg_t.item() / num_ranks
+    comb_avg_us = comb_avg_t.item() / num_ranks
+    disp_bw_per_rank = bytes_one_way / (disp_avg_us * 1e-6) / 1e9
+    comb_bw_per_rank = bytes_one_way / (comb_avg_us * 1e-6) / 1e9
     if rank == 0:
         print(
             f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
@@ -362,15 +370,15 @@ def main():
             flush=True,
         )
         print(
-            f" dispatch: {disp_us_t.item():.1f}us (max) "
-            f"per_rank_bw={disp_bw_t.item() / num_ranks:.2f} GB/s "
-            f"agg_bw={disp_bw_t.item():.2f} GB/s",
+            f" dispatch: avg={disp_avg_us:.1f}us min={disp_min_t.item():.1f}us max={disp_max_t.item():.1f}us "
+            f"per_rank_bw={disp_bw_per_rank:.2f} GB/s "
+            f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
         print(
-            f" combine : {comb_us_t.item():.1f}us (max) "
-            f"per_rank_bw={comb_bw_t.item() / num_ranks:.2f} GB/s "
-            f"agg_bw={comb_bw_t.item():.2f} GB/s",
+            f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
+            f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
+            f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )

View File

@@ -288,17 +288,25 @@ def main():
     # the input footprint, not by the recv-side fan-out. We use the same
     # convention here so `per_rank_bw` is directly comparable across stacks.
     bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
-    disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
-    comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
-    disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
-    comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
-    disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
-    comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
-    dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
-    dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
-    dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
-    dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
+    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
+    # `ep_bench.cu` convention.
+    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    disp_avg_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    disp_max_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    comb_min_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    comb_avg_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    comb_max_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    dist.all_reduce(disp_min_t, op=dist.ReduceOp.MIN, group=group)
+    dist.all_reduce(disp_avg_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(disp_max_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(comb_min_t, op=dist.ReduceOp.MIN, group=group)
+    dist.all_reduce(comb_avg_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(comb_max_t, op=dist.ReduceOp.MAX, group=group)
+    disp_avg_us = disp_avg_t.item() / num_ranks
+    comb_avg_us = comb_avg_t.item() / num_ranks
+    disp_bw_per_rank = bytes_one_way / (disp_avg_us * 1e-6) / 1e9
+    comb_bw_per_rank = bytes_one_way / (comb_avg_us * 1e-6) / 1e9
     if rank == 0:
         print(
             f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
@@ -307,15 +315,15 @@ def main():
             flush=True,
         )
         print(
-            f" dispatch: {disp_us_t.item():.1f}us (max) "
-            f"per_rank_bw={disp_bw_t.item() / num_ranks:.2f} GB/s "
-            f"agg_bw={disp_bw_t.item():.2f} GB/s",
+            f" dispatch: avg={disp_avg_us:.1f}us min={disp_min_t.item():.1f}us max={disp_max_t.item():.1f}us "
+            f"per_rank_bw={disp_bw_per_rank:.2f} GB/s "
+            f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
         print(
-            f" combine : {comb_us_t.item():.1f}us (max) "
-            f"per_rank_bw={comb_bw_t.item() / num_ranks:.2f} GB/s "
-            f"agg_bw={comb_bw_t.item():.2f} GB/s",
+            f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
+            f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
+            f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )

View File

@@ -268,22 +268,25 @@ def main():
     # the actual send payload by ~num_topk×.
     disp_bytes = recv_tokens * hidden * 2
     comb_bytes = recv_tokens * hidden * 2
-    disp_bw = disp_bytes / (disp_us * 1e-6) / 1e9
-    comb_bw = comb_bytes / (comb_us * 1e-6) / 1e9
-    disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
-    comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
-    disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
-    comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
-    dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
-    dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
-    # Aggregate = sum across ranks; per-rank avg = sum / num_ranks. Also report
-    # per-rank numbers to line up with NCCL-EP's `ep_bench.cu`, which prints the
-    # rank's own bytes / its own elapsed time.
-    disp_bw_agg = disp_bw_t.clone()
-    comb_bw_agg = comb_bw_t.clone()
-    dist.all_reduce(disp_bw_agg, op=dist.ReduceOp.SUM, group=group)
-    dist.all_reduce(comb_bw_agg, op=dist.ReduceOp.SUM, group=group)
+    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
+    # `ep_bench.cu` convention.
+    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    disp_avg_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    disp_max_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    comb_min_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    comb_avg_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    comb_max_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    dist.all_reduce(disp_min_t, op=dist.ReduceOp.MIN, group=group)
+    dist.all_reduce(disp_avg_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(disp_max_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(comb_min_t, op=dist.ReduceOp.MIN, group=group)
+    dist.all_reduce(comb_avg_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(comb_max_t, op=dist.ReduceOp.MAX, group=group)
+    disp_avg_us = disp_avg_t.item() / num_ranks
+    comb_avg_us = comb_avg_t.item() / num_ranks
+    disp_bw_per_rank = disp_bytes / (disp_avg_us * 1e-6) / 1e9
+    comb_bw_per_rank = comb_bytes / (comb_avg_us * 1e-6) / 1e9
     if rank == 0:
         print(
             f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
@@ -291,15 +294,15 @@ def main():
             flush=True,
         )
         print(
-            f" dispatch: {disp_us_t.item():.1f}us (max) "
-            f"per_rank_bw={disp_bw_agg.item() / num_ranks:.2f} GB/s "
-            f"agg_bw={disp_bw_agg.item():.2f} GB/s",
+            f" dispatch: avg={disp_avg_us:.1f}us min={disp_min_t.item():.1f}us max={disp_max_t.item():.1f}us "
+            f"per_rank_bw={disp_bw_per_rank:.2f} GB/s "
+            f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
         print(
-            f" combine : {comb_us_t.item():.1f}us (max) "
-            f"per_rank_bw={comb_bw_agg.item() / num_ranks:.2f} GB/s "
-            f"agg_bw={comb_bw_agg.item():.2f} GB/s",
+            f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
+            f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
+            f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
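
As a hypothetical worked example of the new reporting (all numbers invented for illustration): with bench_tokens=4096, bench_hidden=7168, and 2-byte elements, bytes_one_way = 4096 * 7168 * 2 = 58,720,256 B. An average dispatch time of 500us across ranks then gives per_rank_bw = 58,720,256 / (500 * 1e-6) / 1e9 ≈ 117.44 GB/s, and with num_ranks=16, agg_bw ≈ 1879.0 GB/s. Under the old scheme the same run would instead have been summarized by the slowest rank's elapsed time.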