diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index 153985ff..cb68c80c 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -967,15 +967,19 @@ class SwitchChannel: chunk to the specified buffer region across all ranks in the rank group, with no explicit barrier required (the packet flag provides synchronization). + Both the source chunk and the destination buffer must be scratch buffers, + because the data is broadcast in LL (Low Latency) packet format (data + flag). + Args: rank (int): The rank that will execute this broadcast operation. - src_chunk (Chunk): The source chunk containing packet data to broadcast. - buffer_offset (int): The offset in the destination buffer where data will be stored. + src_chunk (Chunk): The source scratch chunk containing packet data to broadcast. + buffer_offset (int): The offset in the destination scratch buffer where data will be stored. size (int): The size of data to broadcast. tb (int): The thread block ID that will execute this operation. Raises: - RuntimeError: If src_chunk rank is not in the rank group, if chunk size + RuntimeError: If src_chunk rank is not in the rank group, if the source + chunk or destination buffer is not a scratch buffer, if chunk size doesn't match the required size, or if buffer size is insufficient. Example: @@ -986,14 +990,15 @@ class SwitchChannel: raise RuntimeError( f"Source chunk rank {src_chunk.rank} is not part of the rank group {self.rank_group.ranks}." ) + if src_chunk.buffer != BufferType.scratch: + raise RuntimeError(f"Source chunk must be of type scratch for the packet broadcast.") + if self.buffer_type != BufferType.scratch: + raise RuntimeError(f"Destination buffer must be of type scratch for the packet broadcast.") if src_chunk.size != size: raise RuntimeError(f"Source chunk size {src_chunk.size} does not match the required size {size}.") for rank in self.rank_group.ranks: - if self.buffer_type == BufferType.scratch: - buffer_size = get_program().gpus[rank].scratch_chunks - else: - buffer_size = get_program().buffers[rank][self.buffer_type].size + buffer_size = get_program().gpus[rank].scratch_chunks if buffer_size < buffer_offset + size: raise RuntimeError(