ext/ep tests: add optional HT benchmark pass
Gated behind MSCCLPP_EP_BENCH=1 to keep correctness runs fast. Reports per-iter latency (max across ranks, CUDA-event timed) and aggregate effective bandwidth (sum across ranks, dispatch+combine payload bytes). Tunable via MSCCLPP_EP_BENCH_WARMUP / _ITERS / _TOKENS / _HIDDEN. Bench reuses the Buffer allocated for the correctness phase and self-skips if the requested hidden exceeds the per-peer NVL/RDMA budget.
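For intuition about the reported number, here is a back-of-envelope sketch of the effective-bandwidth formula the bench uses. The received-token count and per-iteration time below are made-up illustration values, not measurements; only the formula, the bf16 element size, and the hidden default come from the change itself.

# Bench defaults: hidden=7168, bf16 payload (2 bytes/element). The 2x
# counts the dispatch payload out plus the combine payload back.
avg_recv = 8192          # hypothetical tokens received per iteration
bench_hidden = 7168      # MSCCLPP_EP_BENCH_HIDDEN default
elt_size = 2             # torch.bfloat16 element size
us_per_iter = 900.0      # hypothetical CUDA-event time per iteration
bytes_per_iter = 2 * avg_recv * bench_hidden * elt_size   # 234881024
bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9     # ~261 GB/s per rank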
test/python/ext/ep/test_internode_multirank.py
@@ -15,6 +15,12 @@ Launch on each node with (example: 2 nodes x 8 GPUs = 16 ranks):
Round-trip dispatch + combine using internode HT kernels across nodes.

Set ``MSCCLPP_EP_BENCH=1`` to also run a post-correctness benchmark pass
(CUDA-event timed, reports per-iter latency and aggregate effective
bandwidth from rank 0). Override iteration counts with
``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` and the bench
problem size with ``MSCCLPP_EP_BENCH_TOKENS`` / ``_HIDDEN``.
"""

from __future__ import annotations
@@ -202,6 +208,111 @@ def main():
    if rank == 0:
        print("PASS", flush=True)

    # ------------------------------------------------------------------
    # Optional benchmark (enable with MSCCLPP_EP_BENCH=1).
    # ------------------------------------------------------------------
    if os.environ.get("MSCCLPP_EP_BENCH", "0") != "1":
        return

    warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
    iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
    bench_tokens = int(os.environ.get("MSCCLPP_EP_BENCH_TOKENS", "4096"))
    bench_hidden = int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "7168"))

    # Respect the Buffer's pre-sized num_nvl_bytes / num_rdma_bytes budget.
    per_peer_nvl = num_nvl_bytes // max(1, num_ranks)
    per_peer_rdma = num_rdma_bytes // max(1, num_ranks)
    if bench_hidden * x.element_size() > min(per_peer_nvl, per_peer_rdma):
        if rank == 0:
            print(
                f"[bench] skip: hidden={bench_hidden} bytes/row={bench_hidden * x.element_size()} "
                f"> min(per-peer NVL {per_peer_nvl}, RDMA {per_peer_rdma}). "
                f"Rerun with a larger Buffer or smaller hidden.",
                flush=True,
            )
        return

    scores_b = torch.randn((bench_tokens, num_experts), device="cuda", dtype=torch.float32).abs() + 1
    topk_idx_b = torch.topk(scores_b, num_topk, dim=-1, sorted=False).indices
    topk_weights_b = torch.ones((bench_tokens, num_topk), dtype=torch.float32, device="cuda")
    rank_idx_b = topk_idx_b // (num_experts // num_ranks)
    rank_idx_b.masked_fill_(topk_idx_b == -1, -1)
    inplace_unique(rank_idx_b, num_ranks)
    rdma_rank_idx_b = rank_idx_b // num_local_ranks
    rdma_rank_idx_b.masked_fill_(rank_idx_b == -1, -1)
    inplace_unique(rdma_rank_idx_b, num_nodes)

    num_tokens_per_expert_b = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert_b[i] = (topk_idx_b == i).sum()
    num_tokens_per_rank_b = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    num_tokens_per_rdma_rank_b = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
    token_idx_in_rank_b = torch.full((num_ranks, bench_tokens), -1, dtype=torch.long, device="cuda")
    for i in range(num_ranks):
        num_tokens_per_rank_b[i] = (rank_idx_b == i).sum()
        token_sel = (rank_idx_b == i).max(dim=-1).values
        cnt = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True).indices
        tokens[:cnt] = torch.sort(tokens[:cnt]).values
        token_idx_in_rank_b[i][tokens[:cnt]] = torch.arange(cnt, dtype=torch.long, device="cuda")
    for i in range(num_nodes):
        num_tokens_per_rdma_rank_b[i] = (rdma_rank_idx_b == i).sum()
    token_idx_in_rank_b = token_idx_in_rank_b.T.contiguous().to(torch.int)
    is_token_in_rank_b = token_idx_in_rank_b >= 0
    x_b = torch.ones((bench_tokens, bench_hidden), dtype=torch.bfloat16, device="cuda") * float(rank)

    def _one_iter():
        (rx, rxs, rti, rtw, _lst,
         rpm, gpm, rrcpm, rrps, rgpm, rgps,
         rsm, sh_rdma, sh_nvl, _ev) = buf.runtime.internode_dispatch(
            x_b, None, topk_idx_b, topk_weights_b,
            num_tokens_per_rank_b, num_tokens_per_rdma_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
            0, 0, None, None, None, None,
            1, cfg, None, False, False,
        )
        torch.cuda.synchronize()
        dist.barrier(group=group)
        buf.runtime.internode_combine(
            rx, rtw, rsm, is_token_in_rank_b,
            rrcpm, rrps, rgpm,
            sh_rdma, sh_nvl, cfg, None, False, False,
        )
        return rx.size(0)

    for _ in range(warmup):
        _one_iter()
    torch.cuda.synchronize()
    dist.barrier(group=group)

    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)
    start_ev.record()
    recv_tokens_total = 0
    for _ in range(iters):
        recv_tokens_total += _one_iter()
    end_ev.record()
    torch.cuda.synchronize()
    elapsed_ms = start_ev.elapsed_time(end_ev)
    us_per_iter = elapsed_ms * 1e3 / iters

    avg_recv = recv_tokens_total / iters
    bytes_per_iter = 2 * avg_recv * bench_hidden * x_b.element_size()
    bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9

    bw_t = torch.tensor([bw_gbps], dtype=torch.float64, device="cuda")
    us_t = torch.tensor([us_per_iter], dtype=torch.float64, device="cuda")
    dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)
    if rank == 0:
        print(
            f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
            f"tokens={bench_tokens} hidden={bench_hidden} "
            f"warmup={warmup} iters={iters} "
            f"per-iter={us_t.item():.1f}us (max across ranks) "
            f"agg_bw={bw_t.item():.2f} GB/s (sum dispatch+combine)",
            flush=True,
        )


if __name__ == "__main__":
    try:
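The skip guard in the hunk above compares a single row's payload against each transport's per-peer share of the Buffer. A minimal sketch of that arithmetic with made-up sizes (the 1 GiB NVL budget and 16-rank world below are assumptions for illustration, not values from the test):

# Hypothetical numbers only: a bf16 row at the default hidden=7168 is
# 14336 bytes; with a 1 GiB NVL budget split across 16 ranks the
# per-peer share is 64 MiB, so the bench would run rather than skip.
num_ranks = 16
num_nvl_bytes = 1 << 30                              # assumed Buffer sizing
per_peer_nvl = num_nvl_bytes // max(1, num_ranks)    # 67108864
row_bytes = 7168 * 2                                 # bf16 element size = 2
print("skip" if row_bytes > per_peer_nvl else "run")  # -> run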
@@ -6,6 +6,12 @@ Launch with:
Tests that Buffer::sync() succeeds across N GPUs on a single node and that
a round-trip dispatch + combine preserves data (sum of top-k weighted copies).

Set ``MSCCLPP_EP_BENCH=1`` to also run a post-correctness benchmark pass
(N iterations with CUDA events; reports per-iter latency and effective
NVLink bandwidth from rank 0). Override iteration counts with
``MSCCLPP_EP_BENCH_WARMUP`` / ``MSCCLPP_EP_BENCH_ITERS`` and the bench
problem size with ``MSCCLPP_EP_BENCH_TOKENS`` / ``_HIDDEN``.

This is a minimal adaptation of DeepEP's tests/test_intranode.py stripped
to exercise only the code paths we've ported.
"""
@@ -169,6 +175,101 @@ def main():
    if rank == 0:
        print("PASS", flush=True)

    # ------------------------------------------------------------------
    # Optional benchmark (enable with MSCCLPP_EP_BENCH=1).
    # ------------------------------------------------------------------
    if os.environ.get("MSCCLPP_EP_BENCH", "0") != "1":
        return

    warmup = int(os.environ.get("MSCCLPP_EP_BENCH_WARMUP", "5"))
    iters = int(os.environ.get("MSCCLPP_EP_BENCH_ITERS", "20"))
    bench_tokens = int(os.environ.get("MSCCLPP_EP_BENCH_TOKENS", "4096"))
    bench_hidden = int(os.environ.get("MSCCLPP_EP_BENCH_HIDDEN", "7168"))

    # Rebuild inputs at bench size. Keep same layout recipe as above but at
    # larger (num_tokens, hidden); Buffer is sized off the original cfg+hidden,
    # so bench must fit within num_nvl_bytes. If it doesn't, we skip.
    if bench_hidden * x.element_size() > (num_nvl_bytes // max(1, num_ranks)):
        if rank == 0:
            print(
                f"[bench] skip: hidden={bench_hidden} bytes/row={bench_hidden * x.element_size()} "
                f"> per-peer budget {num_nvl_bytes // num_ranks}. "
                f"Rerun with a larger Buffer or smaller hidden.",
                flush=True,
            )
        return

    scores_b = torch.randn((bench_tokens, num_experts), device="cuda", dtype=torch.float32).abs() + 1
    topk_idx_b = torch.topk(scores_b, num_topk, dim=-1, sorted=False).indices
    topk_weights_b = torch.ones((bench_tokens, num_topk), dtype=torch.float32, device="cuda")
    rank_idx_b = topk_idx_b // (num_experts // num_ranks)
    rank_idx_b.masked_fill_(topk_idx_b == -1, -1)
    inplace_unique(rank_idx_b, num_ranks)
    num_tokens_per_expert_b = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert_b[i] = (topk_idx_b == i).sum()
    num_tokens_per_rank_b = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    token_idx_in_rank_b = torch.full((num_ranks, bench_tokens), -1, dtype=torch.long, device="cuda")
    for i in range(num_ranks):
        num_tokens_per_rank_b[i] = (rank_idx_b == i).sum()
        token_sel = (rank_idx_b == i).max(dim=-1).values
        cnt = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True).indices
        tokens[:cnt] = torch.sort(tokens[:cnt]).values
        token_idx_in_rank_b[i][tokens[:cnt]] = torch.arange(cnt, dtype=torch.long, device="cuda")
    token_idx_in_rank_b = token_idx_in_rank_b.T.contiguous().to(torch.int)
    is_token_in_rank_b = token_idx_in_rank_b >= 0
    x_b = torch.ones((bench_tokens, bench_hidden), dtype=torch.bfloat16, device="cuda") * float(rank)

    def _one_iter():
        (rx, rxs, rti, rtw, _lst, rpm, cpm, rcpm, rsi, sh, _ev) = buf.runtime.intranode_dispatch(
            x_b, None, topk_idx_b, topk_weights_b,
            num_tokens_per_rank_b, is_token_in_rank_b, num_tokens_per_expert_b,
            0, None, None, 1, cfg, None, False, False,
        )
        buf.runtime.intranode_combine(
            rx, rtw, rsi, rpm, rcpm, sh, cfg, None, False, False,
        )
        return rx.size(0)

    # Warmup
    for _ in range(warmup):
        _one_iter()
    torch.cuda.synchronize()
    dist.barrier(group=group)

    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)
    start_ev.record()
    recv_tokens_total = 0
    for _ in range(iters):
        recv_tokens_total += _one_iter()
    end_ev.record()
    torch.cuda.synchronize()
    elapsed_ms = start_ev.elapsed_time(end_ev)
    us_per_iter = elapsed_ms * 1e3 / iters

    # Rough effective BW per rank: dispatched + combined payload bytes through
    # comm. We treat each iter as sending `recv_tokens * hidden * elt_size`
    # one way for dispatch and the same back for combine.
    avg_recv = recv_tokens_total / iters
    bytes_per_iter = 2 * avg_recv * bench_hidden * x_b.element_size()
    bw_gbps = bytes_per_iter / (us_per_iter * 1e-6) / 1e9

    # Aggregate across ranks
    bw_t = torch.tensor([bw_gbps], dtype=torch.float64, device="cuda")
    us_t = torch.tensor([us_per_iter], dtype=torch.float64, device="cuda")
    dist.all_reduce(bw_t, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(us_t, op=dist.ReduceOp.MAX, group=group)
    if rank == 0:
        print(
            f"[bench intranode] tokens={bench_tokens} hidden={bench_hidden} "
            f"warmup={warmup} iters={iters} "
            f"per-iter={us_t.item():.1f}us (max across ranks) "
            f"agg_bw={bw_t.item():.2f} GB/s (sum dispatch+combine)",
            flush=True,
        )


if __name__ == "__main__":
    try:
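For reference, the timing recipe both benches share, reduced to a standalone sketch. This is a simplification under stated assumptions: single process, a dummy matmul in place of dispatch+combine, and no torch.distributed reduction of the results.

import torch

def bench_us(fn, warmup=5, iters=20):
    # Warm up outside the timed region, then time with CUDA events so the
    # measurement is taken on the GPU timeline rather than the host clock.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # ms -> us, per iteration

if __name__ == "__main__":
    # Dummy workload standing in for one dispatch+combine round trip.
    x = torch.randn(4096, 7168, device="cuda", dtype=torch.bfloat16)
    print(f"{bench_us(lambda: x @ x.T):.1f} us/iter")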