connect() APIs changed to return an instance instead of a shared_ptr (#680)

The key purpose is handling all mscclpp objects' memory internally by
hiding shared pointers from user APIs.
* `Connection` class is now a wrapper of `BaseConnection` class that is
equivalent to the previous `Connection` class
* `connect()` methods now return `Connection` instead of
`std::shared_ptr<Connection>`
* Removed `connectOnSetup()` method
This commit is contained in:
Changho Hwang
2025-11-15 11:40:40 -08:00
committed by GitHub
parent 7eb3ff701a
commit 1bf4e8c90e
31 changed files with 252 additions and 213 deletions

View File

@@ -21,13 +21,13 @@ class MyProxyService {
private:
int deviceNumaNode_;
int my_rank_, nranks_, dataSize_;
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
std::vector<mscclpp::Connection> connections_;
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
mscclpp::Proxy proxy_;
public:
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<std::shared_ptr<mscclpp::Connection>> conns,
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
: my_rank_(my_rank),
@@ -46,10 +46,10 @@ class MyProxyService {
int dataSizePerRank = dataSize_ / nranks_;
for (int r = 1; r < nranks_; ++r) {
int nghr = (my_rank_ + r) % nranks_;
connections_[nghr]->write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
semaphores_[nghr]->signal();
connections_[nghr]->flush();
connections_[nghr].flush();
}
return mscclpp::ProxyHandlerResult::Continue;
}
@@ -63,7 +63,7 @@ class MyProxyService {
void init_mscclpp_proxy_test_module(nb::module_ &m) {
nb::class_<MyProxyService>(m, "MyProxyService")
.def(nb::init<int, int, int, std::vector<std::shared_ptr<mscclpp::Connection>>,
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),