ext/ep tests: time dispatch and combine separately in MSCCLPP_EP_BENCH

Previously the optional benchmark measured full round-trip latency. Split
it to time dispatch alone (N iters) and combine alone (N iters reusing
one dispatch output), reporting per-phase latency (max across ranks) and
aggregate effective bandwidth (sum across ranks).

Applies to the intranode HT test, the internode HT test, and the LL test
(currently unreachable on an intranode 8-GPU run). The internode HT bench
keeps the sync+barrier guard between dispatch and combine but excludes it
from either phase's timing.
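
In outline, each timed pass now follows the pattern below (a minimal sketch,
not the test code itself: `dispatch`, `combine`, `group`, and the iteration
count are stand-ins):

    import torch
    import torch.distributed as dist

    def bench_split(dispatch, combine, group, iters=20):
        # Phase 1: dispatch alone, back to back.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            out = dispatch()
        end.record()
        torch.cuda.synchronize()
        disp_us = start.elapsed_time(end) * 1e3 / iters  # elapsed_time() is ms

        # Guard between phases (combine must see completed dispatch output
        # on every rank); deliberately outside both timed regions.
        torch.cuda.synchronize()
        dist.barrier(group=group)

        # Phase 2: combine alone, reusing the last dispatch output each iter.
        start.record()
        for _ in range(iters):
            combine(out)
        end.record()
        torch.cuda.synchronize()
        comb_us = start.elapsed_time(end) * 1e3 / iters
        return disp_us, comb_us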
Qinghua Zhou
2026-04-22 23:11:04 +00:00
parent 2391ce1de7
commit c51a8a5305
3 changed files with 184 additions and 51 deletions


@@ -17,8 +17,9 @@ Launch on each node with (example: 2 nodes x 8 GPUs = 16 ranks):
 Round-trip dispatch + combine using internode HT kernels across nodes.
 Set ``MSCCLPP_EP_BENCH=1`` to also run a post-correctness benchmark pass
-(CUDA-event timed, reports per-iter latency and aggregate effective
-bandwidth from rank 0). Override iteration counts with
+that times dispatch and combine **separately** with CUDA events. Reports
+per-phase latency (max across ranks) plus aggregate effective bandwidth
+(sum across ranks). Override iteration counts with
 ``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` and the bench
 problem size with ``MSCCLPP_EP_BENCH_TOKENS`` / ``_HIDDEN``.
 """
@@ -261,55 +262,84 @@ def main():
     is_token_in_rank_b = token_idx_in_rank_b >= 0
     x_b = torch.ones((bench_tokens, bench_hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
-    def _one_iter():
-        (rx, rxs, rti, rtw, _lst,
-         rpm, gpm, rrcpm, rrps, rgpm, rgps,
-         rsm, sh_rdma, sh_nvl, _ev) = buf.runtime.internode_dispatch(
+    def _dispatch():
+        return buf.runtime.internode_dispatch(
             x_b, None, topk_idx_b, topk_weights_b,
             num_tokens_per_rank_b, num_tokens_per_rdma_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
             0, 0, None, None, None, None,
             1, cfg, None, False, False,
         )
-        torch.cuda.synchronize()
-        dist.barrier(group=group)
+    def _combine(dout):
+        (rx, _rxs, _rti, rtw, _lst,
+         _rpm, _gpm, rrcpm, rrps, rgpm, _rgps,
+         rsm, sh_rdma, sh_nvl, _ev) = dout
         buf.runtime.internode_combine(
             rx, rtw, rsm, is_token_in_rank_b,
             rrcpm, rrps, rgpm,
             sh_rdma, sh_nvl, cfg, None, False, False,
         )
         return rx.size(0)
+    # Warmup (full round-trip with the sync/barrier guard between phases,
+    # matching the correctness-path invariant).
     for _ in range(warmup):
-        _one_iter()
+        dout = _dispatch()
+        torch.cuda.synchronize()
+        dist.barrier(group=group)
+        _combine(dout)
     torch.cuda.synchronize()
     dist.barrier(group=group)
+    # Time dispatch alone.
     start_ev = torch.cuda.Event(enable_timing=True)
     end_ev = torch.cuda.Event(enable_timing=True)
     start_ev.record()
-    recv_tokens_total = 0
+    dout = None
     for _ in range(iters):
-        recv_tokens_total += _one_iter()
+        dout = _dispatch()
     end_ev.record()
     torch.cuda.synchronize()
-    elapsed_ms = start_ev.elapsed_time(end_ev)
-    us_per_iter = elapsed_ms * 1e3 / iters
+    disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
+    recv_tokens = dout[0].size(0)
-    avg_recv = recv_tokens_total / iters
-    bytes_per_iter = 2 * avg_recv * bench_hidden * x_b.element_size()
-    bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9
+    # Required guard before combine sees the dispatch outputs (see correctness
+    # path's XXX note). Not included in either phase's timing.
+    torch.cuda.synchronize()
+    dist.barrier(group=group)
-    bw_t = torch.tensor([bw_gbps], dtype=torch.float64, device="cuda")
-    us_t = torch.tensor([us_per_iter], dtype=torch.float64, device="cuda")
-    dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)
-    dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)
+    # Time combine alone (reusing the same dispatch output each iter).
+    start_ev.record()
+    for _ in range(iters):
+        _combine(dout)
+    end_ev.record()
+    torch.cuda.synchronize()
+    comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
+    bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
+    disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
+    comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
+    disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
+    comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
+    disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
+    comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
+    dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
+    dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
     if rank == 0:
         print(
             f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
             f"tokens={bench_tokens} hidden={bench_hidden} "
-            f"warmup={warmup} iters={iters} "
-            f"per-iter={us_t.item():.1f}us (max across ranks) "
-            f"agg_bw={bw_t.item():.2f} GB/s (sum dispatch+combine)",
+            f"warmup={warmup} iters={iters}",
             flush=True,
         )
+        print(
+            f"  dispatch: {disp_us_t.item():.1f}us (max)  agg_bw={disp_bw_t.item():.2f} GB/s",
+            flush=True,
+        )
+        print(
+            f"  combine : {comb_us_t.item():.1f}us (max)  agg_bw={comb_bw_t.item():.2f} GB/s",
+            flush=True,
+        )
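
Worked through with hypothetical numbers, the per-phase bandwidth math above
behaves like this (only the formulas come from the diff; the inputs are
made up):

    # Hypothetical inputs: 4096 received tokens, hidden 7168, bf16 (2 bytes),
    # and an assumed measured dispatch latency of 500 us.
    recv_tokens, bench_hidden, elem_size = 4096, 7168, 2
    disp_us = 500.0

    bytes_one_way = recv_tokens * bench_hidden * elem_size  # 58,720,256 bytes
    disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9        # ~117.44 GB/s per rank

    # The test then all-reduces MAX over per-rank latencies and SUM over
    # per-rank bandwidths, so rank 0 prints the slowest rank's latency and
    # the aggregate bandwidth across all ranks.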