mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
tests/ep: add NCCL-EP six-metric BW breakdown (send/recv x total/nvl/rdma)
For HT intra/internode benches, compute per-rank avg total_send/rdma_send and total_recv/rdma_recv token counts (matching NCCL-EP ep_bench accounting) and print send-side and recv-side BW split into total / nvl / rdma columns. The combine phase reverses the send<->recv roles relative to dispatch, so its metrics swap the two sides. Byte-count line mirrors NCCL-EP's '(per rank avg)' summary so numbers are directly comparable.
This commit is contained in:
@@ -22,6 +22,12 @@ from __future__ import annotations
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
|
||||
# It runs in a background thread polling the TCPStore; under mpirun, rank 0
|
||||
# (the store server) can exit before non-zero ranks finish teardown, producing
|
||||
# noisy 'recvValue failed / Connection was likely closed' stack traces.
|
||||
os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -289,6 +295,36 @@ def main():
|
||||
# convention here so `per_rank_bw` is directly comparable across stacks.
|
||||
bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
|
||||
|
||||
# NCCL-EP `ep_bench` six-metric breakdown
|
||||
# (intranode -> single node, so rdma_*=0; nvl_*=total_*).
|
||||
bytes_per_token = bench_hidden * x_b.element_size()
|
||||
total_send_tokens_local = int(num_tokens_per_rank_b.sum().item())
|
||||
rdma_send_tokens_local = 0 # intranode: no remote nodes
|
||||
recv_from_src = torch.empty(num_ranks, dtype=torch.int64, device="cuda")
|
||||
dist.all_to_all_single(
|
||||
recv_from_src,
|
||||
num_tokens_per_rank_b.to(torch.int64),
|
||||
group=group,
|
||||
)
|
||||
total_recv_tokens_local = int(recv_from_src.sum().item())
|
||||
rdma_recv_tokens_local = 0 # intranode
|
||||
|
||||
# Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`).
|
||||
counts_t = torch.tensor(
|
||||
[total_send_tokens_local, rdma_send_tokens_local,
|
||||
total_recv_tokens_local, rdma_recv_tokens_local],
|
||||
dtype=torch.float64, device="cuda",
|
||||
)
|
||||
dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
|
||||
counts_avg = (counts_t / num_ranks).tolist()
|
||||
total_send_avg, rdma_send_avg, total_recv_avg, rdma_recv_avg = counts_avg
|
||||
total_send_bytes = total_send_avg * bytes_per_token
|
||||
rdma_send_bytes = rdma_send_avg * bytes_per_token
|
||||
total_recv_bytes = total_recv_avg * bytes_per_token
|
||||
rdma_recv_bytes = rdma_recv_avg * bytes_per_token
|
||||
nvl_send_bytes = total_send_bytes - rdma_send_bytes
|
||||
nvl_recv_bytes = total_recv_bytes - rdma_recv_bytes
|
||||
|
||||
# Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
|
||||
# `ep_bench.cu` convention.
|
||||
disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
|
||||
@@ -307,6 +343,21 @@ def main():
|
||||
comb_avg_us = comb_avg_t.item() / num_ranks
|
||||
disp_bw_per_rank = bytes_one_way / (disp_avg_us * 1e-6) / 1e9
|
||||
comb_bw_per_rank = bytes_one_way / (comb_avg_us * 1e-6) / 1e9
|
||||
# Six-metric BW (NCCL-EP convention). Combine reverses send<->recv.
|
||||
disp_t_s = disp_avg_us * 1e-6
|
||||
comb_t_s = comb_avg_us * 1e-6
|
||||
d_send_total_bw = total_send_bytes / disp_t_s / 1e9
|
||||
d_send_nvl_bw = nvl_send_bytes / disp_t_s / 1e9
|
||||
d_send_rdma_bw = rdma_send_bytes / disp_t_s / 1e9
|
||||
d_recv_total_bw = total_recv_bytes / disp_t_s / 1e9
|
||||
d_recv_nvl_bw = nvl_recv_bytes / disp_t_s / 1e9
|
||||
d_recv_rdma_bw = rdma_recv_bytes / disp_t_s / 1e9
|
||||
c_send_total_bw = total_recv_bytes / comb_t_s / 1e9 # combine sends back what dispatch received
|
||||
c_send_nvl_bw = nvl_recv_bytes / comb_t_s / 1e9
|
||||
c_send_rdma_bw = rdma_recv_bytes / comb_t_s / 1e9
|
||||
c_recv_total_bw = total_send_bytes / comb_t_s / 1e9 # combine receives back what dispatch sent
|
||||
c_recv_nvl_bw = nvl_send_bytes / comb_t_s / 1e9
|
||||
c_recv_rdma_bw = rdma_send_bytes / comb_t_s / 1e9
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
|
||||
@@ -320,12 +371,30 @@ def main():
|
||||
f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" send: total={d_send_total_bw:.2f} nvl={d_send_nvl_bw:.2f} rdma={d_send_rdma_bw:.2f} GB/s "
|
||||
f"recv: total={d_recv_total_bw:.2f} nvl={d_recv_nvl_bw:.2f} rdma={d_recv_rdma_bw:.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
|
||||
f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
|
||||
f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" send: total={c_send_total_bw:.2f} nvl={c_send_nvl_bw:.2f} rdma={c_send_rdma_bw:.2f} GB/s "
|
||||
f"recv: total={c_recv_total_bw:.2f} nvl={c_recv_nvl_bw:.2f} rdma={c_recv_rdma_bw:.2f} GB/s",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
f" byte counts (per rank avg): "
|
||||
f"total_send={total_send_bytes/1e6:.2f} MB ({total_send_avg:.0f} tok) "
|
||||
f"rdma_send={rdma_send_bytes/1e6:.2f} MB ({rdma_send_avg:.0f} tok) "
|
||||
f"total_recv={total_recv_bytes/1e6:.2f} MB ({total_recv_avg:.0f} tok) "
|
||||
f"rdma_recv={rdma_recv_bytes/1e6:.2f} MB ({rdma_recv_avg:.0f} tok)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user