mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-01 20:21:26 +00:00
Fix Python bindings and tests (#690)
This is a minimal fix to get the Python bindings and tests working again. A more careful follow-up is needed to prevent nanobind from silently falling back when it fails to (properly) construct a C++ STL object holding mscclpp instances.
This commit is contained in:
@@ -10,7 +10,6 @@
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/numa.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
#include <vector>
|
||||
@@ -19,37 +18,39 @@ namespace nb = nanobind;
|
||||
|
||||
class MyProxyService {
|
||||
private:
|
||||
int deviceNumaNode_;
|
||||
int my_rank_, nranks_, dataSize_;
|
||||
std::vector<mscclpp::Connection> connections_;
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<mscclpp::RegisteredMemory> allRegMem_;
|
||||
std::vector<mscclpp::Host2DeviceSemaphore> semaphores_;
|
||||
mscclpp::Proxy proxy_;
|
||||
|
||||
public:
|
||||
// Constructs the proxy service: stores the rank, rank count and data size, wires the proxy
// handler to handleTrigger(), records the NUMA node of the current CUDA device, and fills
// allRegMem_ / semaphores_.
// NOTE(review): this span is a rendered diff with no +/- markers. The signature taking
// std::vector arguments and the connections_/allRegMem_/semaphores_ initializers look like the
// removed revision; the nb::list signature and the cast loops look like the added revision.
// Confirm against the repository before editing.
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
|
||||
MyProxyService(int my_rank, int nranks, int dataSize, nb::list allRegMemList, nb::list semaphoreList)
|
||||
: my_rank_(my_rank),
|
||||
nranks_(nranks),
|
||||
dataSize_(dataSize),
|
||||
connections_(conns),
|
||||
allRegMem_(allRegMem),
|
||||
semaphores_(semaphores),
|
||||
// The handler given to the proxy forwards each trigger to handleTrigger().
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {
|
||||
int cudaDevice;
|
||||
// Query the current CUDA device and remember the NUMA node it reports.
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
deviceNumaNode_ = mscclpp::getDeviceNumaNode(cudaDevice);
|
||||
// Copy every RegisteredMemory out of the Python list into allRegMem_.
allRegMem_.reserve(allRegMemList.size());
|
||||
for (size_t i = 0; i < allRegMemList.size(); ++i) {
|
||||
auto& regMem = nb::cast<const mscclpp::RegisteredMemory&>(allRegMemList[i]);
|
||||
allRegMem_.push_back(regMem);
|
||||
}
|
||||
// Build one semaphores_ element in place from each Semaphore in the Python list.
semaphores_.reserve(semaphoreList.size());
|
||||
for (size_t i = 0; i < semaphoreList.size(); ++i) {
|
||||
auto& sema = nb::cast<const mscclpp::Semaphore&>(semaphoreList[i]);
|
||||
semaphores_.emplace_back(sema);
|
||||
}
|
||||
}
|
||||
|
||||
// Proxy trigger handler; the trigger argument itself is unused.
// For each of the other nranks_ - 1 ranks, visited as (my_rank_ + r) % nranks_, it issues a
// write between allRegMem_[nghr] and allRegMem_[my_rank_] at offset my_rank_ * dataSizePerRank
// with length dataSizePerRank, then signals that rank's semaphore and flushes the connection.
// Always returns ProxyHandlerResult::Continue.
// NOTE(review): write() is presumably (dst, dstOffset, src, srcOffset, size) -- confirm against
// the mscclpp Connection API.
// NOTE(review): this span is a rendered diff with no +/- markers. The connections_[nghr] /
// semaphores_[nghr]->signal() statements look like the removed revision, and the
// sema.connection() statements look like the added revision -- confirm before editing.
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger) {
|
||||
// Each rank owns an equal share of the data (integer division).
int dataSizePerRank = dataSize_ / nranks_;
|
||||
for (int r = 1; r < nranks_; ++r) {
|
||||
// Start with the rank after my_rank_ and wrap around; r starts at 1, so my_rank_ is skipped.
int nghr = (my_rank_ + r) % nranks_;
|
||||
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
semaphores_[nghr]->signal();
|
||||
connections_[nghr].flush();
|
||||
// Variant that obtains the connection from the semaphore instead of from connections_.
auto& sema = semaphores_[nghr];
|
||||
auto& conn = sema.connection();
|
||||
conn.write(allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
sema.signal();
|
||||
conn.flush();
|
||||
}
|
||||
return mscclpp::ProxyHandlerResult::Continue;
|
||||
}
|
||||
@@ -61,16 +62,11 @@ class MyProxyService {
|
||||
// Returns the device handle of the FIFO owned by proxy_.
mscclpp::FifoDeviceHandle fifoDeviceHandle() { return proxy_.fifo()->deviceHandle(); }
|
||||
};
|
||||
|
||||
// Python bindings for the test extension: registers the MyProxyService class with its
// constructor and the fifo_device_handle / start / stop methods.
// NOTE(review): this span is a rendered diff with no +/- markers. The
// init_mscclpp_proxy_test_module header, the nb::init overload taking std::vector arguments,
// and the trailing one-line NB_MODULE look like the removed revision; the NB_MODULE(_ext, m)
// header and the nb::init overload taking two nb::list arguments look like the added
// revision -- confirm against the repository before editing.
void init_mscclpp_proxy_test_module(nb::module_ &m) {
|
||||
NB_MODULE(_ext, m) {
|
||||
nb::class_<MyProxyService>(m, "MyProxyService")
|
||||
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
|
||||
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
|
||||
nb::arg("h2d_sem_vec"))
|
||||
.def(nb::init<int, int, int, nb::list, nb::list>(), nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"),
|
||||
nb::arg("reg_mem_list"), nb::arg("sem_list"))
|
||||
.def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
|
||||
.def("start", &MyProxyService::start)
|
||||
.def("stop", &MyProxyService::stop);
|
||||
}
|
||||
|
||||
// Module entry point that delegates to init_mscclpp_proxy_test_module().
NB_MODULE(_ext, m) { init_mscclpp_proxy_test_module(m); }
|
||||
|
||||
Reference in New Issue
Block a user