mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
WIP
This commit is contained in:
@@ -11,7 +11,7 @@ from mscclpp.language.collectives import *
|
||||
|
||||
def send_recv_test(name, nnodes, gpus_per_node, split_mask):
|
||||
gpu_size = nnodes * gpus_per_node
|
||||
collective = TestCollective(gpu_size, 1, 1)
|
||||
collective = SendRecv(gpu_size, 1, False)
|
||||
with CollectiveProgram(
|
||||
name,
|
||||
collective,
|
||||
@@ -236,3 +236,46 @@ class AllToAll(Collective):
|
||||
}
|
||||
rank_buffers.append(buffers)
|
||||
return rank_buffers
|
||||
|
||||
|
||||
class SendRecv(Collective):
|
||||
"""A SendRecv collective communication pattern.
|
||||
|
||||
SendRecv performs a point-to-point send/receive operation in a ring topology.
|
||||
Each rank sends its input buffer to the next rank and receives data from the
|
||||
previous rank into its output buffer.
|
||||
|
||||
This operation creates input and output buffers both sized by chunk_factor,
|
||||
as each rank sends and receives the same amount of data.
|
||||
"""
|
||||
|
||||
def __init__(self, num_ranks, chunk_factor, inplace):
|
||||
"""Initialize a new SendRecv collective.
|
||||
|
||||
Args:
|
||||
num_ranks (int): The number of ranks participating in the SendRecv.
|
||||
chunk_factor (int): The size factor for data chunks.
|
||||
inplace (bool): Whether the operation should be performed in-place.
|
||||
|
||||
Example:
|
||||
>>> sendrecv = SendRecv(num_ranks=4, chunk_factor=1, inplace=False)
|
||||
"""
|
||||
Collective.__init__(self, num_ranks, chunk_factor, inplace)
|
||||
self.name = "sendrecv"
|
||||
|
||||
def init_buffers(self):
|
||||
"""Initialize buffers for the SendRecv operation.
|
||||
|
||||
Creates input and output buffers both sized by chunk_factor.
|
||||
|
||||
Returns:
|
||||
list: A list of buffer dictionaries, one for each rank.
|
||||
"""
|
||||
rank_buffers = []
|
||||
for rank in range(self.num_ranks):
|
||||
buffers = {
|
||||
BufferType.input: BaseBuffer(rank, BufferType.input, 0, self.chunk_factor),
|
||||
BufferType.output: BaseBuffer(rank, BufferType.output, 0, self.chunk_factor),
|
||||
}
|
||||
rank_buffers.append(buffers)
|
||||
return rank_buffers
|
||||
|
||||
Reference in New Issue
Block a user