Revised ProxyChannel interfaces (#400)

* Renamed `ProxyChannel` -> `BaseProxyChannel` and `SimpleProxyChannel`
-> `ProxyChannel`. It makes the interface more consistent by defining
channels to be associated with a certain src/dst memory region:
`ProxyChannel` as "sema + src/dst + fifo" and `SmChannel` as "sema +
src/dst". BaseProxyChannel is not associated with any memory regions; it
is defined as "sema + fifo".
* `ProxyChannelDeviceHandle` now inherits from
`BaseProxyChannelDeviceHandle`, instead of having one as a member.
This commit is contained in:
Changho Hwang
2024-12-06 10:53:34 -08:00
committed by GitHub
parent f6305a3c1d
commit 756f24c697
25 changed files with 255 additions and 250 deletions

View File

@@ -14,7 +14,7 @@ from ._mscclpp import (
numa,
ProxyService,
RegisteredMemory,
SimpleProxyChannel,
ProxyChannel,
SmChannel,
SmDevice2DeviceSemaphore,
TcpBootstrap,

View File

@@ -14,7 +14,7 @@ from ._mscclpp import (
Host2HostSemaphore,
ProxyService,
RegisteredMemory,
SimpleProxyChannel,
ProxyChannel,
SmChannel,
SmDevice2DeviceSemaphore,
TcpBootstrap,
@@ -180,8 +180,8 @@ class CommGroup:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = SimpleProxyChannel(
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
channels[rank] = proxy_service.proxy_channel(
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
)
return channels
@@ -218,8 +218,8 @@ class CommGroup:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = SimpleProxyChannel(
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
channels[rank] = proxy_service.proxy_channel(
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
)
return channels
@@ -232,7 +232,7 @@ class CommGroup:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = proxy_service.proxy_channel(semaphore_ids[rank])
channels[rank] = proxy_service.base_proxy_channel(semaphore_ids[rank])
return channels
def register_memory_with_proxy(

View File

@@ -23,11 +23,26 @@ void register_proxy_channel(nb::module_& m) {
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
.def("base_proxy_channel", &ProxyService::baseProxyChannel, nb::arg("id"))
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
nb::class_<ProxyChannel>(m, "ProxyChannel")
nb::class_<BaseProxyChannel>(m, "BaseProxyChannel")
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
.def("device_handle", &BaseProxyChannel::deviceHandle);
nb::class_<BaseProxyChannel::DeviceHandle>(m, "BaseProxyChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphoreId_", &BaseProxyChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &BaseProxyChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &BaseProxyChannel::DeviceHandle::fifo_)
.def_prop_ro("raw", [](const BaseProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<ProxyChannel>(m, "ProxyChannel")
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
.def("device_handle", &ProxyChannel::deviceHandle);
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
@@ -35,21 +50,9 @@ void register_proxy_channel(nb::module_& m) {
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
.def_rw("src_", &ProxyChannel::DeviceHandle::src_)
.def_rw("dst_", &ProxyChannel::DeviceHandle::dst_)
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src"))
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"))
.def("device_handle", &SimpleProxyChannel::deviceHandle);
nb::class_<SimpleProxyChannel::DeviceHandle>(m, "SimpleProxyChannelDeviceHandle")
.def(nb::init<>())
.def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_)
.def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_)
.def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_)
.def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
};

View File

@@ -301,9 +301,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
// -------------------------------------------
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce3(mscclpp::SimpleProxyChannelDeviceHandle* fstRoundChans,
mscclpp::SimpleProxyChannelDeviceHandle* sndRoundChans, TYPE* buff, TYPE* scratch, int rank,
int worldSize, size_t nelems) {
allreduce3(mscclpp::ProxyChannelDeviceHandle* fstRoundChans, mscclpp::ProxyChannelDeviceHandle* sndRoundChans,
TYPE* buff, TYPE* scratch, int rank, int worldSize, size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
int isComm = (threadIdx.x == 0) && (blockIdx.x == 0);
@@ -312,10 +311,10 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
int peerSendId = (remoteSendRank < rank) ? remoteSendRank : remoteSendRank - 1;
int peerRecvId = (remoteRecvRank < rank) ? remoteRecvRank : remoteRecvRank - 1;
mscclpp::SimpleProxyChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
mscclpp::SimpleProxyChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
mscclpp::SimpleProxyChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
mscclpp::SimpleProxyChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
mscclpp::ProxyChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
mscclpp::ProxyChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
mscclpp::ProxyChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
mscclpp::ProxyChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
// Step 1
size_t chunkIndex = (rank + worldSize - 1) % worldSize;
@@ -529,9 +528,8 @@ __device__ void localAllGatherAllPairsSm(mscclpp::SmChannelDeviceHandle* smChans
}
// This is an allgather4 equivalent
__device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans,
mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, int rank, int worldSize,
int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
__device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
// this allgather is a pipelined and hierarchical one and only works for two nodes
// it is implemented as follows:
// Step 1: each node does a local allgather and concurrently,
@@ -546,7 +544,7 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans,
int peerRank = (rank + nRanksPerNode) % worldSize;
int peerNodeId = peerRank / nRanksPerNode;
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[peer];
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
const size_t nBlocksForLocalAllGather = gridDim.x / (nRanksPerNode - 1) * (nRanksPerNode - 1);
const size_t rankChunkSize = nelemsPerGPU * sizeof(int);
const int startRankIndexInLocalNode = (rank / nRanksPerNode) * nRanksPerNode;
@@ -590,9 +588,8 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans,
nBlocksForLocalAllGather);
}
__device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans,
mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, TYPE* buff, TYPE* scratch,
int rank, int nRanksPerNode, int worldSize,
__device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
TYPE* buff, TYPE* scratch, int rank, int nRanksPerNode, int worldSize,
size_t nelems, // must be divisible by 3
int pipelineDepth) {
// this reduce-scatter algorithm works as follows:
@@ -615,7 +612,7 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans,
int isComm = (threadIdx.x == 0) && (blockIdx.x == nBlocksForReduceScatter);
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
int nBlocksRemain = gridDim.x - nBlocksForReduceScatter;
mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[peer];
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
if (peerNodeId == rank / nRanksPerNode) {
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
return;
@@ -675,9 +672,8 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans,
}
extern "C" __global__ void __launch_bounds__(1024, 1) __global__
allreduce4(mscclpp::SmChannelDeviceHandle* smChans,
mscclpp::SimpleProxyChannelDeviceHandle* reduceScatterProxyChans,
mscclpp::SimpleProxyChannelDeviceHandle* allGatherProxyChans, TYPE* buff, TYPE* scratch, int rank,
allreduce4(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* reduceScatterProxyChans,
mscclpp::ProxyChannelDeviceHandle* allGatherProxyChans, TYPE* buff, TYPE* scratch, int rank,
int nRanksPerNode, int worldSize, size_t nelems, int pipelineDepth) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
reduceScatterSm(smChans, reduceScatterProxyChans, buff, scratch, rank, nRanksPerNode, worldSize, nelems,
@@ -688,7 +684,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1) __global__
// allreduce 5 for 2-nodes
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce5(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, TYPE* buff,
allreduce5(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans, TYPE* buff,
TYPE* scratch, TYPE* putBuff, TYPE* resultBuff, int rank, int nRanksPerNode, int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
@@ -706,7 +702,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[localRankId];
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[localRankId];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);

View File

@@ -6,8 +6,8 @@
// be careful about using channels[my_rank] as it is invalid and it is there just for simplicity of indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data,
int* scratch, int num_elements, int use_packet) {
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
int num_elements, int use_packet) {
int tid = threadIdx.x;
int nthreads = blockDim.x;
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;

View File

@@ -346,9 +346,9 @@ class MscclppKernel:
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = nranks
elif test_name == "simple_proxy_channel":
elif test_name == "proxy_channel":
self._kernel = KernelBuilder(
file="simple_proxy_channel_test.cu", kernel_name="simple_proxy_channel", file_dir=file_dir
file="proxy_channel_test.cu", kernel_name="proxy_channel", file_dir=file_dir
).get_compiled_kernel()
self.nblocks = 1
self.nthreads = 1024
@@ -376,11 +376,11 @@ class MscclppKernel:
# keep a reference to the device handles so that they don't get garbage collected
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "proxy_channel"]:
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
if test_name == "sm_channel":
self.params += pack(tensor.size, use_packet)
if test_name == "simple_proxy_channel":
if test_name == "proxy_channel":
self.params += pack(tensor, scratch, tensor.size, use_packet)
elif test_name == "fifo":
self.params = fifo.device_handle().raw
@@ -531,7 +531,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@pytest.mark.parametrize("use_packet", [False, True])
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
group, connections = create_group_and_connection(mpi_group, transport)
memory = cp.zeros(nelem, dtype=cp.int32)
@@ -552,13 +552,13 @@ def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, u
memory_to_register = scratch
else:
memory_to_register = memory
simple_channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
kernel = MscclppKernel(
"simple_proxy_channel",
"proxy_channel",
my_rank=group.my_rank,
nranks=group.nranks,
semaphore_or_channels=simple_channels,
semaphore_or_channels=channels,
tensor=memory,
use_packet=use_packet,
scratch=scratch,