add support for allgather packet for small message sizes

This commit is contained in:
Empyreus
2026-06-11 20:46:09 +00:00
parent 5d7737437a
commit 02eb2cfc2e
8 changed files with 125 additions and 1 deletions

View File

@@ -959,6 +959,51 @@ class SwitchChannel:
op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
get_program().add_operation(self.src_rank, tb, op)
def broadcast_packets(self, rank, src_chunk: Chunk, buffer_offset, size, tb):
"""Broadcast packet-formatted data from source chunk to all ranks in the switch channel.
Packet variant of :meth:`broadcast`. Emits a ``gstorepkt`` (MULTI_STORE_PKT)
operation that multicasts LL-protocol packets (data + flag) from the source
chunk to the specified buffer region across all ranks in the rank group, with
no explicit barrier required (the packet flag provides synchronization).
Args:
rank (int): The rank that will execute this broadcast operation.
src_chunk (Chunk): The source chunk containing packet data to broadcast.
buffer_offset (int): The offset in the destination buffer where data will be stored.
size (int): The size of data to broadcast.
tb (int): The thread block ID that will execute this operation.
Raises:
RuntimeError: If src_chunk rank is not in the rank group, if chunk size
doesn't match the required size, or if buffer size is insufficient.
Example:
>>> channel.broadcast_packets(rank=0, src_chunk=chunk, buffer_offset=0, size=1, tb=0)
"""
self.src_rank = rank
if src_chunk.rank not in self.rank_group.ranks:
raise RuntimeError(
f"Destination chunk rank {src_chunk.rank} is not part of the rank group {self.rank_group.ranks}."
)
if src_chunk.size != size:
raise RuntimeError(f"Destination chunk size {src_chunk.size} does not match the required size {size}.")
for rank in self.rank_group.ranks:
if self.buffer_type == BufferType.scratch:
buffer_size = get_program().gpus[rank].scratch_chunks
else:
buffer_size = get_program().buffers[rank][self.buffer_type].size
if buffer_size < buffer_offset + size:
raise RuntimeError(
f"Buffer size {buffer_size} is smaller than required size {buffer_offset + size} for rank {rank}."
)
tb_channel_ids = get_program().setup_channel(tb, self)
op = GroupStorePacket(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
get_program().add_operation(self.src_rank, tb, op)
class SwitchChannelRankView:
"""A rank-specific view of a SwitchChannel for performing operations.
@@ -1022,3 +1067,23 @@ class SwitchChannel:
>>> rank_view.broadcast(src_chunk=chunk, buffer_offset=0, size=1, tb=0)
"""
return self._channel.broadcast(self._rank, src_chunk, buffer_offset, size, tb)
def broadcast_packets(self, src_chunk: Chunk, buffer_offset, size, tb):
"""Perform a packet broadcast operation from this rank's perspective.
Convenience method that calls the underlying channel's broadcast_packets
method with this view's rank automatically provided.
Args:
src_chunk (Chunk): The source chunk containing packet data to broadcast.
buffer_offset (int): The offset in the destination buffer where data will be stored.
size (int): The size of data to broadcast.
tb (int): The thread block ID that will execute this operation.
Returns:
The result of the underlying channel's broadcast_packets operation.
Example:
>>> rank_view.broadcast_packets(src_chunk=chunk, buffer_offset=0, size=1, tb=0)
"""
return self._channel.broadcast_packets(self._rank, src_chunk, buffer_offset, size, tb)

View File

@@ -934,6 +934,38 @@ class GroupStore(BaseOperation):
return result
class GroupStorePacket(BaseOperation):
def __init__(
self,
src_chunk: Chunk,
buffer_type: BufferType,
buffer_offset: int,
size: int,
channel_ids: List[int],
channel_type: ChannelType = ChannelType.switch,
):
super().__init__(Instruction.group_store_packet)
self.src_chunk = src_chunk
self.buffer_type = buffer_type
self.buffer_offset = buffer_offset
self.size = size
self.channel_ids = channel_ids
self.channel_type = channel_type
def shift_buffers(self, instance, num_instances, replication_function):
self.buffer_offset = replication_function(self.buffer_offset, self.size, instance, num_instances)
self.src_chunk.index = replication_function(self.src_chunk.index, self.size, instance, num_instances)
def to_dict(self):
result = {"name": self.name.value}
result["src_buff"] = [{"type": self.src_chunk.buffer.value, "index": self.src_chunk.index, "size": self.size}]
result["dst_buff"] = [
{"switch_channel_id": self.channel_ids[0], "index": self.buffer_offset, "size": self.size}
]
result["channel_type"] = self.channel_type.value
return result
@dataclass
class GroupLoadReduceStore(BaseOperation):
def __init__(

View File

@@ -90,6 +90,7 @@ class Instruction(Enum):
read_reduce = "rre"
read_reduce_send = "rres"
group_store = "gstore"
group_store_packet = "gstorepkt"
group_load_reduce = "glre"
group_load_reduce_store = "glres"
pipeline = "pipeline"