connect() APIs changed to return an instance instead of a shared_ptr (#680)

The key purpose is handling all mscclpp objects' memory internally by
hiding shared pointers from user APIs.
* `Connection` class is now a wrapper of `BaseConnection` class that is
equivalent to the previous `Connection` class
* `connect()` methods now return `Connection` instead of
`std::shared_ptr<Connection>`
* Removed `connectOnSetup()` method
This commit is contained in:
Changho Hwang
2025-11-15 11:40:40 -08:00
committed by GitHub
parent 7eb3ff701a
commit 1bf4e8c90e
31 changed files with 252 additions and 213 deletions

View File

@@ -112,7 +112,7 @@ namespace mscclpp {
struct ExecutionContext {
std::shared_ptr<ProxyService> proxyService;
std::unordered_map<int, std::shared_ptr<Connection>> connections;
std::unordered_map<int, Connection> connections;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
MemoryId localMemoryIdBegin = MemoryId(0);
@@ -270,7 +270,7 @@ struct Executor::Impl {
};
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
for (int peer : connectedPeers) {
Transport transport =
inSameNode(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
@@ -339,10 +339,10 @@ struct Executor::Impl {
auto connection = context.connections.at(peer);
if (info.channelType == ChannelType::MEMORY) {
futureMemorySemaphores.push_back(this->comm->buildSemaphore(
connection, this->comm->remoteRankOf(*connection), this->comm->tagOf(*connection)));
connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection)));
} else if (info.channelType == ChannelType::PORT) {
futureProxySemaphores.push_back(this->comm->buildSemaphore(
connection, this->comm->remoteRankOf(*connection), this->comm->tagOf(*connection)));
futureProxySemaphores.push_back(this->comm->buildSemaphore(connection, this->comm->remoteRankOf(connection),
this->comm->tagOf(connection)));
}
}
}