Fix Python bindings and tests (#690)

Minimal fix to make things work. We still need to look more carefully at
how to prevent nanobind from silently falling back when it fails to
(properly) construct a C++ STL container of mscclpp instances.
This commit is contained in:
Changho Hwang
2025-11-21 12:53:12 -08:00
committed by GitHub
parent 060c35fec6
commit 8b8593ba51
8 changed files with 73 additions and 63 deletions

View File

@@ -10,7 +10,6 @@
#include <mscclpp/core.hpp>
#include <mscclpp/fifo.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/numa.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/semaphore.hpp>
#include <vector>
@@ -19,37 +18,39 @@ namespace nb = nanobind;
class MyProxyService {
private:
int deviceNumaNode_;
int my_rank_, nranks_, dataSize_;
std::vector<mscclpp::Connection> connections_;
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
std::vector<mscclpp::RegisteredMemory> allRegMem_;
std::vector<mscclpp::Host2DeviceSemaphore> semaphores_;
mscclpp::Proxy proxy_;
public:
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
MyProxyService(int my_rank, int nranks, int dataSize, nb::list allRegMemList, nb::list semaphoreList)
: my_rank_(my_rank),
nranks_(nranks),
dataSize_(dataSize),
connections_(conns),
allRegMem_(allRegMem),
semaphores_(semaphores),
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
deviceNumaNode_ = mscclpp::getDeviceNumaNode(cudaDevice);
allRegMem_.reserve(allRegMemList.size());
for (size_t i = 0; i < allRegMemList.size(); ++i) {
auto& regMem = nb::cast<const mscclpp::RegisteredMemory&>(allRegMemList[i]);
allRegMem_.push_back(regMem);
}
semaphores_.reserve(semaphoreList.size());
for (size_t i = 0; i < semaphoreList.size(); ++i) {
auto& sema = nb::cast<const mscclpp::Semaphore&>(semaphoreList[i]);
semaphores_.emplace_back(sema);
}
}
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger) {
int dataSizePerRank = dataSize_ / nranks_;
for (int r = 1; r < nranks_; ++r) {
int nghr = (my_rank_ + r) % nranks_;
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
semaphores_[nghr]->signal();
connections_[nghr].flush();
auto& sema = semaphores_[nghr];
auto& conn = sema.connection();
conn.write(allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
sema.signal();
conn.flush();
}
return mscclpp::ProxyHandlerResult::Continue;
}
@@ -61,16 +62,11 @@ class MyProxyService {
mscclpp::FifoDeviceHandle fifoDeviceHandle() { return proxy_.fifo()->deviceHandle(); }
};
void init_mscclpp_proxy_test_module(nb::module_ &m) {
NB_MODULE(_ext, m) {
nb::class_<MyProxyService>(m, "MyProxyService")
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
nb::arg("h2d_sem_vec"))
.def(nb::init<int, int, int, nb::list, nb::list>(), nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"),
nb::arg("reg_mem_list"), nb::arg("sem_list"))
.def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
.def("start", &MyProxyService::start)
.def("stop", &MyProxyService::stop);
}
NB_MODULE(_ext, m) { init_mscclpp_proxy_test_module(m); }