enforce scratch on boardcast packets

This commit is contained in:
Empyreus
2026-06-18 16:17:56 +00:00
parent 253bc05c7c
commit 9a02f3669f

View File

@@ -967,15 +967,19 @@ class SwitchChannel:
chunk to the specified buffer region across all ranks in the rank group, with
no explicit barrier required (the packet flag provides synchronization).
Both the source chunk and the destination buffer must be scratch buffers,
because the data is broadcast in LL (Low Latency) packet format (data + flag).
Args:
rank (int): The rank that will execute this broadcast operation.
src_chunk (Chunk): The source chunk containing packet data to broadcast.
buffer_offset (int): The offset in the destination buffer where data will be stored.
src_chunk (Chunk): The source scratch chunk containing packet data to broadcast.
buffer_offset (int): The offset in the destination scratch buffer where data will be stored.
size (int): The size of data to broadcast.
tb (int): The thread block ID that will execute this operation.
Raises:
RuntimeError: If src_chunk rank is not in the rank group, if chunk size
RuntimeError: If src_chunk rank is not in the rank group, if the source
chunk or destination buffer is not a scratch buffer, if chunk size
doesn't match the required size, or if buffer size is insufficient.
Example:
@@ -986,14 +990,15 @@ class SwitchChannel:
raise RuntimeError(
f"Source chunk rank {src_chunk.rank} is not part of the rank group {self.rank_group.ranks}."
)
if src_chunk.buffer != BufferType.scratch:
raise RuntimeError(f"Source chunk must be of type scratch for the packet broadcast.")
if self.buffer_type != BufferType.scratch:
raise RuntimeError(f"Destination buffer must be of type scratch for the packet broadcast.")
if src_chunk.size != size:
raise RuntimeError(f"Source chunk size {src_chunk.size} does not match the required size {size}.")
for rank in self.rank_group.ranks:
if self.buffer_type == BufferType.scratch:
buffer_size = get_program().gpus[rank].scratch_chunks
else:
buffer_size = get_program().buffers[rank][self.buffer_type].size
buffer_size = get_program().gpus[rank].scratch_chunks
if buffer_size < buffer_offset + size:
raise RuntimeError(