Fix Python bindings and tests (#690)

Minimal fix to make the Python bindings and tests work. A more careful
follow-up is needed to keep nanobind from silently falling back when it
fails to (properly) construct a C++ STL object that holds mscclpp instances.
This commit is contained in:
Changho Hwang
2025-11-21 12:53:12 -08:00
committed by GitHub
parent 060c35fec6
commit 8b8593ba51
8 changed files with 73 additions and 63 deletions

View File

@@ -216,6 +216,7 @@ void register_core(nb::module_& m) {
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
def_shared_future<Connection>(m, "Connection");
def_shared_future<Semaphore>(m, "Semaphore");
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
@@ -242,7 +243,7 @@ void register_core(nb::module_& m) {
nb::arg("remote_rank"), nb::arg("tag"), nb::arg("local_config"))
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag"))
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("local_flag"), nb::arg("remote_rank"),
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("connection"), nb::arg("remote_rank"),
nb::arg("tag") = 0)
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)

View File

@@ -26,14 +26,20 @@ void register_memory_channel(nb::module_& m) {
nb::class_<MemoryChannel>(m, "MemoryChannel")
.def(nb::init<>())
.def("__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
.def("__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
})
.def(
"__init__",
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
},
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
.def(
"__init__",
[](MemoryChannel* memoryChannel, const Semaphore& semaphore, RegisteredMemory dst, RegisteredMemory src,
uintptr_t packet_buffer = 0) {
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
},
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
.def("device_handle", &MemoryChannel::deviceHandle);
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")

View File

@@ -47,6 +47,7 @@ from ._mscclpp import (
connect_nvls_collective,
EndpointConfig,
Fifo,
Semaphore,
Host2DeviceSemaphore,
Host2HostSemaphore,
numa,
@@ -79,6 +80,7 @@ __all__ = [
"connect_nvls_collective",
"EndpointConfig",
"Fifo",
"Semaphore",
"Host2DeviceSemaphore",
"Host2HostSemaphore",
"numa",

View File

@@ -10,6 +10,7 @@ from ._mscclpp import (
Connection,
connect_nvls_collective,
EndpointConfig,
Semaphore,
Host2DeviceSemaphore,
Host2HostSemaphore,
ProxyService,
@@ -133,18 +134,14 @@ class CommGroup:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
def make_semaphore(
self,
connections: dict[int, Connection],
semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
) -> dict[int, Host2HostSemaphore]:
semaphores = {}
def make_semaphores(self, connections: dict[int, Connection]) -> dict[int, Semaphore]:
future_semaphores = {}
for rank in connections:
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
return semaphores
future_semaphores[rank] = self.communicator.build_semaphore(connections[rank], rank)
return {rank: future.get() for rank, future in future_semaphores.items()}
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self.register_tensor_with_connections(tensor, connections)
channels = {}
for rank in connections:
@@ -159,7 +156,7 @@ class CommGroup:
registeredScratchBuffer: RegisteredMemory,
connections: dict[int, Connection],
) -> dict[int, MemoryChannel]:
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
@@ -177,7 +174,7 @@ class CommGroup:
def make_port_channels(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
) -> dict[int, PortChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
semaphore_ids = {}
@@ -210,7 +207,7 @@ class CommGroup:
)
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
memory_ids = {}
semaphore_ids = {}
@@ -229,7 +226,7 @@ class CommGroup:
def register_semaphore_with_proxy(
self, proxy_service: ProxyService, connections: dict[int, Connection]
) -> dict[int, PortChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = self.make_semaphores(connections)
semaphore_ids = {}
for rank in semaphores:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])

View File

@@ -453,7 +453,10 @@ class MscclppAllReduce6:
)
# create a memory_channel for each remote neighbor
self.semaphores = group.make_semaphore(self.nvlink_connections, MemoryDevice2DeviceSemaphore)
self.semaphores = {
rank: MemoryDevice2DeviceSemaphore(sema)
for rank, sema in group.make_semaphores(self.nvlink_connections).items()
}
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(
file="allreduce.cu",

View File

@@ -10,7 +10,6 @@
#include <mscclpp/core.hpp>
#include <mscclpp/fifo.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/numa.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/semaphore.hpp>
#include <vector>
@@ -19,37 +18,39 @@ namespace nb = nanobind;
class MyProxyService {
private:
int deviceNumaNode_;
int my_rank_, nranks_, dataSize_;
std::vector<mscclpp::Connection> connections_;
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
std::vector<mscclpp::RegisteredMemory> allRegMem_;
std::vector<mscclpp::Host2DeviceSemaphore> semaphores_;
mscclpp::Proxy proxy_;
public:
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
MyProxyService(int my_rank, int nranks, int dataSize, nb::list allRegMemList, nb::list semaphoreList)
: my_rank_(my_rank),
nranks_(nranks),
dataSize_(dataSize),
connections_(conns),
allRegMem_(allRegMem),
semaphores_(semaphores),
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
deviceNumaNode_ = mscclpp::getDeviceNumaNode(cudaDevice);
allRegMem_.reserve(allRegMemList.size());
for (size_t i = 0; i < allRegMemList.size(); ++i) {
auto& regMem = nb::cast<const mscclpp::RegisteredMemory&>(allRegMemList[i]);
allRegMem_.push_back(regMem);
}
semaphores_.reserve(semaphoreList.size());
for (size_t i = 0; i < semaphoreList.size(); ++i) {
auto& sema = nb::cast<const mscclpp::Semaphore&>(semaphoreList[i]);
semaphores_.emplace_back(sema);
}
}
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger) {
int dataSizePerRank = dataSize_ / nranks_;
for (int r = 1; r < nranks_; ++r) {
int nghr = (my_rank_ + r) % nranks_;
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
semaphores_[nghr]->signal();
connections_[nghr].flush();
auto& sema = semaphores_[nghr];
auto& conn = sema.connection();
conn.write(allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
sema.signal();
conn.flush();
}
return mscclpp::ProxyHandlerResult::Continue;
}
@@ -61,16 +62,11 @@ class MyProxyService {
mscclpp::FifoDeviceHandle fifoDeviceHandle() { return proxy_.fifo()->deviceHandle(); }
};
void init_mscclpp_proxy_test_module(nb::module_ &m) {
NB_MODULE(_ext, m) {
nb::class_<MyProxyService>(m, "MyProxyService")
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
nb::arg("h2d_sem_vec"))
.def(nb::init<int, int, int, nb::list, nb::list>(), nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"),
nb::arg("reg_mem_list"), nb::arg("sem_list"))
.def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
.def("start", &MyProxyService::start)
.def("stop", &MyProxyService::stop);
}
NB_MODULE(_ext, m) { init_mscclpp_proxy_test_module(m); }

View File

@@ -290,7 +290,8 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
connections = {rank: conn.get() for rank, conn in connections.items()}
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
semaphores = group.make_semaphores(connections)
semaphores = {rank: Host2HostSemaphore(sema) for rank, sema in semaphores.items()}
for rank in connections:
semaphores[rank].signal()
@@ -309,7 +310,8 @@ def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
connections = {rank: conn.get() for rank, conn in connections.items()}
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
semaphores = group.make_semaphores(connections)
semaphores = {rank: Host2HostSemaphore(sema) for rank, sema in semaphores.items()}
def target_wait(sems, conns):
for rank in conns:
@@ -457,7 +459,8 @@ def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
group, connections = create_group_and_connection(mpi_group, connection_type)
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = group.make_semaphores(connections)
semaphores = {rank: Host2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
kernel()
@@ -473,7 +476,8 @@ def test_h2d_semaphores(mpi_group: MpiGroup, connection_type: str):
def test_d2d_semaphores(mpi_group: MpiGroup):
group, connections = create_group_and_connection(mpi_group, "NVLink")
semaphores = group.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
semaphores = group.make_semaphores(connections)
semaphores = {rank: MemoryDevice2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
group.barrier()
kernel = MscclppKernel("d2d_semaphore", group.my_rank, group.nranks, semaphores)
kernel()
@@ -545,29 +549,29 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
group.barrier()
all_reg_memories = group.register_tensor_with_connections(memory, connections)
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
semaphores = group.make_semaphores(connections)
list_conn = []
list_sem = []
list_reg_mem = []
first_conn = next(iter(connections.values()))
first_sem = next(iter(semaphores.values()))
for rank in range(group.nranks):
if rank in connections:
list_conn.append(connections[rank])
list_sem.append(semaphores[rank])
else:
list_conn.append(first_conn) # just for simplicity of indexing
list_sem.append(first_sem)
list_reg_mem.append(all_reg_memories[rank])
proxy = _ext.MyProxyService(group.my_rank, group.nranks, nelem * memory.itemsize, list_conn, list_reg_mem, list_sem)
proxy = _ext.MyProxyService(group.my_rank, group.nranks, nelem * memory.itemsize, list_reg_mem, list_sem)
fifo_device_handle = proxy.fifo_device_handle()
kernel = MscclppKernel(
"proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=semaphores, fifo=fifo_device_handle
"proxy",
my_rank=group.my_rank,
nranks=group.nranks,
semaphore_or_channels={rank: Host2DeviceSemaphore(sema) for rank, sema in semaphores.items()},
fifo=fifo_device_handle,
)
proxy.start()
group.barrier()
@@ -632,7 +636,8 @@ def test_nvls(mpi_group: MpiGroup):
mem_handle = nvls_connection.bind_allocated_memory(memory.data.ptr, memory.data.mem.size)
nvlinks_connections = create_connection(group, "NVLink")
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)
semaphores = group.make_semaphores(nvlinks_connections)
semaphores = {rank: MemoryDevice2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
kernel = MscclppKernel(
"nvls",