From 812f6cfdede1a7102a105e9530cada27f9defed6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 7 Apr 2026 01:32:54 +0000 Subject: [PATCH] fix hang on 4 ranks and make send/recv test more like nccl-test --- .../default_algos/mscclpp_send_recv.py | 121 +++++++++++------- 1 file changed, 77 insertions(+), 44 deletions(-) diff --git a/python/mscclpp/default_algos/mscclpp_send_recv.py b/python/mscclpp/default_algos/mscclpp_send_recv.py index ef052210..7f68fe86 100644 --- a/python/mscclpp/default_algos/mscclpp_send_recv.py +++ b/python/mscclpp/default_algos/mscclpp_send_recv.py @@ -12,6 +12,7 @@ from mscclpp.language.collectives import * def send_recv_test(name, nnodes, gpus_per_node, split_mask): gpu_size = nnodes * gpus_per_node collective = TestCollective(gpu_size, 1, 1) + with CollectiveProgram( name, collective, @@ -21,70 +22,102 @@ def send_recv_test(name, nnodes, gpus_per_node, split_mask): use_double_scratch_buffer=False, min_message_size=0, max_message_size=2**64 - 1, - instances=4 + instances=1, # ✅ correctness-first ): - # Creating separate port channels for next and prev directions. - # When prev and next are the same peer (e.g., 2-node ring), both channels go to the same peer - # and get distinct tags. To ensure cross-rank tag matching (rank A's prev_channel signal - # arrives at rank B's next_channel wait), we create channels in opposite order for the - # "higher" rank so that tags cross-match: - # Lower rank: [next(tag0), prev(tag1)] - # Higher rank: [prev(tag0), next(tag1)] - # Then lower.prev(tag1) == higher.next(tag1) ✓ and higher.prev(tag0) == lower.next(tag0) ✓ - # When prev != next (3+ nodes), each channel targets a different peer so each gets tag 0 - # and this ordering doesn't matter. + + # Ring grouping group_size = split_mask + 1 num_groups = gpu_size // group_size - next_channels = {} # channel for sending to next rank - prev_channels = {} # channel for receiving from prev rank + + next_channels = {} + prev_channels = {} prev_next_ids = {} + + # ------------------------------------------------------------------ + # Channel creation (parity-based for deterministic tag matching) + # ------------------------------------------------------------------ for node in range(nnodes): for gpu in range(gpus_per_node): - global_rank_id = gpu + gpus_per_node * node - position_in_group = global_rank_id & split_mask - group_id = global_rank_id // group_size - next_group_id = (group_id + 1) % num_groups - next_global_rank_id = next_group_id * group_size + position_in_group - prev_group_id = (group_id - 1 + num_groups) % num_groups - prev_global_rank_id = prev_group_id * group_size + position_in_group - if prev_global_rank_id == next_global_rank_id and global_rank_id > prev_global_rank_id: - # Higher rank: create prev first, then next (swapped order) - prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id) - next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id) + rank = gpu + gpus_per_node * node + + pos = rank & split_mask + group = rank // group_size + + next_group = (group + 1) % num_groups + prev_group = (group - 1 + num_groups) % num_groups + + next_rank = next_group * group_size + pos + prev_rank = prev_group * group_size + pos + + # ✅ parity-based creation order + if (rank & 1) == 0: + next_channels[rank] = PortChannel(next_rank, rank) + prev_channels[rank] = PortChannel(prev_rank, rank) else: - # Lower rank or different peers: create next first, then prev - next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id) - prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id) - prev_next_ids[global_rank_id] = (prev_global_rank_id, next_global_rank_id) + prev_channels[rank] = PortChannel(prev_rank, rank) + next_channels[rank] = PortChannel(next_rank, rank) - # sync with the next rank and the previous rank in the group + prev_next_ids[rank] = (prev_rank, next_rank) + + # ------------------------------------------------------------------ + # Ring send/recv (deadlock-free) + # ------------------------------------------------------------------ for node in range(nnodes): for gpu in range(gpus_per_node): - global_rank_id = gpu + gpus_per_node * node - prev_global_rank_id, next_global_rank_id = prev_next_ids[global_rank_id] - prev_channels[global_rank_id].signal(tb=0, data_sync=SyncType.none) - next_channels[global_rank_id].wait(tb=0, data_sync=SyncType.after) - - src_rank = Rank(global_rank_id) - src_buffer = src_rank.get_input_buffer() - dst_rank = Rank(next_global_rank_id) - dst_buffer = dst_rank.get_output_buffer() + rank = gpu + gpus_per_node * node + prev_rank, next_rank = prev_next_ids[rank] + + ch_from_prev = prev_channels[rank] + ch_to_next = next_channels[rank] + + src_rank = Rank(rank) + src_buf = src_rank.get_input_buffer() + src_chunk = src_buf[0:src_buf.size] + + dst_rank = Rank(next_rank) + dst_buf = dst_rank.get_output_buffer() + dst_chunk = dst_buf[0:dst_buf.size] + + if rank == 0: + # ✅ starter sends first + ch_to_next.put_with_signal_and_flush( + dst_chunk, + src_chunk, + tb=0, + ) + # then receive from prev + ch_from_prev.wait(tb=0, data_sync=SyncType.after) + else: + # ✅ everyone else receives first + ch_from_prev.wait(tb=0, data_sync=SyncType.after) + ch_to_next.put_with_signal_and_flush( + dst_chunk, + src_chunk, + tb=0, + ) - next_channels[global_rank_id].put_with_signal(dst_buffer[:], src_buffer[:], tb=0) - prev_channels[global_rank_id].wait(tb=0, data_sync=SyncType.none) - print(JSON()) +# ---------------------------------------------------------------------- +# CLI +# ---------------------------------------------------------------------- parser = argparse.ArgumentParser() - parser.add_argument("--name", type=str, help="name of the program") parser.add_argument("--nnodes", type=int, default=1, help="number of nodes") parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node") -parser.add_argument("--split_mask", type=lambda x: int(x, 0), default=0x3, help="split mask (e.g. 0x3)") +parser.add_argument( + "--split_mask", + type=lambda x: int(x, 0), + default=0x3, + help="split mask (e.g. 0x3)", +) args = parser.parse_args() send_recv_test( - args.name, args.nnodes, args.gpus_per_node, args.split_mask + args.name, + args.nnodes, + args.gpus_per_node, + args.split_mask, )