Files
mscclpp/python/mscclpp/language/channel.py
2026-04-08 17:02:07 -07:00

1019 lines
46 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from mscclpp.language.internal.types import RemoteBuffer, SyncType, ReduceOperationType, Chunk, RankGroup
from mscclpp.language.thread_block_group import *
from mscclpp.language.internal.globals import get_program
from mscclpp.language.internal.operations import *
from dataclasses import dataclass
from collections import defaultdict
@dataclass
class MemoryChannel:
    """A memory channel for direct memory access communication between GPUs.

    MemoryChannel enables direct memory access between GPUs through memory mapping,
    providing high-performance communication for operations like put, get, and reduce.
    Each channel connects a source rank to a destination rank.

    Attributes:
        channel_id (int): Unique identifier for this channel within the source rank.
        dst_rank (int): The destination rank for communication operations.
        src_rank (int): The source rank for communication operations.
        channel_type (ChannelType): The type of channel (memory).
    """

    # Per-source-rank counter used to hand out channel ids; cleared by reset().
    _channel_counts = defaultdict(int)

    @classmethod
    def reset(cls):
        """Reset all channel counts for this channel type."""
        cls._channel_counts.clear()

    def __init__(self, dst_rank: int, src_rank: int):
        """Initialize a new MemoryChannel.

        Args:
            dst_rank (int): The destination rank for this channel.
            src_rank (int): The source rank for this channel.

        Raises:
            RuntimeError: If src_rank or dst_rank is negative or not smaller than the
                number of ranks in the current program.

        Example:
            >>> channel = MemoryChannel(dst_rank=1, src_rank=0)
        """
        num_ranks = get_program().num_ranks
        # Negative ranks are rejected too: they would otherwise slip through and be
        # used as per-rank keys/indices downstream.
        if not 0 <= src_rank < num_ranks:
            raise RuntimeError(f"Source rank {src_rank} is out of bounds. Number of ranks: {num_ranks}")
        if not 0 <= dst_rank < num_ranks:
            raise RuntimeError(f"Destination rank {dst_rank} is out of bounds. Number of ranks: {num_ranks}")
        self.channel_id = MemoryChannel._channel_counts[src_rank]
        MemoryChannel._channel_counts[src_rank] += 1
        self.dst_rank = dst_rank
        self.src_rank = src_rank
        self.channel_type = ChannelType.memory
        get_program().add_channel(self)

    @staticmethod
    def _resolve_tb_list(tb, tb_group, error_type=RuntimeError):
        """Return the thread block ids an operation must be emitted on.

        ``tb`` takes precedence over ``tb_group``. When both are given, ``tb`` must belong
        to ``tb_group``, because the emitted operation is tagged with ``tb_group``'s
        per-thread-block info for that id.

        Raises:
            error_type: If neither argument is given, or ``tb`` is not in ``tb_group``.
        """
        if tb is not None:
            if tb_group is not None and tb not in tb_group.tb_list:
                raise error_type(f"Thread block {tb} is not part of the provided ThreadBlockGroup {tb_group.tb_list}.")
            return [tb]
        if tb_group is not None:
            return tb_group.tb_list
        raise error_type(
            "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
        )

    @staticmethod
    def _tbg_info(tb_group, tb_id):
        """Build the ThreadBlockGroupInfo for ``tb_id``, or None when no group is used."""
        if tb_group is None:
            return None
        return ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())

    def _validate_put_chunks(self, dst_chunk: Chunk, src_chunk: Chunk, src_scratch=False, dst_scratch=False):
        """Check that a local->remote transfer matches this channel.

        Checks run in a fixed order: source rank, (source scratch), destination rank,
        (destination scratch), then size equality.

        Raises:
            RuntimeError: On the first failed check.
        """
        if src_chunk.rank != self.src_rank:
            raise RuntimeError(
                f"Source chunk rank {src_chunk.rank} does not match current channel source rank {self.src_rank}."
            )
        if src_scratch and src_chunk.buffer != BufferType.scratch:
            raise RuntimeError("Source chunk must be of type scratch.")
        if dst_chunk.rank != self.dst_rank:
            raise RuntimeError(
                f"Dst chunk rank {dst_chunk.rank} does not match current channel dst rank {self.dst_rank}."
            )
        if dst_scratch and dst_chunk.buffer != BufferType.scratch:
            raise RuntimeError("Destination chunk must be of type scratch.")
        if dst_chunk.size != src_chunk.size:
            raise RuntimeError(
                f"Destination chunk size {dst_chunk.size} does not match source chunk size {src_chunk.size}."
            )

    def _emit_put(self, dst_chunk: Chunk, src_chunk: Chunk, tb, tb_group, error_type=RuntimeError, **packet_flags):
        """Register one PutOperation per selected thread block.

        ``packet_flags`` (e.g. ``from_packet``/``to_packet``) are forwarded unchanged to
        PutOperation; omitting them keeps PutOperation's own defaults.
        """
        remote_chunk = RemoteBuffer(src_chunk.rank, dst_chunk.rank, dst_chunk.buffer, self.channel_type)
        for tb_id in self._resolve_tb_list(tb, tb_group, error_type):
            tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
            tb_channel_ids = get_program().setup_channel(tb_id, self)
            op = PutOperation(
                src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
                dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
                channel_ids=tb_channel_ids,
                channel_type=self.channel_type,
                tbg_info=self._tbg_info(tb_group, tb_id),
                **packet_flags,
            )
            get_program().add_operation(self.src_rank, tb_id, op)

    def signal(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
        """Send a signal through the memory channel.

        Signals notify the destination that data is ready or an operation has completed.
        This is used for synchronization between ranks.

        Args:
            tb (int): The thread block ID that will execute this signal operation.
            data_sync (SyncType, optional): Defines the order where threads inside the thread block
                will be synchronized (equivalent to __syncthreads()) relative to the signal operation.
                Defaults to SyncType.both.
            relaxed (bool, optional): Whether to use relaxed memory ordering.
                Defaults to False.

        Example:
            >>> channel.signal(tb=0, data_sync=SyncType.before)
        """
        tb_channel_ids = get_program().setup_channel(tb, self)
        op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
        get_program().add_operation(self.src_rank, tb, op)

    def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
        """Wait for a signal through the memory channel.

        Waits for a signal from the destination rank, typically used for synchronization
        to ensure operations are completed before proceeding.

        Args:
            tb (int): The thread block ID that will execute this wait operation.
            data_sync (SyncType, optional): Defines the order where threads inside the thread block
                will be synchronized (equivalent to __syncthreads()) relative to the wait operation.
                Defaults to SyncType.both.
            relaxed (bool, optional): Whether to use relaxed memory ordering.
                Defaults to False.

        Example:
            >>> channel.wait(tb=0, data_sync=SyncType.after)
        """
        tb_channel_ids = get_program().setup_channel(tb, self)
        op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
        get_program().add_operation(self.src_rank, tb, op)

    def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
        """Retrieve data from remote memory to local memory.

        Performs a get operation to copy data from the destination rank's memory
        to the source rank's local memory through the memory channel.

        Args:
            dst_chunk (Chunk): The destination chunk on the source rank where data will be stored.
            src_chunk (Chunk): The source chunk on the destination rank to retrieve data from.
            tb (int, optional): The thread block ID that will execute this get operation. Defaults to None.
            tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this get operation.
                Defaults to None.

        Raises:
            RuntimeError: If chunk ranks don't match the channel configuration, if chunk
                sizes don't match, if neither tb nor tb_group is given, or if tb is not
                part of tb_group when both are given.

        Example:
            >>> channel.get(dst_chunk, src_chunk, tb=0)
        """
        # For a get, the *destination* chunk lives on this channel's source rank and the
        # *source* chunk lives on the remote (dst) rank.
        if dst_chunk.rank != self.src_rank:
            raise RuntimeError(
                f"Destination chunk rank {dst_chunk.rank} does not match current channel source rank {self.src_rank}."
            )
        if src_chunk.rank != self.dst_rank:
            raise RuntimeError(
                f"Source chunk rank {src_chunk.rank} does not match current channel dst rank {self.dst_rank}."
            )
        if dst_chunk.size != src_chunk.size:
            raise RuntimeError(
                f"Destination chunk size {dst_chunk.size} does not match source chunk size {src_chunk.size}."
            )
        remote_chunk = RemoteBuffer(dst_chunk.rank, src_chunk.rank, src_chunk.buffer, self.channel_type)
        for tb_id in self._resolve_tb_list(tb, tb_group):
            tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
            tb_channel_ids = get_program().setup_channel(tb_id, self)
            op = GetOperation(
                src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
                dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
                channel_ids=tb_channel_ids,
                channel_type=self.channel_type,
                tbg_info=self._tbg_info(tb_group, tb_id),
            )
            get_program().add_operation(self.src_rank, tb_id, op)

    def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
        """Send data from local memory to remote memory.

        Performs a put operation to copy data from the source rank's local memory
        to the destination rank's memory through the memory channel.

        Args:
            dst_chunk (Chunk): The destination chunk on the destination rank where data will be stored.
            src_chunk (Chunk): The source chunk on the source rank to send data from.
            tb (int, optional): The thread block ID that will execute this put operation. Defaults to None.
            tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this put operation.
                Defaults to None.

        Raises:
            RuntimeError: If chunk ranks don't match the channel configuration, if chunk
                sizes don't match, if neither tb nor tb_group is given, or if tb is not
                part of tb_group when both are given.

        Example:
            >>> channel.put(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk)
        self._emit_put(dst_chunk, src_chunk, tb, tb_group)

    def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
        """Transfer data in packet format from local to remote scratch buffer.

        Performs a specialized put operation that transfers data in packet format
        from the source rank's scratch buffer to the destination rank's scratch buffer.
        Both source and destination chunks must be scratch buffers.

        Args:
            dst_chunk (Chunk): The destination scratch chunk on the destination rank.
            src_chunk (Chunk): The source scratch chunk on the source rank.
            tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
            tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation.
                Defaults to None.

        Raises:
            RuntimeError: If chunk ranks don't match channel configuration, if chunks
                are not scratch buffers, or if chunk sizes don't match.
            ValueError: If neither tb nor tb_group is given, or if tb is not part of
                tb_group when both are given.

        Example:
            >>> channel.read_put_packets(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk, src_scratch=True, dst_scratch=True)
        # This method has historically raised ValueError (not RuntimeError) for a missing
        # thread block selection; keep that for callers that catch it.
        self._emit_put(dst_chunk, src_chunk, tb, tb_group, ValueError, from_packet=True, to_packet=True)

    def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
        """Transfer data from local buffer to remote scratch buffer in packet format.

        Performs a put operation that transfers data from the source rank's buffer
        to the destination rank's scratch buffer in packet format. The destination
        chunk must be a scratch buffer.

        Args:
            dst_chunk (Chunk): The destination scratch chunk on the destination rank.
            src_chunk (Chunk): The source chunk on the source rank.
            tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
            tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation.
                Defaults to None.

        Raises:
            RuntimeError: If chunk ranks don't match channel configuration, if destination
                chunk is not a scratch buffer, if chunk sizes don't match, if neither tb nor
                tb_group is given, or if tb is not part of tb_group when both are given.

        Example:
            >>> channel.put_packets(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk, dst_scratch=True)
        self._emit_put(dst_chunk, src_chunk, tb, tb_group, from_packet=False, to_packet=True)

    def reduce(
        self,
        local_src_chunk: Chunk,
        remote_src_chunks: List[Chunk],
        tb: int = None,
        tb_group: ThreadBlockGroup = None,
        local_dst_chunk: Chunk = None,
        reduce_op: ReduceOperationType = ReduceOperationType.sum,
    ):
        """Perform a reduction operation combining local and remote data.

        Reduces data from multiple remote source chunks with a local source chunk,
        storing the result in a local destination chunk. If no destination chunk
        is specified, the result is stored in the local source chunk.

        Args:
            local_src_chunk (Chunk): The local source chunk on the source rank.
            remote_src_chunks (List[Chunk]): List of remote source chunks to reduce with.
            tb (int, optional): The thread block ID that will execute this operation. Defaults to None.
            tb_group (ThreadBlockGroup, optional): The ThreadBlockGroup that will execute this operation.
                Defaults to None.
            local_dst_chunk (Chunk, optional): The local destination chunk. If None,
                uses local_src_chunk as destination. Defaults to None.
            reduce_op (ReduceOperationType, optional): The reduction operation to perform.
                Defaults to ReduceOperationType.sum.

        Raises:
            RuntimeError: If chunk ranks don't match channel configuration (both local
                chunks must be on the source rank, every remote chunk on the dst rank),
                if chunk sizes are inconsistent, if neither tb nor tb_group is given, or if
                tb is not part of tb_group when both are given.

        Example:
            >>> channel.reduce(local_chunk, remote_chunks, tb=0)
        """
        if local_dst_chunk is None:
            local_dst_chunk = local_src_chunk
        if local_src_chunk.rank != self.src_rank:
            raise RuntimeError(
                f"Local source chunk rank {local_src_chunk.rank} does not match current channel source rank {self.src_rank}."
            )
        # The result is written locally by the source rank, so the destination must be
        # on the source rank as well.
        if local_dst_chunk.rank != self.src_rank:
            raise RuntimeError(
                f"Local destination chunk rank {local_dst_chunk.rank} does not match current channel source rank {self.src_rank}."
            )
        if local_src_chunk.size != local_dst_chunk.size:
            raise RuntimeError(
                f"Source chunk size {local_src_chunk.size} does not match destination chunk size {local_dst_chunk.size}."
            )
        for chunk in remote_src_chunks:
            if chunk.rank != self.dst_rank:
                raise RuntimeError(
                    f"Source chunk rank {chunk.rank} does not match current channel dst rank {self.dst_rank}."
                )
            if chunk.size != local_src_chunk.size:
                raise RuntimeError(
                    f"Source chunk size {chunk.size} does not match local source chunk size {local_src_chunk.size}."
                )
        tb_list = self._resolve_tb_list(tb, tb_group)
        for tb_id in tb_list:
            # Each thread block needs its own registration of every remote buffer.
            remote_chunks = [
                RemoteChunk(
                    chunk.buffer,
                    chunk.index,
                    chunk.size,
                    get_program().setup_remote_chunk(
                        self.src_rank,
                        tb_id,
                        RemoteBuffer(local_src_chunk.rank, chunk.rank, chunk.buffer, self.channel_type),
                        self.channel_type,
                    ),
                )
                for chunk in remote_src_chunks
            ]
            tb_channel_ids = get_program().setup_channel(tb_id, self)
            op = ReduceOperation(
                local_src_buff=[LocalChunk(local_src_chunk.buffer, local_src_chunk.index, local_src_chunk.size)],
                local_dst_buff=[LocalChunk(local_dst_chunk.buffer, local_dst_chunk.index, local_dst_chunk.size)],
                remote_src_buff=remote_chunks,
                remote_dst_buff=[],
                channel_ids=tb_channel_ids,
                channel_type=self.channel_type,
                tbg_info=self._tbg_info(tb_group, tb_id),
                reduce_operation=reduce_op,
            )
            get_program().add_operation(self.src_rank, tb_id, op)
@dataclass
class PortChannel:
    """A port channel for communication using port-based mechanisms between GPUs.

    PortChannel enables communication between GPUs using interconnection ports, supporting
    operations such as put, signal, wait, and flush. Each channel connects a source rank to a
    destination rank and is suitable for scenarios requiring port-mapping-based data copy
    and synchronization methods.

    Attributes:
        channel_id (int): Unique identifier for this channel within the source rank.
        dst_rank (int): The destination rank for communication operations.
        src_rank (int): The source rank for communication operations.
        channel_type (ChannelType): The type of channel (port).
    """

    # Per-source-rank counter used to hand out channel ids; cleared by reset().
    _channel_counts = defaultdict(int)

    @classmethod
    def reset(cls):
        """Reset all channel counts for this channel type."""
        cls._channel_counts.clear()

    def __init__(self, dst_rank: int, src_rank: int):
        """Initialize a new PortChannel.

        Args:
            dst_rank (int): The destination rank for this channel.
            src_rank (int): The source rank for this channel.

        Raises:
            RuntimeError: If src_rank or dst_rank is negative or not smaller than the
                number of ranks in the current program.

        Example:
            >>> channel = PortChannel(dst_rank=1, src_rank=0)
        """
        num_ranks = get_program().num_ranks
        # Negative ranks are rejected too: they would otherwise slip through and be
        # used as per-rank keys/indices downstream.
        if not 0 <= src_rank < num_ranks:
            raise RuntimeError(f"Source rank {src_rank} is out of bounds. Number of ranks: {num_ranks}")
        if not 0 <= dst_rank < num_ranks:
            raise RuntimeError(f"Destination rank {dst_rank} is out of bounds. Number of ranks: {num_ranks}")
        self.channel_id = PortChannel._channel_counts[src_rank]
        PortChannel._channel_counts[src_rank] += 1
        self.dst_rank = dst_rank
        self.src_rank = src_rank
        self.channel_type = ChannelType.port
        get_program().add_channel(self)

    def signal(self, tb: int, data_sync: SyncType = SyncType.both):
        """Send a signal through the port channel.

        Signals notify the destination that data is ready or an operation has completed.
        This is used for synchronization between ranks through port-based mechanisms.

        Args:
            tb (int): The thread block ID that will execute this signal operation.
            data_sync (SyncType, optional): Defines the order where threads inside the thread block
                will be synchronized (equivalent to __syncthreads()) relative to the signal operation.
                Defaults to SyncType.both.

        Example:
            >>> channel.signal(tb=0, data_sync=SyncType.before)
        """
        tb_channel_ids = get_program().setup_channel(tb, self)
        op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
        get_program().add_operation(self.src_rank, tb, op)

    def wait(self, tb: int, data_sync: SyncType = SyncType.both):
        """Wait for a signal through the port channel.

        Waits for a signal from the destination rank, typically used for synchronization
        to ensure operations are completed before proceeding through port-based mechanisms.

        Args:
            tb (int): The thread block ID that will execute this wait operation.
            data_sync (SyncType, optional): Defines the order where threads inside the thread block
                will be synchronized (equivalent to __syncthreads()) relative to the wait operation.
                Defaults to SyncType.both.

        Example:
            >>> channel.wait(tb=0, data_sync=SyncType.after)
        """
        tb_channel_ids = get_program().setup_channel(tb, self)
        op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
        get_program().add_operation(self.src_rank, tb, op)

    def flush(self, tb: int, data_sync: SyncType = SyncType.both):
        """Flush pending operations through the port channel.

        Forces completion of all pending operations on the port channel, ensuring
        data consistency. This operation is only supported for port channels.

        Args:
            tb (int): The thread block ID that will execute this flush operation.
            data_sync (SyncType, optional): Defines the order where threads inside the thread block
                will be synchronized (equivalent to __syncthreads()) relative to the flush operation.
                Defaults to SyncType.both.

        Example:
            >>> channel.flush(tb=0, data_sync=SyncType.after)
        """
        tb_channel_ids = get_program().setup_channel(tb, self)
        op = FlushOperation(tb_channel_ids, self.channel_type, data_sync)
        get_program().add_operation(self.src_rank, tb, op)

    def _validate_put_chunks(self, dst_chunk: Chunk, src_chunk: Chunk, src_scratch=False, dst_scratch=False):
        """Check that a local->remote transfer matches this channel.

        Checks run in a fixed order: source rank, (source scratch), destination rank,
        (destination scratch), then size equality.

        Raises:
            RuntimeError: On the first failed check.
        """
        if src_chunk.rank != self.src_rank:
            raise RuntimeError(
                f"Source chunk rank {src_chunk.rank} does not match current channel source rank {self.src_rank}."
            )
        if src_scratch and src_chunk.buffer != BufferType.scratch:
            raise RuntimeError("Source chunk must be of type scratch.")
        if dst_chunk.rank != self.dst_rank:
            raise RuntimeError(
                f"Dst chunk rank {dst_chunk.rank} does not match current channel dst rank {self.dst_rank}."
            )
        if dst_scratch and dst_chunk.buffer != BufferType.scratch:
            raise RuntimeError("Destination chunk must be of type scratch.")
        if dst_chunk.size != src_chunk.size:
            raise RuntimeError(
                f"Destination chunk size {dst_chunk.size} does not match source chunk size {src_chunk.size}."
            )

    def _emit_put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int, **op_flags):
        """Register a single PutOperation on thread block ``tb``.

        ``op_flags`` (e.g. ``with_signal``, ``from_packet``/``to_packet``) are forwarded
        unchanged to PutOperation; omitting them keeps PutOperation's own defaults.
        """
        remote_chunk = RemoteBuffer(src_chunk.rank, dst_chunk.rank, dst_chunk.buffer, self.channel_type)
        tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb, remote_chunk, self.channel_type)
        tb_channel_ids = get_program().setup_channel(tb, self)
        op = PutOperation(
            src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
            dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
            channel_ids=tb_channel_ids,
            channel_type=self.channel_type,
            **op_flags,
        )
        get_program().add_operation(self.src_rank, tb, op)

    def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
        """Send data from local memory to remote memory through the port channel.

        Performs a put operation to copy data from the source rank's local memory
        to the destination rank's memory through the port channel.

        Args:
            dst_chunk (Chunk): The destination chunk on the destination rank where data will be stored.
            src_chunk (Chunk): The source chunk on the source rank to send data from.
            tb (int): The thread block ID that will execute this put operation.

        Raises:
            RuntimeError: If chunk ranks don't match the channel configuration or
                if chunk sizes don't match.

        Example:
            >>> channel.put(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk)
        self._emit_put(dst_chunk, src_chunk, tb)

    def put_with_signal(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
        """Send data from local memory to remote memory with automatic signaling.

        Performs a put operation that transfers data and automatically sends a signal
        to notify the destination that the data transfer is complete. This combines
        data transfer and synchronization in a single operation.

        Args:
            dst_chunk (Chunk): The destination chunk on the destination rank where data will be stored.
            src_chunk (Chunk): The source chunk on the source rank to send data from.
            tb (int): The thread block ID that will execute this put operation.

        Raises:
            RuntimeError: If chunk ranks don't match the channel configuration or
                if chunk sizes don't match.

        Example:
            >>> channel.put_with_signal(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk)
        self._emit_put(dst_chunk, src_chunk, tb, with_signal=True)

    def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
        """Send data from local memory to remote memory with signal and flush.

        Performs a put operation that transfers data, automatically sends a signal,
        and flushes the channel. This provides the guarantee of data transfer completion.

        Args:
            dst_chunk (Chunk): The destination chunk on the destination rank where data will be stored.
            src_chunk (Chunk): The source chunk on the source rank to send data from.
            tb (int): The thread block ID that will execute this put operation.

        Raises:
            RuntimeError: If chunk ranks don't match the channel configuration or
                if chunk sizes don't match.

        Example:
            >>> channel.put_with_signal_and_flush(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk)
        self._emit_put(dst_chunk, src_chunk, tb, with_signal_and_flush=True)

    def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
        """Transfer data from local buffer to remote scratch buffer in packet format.

        Performs a specialized put operation that transfers data from the source rank's buffer
        to the destination rank's scratch buffer in packet format through the port channel.
        The destination chunk must be a scratch buffer.

        Args:
            dst_chunk (Chunk): The destination scratch chunk on the destination rank.
            src_chunk (Chunk): The source chunk on the source rank (any buffer type).
            tb (int): The thread block ID that will execute this operation.

        Raises:
            RuntimeError: If chunk ranks don't match channel configuration, if destination
                chunk is not a scratch buffer, or if chunk sizes don't match.

        Example:
            >>> channel.put_packets(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk, dst_scratch=True)
        self._emit_put(dst_chunk, src_chunk, tb, from_packet=False, to_packet=True)

    def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
        """Transfer data in packet format from local to remote scratch buffer.

        Performs a specialized put operation that transfers data in packet format
        from the source rank's scratch buffer to the destination rank's scratch buffer
        through the port channel. Both source and destination chunks must be scratch buffers.

        Args:
            dst_chunk (Chunk): The destination scratch chunk on the destination rank.
            src_chunk (Chunk): The source scratch chunk on the source rank.
            tb (int): The thread block ID that will execute this operation.

        Raises:
            RuntimeError: If chunk ranks don't match channel configuration, if chunks
                are not scratch buffers, or if chunk sizes don't match.

        Example:
            >>> channel.read_put_packets(dst_chunk, src_chunk, tb=0)
        """
        self._validate_put_chunks(dst_chunk, src_chunk, src_scratch=True, dst_scratch=True)
        self._emit_put(dst_chunk, src_chunk, tb, from_packet=True, to_packet=True)
@dataclass
class SwitchChannel:
"""A switch channel for collective communication operations among multiple GPUs.
SwitchChannel enables collective operations like reduce and broadcast among a group
of ranks through a switch-based communication mechanism. It supports operations
on shared buffers across multiple ranks in the group.
Attributes:
channel_ids (dict): Dictionary mapping ranks to their channel IDs.
channel_type (ChannelType): The type of channel (switch).
buffer_type (BufferType): The type of buffer used for operations.
rank_group (RankGroup): The group of ranks participating in this channel.
"""
_channel_counts = defaultdict(int)
@classmethod
def reset(cls):
    """Forget every per-rank SwitchChannel counter.

    After this call, the next SwitchChannel created on any rank is again handed the
    counter's starting value as its channel id for that rank.
    """
    counts = cls._channel_counts
    counts.clear()
def __init__(self, rank_list: List[int], buffer_type: BufferType):
    """Initialize a new SwitchChannel.

    Every participating rank receives its own per-rank channel id, taken from the
    class-level counter for that rank.

    Args:
        rank_list (List[int]): Ranks that will participate in this switch channel.
            Each rank must appear at most once.
        buffer_type (BufferType): The type of buffer to use for switch operations.

    Raises:
        RuntimeError: If any rank in rank_list is negative or not smaller than the
            number of ranks in the current program, or if a rank is listed more than once.

    Example:
        >>> channel = SwitchChannel(rank_list=[0, 1, 2, 3], buffer_type=BufferType.input)
    """
    num_ranks = get_program().num_ranks
    # Materialize once so a one-shot iterable is not exhausted by validation before
    # the channel ids are assigned.
    ranks = list(rank_list)
    if not all(0 <= rank < num_ranks for rank in ranks):
        raise RuntimeError(f"One or more ranks in {rank_list} are out of bounds. Number of ranks: {num_ranks}")
    # A duplicated rank would consume two channel ids on that rank and inflate the
    # rank group's size.
    if len(set(ranks)) != len(ranks):
        raise RuntimeError(f"Duplicate ranks in {rank_list}; each rank may appear only once.")
    self.channel_ids = {}
    for rank in ranks:
        self.channel_ids[rank] = SwitchChannel._channel_counts[rank]
        SwitchChannel._channel_counts[rank] += 1
    self.channel_type = ChannelType.switch
    self.buffer_type = buffer_type
    self.rank_group = RankGroup(len(ranks), ranks)
    get_program().add_channel(self)
def at_rank(self, rank):
    """Return a view of this switch channel bound to a single participating rank.

    The returned SwitchChannelRankView exposes the reduce/broadcast operations for
    ``rank``.

    Args:
        rank (int): The participating rank the view is bound to.

    Returns:
        SwitchChannelRankView: The rank-bound view of this channel.

    Raises:
        RuntimeError: If ``rank`` does not belong to this channel's rank group.

    Example:
        >>> channel.at_rank(0)
    """
    members = self.rank_group.ranks
    if rank in members:
        return SwitchChannel.SwitchChannelRankView(self, rank)
    raise RuntimeError(f"Rank {rank} is not part of this SwitchChannel's rank group.")
def reduce(self, rank, buffer_offset, size, dst_chunk: Chunk, tb, reduce_op=ReduceOperationType.sum):
    """Perform a reduction operation across all ranks in the switch channel.

    Reduces data from the specified buffer region across all ranks in the
    rank group, storing the result in the destination chunk. The operation is
    recorded on the current program as a GroupLoadReduce for ``rank``/``tb``.

    Args:
        rank (int): The rank that will execute this reduction operation.
        buffer_offset (int): The offset in the buffer where reduction data starts.
        size (int): The size of data to reduce.
        dst_chunk (Chunk): The destination chunk where the result will be stored.
        tb (int): The thread block ID that will execute this operation.
        reduce_op (ReduceOperationType, optional): The reduction operation to perform.
            Defaults to ReduceOperationType.sum.

    Raises:
        RuntimeError: If dst_chunk rank is not in the rank group, if chunk size
            doesn't match the required size, or if buffer size is insufficient
            on any rank of the group. On error the channel is left unchanged.

    Example:
        >>> channel.reduce(rank=0, buffer_offset=0, size=1, dst_chunk=chunk, tb=0)
    """
    if dst_chunk.rank not in self.rank_group.ranks:
        raise RuntimeError(
            f"Destination chunk rank {dst_chunk.rank} is not part of the rank group {self.rank_group.ranks}."
        )
    if dst_chunk.size != size:
        raise RuntimeError(f"Destination chunk size {dst_chunk.size} does not match the required size {size}.")
    program = get_program()
    # Every member's buffer must hold [buffer_offset, buffer_offset + size).
    # A distinct loop name keeps the executing `rank` argument intact.
    for peer_rank in self.rank_group.ranks:
        if self.buffer_type == BufferType.scratch:
            buffer_size = program.gpus[peer_rank].scratch_chunks
        else:
            buffer_size = program.buffers[peer_rank][self.buffer_type].size
        if buffer_size < buffer_offset + size:
            raise RuntimeError(
                f"Buffer size {buffer_size} is smaller than required size {buffer_offset + size} for rank {peer_rank}."
            )
    # Record the executing rank only after validation succeeds, so a rejected
    # call does not modify this (shared) channel object. Presumably setup_channel
    # reads src_rank, so it must be set before that call — confirm.
    self.src_rank = rank
    tb_channel_ids = program.setup_channel(tb, self)
    op = GroupLoadReduce(
        self.buffer_type,
        buffer_offset,
        size,
        dst_chunk,
        tb_channel_ids,
        self.channel_type,
        reduce_op,
    )
    program.add_operation(self.src_rank, tb, op)
def broadcast(self, rank, src_chunk: Chunk, buffer_offset, size, tb):
    """Broadcast data from source chunk to all ranks in the switch channel.

    Broadcasts data from the source chunk to the specified buffer region
    across all ranks in the rank group. The operation is recorded on the
    current program as a GroupStore for ``rank``/``tb``.

    Args:
        rank (int): The rank that will execute this broadcast operation.
        src_chunk (Chunk): The source chunk containing data to broadcast.
        buffer_offset (int): The offset in the destination buffer where data will be stored.
        size (int): The size of data to broadcast.
        tb (int): The thread block ID that will execute this operation.

    Raises:
        RuntimeError: If src_chunk rank is not in the rank group, if chunk size
            doesn't match the required size, or if buffer size is insufficient
            on any rank of the group. On error the channel is left unchanged.

    Example:
        >>> channel.broadcast(rank=0, src_chunk=chunk, buffer_offset=0, size=1, tb=0)
    """
    # The messages refer to the *source* chunk; the earlier text said
    # "Destination", which misdescribed the failing argument.
    if src_chunk.rank not in self.rank_group.ranks:
        raise RuntimeError(
            f"Source chunk rank {src_chunk.rank} is not part of the rank group {self.rank_group.ranks}."
        )
    if src_chunk.size != size:
        raise RuntimeError(f"Source chunk size {src_chunk.size} does not match the required size {size}.")
    program = get_program()
    # Every member's buffer must hold [buffer_offset, buffer_offset + size).
    # A distinct loop name keeps the executing `rank` argument intact.
    for peer_rank in self.rank_group.ranks:
        if self.buffer_type == BufferType.scratch:
            buffer_size = program.gpus[peer_rank].scratch_chunks
        else:
            buffer_size = program.buffers[peer_rank][self.buffer_type].size
        if buffer_size < buffer_offset + size:
            raise RuntimeError(
                f"Buffer size {buffer_size} is smaller than required size {buffer_offset + size} for rank {peer_rank}."
            )
    # Record the executing rank only after validation succeeds, so a rejected
    # call does not modify this (shared) channel object. Presumably setup_channel
    # reads src_rank, so it must be set before that call — confirm.
    self.src_rank = rank
    tb_channel_ids = program.setup_channel(tb, self)
    op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
    program.add_operation(self.src_rank, tb, op)
class SwitchChannelRankView:
    """A rank-specific view of a SwitchChannel for performing operations.

    This class provides a convenient interface for performing switch channel
    operations from the perspective of a specific rank, automatically passing
    the rank parameter to the underlying channel methods. Views are normally
    obtained from ``SwitchChannel.at_rank``, which checks that the rank belongs
    to the channel's rank group; constructing a view directly skips that check.

    Attributes:
        _channel (SwitchChannel): The underlying switch channel.
        _rank (int): The rank this view represents.
    """

    def __init__(self, channel, rank):
        """Initialize a new SwitchChannelRankView.

        Args:
            channel (SwitchChannel): The switch channel to create a view for.
            rank (int): The rank this view will represent. Not validated here.
        """
        self._channel: SwitchChannel = channel
        self._rank: int = rank

    def reduce(self, buffer_offset, size, dst_chunk: Chunk, tb, reduce_op=ReduceOperationType.sum):
        """Perform a reduction operation from this rank's perspective.

        Convenience method that calls the underlying channel's reduce method
        with this view's rank automatically provided.

        Args:
            buffer_offset (int): The offset in the buffer where reduction data starts.
            size (int): The size of data to reduce.
            dst_chunk (Chunk): The destination chunk where the result will be stored.
            tb (int): The thread block ID that will execute this operation.
            reduce_op (ReduceOperationType, optional): The reduction operation to perform.
                Defaults to ReduceOperationType.sum.

        Returns:
            Whatever ``SwitchChannel.reduce`` returns. That method records the
            operation on the program and has no return statement, so this is None.

        Raises:
            RuntimeError: Propagated from ``SwitchChannel.reduce`` when validation fails.

        Example:
            >>> rank_view.reduce(buffer_offset=0, size=1, dst_chunk=chunk, tb=0)
        """
        return self._channel.reduce(self._rank, buffer_offset, size, dst_chunk, tb, reduce_op)

    def broadcast(self, src_chunk: Chunk, buffer_offset, size, tb):
        """Perform a broadcast operation from this rank's perspective.

        Convenience method that calls the underlying channel's broadcast method
        with this view's rank automatically provided.

        Args:
            src_chunk (Chunk): The source chunk containing data to broadcast.
            buffer_offset (int): The offset in the destination buffer where data will be stored.
            size (int): The size of data to broadcast.
            tb (int): The thread block ID that will execute this operation.

        Returns:
            Whatever ``SwitchChannel.broadcast`` returns. That method records the
            operation on the program and has no return statement, so this is None.

        Raises:
            RuntimeError: Propagated from ``SwitchChannel.broadcast`` when validation fails.

        Example:
            >>> rank_view.broadcast(src_chunk=chunk, buffer_offset=0, size=1, tb=0)
        """
        return self._channel.broadcast(self._rank, src_chunk, buffer_offset, size, tb)