from mscclpp.language.program import *
from mscclpp.language.collectives import *


def send_recv_test_ring_even_ranks(name, nnodes, gpus_per_node):
    """Build a ring send/recv test program and print it as JSON.

    All ``nnodes * gpus_per_node`` ranks form a single ring. Each rank copies
    its whole input buffer into the output buffer of ``(rank + 1) % nranks``
    and waits for the matching transfer from ``(rank - 1) % nranks``. Every
    transfer is followed by an explicit ACK: the receiver signals back to the
    sender, and the sender waits for that signal before finishing.

    Even ranks send first and then receive; odd ranks receive first and then
    send. That alternation is why an even number of ranks is required.

    Args:
        name: Name given to the generated program.
        nnodes: Number of nodes.
        gpus_per_node: Number of GPUs per node.

    Raises:
        ValueError: If the total rank count is below 2 or is odd.
    """
    nranks = nnodes * gpus_per_node

    if nranks < 2:
        raise ValueError("This test requires at least 2 ranks")
    if nranks % 2 != 0:
        raise ValueError(
            f"This odd/even ring schedule requires an even number of ranks, got {nranks}"
        )

    collective = TestCollective(nranks, 1, 1)

    with CollectiveProgram(
        name,
        collective,
        nranks,
        protocol="Simple",
        num_threads_per_block=1024,
        use_double_scratch_buffer=False,
        min_message_size=0,
        max_message_size=2**64 - 1,
        instances=2,
    ):
        next_channels = {}
        prev_channels = {}

        # --------------------------------------------------------------
        # Channel creation: classic ring across all ranks.
        #
        # Even ranks create their "next" channel first, odd ranks their
        # "prev" channel first, so creation order is deterministic.
        # NOTE(review): presumably this keeps the two channels between the
        # same pair of ranks (nranks == 2) matched up -- confirm against the
        # PortChannel pairing rules.
        # --------------------------------------------------------------
        for rank in range(nranks):
            # Python's % is non-negative for a positive modulus, so no
            # "+ nranks" correction is needed for rank 0.
            prev_rank = (rank - 1) % nranks
            next_rank = (rank + 1) % nranks

            if rank % 2 == 0:
                next_channels[rank] = PortChannel(next_rank, rank)
                prev_channels[rank] = PortChannel(prev_rank, rank)
            else:
                prev_channels[rank] = PortChannel(prev_rank, rank)
                next_channels[rank] = PortChannel(next_rank, rank)

        # --------------------------------------------------------------
        # Ring send/recv with explicit ACK
        #
        # Data path:
        #   sender:   put_with_signal() to next
        #   receiver: wait() from prev
        #
        # ACK path:
        #   receiver: signal() back to prev after data is available
        #   sender:   wait() for ACK from next before proceeding
        #
        # Even ranks: send, recv, ACK prev, wait ACK
        # Odd ranks : recv, ACK prev, send, wait ACK
        # --------------------------------------------------------------
        for rank in range(nranks):
            next_rank = (rank + 1) % nranks

            src_buf = Rank(rank).get_input_buffer()
            next_out_buf = Rank(next_rank).get_output_buffer()

            src_chunk = src_buf[0:src_buf.size]
            dst_chunk = next_out_buf[0:next_out_buf.size]

            ch_to_next = next_channels[rank]
            ch_from_prev = prev_channels[rank]

            if rank % 2 == 0:
                # Send data to next and signal arrival.
                ch_to_next.put_with_signal(dst_chunk, src_chunk, tb=0)
                # Wait for data from prev, then ACK it back to prev.
                ch_from_prev.wait(tb=0, data_sync=SyncType.after)
                ch_from_prev.signal(tb=0)
            else:
                # Wait for data from prev first, then ACK it back to prev.
                ch_from_prev.wait(tb=0, data_sync=SyncType.after)
                ch_from_prev.signal(tb=0)
                # Only then send data to next.
                ch_to_next.put_with_signal(dst_chunk, src_chunk, tb=0)

            # Both parities finish by waiting for next's ACK of our transfer.
            ch_to_next.wait(tb=0)

        print(JSON())


# ----------------------------------------------------------------------
# CLI
# ----------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, required=True, help="name of the program")
parser.add_argument("--nnodes", type=int, default=1, help="number of nodes")
parser.add_argument("--gpus_per_node", type=int, required=True, help="number of GPUs per node")
args = parser.parse_args()

send_recv_test_ring_even_ranks(
    args.name,
    args.nnodes,
    args.gpus_per_node,
)