mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-29 10:57:27 +00:00
add support for allgather packet for small message sizes
This commit is contained in:
@@ -959,6 +959,51 @@ class SwitchChannel:
|
||||
op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
|
||||
def broadcast_packets(self, rank, src_chunk: Chunk, buffer_offset, size, tb):
|
||||
"""Broadcast packet-formatted data from source chunk to all ranks in the switch channel.
|
||||
|
||||
Packet variant of :meth:`broadcast`. Emits a ``gstorepkt`` (MULTI_STORE_PKT)
|
||||
operation that multicasts LL-protocol packets (data + flag) from the source
|
||||
chunk to the specified buffer region across all ranks in the rank group, with
|
||||
no explicit barrier required (the packet flag provides synchronization).
|
||||
|
||||
Args:
|
||||
rank (int): The rank that will execute this broadcast operation.
|
||||
src_chunk (Chunk): The source chunk containing packet data to broadcast.
|
||||
buffer_offset (int): The offset in the destination buffer where data will be stored.
|
||||
size (int): The size of data to broadcast.
|
||||
tb (int): The thread block ID that will execute this operation.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If src_chunk rank is not in the rank group, if chunk size
|
||||
doesn't match the required size, or if buffer size is insufficient.
|
||||
|
||||
Example:
|
||||
>>> channel.broadcast_packets(rank=0, src_chunk=chunk, buffer_offset=0, size=1, tb=0)
|
||||
"""
|
||||
self.src_rank = rank
|
||||
if src_chunk.rank not in self.rank_group.ranks:
|
||||
raise RuntimeError(
|
||||
f"Destination chunk rank {src_chunk.rank} is not part of the rank group {self.rank_group.ranks}."
|
||||
)
|
||||
if src_chunk.size != size:
|
||||
raise RuntimeError(f"Destination chunk size {src_chunk.size} does not match the required size {size}.")
|
||||
|
||||
for rank in self.rank_group.ranks:
|
||||
if self.buffer_type == BufferType.scratch:
|
||||
buffer_size = get_program().gpus[rank].scratch_chunks
|
||||
else:
|
||||
buffer_size = get_program().buffers[rank][self.buffer_type].size
|
||||
|
||||
if buffer_size < buffer_offset + size:
|
||||
raise RuntimeError(
|
||||
f"Buffer size {buffer_size} is smaller than required size {buffer_offset + size} for rank {rank}."
|
||||
)
|
||||
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = GroupStorePacket(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
|
||||
class SwitchChannelRankView:
|
||||
"""A rank-specific view of a SwitchChannel for performing operations.
|
||||
|
||||
@@ -1022,3 +1067,23 @@ class SwitchChannel:
|
||||
>>> rank_view.broadcast(src_chunk=chunk, buffer_offset=0, size=1, tb=0)
|
||||
"""
|
||||
return self._channel.broadcast(self._rank, src_chunk, buffer_offset, size, tb)
|
||||
|
||||
def broadcast_packets(self, src_chunk: Chunk, buffer_offset, size, tb):
|
||||
"""Perform a packet broadcast operation from this rank's perspective.
|
||||
|
||||
Convenience method that calls the underlying channel's broadcast_packets
|
||||
method with this view's rank automatically provided.
|
||||
|
||||
Args:
|
||||
src_chunk (Chunk): The source chunk containing packet data to broadcast.
|
||||
buffer_offset (int): The offset in the destination buffer where data will be stored.
|
||||
size (int): The size of data to broadcast.
|
||||
tb (int): The thread block ID that will execute this operation.
|
||||
|
||||
Returns:
|
||||
The result of the underlying channel's broadcast_packets operation.
|
||||
|
||||
Example:
|
||||
>>> rank_view.broadcast_packets(src_chunk=chunk, buffer_offset=0, size=1, tb=0)
|
||||
"""
|
||||
return self._channel.broadcast_packets(self._rank, src_chunk, buffer_offset, size, tb)
|
||||
|
||||
@@ -934,6 +934,38 @@ class GroupStore(BaseOperation):
|
||||
return result
|
||||
|
||||
|
||||
class GroupStorePacket(BaseOperation):
|
||||
def __init__(
|
||||
self,
|
||||
src_chunk: Chunk,
|
||||
buffer_type: BufferType,
|
||||
buffer_offset: int,
|
||||
size: int,
|
||||
channel_ids: List[int],
|
||||
channel_type: ChannelType = ChannelType.switch,
|
||||
):
|
||||
super().__init__(Instruction.group_store_packet)
|
||||
self.src_chunk = src_chunk
|
||||
self.buffer_type = buffer_type
|
||||
self.buffer_offset = buffer_offset
|
||||
self.size = size
|
||||
self.channel_ids = channel_ids
|
||||
self.channel_type = channel_type
|
||||
|
||||
def shift_buffers(self, instance, num_instances, replication_function):
|
||||
self.buffer_offset = replication_function(self.buffer_offset, self.size, instance, num_instances)
|
||||
self.src_chunk.index = replication_function(self.src_chunk.index, self.size, instance, num_instances)
|
||||
|
||||
def to_dict(self):
|
||||
result = {"name": self.name.value}
|
||||
result["src_buff"] = [{"type": self.src_chunk.buffer.value, "index": self.src_chunk.index, "size": self.size}]
|
||||
result["dst_buff"] = [
|
||||
{"switch_channel_id": self.channel_ids[0], "index": self.buffer_offset, "size": self.size}
|
||||
]
|
||||
result["channel_type"] = self.channel_type.value
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupLoadReduceStore(BaseOperation):
|
||||
def __init__(
|
||||
|
||||
@@ -90,6 +90,7 @@ class Instruction(Enum):
|
||||
read_reduce = "rre"
|
||||
read_reduce_send = "rres"
|
||||
group_store = "gstore"
|
||||
group_store_packet = "gstorepkt"
|
||||
group_load_reduce = "glre"
|
||||
group_load_reduce_store = "glres"
|
||||
pipeline = "pipeline"
|
||||
|
||||
Reference in New Issue
Block a user