fix hang on 4 ranks and make send/recv test more like nccl-test

Author: Ubuntu
Date:   2026-04-07 01:32:54 +00:00
parent 1a065dd6ad
commit 812f6cfded
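
The hang this commit fixes is the classic cyclic wait in a ring exchange: if every rank blocks in its send path before servicing the receive from its neighbor, no rank can make progress. nccl-tests breaks the cycle by staggering one rank, and that is the ordering adopted in the diff below ("starter sends first", "everyone else receives first"). Purely as an illustration of that ordering, here is the same pattern with blocking MPI point-to-point calls (mpi4py is only a stand-in for this sketch; the commit itself uses mscclpp port channels):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    nxt = (rank + 1) % size
    prev = (rank - 1 + size) % size

    send_buf = bytes(1 << 20)      # 1 MiB: large enough to rendezvous rather than eager-send
    recv_buf = bytearray(1 << 20)

    if rank == 0:
        # starter sends first, then receives: breaks the cyclic wait
        comm.Send(send_buf, dest=nxt)
        comm.Recv(recv_buf, source=prev)
    else:
        # everyone else receives first, then sends
        comm.Recv(recv_buf, source=prev)
        comm.Send(send_buf, dest=nxt)

With the staggered branches, rank 0's send is the only operation nobody else must complete first, so the chain unwinds rank by rank instead of deadlocking.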


@@ -12,6 +12,7 @@ from mscclpp.language.collectives import *
 def send_recv_test(name, nnodes, gpus_per_node, split_mask):
     gpu_size = nnodes * gpus_per_node
     collective = TestCollective(gpu_size, 1, 1)
     with CollectiveProgram(
         name,
         collective,
@@ -21,70 +22,102 @@ def send_recv_test(name, nnodes, gpus_per_node, split_mask):
         use_double_scratch_buffer=False,
         min_message_size=0,
         max_message_size=2**64 - 1,
-        instances=4
+        instances=1,  # ✅ correctness-first
     ):
-        # Creating separate port channels for next and prev directions.
-        # When prev and next are the same peer (e.g., 2-node ring), both channels go to the same peer
-        # and get distinct tags. To ensure cross-rank tag matching (rank A's prev_channel signal
-        # arrives at rank B's next_channel wait), we create channels in opposite order for the
-        # "higher" rank so that tags cross-match:
-        #   Lower rank:  [next(tag0), prev(tag1)]
-        #   Higher rank: [prev(tag0), next(tag1)]
-        # Then lower.prev(tag1) == higher.next(tag1) ✓ and higher.prev(tag0) == lower.next(tag0) ✓
-        # When prev != next (3+ nodes), each channel targets a different peer so each gets tag 0
-        # and this ordering doesn't matter.
+        # Ring grouping
         group_size = split_mask + 1
         num_groups = gpu_size // group_size
-        next_channels = {}  # channel for sending to next rank
-        prev_channels = {}  # channel for receiving from prev rank
+        next_channels = {}
+        prev_channels = {}
         prev_next_ids = {}
+        # ------------------------------------------------------------------
+        # Channel creation (parity-based for deterministic tag matching)
+        # ------------------------------------------------------------------
         for node in range(nnodes):
             for gpu in range(gpus_per_node):
-                global_rank_id = gpu + gpus_per_node * node
-                position_in_group = global_rank_id & split_mask
-                group_id = global_rank_id // group_size
-                next_group_id = (group_id + 1) % num_groups
-                next_global_rank_id = next_group_id * group_size + position_in_group
-                prev_group_id = (group_id - 1 + num_groups) % num_groups
-                prev_global_rank_id = prev_group_id * group_size + position_in_group
-                if prev_global_rank_id == next_global_rank_id and global_rank_id > prev_global_rank_id:
-                    # Higher rank: create prev first, then next (swapped order)
-                    prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id)
-                    next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id)
+                rank = gpu + gpus_per_node * node
+                pos = rank & split_mask
+                group = rank // group_size
+                next_group = (group + 1) % num_groups
+                prev_group = (group - 1 + num_groups) % num_groups
+                next_rank = next_group * group_size + pos
+                prev_rank = prev_group * group_size + pos
+                # ✅ parity-based creation order
+                if (rank & 1) == 0:
+                    next_channels[rank] = PortChannel(next_rank, rank)
+                    prev_channels[rank] = PortChannel(prev_rank, rank)
                 else:
-                    # Lower rank or different peers: create next first, then prev
-                    next_channels[global_rank_id] = PortChannel(next_global_rank_id, global_rank_id)
-                    prev_channels[global_rank_id] = PortChannel(prev_global_rank_id, global_rank_id)
-                prev_next_ids[global_rank_id] = (prev_global_rank_id, next_global_rank_id)
-        # sync with the next rank and the previous rank in the group
+                    prev_channels[rank] = PortChannel(prev_rank, rank)
+                    next_channels[rank] = PortChannel(next_rank, rank)
+                prev_next_ids[rank] = (prev_rank, next_rank)
+        # ------------------------------------------------------------------
+        # Ring send/recv (deadlock-free)
+        # ------------------------------------------------------------------
         for node in range(nnodes):
             for gpu in range(gpus_per_node):
-                global_rank_id = gpu + gpus_per_node * node
-                prev_global_rank_id, next_global_rank_id = prev_next_ids[global_rank_id]
-                prev_channels[global_rank_id].signal(tb=0, data_sync=SyncType.none)
-                next_channels[global_rank_id].wait(tb=0, data_sync=SyncType.after)
-                src_rank = Rank(global_rank_id)
-                src_buffer = src_rank.get_input_buffer()
-                dst_rank = Rank(next_global_rank_id)
-                dst_buffer = dst_rank.get_output_buffer()
-                next_channels[global_rank_id].put_with_signal(dst_buffer[:], src_buffer[:], tb=0)
-                prev_channels[global_rank_id].wait(tb=0, data_sync=SyncType.none)
+                rank = gpu + gpus_per_node * node
+                prev_rank, next_rank = prev_next_ids[rank]
+                ch_from_prev = prev_channels[rank]
+                ch_to_next = next_channels[rank]
+                src_rank = Rank(rank)
+                src_buf = src_rank.get_input_buffer()
+                src_chunk = src_buf[0:src_buf.size]
+                dst_rank = Rank(next_rank)
+                dst_buf = dst_rank.get_output_buffer()
+                dst_chunk = dst_buf[0:dst_buf.size]
+                if rank == 0:
+                    # ✅ starter sends first
+                    ch_to_next.put_with_signal_and_flush(
+                        dst_chunk,
+                        src_chunk,
+                        tb=0,
+                    )
+                    # then receive from prev
+                    ch_from_prev.wait(tb=0, data_sync=SyncType.after)
+                else:
+                    # ✅ everyone else receives first
+                    ch_from_prev.wait(tb=0, data_sync=SyncType.after)
+                    ch_to_next.put_with_signal_and_flush(
+                        dst_chunk,
+                        src_chunk,
+                        tb=0,
+                    )
         print(JSON())
+# ----------------------------------------------------------------------
+# CLI
+# ----------------------------------------------------------------------
 parser = argparse.ArgumentParser()
 parser.add_argument("--name", type=str, help="name of the program")
 parser.add_argument("--nnodes", type=int, default=1, help="number of nodes")
 parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
-parser.add_argument("--split_mask", type=lambda x: int(x, 0), default=0x3, help="split mask (e.g. 0x3)")
+parser.add_argument(
+    "--split_mask",
+    type=lambda x: int(x, 0),
+    default=0x3,
+    help="split mask (e.g. 0x3)",
+)
 args = parser.parse_args()
 send_recv_test(
-    args.name, args.nnodes, args.gpus_per_node, args.split_mask
+    args.name,
+    args.nnodes,
+    args.gpus_per_node,
+    args.split_mask,
 )
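
For concreteness, the ring grouping above pairs every rank with the rank holding the same position (rank & split_mask) in the neighboring group. The arithmetic from the channel-creation loop, extracted into a standalone script (pure Python, no mscclpp needed; the 8-GPU count is just an example):

    split_mask = 0x3
    gpu_size = 8                          # e.g. 1 node x 8 GPUs
    group_size = split_mask + 1           # 4 positions per group
    num_groups = gpu_size // group_size   # 2 groups

    for rank in range(gpu_size):
        pos = rank & split_mask           # position inside the group
        group = rank // group_size
        next_rank = ((group + 1) % num_groups) * group_size + pos
        prev_rank = ((group - 1 + num_groups) % num_groups) * group_size + pos
        print(f"rank {rank}: sends to {next_rank}, receives from {prev_rank}")

With two groups, next and prev collapse to the same peer (0 <-> 4, 1 <-> 5, ...), which is exactly the case the channel-creation order has to handle.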
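
The parity-based creation order replaces the deleted lower/higher-rank swap but targets the same constraint that comment described: two channels created to the same peer get distinct tags in creation order, and a signal put on one rank's channel must land on the peer's counterpart channel with a matching tag. A sketch of that tag-assignment model (the model is taken from the deleted comment, not from the mscclpp API, and ranks 0 and 1 stand in as an even/odd peer pair):

    from collections import defaultdict

    def assign_tags(creation_order):
        # Model from the deleted comment: each (rank, peer) pair hands out
        # tags 0, 1, ... to its channels in creation order.
        counters = defaultdict(int)
        tags = {}
        for rank in (0, 1):
            for direction, peer in creation_order(rank):
                tags[(rank, direction)] = counters[(rank, peer)]
                counters[(rank, peer)] += 1
        return tags

    def parity_order(rank):
        peer = 1 - rank  # 2-rank ring: prev == next == the other rank
        if (rank & 1) == 0:
            return [("next", peer), ("prev", peer)]  # even rank: next first
        return [("prev", peer), ("next", peer)]      # odd rank: prev first

    tags = assign_tags(parity_order)
    # rank 0's next channel must pair with rank 1's prev channel, and vice versa:
    assert tags[(0, "next")] == tags[(1, "prev")] == 0
    assert tags[(0, "prev")] == tags[(1, "next")] == 1
    print("tags cross-match:", tags)

Separately, the CLI keeps type=lambda x: int(x, 0) for --split_mask, so both hex and decimal are accepted, e.g. python sendrecv_test.py --name sendrecv --gpus_per_node 8 --split_mask 0x3 (script name illustrative).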