Revised ProxyChannel interfaces (#400)

* Renamed `ProxyChannel` -> `BaseProxyChannel` and `SimpleProxyChannel`
-> `ProxyChannel`. This makes the interface more consistent: a channel is
now defined as being associated with a specific src/dst memory region,
with `ProxyChannel` as "sema + src/dst + fifo" and `SmChannel` as "sema +
src/dst". `BaseProxyChannel` is not associated with any memory region; it
is just "sema + fifo".
* `ProxyChannelDeviceHandle` now inherits from
`BaseProxyChannelDeviceHandle`, instead of having one as a member.
This commit is contained in:
Changho Hwang
2024-12-06 10:53:34 -08:00
committed by GitHub
parent f6305a3c1d
commit 756f24c697
25 changed files with 255 additions and 250 deletions

View File

@@ -6,8 +6,8 @@
// be careful about using channels[my_rank]: it is invalid and is there just for simplicity of indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data,
int* scratch, int num_elements, int use_packet) {
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
int num_elements, int use_packet) {
int tid = threadIdx.x;
int nthreads = blockDim.x;
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;

View File

@@ -346,9 +346,9 @@ class MscclppKernel:
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = nranks
elif test_name == "simple_proxy_channel":
elif test_name == "proxy_channel":
self._kernel = KernelBuilder(
file="simple_proxy_channel_test.cu", kernel_name="simple_proxy_channel", file_dir=file_dir
file="proxy_channel_test.cu", kernel_name="proxy_channel", file_dir=file_dir
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = 1024
@@ -376,11 +376,11 @@ class MscclppKernel:
# keep a reference to the device handles so that they don't get garbage collected
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "proxy_channel"]:
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
if test_name == "sm_channel":
self.params += pack(tensor.size, use_packet)
if test_name == "simple_proxy_channel":
if test_name == "proxy_channel":
self.params += pack(tensor, scratch, tensor.size, use_packet)
elif test_name == "fifo":
self.params = fifo.device_handle().raw
@@ -531,7 +531,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@pytest.mark.parametrize("use_packet", [False, True])
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
group, connections = create_group_and_connection(mpi_group, transport)
memory = cp.zeros(nelem, dtype=cp.int32)
@@ -552,13 +552,13 @@ def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, u
memory_to_register = scratch
else:
memory_to_register = memory
simple_channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
kernel = MscclppKernel(
"simple_proxy_channel",
"proxy_channel",
my_rank=group.my_rank,
nranks=group.nranks,
semaphore_or_channels=simple_channels,
semaphore_or_channels=channels,
tensor=memory,
use_packet=use_packet,
scratch=scratch,