mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
ext/ep tests: time dispatch and combine separately in MSCCLPP_EP_BENCH
Previously the optional benchmark measured full round-trip latency. Split it to time dispatch alone (N iters) and combine alone (N iters reusing one dispatch output), reporting per-phase latency (max across ranks) and aggregate effective bandwidth (sum across ranks). Applies to intranode HT, internode HT, and the (currently unreachable on intra-node 8-GPU) LL test. Internode HT keeps the sync+barrier guard between dispatch and combine but excludes it from either phase's timing.
This commit is contained in:
@@ -17,8 +17,9 @@ Launch on each node with (example: 2 nodes x 8 GPUs = 16 ranks):
|
||||
Round-trip dispatch + combine using internode HT kernels across nodes.
|
||||
|
||||
Set ``MSCCLPP_EP_BENCH=1`` to also run a post-correctness benchmark pass
|
||||
(CUDA-event timed, reports per-iter latency and aggregate effective
|
||||
bandwidth from rank 0). Override iteration counts with
|
||||
that times dispatch and combine **separately** with CUDA events. Reports
|
||||
per-phase latency (max across ranks) plus aggregate effective bandwidth
|
||||
(sum across ranks). Override iteration counts with
|
||||
``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` and the bench
|
||||
problem size with ``MSCCLPP_EP_BENCH_TOKENS`` / ``_HIDDEN``.
|
||||
"""
|
||||
@@ -261,55 +262,84 @@ def main():
|
||||
is_token_in_rank_b = token_idx_in_rank_b >= 0
|
||||
x_b = torch.ones((bench_tokens, bench_hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
|
||||
|
||||
def _one_iter():
|
||||
(rx, rxs, rti, rtw, _lst,
|
||||
rpm, gpm, rrcpm, rrps, rgpm, rgps,
|
||||
rsm, sh_rdma, sh_nvl, _ev) = buf.runtime.internode_dispatch(
|
||||
def _dispatch():
|
||||
return buf.runtime.internode_dispatch(
|
||||
x_b, None, topk_idx_b, topk_weights_b,
|
||||
num_tokens_per_rank_b, num_tokens_per_rdma_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
|
||||
0, 0, None, None, None, None,
|
||||
1, cfg, None, False, False,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier(group=group)
|
||||
|
||||
def _combine(dout):
|
||||
(rx, _rxs, _rti, rtw, _lst,
|
||||
_rpm, _gpm, rrcpm, rrps, rgpm, _rgps,
|
||||
rsm, sh_rdma, sh_nvl, _ev) = dout
|
||||
buf.runtime.internode_combine(
|
||||
rx, rtw, rsm, is_token_in_rank_b,
|
||||
rrcpm, rrps, rgpm,
|
||||
sh_rdma, sh_nvl, cfg, None, False, False,
|
||||
)
|
||||
return rx.size(0)
|
||||
|
||||
# Warmup (full round-trip with the sync/barrier guard between phases,
|
||||
# matching the correctness-path invariant).
|
||||
for _ in range(warmup):
|
||||
_one_iter()
|
||||
dout = _dispatch()
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier(group=group)
|
||||
_combine(dout)
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier(group=group)
|
||||
|
||||
# Time dispatch alone.
|
||||
start_ev = torch.cuda.Event(enable_timing=True)
|
||||
end_ev = torch.cuda.Event(enable_timing=True)
|
||||
start_ev.record()
|
||||
recv_tokens_total = 0
|
||||
dout = None
|
||||
for _ in range(iters):
|
||||
recv_tokens_total += _one_iter()
|
||||
dout = _dispatch()
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
elapsed_ms = start_ev.elapsed_time(end_ev)
|
||||
us_per_iter = elapsed_ms * 1e3 / iters
|
||||
disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
recv_tokens = dout[0].size(0)
|
||||
|
||||
avg_recv = recv_tokens_total / iters
|
||||
bytes_per_iter = 2 * avg_recv * bench_hidden * x_b.element_size()
|
||||
bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9
|
||||
# Required guard before combine sees the dispatch outputs (see correctness
|
||||
# path's XXX note). Not included in either phase's timing.
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier(group=group)
|
||||
|
||||
bw_t = torch.tensor([bw_gbps], dtype=torch.float64, device="cuda")
|
||||
us_t = torch.tensor([us_per_iter], dtype=torch.float64, device="cuda")
|
||||
dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
# Time combine alone (reusing the same dispatch output each iter).
|
||||
start_ev.record()
|
||||
for _ in range(iters):
|
||||
_combine(dout)
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
|
||||
bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
|
||||
disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
|
||||
comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
|
||||
|
||||
disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
|
||||
comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
|
||||
disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
|
||||
comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
|
||||
dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
|
||||
f"tokens={bench_tokens} hidden={bench_hidden} "
|
||||
f"warmup={warmup} iters={iters} "
|
||||
f"per-iter={us_t.item():.1f}us (max across ranks) "
|
||||
f"agg_bw={bw_t.item():.2f} GB/s (sum dispatch+combine)",
|
||||
f"warmup={warmup} iters={iters}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,9 +7,10 @@ Tests that Buffer::sync() succeeds across N GPUs on a single node and that
|
||||
a round-trip dispatch + combine preserves data (sum of top-k weighted copies).
|
||||
|
||||
Set ``MSCCLPP_EP_BENCH=1`` to also run a post-correctness benchmark pass
|
||||
(N iterations with CUDA events; reports per-iter latency and effective
|
||||
NVLink bandwidth from rank 0). Override iteration counts with
|
||||
``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` / the bench
|
||||
that times dispatch and combine **separately** with CUDA events and
|
||||
reports per-phase latency (max across ranks) plus aggregate effective
|
||||
NVLink bandwidth (sum across ranks). Override iteration counts with
|
||||
``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` and the bench
|
||||
problem size with ``MSCCLPP_EP_BENCH_TOKENS`` / ``_HIDDEN``.
|
||||
|
||||
This is a minimal adaptation of DeepEP's tests/test_intranode.py stripped
|
||||
@@ -221,52 +222,71 @@ def main():
|
||||
is_token_in_rank_b = token_idx_in_rank_b >= 0
|
||||
x_b = torch.ones((bench_tokens, bench_hidden), dtype=torch.bfloat16, device="cuda") * float(rank)
|
||||
|
||||
def _one_iter():
|
||||
(rx, rxs, rti, rtw, _lst, rpm, cpm, rcpm, rsi, sh, _ev) = buf.runtime.intranode_dispatch(
|
||||
def _dispatch():
|
||||
return buf.runtime.intranode_dispatch(
|
||||
x_b, None, topk_idx_b, topk_weights_b,
|
||||
num_tokens_per_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
|
||||
0, None, None, 1, cfg, None, False, False,
|
||||
)
|
||||
|
||||
def _combine(dout):
|
||||
(rx, _rxs, _rti, rtw, _lst, rpm, _cpm, rcpm, rsi, sh, _ev) = dout
|
||||
buf.runtime.intranode_combine(
|
||||
rx, rtw, rsi, rpm, rcpm, sh, cfg, None, False, False,
|
||||
)
|
||||
return rx.size(0)
|
||||
|
||||
# Warmup
|
||||
# Warmup (full round-trip).
|
||||
for _ in range(warmup):
|
||||
_one_iter()
|
||||
_combine(_dispatch())
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier(group=group)
|
||||
|
||||
# Time dispatch alone.
|
||||
start_ev = torch.cuda.Event(enable_timing=True)
|
||||
end_ev = torch.cuda.Event(enable_timing=True)
|
||||
start_ev.record()
|
||||
recv_tokens_total = 0
|
||||
dout = None
|
||||
for _ in range(iters):
|
||||
recv_tokens_total += _one_iter()
|
||||
dout = _dispatch()
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
elapsed_ms = start_ev.elapsed_time(end_ev)
|
||||
us_per_iter = elapsed_ms * 1e3 / iters
|
||||
disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
recv_tokens = dout[0].size(0)
|
||||
|
||||
# Rough effective BW per rank: dispatched + combined payload bytes through
|
||||
# comm. We treat each iter as sending `recv_tokens * hidden * elt_size`
|
||||
# one way for dispatch and the same back for combine.
|
||||
avg_recv = recv_tokens_total / iters
|
||||
bytes_per_iter = 2 * avg_recv * bench_hidden * x_b.element_size()
|
||||
bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9
|
||||
# Time combine alone (reusing the same dispatch output each iter).
|
||||
dist.barrier(group=group)
|
||||
start_ev.record()
|
||||
for _ in range(iters):
|
||||
_combine(dout)
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
|
||||
# Aggregate across ranks
|
||||
bw_t = torch.tensor([bw_gbps], dtype=torch.float64, device="cuda")
|
||||
us_t = torch.tensor([us_per_iter], dtype=torch.float64, device="cuda")
|
||||
dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
# One-way payload bytes (per phase) per rank.
|
||||
bytes_one_way = recv_tokens * bench_hidden * x_b.element_size()
|
||||
disp_bw = bytes_one_way / (disp_us * 1e-6) / 1e9
|
||||
comb_bw = bytes_one_way / (comb_us * 1e-6) / 1e9
|
||||
|
||||
disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
|
||||
comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
|
||||
disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
|
||||
comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
|
||||
dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[bench intranode] tokens={bench_tokens} hidden={bench_hidden} "
|
||||
f"warmup={warmup} iters={iters} "
|
||||
f"per-iter={us_t.item():.1f}us (max across ranks) "
|
||||
f"agg_bw={bw_t.item():.2f} GB/s (sum dispatch+combine)",
|
||||
f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
|
||||
f"warmup={warmup} iters={iters}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -203,6 +203,89 @@ def main():
|
||||
if rank == 0:
|
||||
print("PASS", flush=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Optional benchmark (enable with MSCCLPP_EP_BENCH=1). Times dispatch
|
||||
# and combine separately, reporting per-iter latency (max across ranks)
|
||||
# and aggregate effective bandwidth (sum across ranks).
|
||||
# ------------------------------------------------------------------
|
||||
if os.environ.get("MSCCLPP_EP_BENCH", "0") != "1":
|
||||
return
|
||||
|
||||
warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
|
||||
iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
|
||||
|
||||
def _dispatch():
|
||||
return buf.low_latency_dispatch(
|
||||
x, topk_idx, num_tokens, num_experts,
|
||||
False, False, False, # use_fp8, async, return_recv_hook
|
||||
)
|
||||
|
||||
def _combine(dout):
|
||||
(recv_x, _scales, _cnt, src_info_, layout_range_, _ev, _hk) = dout
|
||||
out_ = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
buf.low_latency_combine(
|
||||
recv_x.clone(), topk_idx, topk_weights,
|
||||
src_info_, layout_range_,
|
||||
num_tokens, num_experts,
|
||||
False, False, False,
|
||||
out_,
|
||||
)
|
||||
|
||||
for _ in range(warmup):
|
||||
_combine(_dispatch())
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier(group=group)
|
||||
|
||||
start_ev = torch.cuda.Event(enable_timing=True)
|
||||
end_ev = torch.cuda.Event(enable_timing=True)
|
||||
start_ev.record()
|
||||
dout = None
|
||||
for _ in range(iters):
|
||||
dout = _dispatch()
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
disp_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
recv_tokens = int(dout[2].sum().item()) # packed_recv_count summed over local experts
|
||||
|
||||
dist.barrier(group=group)
|
||||
start_ev.record()
|
||||
for _ in range(iters):
|
||||
_combine(dout)
|
||||
end_ev.record()
|
||||
torch.cuda.synchronize()
|
||||
comb_us = start_ev.elapsed_time(end_ev) * 1e3 / iters
|
||||
|
||||
# Dispatch payload: recv_tokens × hidden × bf16 (received on this rank).
|
||||
# Combine payload: num_tokens × hidden × bf16 (sent from each local expert
|
||||
# back to the owning rank; one token's worth of bytes per reduction).
|
||||
disp_bytes = recv_tokens * hidden * 2
|
||||
comb_bytes = num_tokens * hidden * 2
|
||||
disp_bw = disp_bytes / (disp_us * 1e-6) / 1e9
|
||||
comb_bw = comb_bytes / (comb_us * 1e-6) / 1e9
|
||||
|
||||
disp_us_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
|
||||
comb_us_t = torch.tensor([comb_us], dtype=torch.float64, device="cuda")
|
||||
disp_bw_t = torch.tensor([disp_bw], dtype=torch.float64, device="cuda")
|
||||
comb_bw_t = torch.tensor([comb_bw], dtype=torch.float64, device="cuda")
|
||||
dist.all_reduce(disp_us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
dist.all_reduce(comb_us_t, op=dist.ReduceOp.MAX, group=group)
|
||||
dist.all_reduce(disp_bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
dist.all_reduce(comb_bw_t, op=dist.ReduceOp.SUM, group=group)
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[bench LL] num_ranks={num_ranks} tokens={num_tokens} hidden={hidden} "
|
||||
f"num_experts={num_experts} warmup={warmup} iters={iters}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" dispatch: {disp_us_t.item():.1f}us (max) agg_bw={disp_bw_t.item():.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" combine : {comb_us_t.item():.1f}us (max) agg_bw={comb_bw_t.item():.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user