mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
* In cases where the same `tag` is used for receiving data from the same remote rank, #514 changed the behavior of `Communicator::connect` and `Communicator::recvMemory` to receive data in the order in which `std::shared_future::get()` is called, instead of the original behavior of receiving data in the order of the method calls. Since the original behavior is more intuitive, we restore it. Now when `get()` is called on a future, the async function will first call `wait()` on the most recently returned previous future. In a recursive manner, this calls `wait()` on all previous futures that are not yet ready. * Removed all deprecated API calls and replaced them with the new ones.
This commit is contained in:
@@ -270,13 +270,12 @@ static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_pt
|
||||
mscclpp::TransportFlags transport) {
|
||||
std::vector<mscclpp::RegisteredMemory> remoteMemories;
|
||||
mscclpp::RegisteredMemory memory = comm->registerMemory(buff, bytes, transport);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
|
||||
for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
|
||||
if (i == rank) continue;
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemoryOnSetup(i, 0));
|
||||
comm->sendMemoryOnSetup(memory, i, 0);
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemory(i, 0));
|
||||
comm->sendMemory(memory, i, 0);
|
||||
}
|
||||
comm->setup();
|
||||
std::transform(remoteRegMemoryFutures.begin(), remoteRegMemoryFutures.end(), std::back_inserter(remoteMemories),
|
||||
[](const auto& future) { return future.get(); });
|
||||
return remoteMemories;
|
||||
@@ -602,15 +601,13 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
|
||||
|
||||
static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_ptr<mscclpp::Communicator> mscclppComm,
|
||||
int rank) {
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
|
||||
for (int i = 0; i < mscclppComm->bootstrap()->getNranks(); i++) {
|
||||
if (i == rank) continue;
|
||||
mscclpp::Transport transport = getTransport(rank, i);
|
||||
connectionFutures.push_back(mscclppComm->connectOnSetup(i, 0, transport));
|
||||
connectionFutures.push_back(mscclppComm->connect(i, 0, transport));
|
||||
}
|
||||
mscclppComm->setup();
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
[](const auto& future) { return future.get(); });
|
||||
@@ -625,7 +622,6 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
|
||||
}
|
||||
}
|
||||
|
||||
mscclppComm->setup();
|
||||
commPtr->connections = std::move(connections);
|
||||
if (mscclpp::isNvlsSupported()) {
|
||||
commPtr->nvlsConnections = setupNvlsConnections(commPtr, NVLS_BUFFER_SIZE);
|
||||
|
||||
@@ -32,29 +32,25 @@ void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {
|
||||
|
||||
std::vector<mscclpp::SemaphoreId> semaphoreIds;
|
||||
std::vector<mscclpp::RegisteredMemory> localMemories;
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemories;
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
|
||||
// Connect with all other ranks
|
||||
connections[r] = comm.connectOnSetup(r, 0, transport);
|
||||
connections[r] = comm.connect(r, 0, transport);
|
||||
auto memory = comm.registerMemory(data, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
localMemories.push_back(memory);
|
||||
comm.sendMemoryOnSetup(memory, r, 0);
|
||||
remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0));
|
||||
comm.sendMemory(memory, r, 0);
|
||||
remoteMemories.push_back(comm.recvMemory(r, 0));
|
||||
}
|
||||
|
||||
comm.setup();
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
|
||||
}
|
||||
|
||||
comm.setup();
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
|
||||
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
|
||||
portChannels.push_back(mscclpp::deviceHandle(mscclpp::PortChannel(
|
||||
|
||||
@@ -101,13 +101,12 @@ class CommGroup:
|
||||
if endpoint.transport == Transport.Nvls:
|
||||
return connect_nvls_collective(self.communicator, all_ranks, 2**30)
|
||||
else:
|
||||
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
|
||||
self.communicator.setup()
|
||||
connections[rank] = self.communicator.connect(rank, 0, endpoint)
|
||||
connections = {rank: connections[rank].get() for rank in connections}
|
||||
return connections
|
||||
|
||||
def register_tensor_with_connections(
|
||||
self, tensor: Type[cp.ndarray] or Type[np.ndarray], connections: dict[int, Connection]
|
||||
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
|
||||
) -> dict[int, RegisteredMemory]:
|
||||
transport_flags = TransportFlags()
|
||||
for rank in connections:
|
||||
@@ -125,9 +124,8 @@ class CommGroup:
|
||||
all_registered_memories[self.my_rank] = local_reg_memory
|
||||
future_memories = {}
|
||||
for rank in connections:
|
||||
self.communicator.send_memory_on_setup(local_reg_memory, rank, 0)
|
||||
future_memories[rank] = self.communicator.recv_memory_on_setup(rank, 0)
|
||||
self.communicator.setup()
|
||||
self.communicator.send_memory(local_reg_memory, rank, 0)
|
||||
future_memories[rank] = self.communicator.recv_memory(rank, 0)
|
||||
for rank in connections:
|
||||
all_registered_memories[rank] = future_memories[rank].get()
|
||||
return all_registered_memories
|
||||
@@ -135,12 +133,11 @@ class CommGroup:
|
||||
def make_semaphore(
|
||||
self,
|
||||
connections: dict[int, Connection],
|
||||
semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[MemoryDevice2DeviceSemaphore],
|
||||
semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
|
||||
) -> dict[int, Host2HostSemaphore]:
|
||||
semaphores = {}
|
||||
for rank in connections:
|
||||
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
|
||||
self.communicator.setup()
|
||||
return semaphores
|
||||
|
||||
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
|
||||
|
||||
@@ -27,9 +27,9 @@ extern void register_npkit(nb::module_& m);
|
||||
extern void register_gpu_utils(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
|
||||
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str()).def("get", &NonblockingFuture<T>::get);
|
||||
void def_shared_future(nb::handle& m, const std::string& typestr) {
|
||||
std::string pyclass_name = std::string("shared_future_") + typestr;
|
||||
nb::class_<std::shared_future<T>>(m, pyclass_name.c_str()).def("get", &std::shared_future<T>::get);
|
||||
}
|
||||
|
||||
void register_core(nb::module_& m) {
|
||||
@@ -158,8 +158,8 @@ void register_core(nb::module_& m) {
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_shared_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
@@ -172,14 +172,15 @@ void register_core(nb::module_& m) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
|
||||
nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("localConfig"))
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", &Communicator::setup);
|
||||
.def("setup", [](Communicator*) {});
|
||||
}
|
||||
|
||||
NB_MODULE(_mscclpp, m) {
|
||||
|
||||
@@ -52,14 +52,14 @@ MSCCLPP_API_CPP void Bootstrap::groupBarrier(const std::vector<int>& ranks) {
|
||||
MSCCLPP_API_CPP void Bootstrap::send(const std::vector<char>& data, int peer, int tag) {
|
||||
size_t size = data.size();
|
||||
send((void*)&size, sizeof(size_t), peer, tag);
|
||||
send((void*)data.data(), data.size(), peer, tag + 1);
|
||||
send((void*)data.data(), data.size(), peer, tag);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Bootstrap::recv(std::vector<char>& data, int peer, int tag) {
|
||||
size_t size;
|
||||
recv((void*)&size, sizeof(size_t), peer, tag);
|
||||
data.resize(size);
|
||||
recv((void*)data.data(), data.size(), peer, tag + 1);
|
||||
recv((void*)data.data(), data.size(), peer, tag);
|
||||
}
|
||||
|
||||
struct UniqueIdInternal {
|
||||
@@ -528,6 +528,7 @@ std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag)
|
||||
if (recvPeer == peer && recvTag == tag) {
|
||||
return sock;
|
||||
}
|
||||
// TODO(chhwang): set an exit condition or timeout
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,22 @@ Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<C
|
||||
}
|
||||
}
|
||||
|
||||
void Communicator::Impl::setLastRecvItem(int remoteRank, int tag, std::shared_ptr<BaseRecvItem> item) {
|
||||
lastRecvItems_[{remoteRank, tag}] = item;
|
||||
}
|
||||
|
||||
std::shared_ptr<BaseRecvItem> Communicator::Impl::getLastRecvItem(int remoteRank, int tag) {
|
||||
auto it = lastRecvItems_.find({remoteRank, tag});
|
||||
if (it == lastRecvItems_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (it->second->isReady()) {
|
||||
lastRecvItems_.erase(it);
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Communicator::~Communicator() = default;
|
||||
|
||||
MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context)
|
||||
@@ -31,30 +47,47 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void Communicator::sendMemory(RegisteredMemory memory, int remoteRank, int tag) {
|
||||
pimpl_->bootstrap_->send(memory.serialize(), remoteRank, tag);
|
||||
bootstrap()->send(memory.serialize(), remoteRank, tag);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(int remoteRank, int tag) {
|
||||
return std::async(std::launch::deferred, [this, remoteRank, tag]() {
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
return RegisteredMemory::deserialize(data);
|
||||
});
|
||||
auto future = std::async(std::launch::deferred,
|
||||
[this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() {
|
||||
if (lastRecvItem) {
|
||||
// Recursive call to the previous receive items
|
||||
lastRecvItem->wait();
|
||||
}
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
return RegisteredMemory::deserialize(data);
|
||||
});
|
||||
auto shared_future = std::shared_future<RegisteredMemory>(std::move(future));
|
||||
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<RegisteredMemory>>(shared_future));
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
|
||||
EndpointConfig localConfig) {
|
||||
auto localEndpoint = pimpl_->context_->createEndpoint(localConfig);
|
||||
pimpl_->bootstrap_->send(localEndpoint.serialize(), remoteRank, tag);
|
||||
auto localEndpoint = context()->createEndpoint(localConfig);
|
||||
bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
|
||||
|
||||
return std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint = std::move(localEndpoint)]() mutable {
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
auto remoteEndpoint = Endpoint::deserialize(data);
|
||||
auto connection = context()->connect(localEndpoint, remoteEndpoint);
|
||||
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
|
||||
return connection;
|
||||
});
|
||||
auto future =
|
||||
std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
|
||||
localEndpoint = std::move(localEndpoint)]() mutable {
|
||||
if (lastRecvItem) {
|
||||
// Recursive call to the previous receive items
|
||||
lastRecvItem->wait();
|
||||
}
|
||||
std::vector<char> data;
|
||||
bootstrap()->recv(data, remoteRank, tag);
|
||||
auto remoteEndpoint = Endpoint::deserialize(data);
|
||||
auto connection = context()->connect(localEndpoint, remoteEndpoint);
|
||||
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
|
||||
return connection;
|
||||
});
|
||||
auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
|
||||
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
|
||||
return shared_future;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
|
||||
|
||||
@@ -212,13 +212,12 @@ struct Executor::Impl {
|
||||
void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan, size_t sendBufferSize,
|
||||
size_t recvBufferSize) {
|
||||
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers(rank);
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
for (int peer : connectedPeers) {
|
||||
Transport transport =
|
||||
inSameNode(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
|
||||
connectionFutures.push_back(this->comm->connectOnSetup(peer, 0, transport));
|
||||
connectionFutures.push_back(this->comm->connect(peer, 0, transport));
|
||||
}
|
||||
this->comm->setup();
|
||||
for (size_t i = 0; i < connectionFutures.size(); i++) {
|
||||
context.connections[connectedPeers[i]] = connectionFutures[i].get();
|
||||
}
|
||||
@@ -262,16 +261,15 @@ struct Executor::Impl {
|
||||
RegisteredMemory memory =
|
||||
this->comm->registerMemory(getBufferInfo(bufferType).first, getBufferInfo(bufferType).second, transportFlags);
|
||||
std::vector<int> connectedPeers = getConnectedPeers(channelInfos);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
|
||||
for (int peer : connectedPeers) {
|
||||
comm->sendMemoryOnSetup(memory, peer, 0);
|
||||
comm->sendMemory(memory, peer, 0);
|
||||
}
|
||||
channelInfos = plan.impl_->getChannelInfos(rank, bufferType);
|
||||
connectedPeers = getConnectedPeers(channelInfos);
|
||||
for (int peer : connectedPeers) {
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemoryOnSetup(peer, 0));
|
||||
remoteRegMemoryFutures.push_back(comm->recvMemory(peer, 0));
|
||||
}
|
||||
comm->setup();
|
||||
for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
|
||||
context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());
|
||||
}
|
||||
@@ -307,7 +305,6 @@ struct Executor::Impl {
|
||||
channelInfos = plan.impl_->getUnpairedChannelInfos(rank, nranks, channelType);
|
||||
processChannelInfos(channelInfos);
|
||||
}
|
||||
this->comm->setup();
|
||||
context.memorySemaphores = std::move(memorySemaphores);
|
||||
context.proxySemaphores = std::move(proxySemaphores);
|
||||
|
||||
|
||||
@@ -9,9 +9,29 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "utils_internal.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
class ConnectionBase;
|
||||
class BaseRecvItem {
|
||||
public:
|
||||
virtual ~BaseRecvItem() = default;
|
||||
virtual void wait() = 0;
|
||||
virtual bool isReady() const = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class RecvItem : public BaseRecvItem {
|
||||
public:
|
||||
RecvItem(std::shared_future<T> future) : future_(future) {}
|
||||
|
||||
void wait() { future_.wait(); }
|
||||
|
||||
bool isReady() const { return future_.wait_for(std::chrono::seconds(0)) == std::future_status::ready; }
|
||||
|
||||
private:
|
||||
std::shared_future<T> future_;
|
||||
};
|
||||
|
||||
struct ConnectionInfo {
|
||||
int remoteRank;
|
||||
@@ -22,9 +42,22 @@ struct Communicator::Impl {
|
||||
std::shared_ptr<Bootstrap> bootstrap_;
|
||||
std::shared_ptr<Context> context_;
|
||||
std::unordered_map<const Connection*, ConnectionInfo> connectionInfos_;
|
||||
std::shared_ptr<BaseRecvItem> lastRecvItem_;
|
||||
|
||||
// Temporary storage for the latest RecvItem of each {remoteRank, tag} pair.
|
||||
// If the RecvItem gets ready, it will be removed at the next call to getLastRecvItem.
|
||||
std::unordered_map<std::pair<int, int>, std::shared_ptr<BaseRecvItem>, PairHash> lastRecvItems_;
|
||||
|
||||
Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context);
|
||||
|
||||
// Set the last RecvItem for a {remoteRank, tag} pair.
|
||||
// This is used to store the corresponding RecvItem of a future returned by recvMemory() or connect().
|
||||
void setLastRecvItem(int remoteRank, int tag, std::shared_ptr<BaseRecvItem> item);
|
||||
|
||||
// Return the last RecvItem that is not ready.
|
||||
// If the item is ready, it will be removed from the map and nullptr will be returned.
|
||||
std::shared_ptr<BaseRecvItem> getLastRecvItem(int remoteRank, int tag);
|
||||
|
||||
struct Connector;
|
||||
};
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
static NonblockingFuture<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
|
||||
void* localInboundSemaphoreId) {
|
||||
static std::shared_future<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
|
||||
void* localInboundSemaphoreId) {
|
||||
auto localInboundSemaphoreIdsRegMem =
|
||||
communicator.registerMemory(localInboundSemaphoreId, sizeof(uint64_t), connection->transport());
|
||||
int remoteRank = communicator.remoteRankOf(*connection);
|
||||
int tag = communicator.tagOf(*connection);
|
||||
communicator.sendMemoryOnSetup(localInboundSemaphoreIdsRegMem, remoteRank, tag);
|
||||
return communicator.recvMemoryOnSetup(remoteRank, tag);
|
||||
communicator.sendMemory(localInboundSemaphoreIdsRegMem, remoteRank, tag);
|
||||
return communicator.recvMemory(remoteRank, tag);
|
||||
}
|
||||
|
||||
static detail::UniqueGpuPtr<uint64_t> createGpuSemaphoreId() {
|
||||
|
||||
@@ -206,8 +206,8 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
|
||||
std::vector<mscclpp::SemaphoreId> semaphoreIds;
|
||||
std::vector<mscclpp::RegisteredMemory> localMemories;
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemories;
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
@@ -218,22 +218,18 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
connections[r] = comm.connectOnSetup(r, 0, transport);
|
||||
connections[r] = comm.connect(r, 0, transport);
|
||||
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
localMemories.push_back(memory);
|
||||
comm.sendMemoryOnSetup(memory, r, 0);
|
||||
remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0));
|
||||
comm.sendMemory(memory, r, 0);
|
||||
remoteMemories.push_back(comm.recvMemory(r, 0));
|
||||
}
|
||||
|
||||
comm.setup();
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) continue;
|
||||
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
|
||||
}
|
||||
|
||||
comm.setup();
|
||||
|
||||
std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
|
||||
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
|
||||
portChannels.push_back(mscclpp::deviceHandle(proxyService.portChannel(
|
||||
|
||||
@@ -104,8 +104,8 @@ class MyProxyService {
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionsFuture(world_size);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemoriesFuture(world_size);
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionsFuture(world_size);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemoriesFuture(world_size);
|
||||
|
||||
localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
@@ -122,14 +122,12 @@ class MyProxyService {
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
connectionsFuture[r] = comm.connectOnSetup(r, 0, transport);
|
||||
comm.sendMemoryOnSetup(localMemory_, r, 0);
|
||||
connectionsFuture[r] = comm.connect(r, 0, transport);
|
||||
comm.sendMemory(localMemory_, r, 0);
|
||||
|
||||
remoteMemoriesFuture[r] = comm.recvMemoryOnSetup(r, 0);
|
||||
remoteMemoriesFuture[r] = comm.recvMemory(r, 0);
|
||||
}
|
||||
|
||||
comm.setup();
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
if (r == rank) {
|
||||
continue;
|
||||
@@ -144,8 +142,6 @@ class MyProxyService {
|
||||
deviceSemaphores2_.emplace_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(comm, connections_[r]));
|
||||
remoteMemories_[r] = remoteMemoriesFuture[r].get();
|
||||
}
|
||||
|
||||
comm.setup();
|
||||
}
|
||||
|
||||
void bindThread() {
|
||||
|
||||
@@ -43,19 +43,18 @@ void CommunicatorTestBase::TearDown() {
|
||||
void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; }
|
||||
|
||||
void CommunicatorTestBase::connectMesh(bool useIpc, bool useIb, bool useEthernet) {
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
|
||||
for (int i = 0; i < numRanksToUse; i++) {
|
||||
if (i != gEnv->rank) {
|
||||
if ((rankToNode(i) == rankToNode(gEnv->rank)) && useIpc) {
|
||||
connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::CudaIpc);
|
||||
connectionFutures[i] = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
|
||||
} else if (useIb) {
|
||||
connectionFutures[i] = communicator->connectOnSetup(i, 0, ibTransport);
|
||||
connectionFutures[i] = communicator->connect(i, 0, ibTransport);
|
||||
} else if (useEthernet) {
|
||||
connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::Ethernet);
|
||||
connectionFutures[i] = communicator->connect(i, 0, mscclpp::Transport::Ethernet);
|
||||
}
|
||||
}
|
||||
}
|
||||
communicator->setup();
|
||||
for (int i = 0; i < numRanksToUse; i++) {
|
||||
if (i != gEnv->rank) {
|
||||
connections[i] = connectionFutures[i].get();
|
||||
@@ -69,14 +68,13 @@ void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, mscc
|
||||
mscclpp::RegisteredMemory& localMemory,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemories) {
|
||||
localMemory = communicator->registerMemory(buff, buffSize, transport);
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemories;
|
||||
std::unordered_map<int, std::shared_future<mscclpp::RegisteredMemory>> futureRemoteMemories;
|
||||
for (int remoteRank : remoteRanks) {
|
||||
if (remoteRank != communicator->bootstrap()->getRank()) {
|
||||
communicator->sendMemoryOnSetup(localMemory, remoteRank, tag);
|
||||
futureRemoteMemories[remoteRank] = communicator->recvMemoryOnSetup(remoteRank, tag);
|
||||
communicator->sendMemory(localMemory, remoteRank, tag);
|
||||
futureRemoteMemories[remoteRank] = communicator->recvMemory(remoteRank, tag);
|
||||
}
|
||||
}
|
||||
communicator->setup();
|
||||
for (int remoteRank : remoteRanks) {
|
||||
if (remoteRank != communicator->bootstrap()->getRank()) {
|
||||
remoteMemories[remoteRank] = futureRemoteMemories[remoteRank].get();
|
||||
@@ -208,7 +206,6 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
|
||||
auto& conn = entry.second;
|
||||
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2DeviceSemaphore>(*communicator.get(), conn)});
|
||||
}
|
||||
communicator->setup();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
deviceBufferInit();
|
||||
@@ -250,7 +247,6 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
|
||||
if (conn->transport() == mscclpp::Transport::CudaIpc) continue;
|
||||
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2HostSemaphore>(*communicator.get(), conn)});
|
||||
}
|
||||
communicator->setup();
|
||||
communicator->bootstrap()->barrier();
|
||||
|
||||
deviceBufferInit();
|
||||
|
||||
@@ -25,8 +25,8 @@ void MemoryChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::Memory
|
||||
const bool isInPlace = (outputBuff == nullptr);
|
||||
mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc | ibTransport;
|
||||
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
|
||||
|
||||
mscclpp::RegisteredMemory inputBufRegMem = communicator->registerMemory(inputBuff, inputBuffBytes, transport);
|
||||
mscclpp::RegisteredMemory outputBufRegMem;
|
||||
@@ -39,21 +39,19 @@ void MemoryChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::Memory
|
||||
continue;
|
||||
}
|
||||
if (rankToNode(r) == rankToNode(gEnv->rank)) {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc);
|
||||
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
|
||||
} else {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport);
|
||||
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
|
||||
}
|
||||
|
||||
if (isInPlace) {
|
||||
communicator->sendMemoryOnSetup(inputBufRegMem, r, 0);
|
||||
communicator->sendMemory(inputBufRegMem, r, 0);
|
||||
} else {
|
||||
communicator->sendMemoryOnSetup(outputBufRegMem, r, 0);
|
||||
communicator->sendMemory(outputBufRegMem, r, 0);
|
||||
}
|
||||
remoteMemFutures[r] = communicator->recvMemoryOnSetup(r, 0);
|
||||
remoteMemFutures[r] = communicator->recvMemory(r, 0);
|
||||
}
|
||||
|
||||
communicator->setup();
|
||||
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
if (r == rank) {
|
||||
continue;
|
||||
@@ -65,8 +63,6 @@ void MemoryChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::Memory
|
||||
memoryChannels.emplace_back(memorySemaphores[r], remoteMemFutures[r].get(), inputBufRegMem.data(),
|
||||
(isInPlace ? nullptr : outputBufRegMem.data()));
|
||||
}
|
||||
|
||||
communicator->setup();
|
||||
}
|
||||
|
||||
__constant__ DeviceHandle<mscclpp::MemoryChannel> gChannelOneToOneTestConstMemChans;
|
||||
|
||||
@@ -27,8 +27,8 @@ void PortChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::PortChan
|
||||
if (useIb) transport |= ibTransport;
|
||||
if (useEthernet) transport |= mscclpp::Transport::Ethernet;
|
||||
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
|
||||
|
||||
mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, transport);
|
||||
mscclpp::RegisteredMemory recvBufRegMem;
|
||||
@@ -41,23 +41,21 @@ void PortChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::PortChan
|
||||
continue;
|
||||
}
|
||||
if ((rankToNode(r) == rankToNode(gEnv->rank)) && useIPC) {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc);
|
||||
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
|
||||
} else if (useIb) {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport);
|
||||
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
|
||||
} else if (useEthernet) {
|
||||
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::Ethernet);
|
||||
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::Ethernet);
|
||||
}
|
||||
|
||||
if (isInPlace) {
|
||||
communicator->sendMemoryOnSetup(sendBufRegMem, r, 0);
|
||||
communicator->sendMemory(sendBufRegMem, r, 0);
|
||||
} else {
|
||||
communicator->sendMemoryOnSetup(recvBufRegMem, r, 0);
|
||||
communicator->sendMemory(recvBufRegMem, r, 0);
|
||||
}
|
||||
remoteMemFutures[r] = communicator->recvMemoryOnSetup(r, 0);
|
||||
remoteMemFutures[r] = communicator->recvMemory(r, 0);
|
||||
}
|
||||
|
||||
communicator->setup();
|
||||
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
if (r == rank) {
|
||||
continue;
|
||||
@@ -67,8 +65,6 @@ void PortChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::PortChan
|
||||
portChannels.emplace_back(proxyService->portChannel(cid, proxyService->addMemory(remoteMemFutures[r].get()),
|
||||
proxyService->addMemory(sendBufRegMem)));
|
||||
}
|
||||
|
||||
communicator->setup();
|
||||
}
|
||||
|
||||
__constant__ DeviceHandle<mscclpp::PortChannel> gChannelOneToOneTestConstPortChans;
|
||||
|
||||
@@ -764,7 +764,7 @@ void AllGatherTestEngine::setupConnections() {
|
||||
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
|
||||
setupMeshConnections(devPortChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
|
||||
[&](std::vector<std::shared_ptr<mscclpp::Connection>> conns,
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteMemories,
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteMemories,
|
||||
const mscclpp::RegisteredMemory& localMemory) {
|
||||
std::vector<mscclpp::SemaphoreId> semaphoreIds;
|
||||
for (size_t i = 0; i < conns.size(); ++i) {
|
||||
@@ -772,7 +772,6 @@ void AllGatherTestEngine::setupConnections() {
|
||||
service->addRemoteMemory(remoteMemories[i].get());
|
||||
}
|
||||
service->setLocalMemory(localMemory);
|
||||
comm_->setup();
|
||||
});
|
||||
auto portChannels = service->portChannels();
|
||||
if (portChannels.size() > sizeof(constRawPortChan) / sizeof(DeviceHandle<mscclpp::BasePortChannel>)) {
|
||||
|
||||
@@ -362,13 +362,13 @@ std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createProxyService()
|
||||
|
||||
void BaseTestEngine::setupMeshConnectionsInternal(
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& localRegMemory,
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteRegMemories, bool addConnections) {
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteRegMemories, bool addConnections) {
|
||||
const int worldSize = args_.totalRanks;
|
||||
const int rank = args_.rank;
|
||||
const int nRanksPerNode = args_.nRanksPerNode;
|
||||
const int thisNode = rank / nRanksPerNode;
|
||||
const mscclpp::Transport ibTransport = IBs[args_.gpuNum];
|
||||
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
|
||||
|
||||
auto rankToNode = [&](int rank) { return rank / nRanksPerNode; };
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
@@ -383,16 +383,14 @@ void BaseTestEngine::setupMeshConnectionsInternal(
|
||||
transport = ibTransport;
|
||||
}
|
||||
// Connect with all other ranks
|
||||
connectionFutures.push_back(comm_->connectOnSetup(r, 0, transport));
|
||||
connectionFutures.push_back(comm_->connect(r, 0, transport));
|
||||
}
|
||||
comm_->sendMemoryOnSetup(localRegMemory, r, 0);
|
||||
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
|
||||
comm_->sendMemory(localRegMemory, r, 0);
|
||||
auto remoteMemory = comm_->recvMemory(r, 0);
|
||||
remoteRegMemories.push_back(remoteMemory);
|
||||
}
|
||||
comm_->setup();
|
||||
std::transform(
|
||||
connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
[](const mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>& future) { return future.get(); });
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
[](const std::shared_future<std::shared_ptr<mscclpp::Connection>>& future) { return future.get(); });
|
||||
}
|
||||
|
||||
// Create mesh connections between all ranks. If recvBuff is nullptr, assume in-place.
|
||||
@@ -409,7 +407,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Port
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
mscclpp::RegisteredMemory& localRegMemory = (outputBuff) ? outputBufRegMem : inputBufRegMem;
|
||||
|
||||
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
|
||||
@@ -424,8 +422,6 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Port
|
||||
service->addMemory(remoteRegMemories[i].get()), service->addMemory(inputBufRegMem))));
|
||||
}
|
||||
}
|
||||
|
||||
comm_->setup();
|
||||
}
|
||||
|
||||
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& memoryChannels, void* inputBuff,
|
||||
@@ -441,7 +437,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
mscclpp::RegisteredMemory& localRegMemory =
|
||||
(outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
|
||||
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
|
||||
@@ -455,7 +451,6 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
}
|
||||
}
|
||||
}
|
||||
comm_->setup();
|
||||
|
||||
for (size_t i = 0; i < nChannelPerConnection; ++i) {
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
@@ -490,13 +485,13 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
mscclpp::RegisteredMemory& localRegMemory =
|
||||
(getPacketBuff) ? getPacketBufRegMem : ((outputBuff) ? outputBufRegMem : inputBufRegMem);
|
||||
|
||||
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
|
||||
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemoriesOutput;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoriesOutput;
|
||||
if (outputBuff) {
|
||||
setupMeshConnectionsInternal(connections, outputBufRegMem, remoteRegMemoriesOutput, false);
|
||||
}
|
||||
@@ -512,7 +507,6 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::MemoryChannel>& m
|
||||
connIdToSemId[cid] = service->buildAndAddSemaphore(*comm_, connections[cid]);
|
||||
}
|
||||
}
|
||||
comm_->setup();
|
||||
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
|
||||
|
||||
@@ -102,14 +102,14 @@ class BaseTestEngine {
|
||||
|
||||
double benchTime();
|
||||
|
||||
void setupMeshConnectionsInternal(
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& localMemory,
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteRegMemories,
|
||||
bool addConnections = true);
|
||||
void setupMeshConnectionsInternal(std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
mscclpp::RegisteredMemory& localMemory,
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>& remoteRegMemories,
|
||||
bool addConnections = true);
|
||||
|
||||
protected:
|
||||
using SetupChannelFunc = std::function<void(std::vector<std::shared_ptr<mscclpp::Connection>>,
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>&,
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>>&,
|
||||
const mscclpp::RegisteredMemory&)>;
|
||||
template <class T>
|
||||
using DeviceHandle = mscclpp::DeviceHandle<T>;
|
||||
|
||||
@@ -156,29 +156,25 @@ void SendRecvTestEngine::setupConnections() {
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
|
||||
auto sendConnFuture =
|
||||
comm_->connectOnSetup(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice));
|
||||
comm_->connect(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice));
|
||||
if (recvFromRank != sendToRank) {
|
||||
auto recvConnFuture =
|
||||
comm_->connectOnSetup(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice));
|
||||
comm_->setup();
|
||||
comm_->connect(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice));
|
||||
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
|
||||
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, recvConnFuture.get()));
|
||||
} else {
|
||||
comm_->setup();
|
||||
memorySemaphores.push_back(std::make_shared<mscclpp::MemoryDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
|
||||
memorySemaphores.push_back(memorySemaphores[0]); // reuse the send channel if worldSize is 2
|
||||
}
|
||||
comm_->setup();
|
||||
|
||||
std::vector<mscclpp::RegisteredMemory> localMemories;
|
||||
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemory;
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> futureRemoteMemory;
|
||||
|
||||
for (int i : {0, 1}) {
|
||||
auto regMem = comm_->registerMemory(devicePtrs_[i].get(), args_.maxBytes, mscclpp::Transport::CudaIpc | ibDevice);
|
||||
comm_->sendMemoryOnSetup(regMem, ranks[i], 0);
|
||||
comm_->sendMemory(regMem, ranks[i], 0);
|
||||
localMemories.push_back(regMem);
|
||||
futureRemoteMemory.push_back(comm_->recvMemoryOnSetup(ranks[1 - i], 0));
|
||||
comm_->setup();
|
||||
futureRemoteMemory.push_back(comm_->recvMemory(ranks[1 - i], 0));
|
||||
}
|
||||
|
||||
// swap to make sure devicePtrs_[0] in local rank write to devicePtrs_[1] in remote rank
|
||||
|
||||
@@ -29,9 +29,8 @@ TEST_F(LocalCommunicatorTest, RegisterMemory) {
|
||||
TEST_F(LocalCommunicatorTest, SendMemoryToSelf) {
|
||||
int dummy[42];
|
||||
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);
|
||||
comm->sendMemoryOnSetup(memory, 0, 0);
|
||||
auto memoryFuture = comm->recvMemoryOnSetup(0, 0);
|
||||
comm->setup();
|
||||
comm->sendMemory(memory, 0, 0);
|
||||
auto memoryFuture = comm->recvMemory(0, 0);
|
||||
auto sameMemory = memoryFuture.get();
|
||||
EXPECT_EQ(sameMemory.data(), memory.data());
|
||||
EXPECT_EQ(sameMemory.size(), memory.size());
|
||||
|
||||
Reference in New Issue
Block a user