tests/ep: add NCCL-EP six-metric BW breakdown (send/recv x total/nvl/rdma)

For the HT intranode/internode benches, compute per-rank average
total_send/rdma_send and total_recv/rdma_recv token counts (matching
NCCL-EP's ep_bench accounting) and print send-side and recv-side BW
split into total / nvl / rdma columns. Combine reverses send<->recv:
in combine, each rank pushes back what it received in dispatch. The
byte-count line mirrors NCCL-EP's '(per rank avg)' summary so the
numbers are directly comparable.
commit e752dbaf97
parent f2feb120b8
Author: Qinghua Zhou
Date:   2026-04-29 20:44:10 +00:00

2 changed files with 148 additions and 0 deletions
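
To make the accounting concrete, here is a toy walk-through of the send-side
split (illustrative only; `tokens_to_node` and all counts are made up, standing
in for the dispatch layout's per-node token counts):

# Hypothetical sketch of the send-side six-metric split; not code from this
# commit. Assumes 2 nodes x 4 local ranks, with this rank on node 0.
num_local_ranks = 4
rank = 2
local_node = rank // num_local_ranks      # -> node 0
# Unique (token, dst_node) pairs per node: a token routed to several ranks
# on the same remote node crosses the inter-node wire only once.
tokens_to_node = [35, 28]                 # made-up counts for nodes 0 and 1
total_send = sum(tokens_to_node)          # 63 unique (token, node) pairs
nvl_send = tokens_to_node[local_node]     # 35 stay on-node (NVLink)
rdma_send = total_send - nvl_send         # 28 cross nodes (RDMA)
# BW then follows as bytes / time, e.g. for bf16 tokens of hidden size H:
# bytes_per_token = H * 2; send_total_bw = total_send * bytes_per_token / t / 1e9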

File 1 of 2: internode HT bench

@@ -29,6 +29,12 @@ from __future__ import annotations
import os
import sys
# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
# It runs in a background thread polling the TCPStore; under mpirun, rank 0
# (the store server) can exit before non-zero ranks finish teardown, producing
# noisy 'recvValue failed / Connection was likely closed' stack traces.
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
import torch
import torch.distributed as dist
@@ -343,6 +349,45 @@ def main():
    # convention here so `per_rank_bw` is directly comparable across stacks.
    bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
    # NCCL-EP `ep_bench` six-metric breakdown.
    # Send-side accounting follows NCCL-EP: count unique (token, dst_node) pairs.
    # `num_tokens_per_rdma_rank_b[n]` is exactly that count for node `n`.
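    # (Illustrative: with 4 ranks per node, a token routed to remote ranks 5
    # and 6, which share a node, counts as one (token, node) pair, not two.)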
    # Recv-side accounting: each rank reports `num_tokens_per_rank_b[r]`
    # (tokens it sends to dst rank `r`); an `all_to_all_single` lets every
    # rank read how many tokens each source rank sent to it.
    bytes_per_token = bench_hidden * x_b.element_size()
    local_node = rank // num_local_ranks
    nodes_unique = num_tokens_per_rdma_rank_b.to(torch.int64)
    total_send_tokens_local = int(nodes_unique.sum().item())
    nvl_send_tokens_local = int(nodes_unique[local_node].item())
    rdma_send_tokens_local = total_send_tokens_local - nvl_send_tokens_local
    recv_from_src = torch.empty(num_ranks, dtype=torch.int64, device="cuda")
    dist.all_to_all_single(
        recv_from_src,
        num_tokens_per_rank_b.to(torch.int64),
        group=group,
    )
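    # recv_from_src[s] now holds how many tokens source rank s routes to this
    # rank: the one-element-per-rank all_to_all transposes the count matrix.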
    src_node = torch.arange(num_ranks, device="cuda") // num_local_ranks
    remote_mask = (src_node != local_node).to(torch.int64)
    total_recv_tokens_local = int(recv_from_src.sum().item())
    rdma_recv_tokens_local = int((recv_from_src * remote_mask).sum().item())
    # Average per-rank token counts across ranks (matches NCCL-EP
    # `Byte counts (per rank avg)`).
    counts_t = torch.tensor(
        [total_send_tokens_local, rdma_send_tokens_local,
         total_recv_tokens_local, rdma_recv_tokens_local],
        dtype=torch.float64, device="cuda",
    )
    dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
    counts_avg = (counts_t / num_ranks).tolist()
    total_send_avg, rdma_send_avg, total_recv_avg, rdma_recv_avg = counts_avg
    total_send_bytes = total_send_avg * bytes_per_token
    rdma_send_bytes = rdma_send_avg * bytes_per_token
    total_recv_bytes = total_recv_avg * bytes_per_token
    rdma_recv_bytes = rdma_recv_avg * bytes_per_token
    nvl_send_bytes = total_send_bytes - rdma_send_bytes
    nvl_recv_bytes = total_recv_bytes - rdma_recv_bytes
    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
    # `ep_bench.cu` convention.
    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
@@ -361,6 +406,22 @@ def main():
    comb_avg_us = comb_avg_t.item() / num_ranks
    disp_bw_per_rank = bytes_one_way / (disp_avg_us * 1e-6) / 1e9
    comb_bw_per_rank = bytes_one_way / (comb_avg_us * 1e-6) / 1e9
    # Six-metric BW (NCCL-EP convention). Combine reverses send<->recv:
    # in combine, this rank pushes back what it received in dispatch.
    disp_t_s = disp_avg_us * 1e-6
    comb_t_s = comb_avg_us * 1e-6
    d_send_total_bw = total_send_bytes / disp_t_s / 1e9
    d_send_nvl_bw = nvl_send_bytes / disp_t_s / 1e9
    d_send_rdma_bw = rdma_send_bytes / disp_t_s / 1e9
    d_recv_total_bw = total_recv_bytes / disp_t_s / 1e9
    d_recv_nvl_bw = nvl_recv_bytes / disp_t_s / 1e9
    d_recv_rdma_bw = rdma_recv_bytes / disp_t_s / 1e9
    c_send_total_bw = total_recv_bytes / comb_t_s / 1e9
    c_send_nvl_bw = nvl_recv_bytes / comb_t_s / 1e9
    c_send_rdma_bw = rdma_recv_bytes / comb_t_s / 1e9
    c_recv_total_bw = total_send_bytes / comb_t_s / 1e9
    c_recv_nvl_bw = nvl_send_bytes / comb_t_s / 1e9
    c_recv_rdma_bw = rdma_send_bytes / comb_t_s / 1e9
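    # Note the cross-phase identity: c_send_* equals d_recv_* scaled by
    # disp_t_s / comb_t_s, since the byte counts swap roles between phases.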
    if rank == 0:
        print(
            f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
@@ -375,12 +436,30 @@ def main():
f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
flush=True,
)
print(
f" send: total={d_send_total_bw:.2f} nvl={d_send_nvl_bw:.2f} rdma={d_send_rdma_bw:.2f} GB/s "
f"recv: total={d_recv_total_bw:.2f} nvl={d_recv_nvl_bw:.2f} rdma={d_recv_rdma_bw:.2f} GB/s",
flush=True,
)
print(
f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
flush=True,
)
print(
f" send: total={c_send_total_bw:.2f} nvl={c_send_nvl_bw:.2f} rdma={c_send_rdma_bw:.2f} GB/s "
f"recv: total={c_recv_total_bw:.2f} nvl={c_recv_nvl_bw:.2f} rdma={c_recv_rdma_bw:.2f} GB/s",
flush=True,
)
print(
f" byte counts (per rank avg): "
f"total_send={total_send_bytes/1e6:.2f} MB ({total_send_avg:.0f} tok) "
f"rdma_send={rdma_send_bytes/1e6:.2f} MB ({rdma_send_avg:.0f} tok) "
f"total_recv={total_recv_bytes/1e6:.2f} MB ({total_recv_avg:.0f} tok) "
f"rdma_recv={rdma_recv_bytes/1e6:.2f} MB ({rdma_recv_avg:.0f} tok)",
flush=True,
)
if __name__ == "__main__":

File 2 of 2: intranode HT bench

@@ -22,6 +22,12 @@ from __future__ import annotations
import os
import sys
# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
# It runs in a background thread polling the TCPStore; under mpirun, rank 0
# (the store server) can exit before non-zero ranks finish teardown, producing
# noisy 'recvValue failed / Connection was likely closed' stack traces.
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
import torch
import torch.distributed as dist
@@ -289,6 +295,36 @@ def main():
    # convention here so `per_rank_bw` is directly comparable across stacks.
    bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
    # NCCL-EP `ep_bench` six-metric breakdown
    # (intranode -> single node, so rdma_*=0; nvl_*=total_*).
    bytes_per_token = bench_hidden * x_b.element_size()
    total_send_tokens_local = int(num_tokens_per_rank_b.sum().item())
    rdma_send_tokens_local = 0  # intranode: no remote nodes
    recv_from_src = torch.empty(num_ranks, dtype=torch.int64, device="cuda")
    dist.all_to_all_single(
        recv_from_src,
        num_tokens_per_rank_b.to(torch.int64),
        group=group,
    )
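    # As in the internode bench: recv_from_src[s] holds how many tokens source
    # rank s routes to this rank (one-element-per-rank all_to_all transpose).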
    total_recv_tokens_local = int(recv_from_src.sum().item())
    rdma_recv_tokens_local = 0  # intranode
    # Average per-rank token counts across ranks (matches NCCL-EP
    # `Byte counts (per rank avg)`).
    counts_t = torch.tensor(
        [total_send_tokens_local, rdma_send_tokens_local,
         total_recv_tokens_local, rdma_recv_tokens_local],
        dtype=torch.float64, device="cuda",
    )
    dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
    counts_avg = (counts_t / num_ranks).tolist()
    total_send_avg, rdma_send_avg, total_recv_avg, rdma_recv_avg = counts_avg
    total_send_bytes = total_send_avg * bytes_per_token
    rdma_send_bytes = rdma_send_avg * bytes_per_token
    total_recv_bytes = total_recv_avg * bytes_per_token
    rdma_recv_bytes = rdma_recv_avg * bytes_per_token
    nvl_send_bytes = total_send_bytes - rdma_send_bytes
    nvl_recv_bytes = total_recv_bytes - rdma_recv_bytes
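    # Single node, so the split is degenerate by construction:
    # rdma_* stays 0 and nvl_* equals total_*.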
    # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
    # `ep_bench.cu` convention.
    disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
@@ -307,6 +343,21 @@ def main():
    comb_avg_us = comb_avg_t.item() / num_ranks
    disp_bw_per_rank = bytes_one_way / (disp_avg_us * 1e-6) / 1e9
    comb_bw_per_rank = bytes_one_way / (comb_avg_us * 1e-6) / 1e9
    # Six-metric BW (NCCL-EP convention). Combine reverses send<->recv.
    disp_t_s = disp_avg_us * 1e-6
    comb_t_s = comb_avg_us * 1e-6
    d_send_total_bw = total_send_bytes / disp_t_s / 1e9
    d_send_nvl_bw = nvl_send_bytes / disp_t_s / 1e9
    d_send_rdma_bw = rdma_send_bytes / disp_t_s / 1e9
    d_recv_total_bw = total_recv_bytes / disp_t_s / 1e9
    d_recv_nvl_bw = nvl_recv_bytes / disp_t_s / 1e9
    d_recv_rdma_bw = rdma_recv_bytes / disp_t_s / 1e9
    c_send_total_bw = total_recv_bytes / comb_t_s / 1e9  # combine sends back what dispatch received
    c_send_nvl_bw = nvl_recv_bytes / comb_t_s / 1e9
    c_send_rdma_bw = rdma_recv_bytes / comb_t_s / 1e9
    c_recv_total_bw = total_send_bytes / comb_t_s / 1e9  # combine receives back what dispatch sent
    c_recv_nvl_bw = nvl_send_bytes / comb_t_s / 1e9
    c_recv_rdma_bw = rdma_send_bytes / comb_t_s / 1e9
    if rank == 0:
        print(
            f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
@@ -320,12 +371,30 @@ def main():
f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
flush=True,
)
print(
f" send: total={d_send_total_bw:.2f} nvl={d_send_nvl_bw:.2f} rdma={d_send_rdma_bw:.2f} GB/s "
f"recv: total={d_recv_total_bw:.2f} nvl={d_recv_nvl_bw:.2f} rdma={d_recv_rdma_bw:.2f} GB/s",
flush=True,
)
print(
f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
flush=True,
)
print(
f" send: total={c_send_total_bw:.2f} nvl={c_send_nvl_bw:.2f} rdma={c_send_rdma_bw:.2f} GB/s "
f"recv: total={c_recv_total_bw:.2f} nvl={c_recv_nvl_bw:.2f} rdma={c_recv_rdma_bw:.2f} GB/s",
flush=True,
)
print(
f" byte counts (per rank avg): "
f"total_send={total_send_bytes/1e6:.2f} MB ({total_send_avg:.0f} tok) "
f"rdma_send={rdma_send_bytes/1e6:.2f} MB ({rdma_send_avg:.0f} tok) "
f"total_recv={total_recv_bytes/1e6:.2f} MB ({total_recv_avg:.0f} tok) "
f"rdma_recv={rdma_recv_bytes/1e6:.2f} MB ({rdma_recv_avg:.0f} tok)",
flush=True,
)
if __name__ == "__main__":