mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
connect() APIs changed to return an instance instead of a shared_ptr (#680)
The key purpose is to manage the memory of all mscclpp objects internally by hiding shared pointers from the user-facing APIs.

* The `Connection` class is now a wrapper around the `BaseConnection` class, which is equivalent to the previous `Connection` class.
* The `connect()` methods now return a `Connection` instead of a `std::shared_ptr<Connection>`.
* Removed the `connectOnSetup()` method.
This commit is contained in:
@@ -202,7 +202,7 @@ void register_core(nb::module_& m) {
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
nb::class_<SemaphoreStub>(m, "SemaphoreStub")
|
||||
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
|
||||
.def(nb::init<const Connection&>(), nb::arg("connection"))
|
||||
.def("memory", &SemaphoreStub::memory)
|
||||
.def("serialize", &SemaphoreStub::serialize)
|
||||
.def_static("deserialize", &SemaphoreStub::deserialize, nb::arg("data"));
|
||||
@@ -215,7 +215,7 @@ void register_core(nb::module_& m) {
|
||||
.def("remote_memory", &Semaphore::remoteMemory);
|
||||
|
||||
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
def_shared_future<Connection>(m, "Connection");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
@@ -231,8 +231,8 @@ void register_core(nb::module_& m) {
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def("connect",
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const EndpointConfig&, int,
|
||||
int)>(&Communicator::connect),
|
||||
static_cast<std::shared_future<Connection> (Communicator::*)(const EndpointConfig&, int, int)>(
|
||||
&Communicator::connect),
|
||||
nb::arg("local_config"), nb::arg("remote_rank"), nb::arg("tag") = 0)
|
||||
.def(
|
||||
"connect_on_setup",
|
||||
|
||||
@@ -12,7 +12,7 @@ using namespace mscclpp;
|
||||
void register_semaphore(nb::module_& m) {
|
||||
nb::class_<Host2DeviceSemaphore> host2DeviceSemaphore(m, "Host2DeviceSemaphore");
|
||||
host2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2DeviceSemaphore::connection)
|
||||
.def("signal", &Host2DeviceSemaphore::signal)
|
||||
.def("device_handle", &Host2DeviceSemaphore::deviceHandle);
|
||||
@@ -27,7 +27,7 @@ void register_semaphore(nb::module_& m) {
|
||||
|
||||
nb::class_<Host2HostSemaphore>(m, "Host2HostSemaphore")
|
||||
.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &Host2HostSemaphore::connection)
|
||||
.def("signal", &Host2HostSemaphore::signal)
|
||||
.def("poll", &Host2HostSemaphore::poll)
|
||||
@@ -36,7 +36,7 @@ void register_semaphore(nb::module_& m) {
|
||||
|
||||
nb::class_<MemoryDevice2DeviceSemaphore> memoryDevice2DeviceSemaphore(m, "MemoryDevice2DeviceSemaphore");
|
||||
memoryDevice2DeviceSemaphore.def(nb::init<const Semaphore&>(), nb::arg("semaphore"))
|
||||
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def(nb::init<Communicator&, const Connection&>(), nb::arg("communicator"), nb::arg("connection"))
|
||||
.def("connection", &MemoryDevice2DeviceSemaphore::connection)
|
||||
.def("device_handle", &MemoryDevice2DeviceSemaphore::deviceHandle);
|
||||
|
||||
|
||||
@@ -21,13 +21,13 @@ class MyProxyService {
|
||||
private:
|
||||
int deviceNumaNode_;
|
||||
int my_rank_, nranks_, dataSize_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
|
||||
std::vector<mscclpp::Connection> connections_;
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
|
||||
mscclpp::Proxy proxy_;
|
||||
|
||||
public:
|
||||
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<std::shared_ptr<mscclpp::Connection>> conns,
|
||||
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
|
||||
: my_rank_(my_rank),
|
||||
@@ -46,10 +46,10 @@ class MyProxyService {
|
||||
int dataSizePerRank = dataSize_ / nranks_;
|
||||
for (int r = 1; r < nranks_; ++r) {
|
||||
int nghr = (my_rank_ + r) % nranks_;
|
||||
connections_[nghr]->write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
|
||||
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
|
||||
semaphores_[nghr]->signal();
|
||||
connections_[nghr]->flush();
|
||||
connections_[nghr].flush();
|
||||
}
|
||||
return mscclpp::ProxyHandlerResult::Continue;
|
||||
}
|
||||
@@ -63,7 +63,7 @@ class MyProxyService {
|
||||
|
||||
void init_mscclpp_proxy_test_module(nb::module_ &m) {
|
||||
nb::class_<MyProxyService>(m, "MyProxyService")
|
||||
.def(nb::init<int, int, int, std::vector<std::shared_ptr<mscclpp::Connection>>,
|
||||
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
|
||||
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
|
||||
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
|
||||
|
||||
Reference in New Issue
Block a user