mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Revised ProxyChannel interfaces (#400)
* Renamed `ProxyChannel` -> `BaseProxyChannel` and `SimpleProxyChannel` -> `ProxyChannel`. It makes the interface more consistent by defining channels to be associated with a certain src/dst memory region: `ProxyChannel` as "sema + src/dst + fifo" and `SmChannel` as "sema + src/dst". BaseProxyChannel is not associated with any memory regions, as "sema + fifo". * `ProxyChannelDeviceHandle` now inherits from `BaseProxyChannelDeviceHandle`, instead of having one as a member.
This commit is contained in:
@@ -14,7 +14,7 @@ from ._mscclpp import (
|
||||
numa,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
SimpleProxyChannel,
|
||||
ProxyChannel,
|
||||
SmChannel,
|
||||
SmDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
|
||||
@@ -14,7 +14,7 @@ from ._mscclpp import (
|
||||
Host2HostSemaphore,
|
||||
ProxyService,
|
||||
RegisteredMemory,
|
||||
SimpleProxyChannel,
|
||||
ProxyChannel,
|
||||
SmChannel,
|
||||
SmDevice2DeviceSemaphore,
|
||||
TcpBootstrap,
|
||||
@@ -180,8 +180,8 @@ class CommGroup:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
channels = {}
|
||||
for rank in semaphores:
|
||||
channels[rank] = SimpleProxyChannel(
|
||||
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
|
||||
channels[rank] = proxy_service.proxy_channel(
|
||||
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
|
||||
)
|
||||
return channels
|
||||
|
||||
@@ -218,8 +218,8 @@ class CommGroup:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
channels = {}
|
||||
for rank in semaphores:
|
||||
channels[rank] = SimpleProxyChannel(
|
||||
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
|
||||
channels[rank] = proxy_service.proxy_channel(
|
||||
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
|
||||
)
|
||||
return channels
|
||||
|
||||
@@ -232,7 +232,7 @@ class CommGroup:
|
||||
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
|
||||
channels = {}
|
||||
for rank in semaphores:
|
||||
channels[rank] = proxy_service.proxy_channel(semaphore_ids[rank])
|
||||
channels[rank] = proxy_service.base_proxy_channel(semaphore_ids[rank])
|
||||
return channels
|
||||
|
||||
def register_memory_with_proxy(
|
||||
|
||||
@@ -23,11 +23,26 @@ void register_proxy_channel(nb::module_& m) {
|
||||
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
|
||||
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
|
||||
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
|
||||
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
|
||||
.def("base_proxy_channel", &ProxyService::baseProxyChannel, nb::arg("id"))
|
||||
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));
|
||||
|
||||
nb::class_<ProxyChannel>(m, "ProxyChannel")
|
||||
nb::class_<BaseProxyChannel>(m, "BaseProxyChannel")
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &BaseProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<BaseProxyChannel::DeviceHandle>(m, "BaseProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("semaphoreId_", &BaseProxyChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &BaseProxyChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &BaseProxyChannel::DeviceHandle::fifo_)
|
||||
.def_prop_ro("raw", [](const BaseProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<ProxyChannel>(m, "ProxyChannel")
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
|
||||
.def("device_handle", &ProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
|
||||
@@ -35,21 +50,9 @@ void register_proxy_channel(nb::module_& m) {
|
||||
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
|
||||
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
|
||||
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
|
||||
.def_rw("src_", &ProxyChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &ProxyChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
|
||||
nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
|
||||
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src"))
|
||||
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"))
|
||||
.def("device_handle", &SimpleProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<SimpleProxyChannel::DeviceHandle>(m, "SimpleProxyChannelDeviceHandle")
|
||||
.def(nb::init<>())
|
||||
.def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_)
|
||||
.def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_)
|
||||
.def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_)
|
||||
.def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes {
|
||||
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -301,9 +301,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
// -------------------------------------------
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce3(mscclpp::SimpleProxyChannelDeviceHandle* fstRoundChans,
|
||||
mscclpp::SimpleProxyChannelDeviceHandle* sndRoundChans, TYPE* buff, TYPE* scratch, int rank,
|
||||
int worldSize, size_t nelems) {
|
||||
allreduce3(mscclpp::ProxyChannelDeviceHandle* fstRoundChans, mscclpp::ProxyChannelDeviceHandle* sndRoundChans,
|
||||
TYPE* buff, TYPE* scratch, int rank, int worldSize, size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
|
||||
int isComm = (threadIdx.x == 0) && (blockIdx.x == 0);
|
||||
@@ -312,10 +311,10 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
int peerSendId = (remoteSendRank < rank) ? remoteSendRank : remoteSendRank - 1;
|
||||
int peerRecvId = (remoteRecvRank < rank) ? remoteRecvRank : remoteRecvRank - 1;
|
||||
|
||||
mscclpp::SimpleProxyChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
|
||||
mscclpp::SimpleProxyChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
|
||||
mscclpp::SimpleProxyChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
|
||||
mscclpp::SimpleProxyChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
|
||||
mscclpp::ProxyChannelDeviceHandle& devFstSendChan = fstRoundChans[peerSendId];
|
||||
mscclpp::ProxyChannelDeviceHandle& devFstRecvChan = fstRoundChans[peerRecvId];
|
||||
mscclpp::ProxyChannelDeviceHandle& devSndSendChan = sndRoundChans[peerSendId];
|
||||
mscclpp::ProxyChannelDeviceHandle& devSndRecvChan = sndRoundChans[peerRecvId];
|
||||
|
||||
// Step 1
|
||||
size_t chunkIndex = (rank + worldSize - 1) % worldSize;
|
||||
@@ -529,9 +528,8 @@ __device__ void localAllGatherAllPairsSm(mscclpp::SmChannelDeviceHandle* smChans
|
||||
}
|
||||
|
||||
// This is an allgather4 equivalent
|
||||
__device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans,
|
||||
mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, int rank, int worldSize,
|
||||
int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
|
||||
__device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
|
||||
int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU, int pipelineDepth) {
|
||||
// this allgather is a pipelined and hierarchical one and only works for two nodes
|
||||
// it is implemented as follows:
|
||||
// Step 1: each node does a local allgather and concurrently,
|
||||
@@ -546,7 +544,7 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans,
|
||||
int peerRank = (rank + nRanksPerNode) % worldSize;
|
||||
int peerNodeId = peerRank / nRanksPerNode;
|
||||
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
|
||||
mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[peer];
|
||||
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
|
||||
const size_t nBlocksForLocalAllGather = gridDim.x / (nRanksPerNode - 1) * (nRanksPerNode - 1);
|
||||
const size_t rankChunkSize = nelemsPerGPU * sizeof(int);
|
||||
const int startRankIndexInLocalNode = (rank / nRanksPerNode) * nRanksPerNode;
|
||||
@@ -590,9 +588,8 @@ __device__ void allGatherSm(mscclpp::SmChannelDeviceHandle* smChans,
|
||||
nBlocksForLocalAllGather);
|
||||
}
|
||||
|
||||
__device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans,
|
||||
mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, TYPE* buff, TYPE* scratch,
|
||||
int rank, int nRanksPerNode, int worldSize,
|
||||
__device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans,
|
||||
TYPE* buff, TYPE* scratch, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems, // must be divisible by 3
|
||||
int pipelineDepth) {
|
||||
// this reduce-scatter algorithm works as follows:
|
||||
@@ -615,7 +612,7 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans,
|
||||
int isComm = (threadIdx.x == 0) && (blockIdx.x == nBlocksForReduceScatter);
|
||||
int peer = (peerRank < rank) ? peerRank : peerRank - 1;
|
||||
int nBlocksRemain = gridDim.x - nBlocksForReduceScatter;
|
||||
mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[peer];
|
||||
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[peer];
|
||||
if (peerNodeId == rank / nRanksPerNode) {
|
||||
localReduceScatterSm(smChans, buff, rank, nRanksPerNode, 0, 0, chunkSize, chunkSize, gridDim.x);
|
||||
return;
|
||||
@@ -675,9 +672,8 @@ __device__ void reduceScatterSm(mscclpp::SmChannelDeviceHandle* smChans,
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1) __global__
|
||||
allreduce4(mscclpp::SmChannelDeviceHandle* smChans,
|
||||
mscclpp::SimpleProxyChannelDeviceHandle* reduceScatterProxyChans,
|
||||
mscclpp::SimpleProxyChannelDeviceHandle* allGatherProxyChans, TYPE* buff, TYPE* scratch, int rank,
|
||||
allreduce4(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* reduceScatterProxyChans,
|
||||
mscclpp::ProxyChannelDeviceHandle* allGatherProxyChans, TYPE* buff, TYPE* scratch, int rank,
|
||||
int nRanksPerNode, int worldSize, size_t nelems, int pipelineDepth) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
reduceScatterSm(smChans, reduceScatterProxyChans, buff, scratch, rank, nRanksPerNode, worldSize, nelems,
|
||||
@@ -688,7 +684,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1) __global__
|
||||
|
||||
// allreduce 5 for 2-nodes
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce5(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::SimpleProxyChannelDeviceHandle* proxyChans, TYPE* buff,
|
||||
allreduce5(mscclpp::SmChannelDeviceHandle* smChans, mscclpp::ProxyChannelDeviceHandle* proxyChans, TYPE* buff,
|
||||
TYPE* scratch, TYPE* putBuff, TYPE* resultBuff, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
@@ -706,7 +702,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
|
||||
mscclpp::SmChannelDeviceHandle smChan = smChans[peerIdx];
|
||||
mscclpp::SimpleProxyChannelDeviceHandle proxyChan = proxyChans[localRankId];
|
||||
mscclpp::ProxyChannelDeviceHandle proxyChan = proxyChans[localRankId];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
|
||||
// be careful about using channels[my_rank] as it is inavlie and it is there just for simplicity of indexing
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data,
|
||||
int* scratch, int num_elements, int use_packet) {
|
||||
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data, int* scratch,
|
||||
int num_elements, int use_packet) {
|
||||
int tid = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
|
||||
@@ -346,9 +346,9 @@ class MscclppKernel:
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = nranks
|
||||
elif test_name == "simple_proxy_channel":
|
||||
elif test_name == "proxy_channel":
|
||||
self._kernel = KernelBuilder(
|
||||
file="simple_proxy_channel_test.cu", kernel_name="simple_proxy_channel", file_dir=file_dir
|
||||
file="proxy_channel_test.cu", kernel_name="proxy_channel", file_dir=file_dir
|
||||
).get_compiled_kernel()
|
||||
self.nblocks = 1
|
||||
self.nthreads = 1024
|
||||
@@ -376,11 +376,11 @@ class MscclppKernel:
|
||||
# keep a reference to the device handles so that they don't get garbage collected
|
||||
self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
|
||||
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
|
||||
if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "proxy_channel"]:
|
||||
self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
|
||||
if test_name == "sm_channel":
|
||||
self.params += pack(tensor.size, use_packet)
|
||||
if test_name == "simple_proxy_channel":
|
||||
if test_name == "proxy_channel":
|
||||
self.params += pack(tensor, scratch, tensor.size, use_packet)
|
||||
elif test_name == "fifo":
|
||||
self.params = fifo.device_handle().raw
|
||||
@@ -531,7 +531,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
|
||||
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
|
||||
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
|
||||
@pytest.mark.parametrize("use_packet", [False, True])
|
||||
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
|
||||
group, connections = create_group_and_connection(mpi_group, transport)
|
||||
|
||||
memory = cp.zeros(nelem, dtype=cp.int32)
|
||||
@@ -552,13 +552,13 @@ def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, u
|
||||
memory_to_register = scratch
|
||||
else:
|
||||
memory_to_register = memory
|
||||
simple_channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
|
||||
channels = group.make_proxy_channels(proxy_service, memory_to_register, connections)
|
||||
|
||||
kernel = MscclppKernel(
|
||||
"simple_proxy_channel",
|
||||
"proxy_channel",
|
||||
my_rank=group.my_rank,
|
||||
nranks=group.nranks,
|
||||
semaphore_or_channels=simple_channels,
|
||||
semaphore_or_channels=channels,
|
||||
tensor=memory,
|
||||
use_packet=use_packet,
|
||||
scratch=scratch,
|
||||
|
||||
Reference in New Issue
Block a user