diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py
index c9d771fc..c9d28a7a 100644
--- a/test/python/ext/ep/test_internode_multirank.py
+++ b/test/python/ext/ep/test_internode_multirank.py
@@ -29,6 +29,12 @@ from __future__ import annotations
 
 import os
 import sys
+# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
+# It runs in a background thread polling the TCPStore; under mpirun, rank 0
+# (the store server) can exit before non-zero ranks finish teardown, producing
+# noisy 'recvValue failed / Connection was likely closed' stack traces.
+os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
+
 import torch
 import torch.distributed as dist
 
@@ -343,6 +349,45 @@ def main():
     # convention here so `per_rank_bw` is directly comparable across stacks.
     bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
 
+    # NCCL-EP `ep_bench` six-metric breakdown.
+    # Send-side accounting follows NCCL-EP: count unique (token, dst_node) pairs.
+    # `num_tokens_per_rdma_rank_b[n]` is exactly that count for node `n`.
+    # Recv-side accounting: each rank reports `num_tokens_per_rank_b[r]`
+    # (tokens it sends to dst rank `r`); an `all_to_all_single` lets every
+    # rank read how many tokens each source rank sent to it.
+    bytes_per_token = bench_hidden * x_b.element_size()
+    local_node = rank // num_local_ranks
+    nodes_unique = num_tokens_per_rdma_rank_b.to(torch.int64)
+    total_send_tokens_local = int(nodes_unique.sum().item())
+    nvl_send_tokens_local = int(nodes_unique[local_node].item())
+    rdma_send_tokens_local = total_send_tokens_local - nvl_send_tokens_local
+    recv_from_src = torch.empty(num_ranks, dtype=torch.int64, device="cuda")
+    dist.all_to_all_single(
+        recv_from_src,
+        num_tokens_per_rank_b.to(torch.int64),
+        group=group,
+    )
+    src_node = (torch.arange(num_ranks, device="cuda") // num_local_ranks)
+    remote_mask = (src_node != local_node).to(torch.int64)
+    total_recv_tokens_local = int(recv_from_src.sum().item())
+    rdma_recv_tokens_local = int((recv_from_src * remote_mask).sum().item())
+
+    # Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`).
+    counts_t = torch.tensor(
+        [total_send_tokens_local, rdma_send_tokens_local,
+         total_recv_tokens_local, rdma_recv_tokens_local],
+        dtype=torch.float64, device="cuda",
+    )
+    dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
+    counts_avg = (counts_t / num_ranks).tolist()
+    total_send_avg, rdma_send_avg, total_recv_avg, rdma_recv_avg = counts_avg
+    total_send_bytes = total_send_avg * bytes_per_token
+    rdma_send_bytes = rdma_send_avg * bytes_per_token
+    total_recv_bytes = total_recv_avg * bytes_per_token
+    rdma_recv_bytes = rdma_recv_avg * bytes_per_token
+    nvl_send_bytes = total_send_bytes - rdma_send_bytes
+    nvl_recv_bytes = total_recv_bytes - rdma_recv_bytes
+
     # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
     # `ep_bench.cu` convention.
     disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
@@ -361,6 +406,22 @@ def main():
     comb_avg_us = comb_avg_t.item() / num_ranks
     disp_bw_per_rank = bytes_one_way / (disp_avg_us * 1e-6) / 1e9
     comb_bw_per_rank = bytes_one_way / (comb_avg_us * 1e-6) / 1e9
+    # Six-metric BW (NCCL-EP convention). Combine reverses send<->recv:
+    # in combine, this rank pushes back what it received in dispatch.
+    disp_t_s = disp_avg_us * 1e-6
+    comb_t_s = comb_avg_us * 1e-6
+    d_send_total_bw = total_send_bytes / disp_t_s / 1e9
+    d_send_nvl_bw = nvl_send_bytes / disp_t_s / 1e9
+    d_send_rdma_bw = rdma_send_bytes / disp_t_s / 1e9
+    d_recv_total_bw = total_recv_bytes / disp_t_s / 1e9
+    d_recv_nvl_bw = nvl_recv_bytes / disp_t_s / 1e9
+    d_recv_rdma_bw = rdma_recv_bytes / disp_t_s / 1e9
+    c_send_total_bw = total_recv_bytes / comb_t_s / 1e9
+    c_send_nvl_bw = nvl_recv_bytes / comb_t_s / 1e9
+    c_send_rdma_bw = rdma_recv_bytes / comb_t_s / 1e9
+    c_recv_total_bw = total_send_bytes / comb_t_s / 1e9
+    c_recv_nvl_bw = nvl_send_bytes / comb_t_s / 1e9
+    c_recv_rdma_bw = rdma_send_bytes / comb_t_s / 1e9
     if rank == 0:
         print(
             f"[bench internode HT] nodes={num_nodes} num_ranks={num_ranks} "
@@ -375,12 +436,30 @@ def main():
             f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
+        print(
+            f" send: total={d_send_total_bw:.2f} nvl={d_send_nvl_bw:.2f} rdma={d_send_rdma_bw:.2f} GB/s "
+            f"recv: total={d_recv_total_bw:.2f} nvl={d_recv_nvl_bw:.2f} rdma={d_recv_rdma_bw:.2f} GB/s",
+            flush=True,
+        )
         print(
             f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
             f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
             f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
+        print(
+            f" send: total={c_send_total_bw:.2f} nvl={c_send_nvl_bw:.2f} rdma={c_send_rdma_bw:.2f} GB/s "
+            f"recv: total={c_recv_total_bw:.2f} nvl={c_recv_nvl_bw:.2f} rdma={c_recv_rdma_bw:.2f} GB/s",
+            flush=True,
+        )
+        print(
+            f" byte counts (per rank avg): "
+            f"total_send={total_send_bytes/1e6:.2f} MB ({total_send_avg:.0f} tok) "
+            f"rdma_send={rdma_send_bytes/1e6:.2f} MB ({rdma_send_avg:.0f} tok) "
+            f"total_recv={total_recv_bytes/1e6:.2f} MB ({total_recv_avg:.0f} tok) "
+            f"rdma_recv={rdma_recv_bytes/1e6:.2f} MB ({rdma_recv_avg:.0f} tok)",
+            flush=True,
+        )
 
 
 if __name__ == "__main__":
diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py
index 3383c78d..54d6b689 100644
--- a/test/python/ext/ep/test_intranode_multirank.py
+++ b/test/python/ext/ep/test_intranode_multirank.py
@@ -22,6 +22,12 @@ from __future__ import annotations
 
 import os
 import sys
+# Disable ProcessGroupNCCL's HeartbeatMonitor before importing torch.distributed.
+# It runs in a background thread polling the TCPStore; under mpirun, rank 0
+# (the store server) can exit before non-zero ranks finish teardown, producing
+# noisy 'recvValue failed / Connection was likely closed' stack traces.
+os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
+
 import torch
 import torch.distributed as dist
 
@@ -289,6 +295,36 @@ def main():
     # convention here so `per_rank_bw` is directly comparable across stacks.
     bytes_one_way = bench_tokens * bench_hidden * x_b.element_size()
 
+    # NCCL-EP `ep_bench` six-metric breakdown
+    # (intranode -> single node, so rdma_*=0; nvl_*=total_*).
+    bytes_per_token = bench_hidden * x_b.element_size()
+    total_send_tokens_local = int(num_tokens_per_rank_b.sum().item())
+    rdma_send_tokens_local = 0  # intranode: no remote nodes
+    recv_from_src = torch.empty(num_ranks, dtype=torch.int64, device="cuda")
+    dist.all_to_all_single(
+        recv_from_src,
+        num_tokens_per_rank_b.to(torch.int64),
+        group=group,
+    )
+    total_recv_tokens_local = int(recv_from_src.sum().item())
+    rdma_recv_tokens_local = 0  # intranode
+
+    # Average per-rank token counts across ranks (matches NCCL-EP `Byte counts (per rank avg)`).
+    counts_t = torch.tensor(
+        [total_send_tokens_local, rdma_send_tokens_local,
+         total_recv_tokens_local, rdma_recv_tokens_local],
+        dtype=torch.float64, device="cuda",
+    )
+    dist.all_reduce(counts_t, op=dist.ReduceOp.SUM, group=group)
+    counts_avg = (counts_t / num_ranks).tolist()
+    total_send_avg, rdma_send_avg, total_recv_avg, rdma_recv_avg = counts_avg
+    total_send_bytes = total_send_avg * bytes_per_token
+    rdma_send_bytes = rdma_send_avg * bytes_per_token
+    total_recv_bytes = total_recv_avg * bytes_per_token
+    rdma_recv_bytes = rdma_recv_avg * bytes_per_token
+    nvl_send_bytes = total_send_bytes - rdma_send_bytes
+    nvl_recv_bytes = total_recv_bytes - rdma_recv_bytes
+
     # Reduce timings: report min/avg/max and base BW on AVG to match NCCL-EP's
     # `ep_bench.cu` convention.
     disp_min_t = torch.tensor([disp_us], dtype=torch.float64, device="cuda")
@@ -307,6 +343,21 @@ def main():
     comb_avg_us = comb_avg_t.item() / num_ranks
     disp_bw_per_rank = bytes_one_way / (disp_avg_us * 1e-6) / 1e9
     comb_bw_per_rank = bytes_one_way / (comb_avg_us * 1e-6) / 1e9
+    # Six-metric BW (NCCL-EP convention). Combine reverses send<->recv.
+    disp_t_s = disp_avg_us * 1e-6
+    comb_t_s = comb_avg_us * 1e-6
+    d_send_total_bw = total_send_bytes / disp_t_s / 1e9
+    d_send_nvl_bw = nvl_send_bytes / disp_t_s / 1e9
+    d_send_rdma_bw = rdma_send_bytes / disp_t_s / 1e9
+    d_recv_total_bw = total_recv_bytes / disp_t_s / 1e9
+    d_recv_nvl_bw = nvl_recv_bytes / disp_t_s / 1e9
+    d_recv_rdma_bw = rdma_recv_bytes / disp_t_s / 1e9
+    c_send_total_bw = total_recv_bytes / comb_t_s / 1e9  # combine sends back what dispatch received
+    c_send_nvl_bw = nvl_recv_bytes / comb_t_s / 1e9
+    c_send_rdma_bw = rdma_recv_bytes / comb_t_s / 1e9
+    c_recv_total_bw = total_send_bytes / comb_t_s / 1e9  # combine receives back what dispatch sent
+    c_recv_nvl_bw = nvl_send_bytes / comb_t_s / 1e9
+    c_recv_rdma_bw = rdma_send_bytes / comb_t_s / 1e9
     if rank == 0:
         print(
             f"[bench intranode HT] tokens={bench_tokens} hidden={bench_hidden} "
@@ -320,12 +371,30 @@ def main():
             f"agg_bw={disp_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
+        print(
+            f" send: total={d_send_total_bw:.2f} nvl={d_send_nvl_bw:.2f} rdma={d_send_rdma_bw:.2f} GB/s "
+            f"recv: total={d_recv_total_bw:.2f} nvl={d_recv_nvl_bw:.2f} rdma={d_recv_rdma_bw:.2f} GB/s",
+            flush=True,
+        )
         print(
             f" combine : avg={comb_avg_us:.1f}us min={comb_min_t.item():.1f}us max={comb_max_t.item():.1f}us "
             f"per_rank_bw={comb_bw_per_rank:.2f} GB/s "
            f"agg_bw={comb_bw_per_rank * num_ranks:.2f} GB/s (BW @ avg time)",
             flush=True,
         )
+        print(
+            f" send: total={c_send_total_bw:.2f} nvl={c_send_nvl_bw:.2f} rdma={c_send_rdma_bw:.2f} GB/s "
+            f"recv: total={c_recv_total_bw:.2f} nvl={c_recv_nvl_bw:.2f} rdma={c_recv_rdma_bw:.2f} GB/s",
+            flush=True,
+        )
+        print(
+            f" byte counts (per rank avg): "
+            f"total_send={total_send_bytes/1e6:.2f} MB ({total_send_avg:.0f} tok) "
+            f"rdma_send={rdma_send_bytes/1e6:.2f} MB ({rdma_send_avg:.0f} tok) "
+            f"total_recv={total_recv_bytes/1e6:.2f} MB ({total_recv_avg:.0f} tok) "
+            f"rdma_recv={rdma_recv_bytes/1e6:.2f} MB ({rdma_recv_avg:.0f} tok)",
+            flush=True,
+        )
 
 
 if __name__ == "__main__":
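
For reviewers who want to sanity-check the six-metric accounting outside the benchmark, the standalone sketch below reproduces the bookkeeping on toy values without torch.distributed or a GPU. It is not part of the patch: the routing tables `tokens_to_rank` / `tokens_to_node`, the node size, the byte size, and the dispatch time are made-up illustrative numbers; in the patch the same quantities come from `num_tokens_per_rank_b`, `num_tokens_per_rdma_rank_b`, and the measured dispatch/combine times.

# Standalone sketch of the six-metric accounting (toy values, plain Python).
# tokens_to_rank[src][dst] plays the role of rank src's num_tokens_per_rank_b
# (recv-side accounting); tokens_to_node[src][n] plays the role of its
# num_tokens_per_rdma_rank_b (send-side accounting over unique (token, dst_node)
# pairs, so a row can be smaller than the per-rank sums for the same node).
num_local_ranks = 2                # ranks per node (assumed)
num_ranks = 4                      # 2 nodes x 2 local ranks (assumed)
bytes_per_token = 7168 * 2         # hidden * element_size, e.g. bf16 (assumed)

tokens_to_rank = [                 # toy routing: tokens_to_rank[src][dst]
    [10, 20, 30, 40],
    [ 5,  5,  5,  5],
    [ 8,  0,  8,  0],
    [ 1,  2,  3,  4],
]
tokens_to_node = [                 # toy per-node unique counts: tokens_to_node[src][node]
    [25, 60],
    [10, 10],
    [ 8,  8],
    [ 3,  7],
]

def node_of(r: int) -> int:
    return r // num_local_ranks

send_total = send_rdma = recv_total = recv_rdma = 0.0
for src in range(num_ranks):
    # Send side: sum the deduplicated per-node counts; remote nodes are RDMA.
    for n in range(num_ranks // num_local_ranks):
        send_total += tokens_to_node[src][n]
        if n != node_of(src):
            send_rdma += tokens_to_node[src][n]
    # Recv side: the counts the all_to_all_single exchange reconstructs per destination.
    for dst in range(num_ranks):
        recv_total += tokens_to_rank[src][dst]
        if node_of(src) != node_of(dst):
            recv_rdma += tokens_to_rank[src][dst]

# Per-rank averages, mirroring the patch's all_reduce(counts) / num_ranks step.
send_total /= num_ranks; send_rdma /= num_ranks
recv_total /= num_ranks; recv_rdma /= num_ranks

disp_t_s = 250e-6                  # made-up dispatch time
total_send_bytes = send_total * bytes_per_token
rdma_send_bytes = send_rdma * bytes_per_token
nvl_send_bytes = total_send_bytes - rdma_send_bytes
print(f"dispatch send: total={total_send_bytes / disp_t_s / 1e9:.2f} "
      f"nvl={nvl_send_bytes / disp_t_s / 1e9:.2f} "
      f"rdma={rdma_send_bytes / disp_t_s / 1e9:.2f} GB/s")

The combine-side figures would reuse the same byte counts with send and recv swapped and the combine time in the denominator, exactly as the patch does.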