mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-07-02 21:37:00 +00:00
enforce scratch on boardcast packets
This commit is contained in:
@@ -967,15 +967,19 @@ class SwitchChannel:
|
||||
chunk to the specified buffer region across all ranks in the rank group, with
|
||||
no explicit barrier required (the packet flag provides synchronization).
|
||||
|
||||
Both the source chunk and the destination buffer must be scratch buffers,
|
||||
because the data is broadcast in LL (Low Latency) packet format (data + flag).
|
||||
|
||||
Args:
|
||||
rank (int): The rank that will execute this broadcast operation.
|
||||
src_chunk (Chunk): The source chunk containing packet data to broadcast.
|
||||
buffer_offset (int): The offset in the destination buffer where data will be stored.
|
||||
src_chunk (Chunk): The source scratch chunk containing packet data to broadcast.
|
||||
buffer_offset (int): The offset in the destination scratch buffer where data will be stored.
|
||||
size (int): The size of data to broadcast.
|
||||
tb (int): The thread block ID that will execute this operation.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If src_chunk rank is not in the rank group, if chunk size
|
||||
RuntimeError: If src_chunk rank is not in the rank group, if the source
|
||||
chunk or destination buffer is not a scratch buffer, if chunk size
|
||||
doesn't match the required size, or if buffer size is insufficient.
|
||||
|
||||
Example:
|
||||
@@ -986,14 +990,15 @@ class SwitchChannel:
|
||||
raise RuntimeError(
|
||||
f"Source chunk rank {src_chunk.rank} is not part of the rank group {self.rank_group.ranks}."
|
||||
)
|
||||
if src_chunk.buffer != BufferType.scratch:
|
||||
raise RuntimeError(f"Source chunk must be of type scratch for the packet broadcast.")
|
||||
if self.buffer_type != BufferType.scratch:
|
||||
raise RuntimeError(f"Destination buffer must be of type scratch for the packet broadcast.")
|
||||
if src_chunk.size != size:
|
||||
raise RuntimeError(f"Source chunk size {src_chunk.size} does not match the required size {size}.")
|
||||
|
||||
for rank in self.rank_group.ranks:
|
||||
if self.buffer_type == BufferType.scratch:
|
||||
buffer_size = get_program().gpus[rank].scratch_chunks
|
||||
else:
|
||||
buffer_size = get_program().buffers[rank][self.buffer_type].size
|
||||
buffer_size = get_program().gpus[rank].scratch_chunks
|
||||
|
||||
if buffer_size < buffer_offset + size:
|
||||
raise RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user