mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
add extra signal/wait and avoid local flush
This commit is contained in:
@@ -9,93 +9,175 @@ from mscclpp.language.program import *
|
||||
from mscclpp.language.collectives import *
|
||||
|
||||
|
||||
def send_recv_test(name, nnodes, gpus_per_node, split_mask):
|
||||
gpu_size = nnodes * gpus_per_node
|
||||
collective = TestCollective(gpu_size, 1, 1)
|
||||
def send_recv_test_ring_even_ranks(name, nnodes, gpus_per_node):
|
||||
nranks = nnodes * gpus_per_node
|
||||
|
||||
if nranks < 2:
|
||||
raise ValueError("This test requires at least 2 ranks")
|
||||
if nranks % 2 != 0:
|
||||
raise ValueError(
|
||||
f"This odd/even ring schedule requires an even number of ranks, got {nranks}"
|
||||
)
|
||||
|
||||
collective = TestCollective(nranks, 1, 1)
|
||||
|
||||
with CollectiveProgram(
|
||||
name,
|
||||
collective,
|
||||
gpu_size,
|
||||
nranks,
|
||||
protocol="Simple",
|
||||
num_threads_per_block=1024,
|
||||
use_double_scratch_buffer=False,
|
||||
min_message_size=0,
|
||||
max_message_size=2**64 - 1,
|
||||
instances=1, # ✅ correctness-first
|
||||
instances=2,
|
||||
):
|
||||
|
||||
# Ring grouping
|
||||
group_size = split_mask + 1
|
||||
num_groups = gpu_size // group_size
|
||||
|
||||
next_channels = {}
|
||||
prev_channels = {}
|
||||
prev_next_ids = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Channel creation (parity-based for deterministic tag matching)
|
||||
# ------------------------------------------------------------------
|
||||
for node in range(nnodes):
|
||||
for gpu in range(gpus_per_node):
|
||||
rank = gpu + gpus_per_node * node
|
||||
# --------------------------------------------------------------
|
||||
# Classic ring across all ranks:
|
||||
# prev = (rank - 1 + nranks) % nranks
|
||||
# next = (rank + 1) % nranks
|
||||
# --------------------------------------------------------------
|
||||
for rank in range(nranks):
|
||||
prev_rank = (rank - 1 + nranks) % nranks
|
||||
next_rank = (rank + 1) % nranks
|
||||
|
||||
pos = rank & split_mask
|
||||
group = rank // group_size
|
||||
# Deterministic channel creation order
|
||||
if (rank & 1) == 0:
|
||||
next_channels[rank] = PortChannel(next_rank, rank)
|
||||
prev_channels[rank] = PortChannel(prev_rank, rank)
|
||||
else:
|
||||
prev_channels[rank] = PortChannel(prev_rank, rank)
|
||||
next_channels[rank] = PortChannel(next_rank, rank)
|
||||
|
||||
next_group = (group + 1) % num_groups
|
||||
prev_group = (group - 1 + num_groups) % num_groups
|
||||
# --------------------------------------------------------------
|
||||
# --------------------------------------------------------------
|
||||
# Ring send/recv with explicit ACK
|
||||
#
|
||||
# Data path:
|
||||
# sender: put_with_signal() to next
|
||||
# receiver: wait() from prev
|
||||
#
|
||||
# ACK path:
|
||||
# receiver: signal() back to prev after data is available
|
||||
# sender: wait() for ACK from next before proceeding
|
||||
#
|
||||
# Even ranks: send first, then recv, then ACK prev, then wait ACK
|
||||
# Odd ranks : recv first, then ACK prev, then send, then wait ACK
|
||||
# --------------------------------------------------------------
|
||||
for rank in range(nranks):
|
||||
prev_rank = (rank - 1 + nranks) % nranks
|
||||
next_rank = (rank + 1) % nranks
|
||||
|
||||
next_rank = next_group * group_size + pos
|
||||
prev_rank = prev_group * group_size + pos
|
||||
src_rank = Rank(rank)
|
||||
next_rank_obj = Rank(next_rank)
|
||||
|
||||
# ✅ parity-based creation order
|
||||
if (rank & 1) == 0:
|
||||
next_channels[rank] = PortChannel(next_rank, rank)
|
||||
prev_channels[rank] = PortChannel(prev_rank, rank)
|
||||
else:
|
||||
prev_channels[rank] = PortChannel(prev_rank, rank)
|
||||
next_channels[rank] = PortChannel(next_rank, rank)
|
||||
src_buf = src_rank.get_input_buffer()
|
||||
next_out_buf = next_rank_obj.get_output_buffer()
|
||||
|
||||
prev_next_ids[rank] = (prev_rank, next_rank)
|
||||
src_chunk = src_buf[0:src_buf.size]
|
||||
dst_chunk = next_out_buf[0:next_out_buf.size]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Ring send/recv (deadlock-free)
|
||||
# ------------------------------------------------------------------
|
||||
for node in range(nnodes):
|
||||
for gpu in range(gpus_per_node):
|
||||
rank = gpu + gpus_per_node * node
|
||||
prev_rank, next_rank = prev_next_ids[rank]
|
||||
ch_to_next = next_channels[rank]
|
||||
ch_from_prev = prev_channels[rank]
|
||||
|
||||
ch_from_prev = prev_channels[rank]
|
||||
ch_to_next = next_channels[rank]
|
||||
if (rank & 1) == 0:
|
||||
# Send data to next and signal arrival
|
||||
ch_to_next.put_with_signal(
|
||||
dst_chunk,
|
||||
src_chunk,
|
||||
tb=0,
|
||||
)
|
||||
|
||||
src_rank = Rank(rank)
|
||||
src_buf = src_rank.get_input_buffer()
|
||||
src_chunk = src_buf[0:src_buf.size]
|
||||
# Wait for data from prev to become visible locally
|
||||
ch_from_prev.wait(
|
||||
tb=0,
|
||||
data_sync=SyncType.after,
|
||||
)
|
||||
|
||||
dst_rank = Rank(next_rank)
|
||||
dst_buf = dst_rank.get_output_buffer()
|
||||
dst_chunk = dst_buf[0:dst_buf.size]
|
||||
# Ack back to prev that this rank has observed/consumed input
|
||||
ch_from_prev.signal(
|
||||
tb=0,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
# ✅ starter sends first
|
||||
ch_to_next.put_with_signal_and_flush(
|
||||
dst_chunk,
|
||||
src_chunk,
|
||||
tb=0,
|
||||
)
|
||||
# then receive from prev
|
||||
ch_from_prev.wait(tb=0, data_sync=SyncType.after)
|
||||
else:
|
||||
# ✅ everyone else receives first
|
||||
ch_from_prev.wait(tb=0, data_sync=SyncType.after)
|
||||
ch_to_next.put_with_signal_and_flush(
|
||||
dst_chunk,
|
||||
src_chunk,
|
||||
tb=0,
|
||||
)
|
||||
# Wait for next rank to ack our outgoing transfer
|
||||
ch_to_next.wait(
|
||||
tb=0,
|
||||
)
|
||||
|
||||
else:
|
||||
# Wait for data from prev first
|
||||
ch_from_prev.wait(
|
||||
tb=0,
|
||||
data_sync=SyncType.after,
|
||||
)
|
||||
|
||||
# Ack back to prev that this rank has observed/consumed input
|
||||
ch_from_prev.signal(
|
||||
tb=0,
|
||||
)
|
||||
|
||||
# Then send data to next
|
||||
ch_to_next.put_with_signal(
|
||||
dst_chunk,
|
||||
src_chunk,
|
||||
tb=0,
|
||||
)
|
||||
|
||||
# Wait for next rank to ack our outgoing transfer
|
||||
ch_to_next.wait(
|
||||
tb=0,
|
||||
)
|
||||
# --------------------------------------------------------------
|
||||
# Ring send/recv
|
||||
#
|
||||
# Even ranks: send first, then wait
|
||||
# Odd ranks : wait first, then send
|
||||
#
|
||||
# This is safe for an even-sized ring and avoids the
|
||||
# single-rank-starter wave.
|
||||
# --------------------------------------------------------------
|
||||
'''
|
||||
for rank in range(nranks):
|
||||
prev_rank = (rank - 1 + nranks) % nranks
|
||||
next_rank = (rank + 1) % nranks
|
||||
|
||||
src_rank = Rank(rank)
|
||||
next_rank_obj = Rank(next_rank)
|
||||
|
||||
src_buf = src_rank.get_input_buffer()
|
||||
next_out_buf = next_rank_obj.get_output_buffer()
|
||||
|
||||
src_chunk = src_buf[0:src_buf.size]
|
||||
dst_chunk = next_out_buf[0:next_out_buf.size]
|
||||
|
||||
ch_to_next = next_channels[rank]
|
||||
ch_from_prev = prev_channels[rank]
|
||||
|
||||
if (rank & 1) == 0:
|
||||
ch_to_next.put_with_signal_and_flush(
|
||||
dst_chunk,
|
||||
src_chunk,
|
||||
tb=0,
|
||||
)
|
||||
ch_from_prev.wait(
|
||||
tb=0,
|
||||
data_sync=SyncType.after,
|
||||
)
|
||||
else:
|
||||
ch_from_prev.wait(
|
||||
tb=0,
|
||||
data_sync=SyncType.after,
|
||||
)
|
||||
ch_to_next.put_with_signal_and_flush(
|
||||
dst_chunk,
|
||||
src_chunk,
|
||||
tb=0,
|
||||
)
|
||||
|
||||
'''
|
||||
print(JSON())
|
||||
|
||||
|
||||
@@ -103,21 +185,14 @@ def send_recv_test(name, nnodes, gpus_per_node, split_mask):
|
||||
# CLI
|
||||
# ----------------------------------------------------------------------
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--name", type=str, help="name of the program")
|
||||
parser.add_argument("--name", type=str, required=True, help="name of the program")
|
||||
parser.add_argument("--nnodes", type=int, default=1, help="number of nodes")
|
||||
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
|
||||
parser.add_argument(
|
||||
"--split_mask",
|
||||
type=lambda x: int(x, 0),
|
||||
default=0x3,
|
||||
help="split mask (e.g. 0x3)",
|
||||
)
|
||||
parser.add_argument("--gpus_per_node", type=int, required=True, help="number of GPUs per node")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
send_recv_test(
|
||||
send_recv_test_ring_even_ranks(
|
||||
args.name,
|
||||
args.nnodes,
|
||||
args.gpus_per_node,
|
||||
args.split_mask,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user