mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Revised ProxyChannel interfaces (#400)
* Renamed `ProxyChannel` -> `BaseProxyChannel` and `SimpleProxyChannel` -> `ProxyChannel`. It makes the interface more consistent by defining channels to be associated with a certain src/dst memory region: `ProxyChannel` as "sema + src/dst + fifo" and `SmChannel` as "sema + src/dst". BaseProxyChannel is not associated with any memory regions, as "sema + fifo". * `ProxyChannelDeviceHandle` now inherits from `BaseProxyChannelDeviceHandle`, instead of having one as a member.
This commit is contained in:
@@ -6,8 +6,8 @@
|
||||
|
||||
// be careful about using channels[my_rank] as it is inavlie and it is there just for simplicity of indexing
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data,
|
||||
int* scratch, int num_elements, int use_packet) {
|
||||
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
|
||||
int num_elements, int use_packet) {
|
||||
int tid = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
|
||||
@@ -346,9 +346,9 @@ class MscclppKernel:
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = nranks
|
||||
elif test_name == "simple_proxy_channel":
|
||||
elif test_name == "proxy_channel":
|
||||
self._kernel = KernelBuilder(
|
||||
file="simple_proxy_channel_test.cu", kernel_name="simple_proxy_channel", file_dir=file_dir
|
||||
file="proxy_channel_test.cu", kernel_name="proxy_channel", file_dir=file_dir
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = 1024
|
||||
@@ -376,11 +376,11 @@ class MscclppKernel:
|
||||
# keep a reference to the device handles so that they don't get garbage collected
|
||||
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
|
||||
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "proxy_channel"]:
|
||||
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
|
||||
if test_name == "sm_channel":
|
||||
self.params += pack(tensor.size, use_packet)
|
||||
if test_name == "simple_proxy_channel":
|
||||
if test_name == "proxy_channel":
|
||||
self.params += pack(tensor, scratch, tensor.size, use_packet)
|
||||
elif test_name == "fifo":
|
||||
self.params = fifo.device_handle().raw
|
||||
@@ -531,7 +531,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
@@ -552,13 +552,13 @@ def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, u
|
||||
memory_to_register = scratch
|
||||
else:
|
||||
memory_to_register = memory
|
||||
simple_channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
|
||||
channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
|
||||
|
||||
kernel = MscclppKernel(
|
||||
"simple_proxy_channel",
|
||||
"proxy_channel",
|
||||
my_rank=group.my_rank,
|
||||
nranks=group.nranks,
|
||||
semaphore_or_channels=simple_channels,
|
||||
semaphore_or_channels=channels,
|
||||
tensor=memory,
|
||||
use_packet=use_packet,
|
||||
scratch=scratch,
|
||||
|
||||
Reference in New Issue
Block a user