diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py index 9ac9639c..33971cf6 100644 --- a/test/python/ext/ep/test_internode_multirank.py +++ b/test/python/ext/ep/test_internode_multirank.py @@ -164,15 +164,28 @@ def main(): if rank == 0: print(f"[dispatch] OK (recv {recv_x.size(0)} tokens)", flush=True) + # XXX: forcing a device+group sync here is currently required for combine + # to see consistent dispatch outputs. Without this both send_nvl_head and + # the various *_channel_prefix_matrix tensors can still be in flight on + # the comm stream when combine launches, producing a deadlock inside the + # combine forwarder (NVL check never advances). Investigate proper + # stream-dependency hand-off in Buffer::internode_dispatch. + torch.cuda.synchronize() + dist.barrier(group=group) + # internode_combine signature: # (x, topk_weights, # src_meta, is_combined_token_in_rank, # rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, # combined_rdma_head, combined_nvl_head, config, previous_event, async, allocate_on_comm_stream) + # NOTE: combine goes in the reverse direction of dispatch, so the prefix + # matrices passed here must be the RECEIVER-side ones returned by dispatch + # (`recv_rdma_channel_prefix_matrix`, `recv_rdma_rank_prefix_sum`, + # `recv_gbl_channel_prefix_matrix`) — not the sender-side ones. combined_x, combined_topk_weights, _ = buf.runtime.internode_combine( recv_x, recv_topk_weights, recv_src_meta, is_token_in_rank, - rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, cfg, None, False, False, ) @@ -182,8 +195,7 @@ def main(): got = combined_x.float().mean(dim=1) diff = (got - expected).abs().max().item() max_exp = expected.abs().max().item() - if rank == 0: - print(f"[combine] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True) + print(f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True) assert diff < 1e-2, f"rank{rank}: combine mismatch max diff {diff}" dist.barrier(group=group)