diff --git a/test/python/ext/ep/test_intranode_multirank.py b/test/python/ext/ep/test_intranode_multirank.py index 54d6b689..4dcfad9f 100644 --- a/test/python/ext/ep/test_intranode_multirank.py +++ b/test/python/ext/ep/test_intranode_multirank.py @@ -297,8 +297,14 @@ def main(): # NCCL-EP `ep_bench` six-metric breakdown # (intranode -> single node, so rdma_*=0; nvl_*=total_*). + # + # Send side follows NCCL-EP: count unique (token, dst_node) pairs. With a + # single node every selected destination collapses to that node, so a + # token with at least one valid expert contributes exactly one to + # `total_send_tokens`. Recv side counts unique (src_rank, token) pairs + # landing on this rank. bytes_per_token = bench_hidden * x_b.element_size() - total_send_tokens_local = int(num_tokens_per_rank_b.sum().item()) + total_send_tokens_local = int(is_token_in_rank_b.any(dim=1).sum().item()) rdma_send_tokens_local = 0 # intranode: no remote nodes recv_from_src = torch.empty(num_ranks, dtype=torch.int64, device="cuda") dist.all_to_all_single(