Get rid of comm.setup()

This commit is contained in:
Olli Saarikivi
2023-08-31 17:45:58 +00:00
committed by Saeed Maleki
parent 0863e862f5
commit 8cb63a7d1a
23 changed files with 253 additions and 352 deletions

View File

@@ -32,15 +32,15 @@ class Bootstrap {
public:
Bootstrap(){};
virtual ~Bootstrap() = default;
virtual int getRank() = 0;
virtual int getNranks() = 0;
virtual int rank() = 0;
virtual int size() = 0;
virtual void send(void* data, int size, int peer, int tag) = 0;
virtual void recv(void* data, int size, int peer, int tag) = 0;
[[nodiscard]] virtual std::future<void> recv(void* data, int size, int peer, int tag) = 0;
virtual void allGather(void* allData, int size) = 0;
virtual void barrier() = 0;
void send(const std::vector<char>& data, int peer, int tag);
void recv(std::vector<char>& data, int peer, int tag);
std::future<std::vector<char>> recv(int peer, int tag);
};
/// A native implementation of the bootstrap using TCP sockets.
@@ -73,10 +73,10 @@ class TcpBootstrap : public Bootstrap {
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30);
/// Return the rank of the process.
int getRank() override;
int rank() override;
/// Return the total number of ranks.
int getNranks() override;
int size() override;
/// Send data to another process.
///
@@ -98,7 +98,8 @@ class TcpBootstrap : public Bootstrap {
/// @param size The size of the data to receive.
/// @param peer The rank of the process to receive the data from.
/// @param tag The tag to receive the data with.
void recv(void* data, int size, int peer, int tag) override;
/// @return A future that will be ready when the data has been received.
[[nodiscard]] std::future<void> recv(void* data, int size, int peer, int tag) override;
/// Gather data from all processes.
///
@@ -329,17 +330,17 @@ class RegisteredMemory {
/// Get the size of the memory block.
///
/// @return The size of the memory block.
size_t size();
size_t size() const;
/// Get the transport flags associated with the memory block.
///
/// @return The transport flags associated with the memory block.
TransportFlags transports();
TransportFlags transports() const;
/// Serialize the RegisteredMemory object to a vector of characters.
///
/// @return A vector of characters representing the serialized RegisteredMemory object.
std::vector<char> serialize();
std::vector<char> serialize() const;
/// Deserialize a RegisteredMemory object from a vector of characters.
///
@@ -370,12 +371,12 @@ class Endpoint {
/// Get the transport used.
///
/// @return The transport used.
Transport transport();
Transport transport() const;
/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
std::vector<char> serialize();
std::vector<char> serialize() const;
/// Deserialize an Endpoint object from a vector of characters.
///
@@ -537,50 +538,14 @@ struct Setuppable {
virtual void endSetup(std::shared_ptr<Bootstrap> bootstrap);
};
/// A non-blocking future that can be used to check if a value is ready and retrieve it.
template <typename T>
class NonblockingFuture {
std::shared_future<T> future;
public:
/// Default constructor.
NonblockingFuture() = default;
/// Constructor that takes a shared future and moves it into the NonblockingFuture.
///
/// @param future The shared future to move.
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}
/// Copy constructor.
///
/// @param other The @ref NonblockingFuture to copy.
NonblockingFuture(const NonblockingFuture& other) = default;
/// Check if the value is ready to be retrieved.
///
/// @return True if the value is ready, false otherwise.
bool ready() const { return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready; }
/// Get the value.
///
/// @return The value.
///
/// @throws Error if the value is not ready.
T get() const {
if (!ready()) throw Error("NonblockingFuture::get() called before ready", ErrorCode::InvalidUsage);
return future.get();
}
};
/// A class that sets up all registered memories and connections between processes.
///
/// A typical way to use this class:
/// 1. Call @ref connectOnSetup() to declare connections between the calling process with other processes.
/// 1. Call @ref connect() to declare connections between the calling process with other processes.
/// 2. Call @ref registerMemory() to register memory regions that will be used for communication.
/// 3. Call @ref sendMemoryOnSetup() or @ref recvMemoryOnSetup() to send/receive registered memory regions to/from
/// 3. Call @ref sendMemory() or @ref recvMemory() to send/receive registered memory regions to/from
/// other processes.
/// 4. No separate setup call is needed: each of the calls above starts its communication immediately.
/// 5. Call @ref NonblockingFuture<RegisteredMemory>::get() to get the registered memory regions received from other
/// 5. Call @ref std::future<RegisteredMemory>::get() to get the registered memory regions received from other
/// processes.
/// 6. All done; use connections and registered memories to build channels.
///
@@ -613,30 +578,23 @@ class Communicator {
/// @return RegisteredMemory A handle to the buffer.
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
/// Send information of a registered memory to the remote side on setup.
///
/// This function registers a send to a remote process that will happen by a following call of @ref setup(). The send
/// will carry information about a registered memory on the local process.
/// Send information of a registered memory to the remote side.
///
/// @param memory The registered memory buffer to send information about.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send.
void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag);
void sendMemory(RegisteredMemory memory, int remoteRank, int tag);
/// Receive memory on setup.
///
/// This function registers a receive from a remote process that will happen by a following call of @ref setup(). The
/// receive will carry information about a registered memory on the remote process.
/// Receive memory.
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the receive.
/// @return NonblockingFuture<RegisteredMemory> A non-blocking future of registered memory.
NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag);
/// @return std::future<RegisteredMemory> A future of registered memory.
std::future<RegisteredMemory> recvMemory(int remoteRank, int tag);
/// Connect to a remote rank on setup.
/// Connect to a remote rank.
///
/// This function only prepares metadata for connection. The actual connection is made by a following call of
/// @ref setup(). Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
/// Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
/// to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if
/// a buffer is spread through multiple pages and do not fully utilize all of them, IB's QP has to register for all
/// involved pages. This potentially has security risks if the connection's accesses are given to a malicious process.
@@ -644,9 +602,8 @@ class Communicator {
/// @param remoteRank The rank of the remote process.
/// @param tag The tag of the connection for identifying it.
/// @param config The configuration for the local endpoint.
/// @return NonblockingFuture<NonblockingFuture<std::shared_ptr<Connection>>> A non-blocking future of shared pointer
/// to the connection.
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);
/// @return std::future<std::shared_ptr<Connection>> A future of shared pointer to the connection.
std::future<std::shared_ptr<Connection>> connect(int remoteRank, int tag, EndpointConfig localConfig);
/// Get the remote rank a connection is connected to.
///
@@ -660,18 +617,6 @@ class Communicator {
/// @return The tag the connection was made with.
int tagOf(const Connection& connection);
/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
///
/// @param setuppable A shared pointer to the Setuppable object.
void onSetup(std::shared_ptr<Setuppable> setuppable);
/// Setup all objects that have registered for setup.
///
/// This includes previous calls of @ref sendMemoryOnSetup(), @ref recvMemoryOnSetup(), @ref connectOnSetup(), and
/// @ref onSetup(). It is allowed to call this function multiple times, where the n-th call will only setup objects
/// that have been registered after the (n-1)-th call.
void setup();
private:
// The internal implementation.
struct Impl;

View File

@@ -30,7 +30,7 @@ template <template <typename> typename InboundDeleter, template <typename> typen
class BaseSemaphore {
protected:
/// The registered memory for the remote peer's inbound semaphore ID.
NonblockingFuture<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
std::shared_future<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
/// The inbound semaphore ID that is incremented by the remote peer and waited on by the local peer.
///

View File

@@ -40,12 +40,11 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
for r in range(world_size):
if r == rank:
continue
conn = comm.connect_on_setup(r, 0, mscclpp.Transport.CudaIpc)
conn = comm.connect(r, 0, mscclpp.Transport.CudaIpc)
connections.append(conn)
comm.send_memory_on_setup(reg_mem, r, 0)
remote_mem = comm.recv_memory_on_setup(r, 0)
comm.send_memory(reg_mem, r, 0)
remote_mem = comm.recv_memory(r, 0)
remote_memories.append(remote_mem)
comm.setup()
connections = [conn.get() for conn in connections]

View File

@@ -35,15 +35,13 @@ def main(args):
size = elements * memory.itemsize
my_reg_mem = comm.register_memory(ptr, size, mscclpp.Transport.IB0)
conn = comm.connect_on_setup((rank + 1) % 2, 0, mscclpp.Transport.IB0)
conn = comm.connect((rank + 1) % 2, 0, mscclpp.Transport.IB0)
other_reg_mem = None
if rank == 0:
other_reg_mem = comm.recv_memory_on_setup((rank + 1) % 2, 0)
other_reg_mem = comm.recv_memory((rank + 1) % 2, 0)
else:
comm.send_memory_on_setup(my_reg_mem, (rank + 1) % 2, 0)
comm.setup()
comm.send_memory(my_reg_mem, (rank + 1) % 2, 0)
if rank == 0:
other_reg_mem = other_reg_mem.get()

View File

@@ -21,19 +21,17 @@ extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);
template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
.def("ready", &NonblockingFuture<T>::ready)
.def("get", &NonblockingFuture<T>::get);
// Registers std::future<T> with the given nanobind handle as a Python class named
// "std_future_<typestr>", exposing only the get() method.
void def_future(nb::handle& m, const std::string& typestr) {
  using Future = std::future<T>;
  const std::string pyclass_name = "std_future_" + typestr;
  nb::class_<Future>(m, pyclass_name.c_str()).def("get", &Future::get);
}
void register_core(nb::module_& m) {
m.def("version", &version);
nb::class_<Bootstrap>(m, "Bootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def_prop_ro("rank", &Bootstrap::rank)
.def_prop_ro("size", &Bootstrap::size)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
@@ -45,15 +43,15 @@ void register_core(nb::module_& m) {
"recv",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
void* data = reinterpret_cast<void*>(ptr);
self->recv(data, size, peer, tag);
return self->recv(data, size, peer, tag);
},
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
.def("barrier", &Bootstrap::barrier)
.def("send", (void (Bootstrap::*)(const std::vector<char>&, int, int)) & Bootstrap::send, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"))
.def("recv", (void (Bootstrap::*)(std::vector<char>&, int, int)) & Bootstrap::recv, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"));
.def("recv", (std::future<std::vector<char>>(Bootstrap::*)(int, int)) & Bootstrap::recv, nb::arg("peer"),
nb::arg("tag"));
nb::class_<UniqueId>(m, "UniqueId");
@@ -149,8 +147,8 @@ void register_core(nb::module_& m) {
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
def_future<RegisteredMemory>(m, "RegisteredMemory");
def_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
@@ -163,14 +161,11 @@ void register_core(nb::module_& m) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("localConfig"))
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
.def("tag_of", &Communicator::tagOf);
}
NB_MODULE(_mscclpp, m) {

View File

@@ -76,8 +76,7 @@ class MscclppGroup:
def make_connection(self, remote_ranks: list[int], transport: Transport) -> dict[int, Connection]:
connections = {}
for rank in remote_ranks:
connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
self.communicator.setup()
connections[rank] = self.communicator.connect(rank, 0, transport)
connections = {rank: connections[rank].get() for rank in connections}
return connections
@@ -93,9 +92,8 @@ class MscclppGroup:
all_registered_memories[self.my_rank] = local_reg_memory
future_memories = {}
for rank in connections:
self.communicator.send_memory_on_setup(local_reg_memory, rank, 0)
future_memories[rank] = self.communicator.recv_memory_on_setup(rank, 0)
self.communicator.setup()
self.communicator.send_memory(local_reg_memory, rank, 0)
future_memories[rank] = self.communicator.recv_memory(rank, 0)
for rank in connections:
all_registered_memories[rank] = future_memories[rank].get()
return all_registered_memories
@@ -108,7 +106,6 @@ class MscclppGroup:
semaphores = {}
for rank in connections:
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
self.communicator.setup()
return semaphores
def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]:

View File

@@ -6,6 +6,7 @@
#include <cstring>
#include <exception>
#include <mscclpp/core.hpp>
#include <mscclpp/errors.hpp>
#include <queue>
#include <sstream>
#include <thread>
#include <unordered_map>
@@ -41,11 +42,16 @@ MSCCLPP_API_CPP void Bootstrap::send(const std::vector<char>& data, int peer, in
send((void*)data.data(), data.size(), peer, tag + 1);
}
MSCCLPP_API_CPP void Bootstrap::recv(std::vector<char>& data, int peer, int tag) {
size_t size;
recv((void*)&size, sizeof(size_t), peer, tag);
data.resize(size);
recv((void*)data.data(), data.size(), peer, tag + 1);
/// Receive a variable-length message from a peer.
///
/// The receive for the size_t length header (on @p tag) is posted right away. The payload
/// receive (on @p tag + 1) is only issued when the returned future is waited on, because the
/// continuation is launched with std::launch::deferred.
///
/// @param peer The rank of the process to receive the data from.
/// @param tag The tag of the length header; the payload uses tag + 1.
/// @return A future holding the received bytes. Any exception stored in the underlying
/// receive futures is rethrown from get() on the returned future.
MSCCLPP_API_CPP std::future<std::vector<char>> Bootstrap::recv(int peer, int tag) {
  // Heap-allocate the header so its address stays valid after this function returns.
  auto size = std::make_unique<size_t>();
  auto recvTask = recv(size.get(), sizeof(size_t), peer, tag);
  return std::async(std::launch::deferred,
                    [this, size = std::move(size), recvTask = std::move(recvTask), peer, tag]() mutable {
                      // get() rather than wait(): a failed header receive must not fall through
                      // to reading *size.
                      recvTask.get();
                      std::vector<char> data(*size);
                      recv(data.data(), data.size(), peer, tag + 1).get();
                      return data;
                    });
}
struct UniqueIdInternal {
@@ -54,6 +60,22 @@ struct UniqueIdInternal {
};
static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is too large to fit into UniqueId");
/// A single pending receive handed to a receive worker thread.
///
/// A default-constructed task (null socket) is pushed onto a worker's queue as the signal
/// for that worker to exit.
struct RecvTask {
  /// Fulfilled by the worker once the receive has finished.
  std::promise<void> promise;
  /// Socket to read from; null marks the exit signal.
  std::shared_ptr<Socket> sock;
  /// Destination buffer the worker writes into; must stay valid until the promise is fulfilled.
  void* data = nullptr;
  /// Size of the data to receive.
  int size = 0;

  // The members above are initialized in-class so that moving a default-constructed
  // (sentinel) task through the queue never reads indeterminate values.
  RecvTask() = default;
  RecvTask(std::shared_ptr<Socket> sock, void* data, int size) : sock(std::move(sock)), data(data), size(size) {}
};
/// State shared between one receive worker thread and the code that queues work for it.
struct RecvThreadData {
/// The worker thread that pops tasks from the queue.
std::thread thread;
/// Held while the queue is pushed to or popped from.
std::mutex mutex;
/// Notified after a task is pushed; the worker waits on it while the queue is empty.
std::condition_variable cond;
/// Pending receives, processed front to back; a task with a null socket asks the worker to exit.
std::queue<RecvTask> queue;
};
class TcpBootstrap::Impl {
public:
Impl(int rank, int nRanks);
@@ -63,11 +85,11 @@ class TcpBootstrap::Impl {
void establishConnections(int64_t timeoutSec);
UniqueId createUniqueId();
UniqueId getUniqueId() const;
int getRank();
int getNranks();
int rank();
int size();
void allGather(void* allData, int size);
void send(void* data, int size, int peer, int tag);
void recv(void* data, int size, int peer, int tag);
std::future<void> recv(void* data, int size, int peer, int tag);
void barrier();
void close();
@@ -85,6 +107,7 @@ class TcpBootstrap::Impl {
std::unique_ptr<uint32_t> abortFlagStorage_;
volatile uint32_t* abortFlag_;
std::thread rootThread_;
std::unordered_map<std::pair<int, int>, std::unique_ptr<RecvThreadData>, PairHash> recvThreads_;
char netIfName_[MAX_IF_NAME_SIZE + 1];
SocketAddress netIfAddr_;
std::unordered_map<std::pair<int, int>, std::shared_ptr<Socket>, PairHash> peerSendSockets_;
@@ -93,6 +116,7 @@ class TcpBootstrap::Impl {
void netSend(Socket* sock, const void* data, int size);
void netRecv(Socket* sock, void* data, int size);
RecvThreadData* getRecvThreadData(int peer, int tag);
std::shared_ptr<Socket> getPeerSendSocket(int peer, int tag);
std::shared_ptr<Socket> getPeerRecvSocket(int peer, int tag);
@@ -128,9 +152,9 @@ UniqueId TcpBootstrap::Impl::createUniqueId() {
return getUniqueId();
}
int TcpBootstrap::Impl::getRank() { return rank_; }
int TcpBootstrap::Impl::rank() { return rank_; }
int TcpBootstrap::Impl::getNranks() { return nRanks_; }
int TcpBootstrap::Impl::size() { return nRanks_; }
void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) {
netInit("", "");
@@ -176,6 +200,16 @@ TcpBootstrap::Impl::~Impl() {
if (rootThread_.joinable()) {
rootThread_.join();
}
for (auto& it : recvThreads_) {
{
std::lock_guard<std::mutex> lock(it.second->mutex);
it.second->queue.push(RecvTask()); // signal thread to exit
it.second->cond.notify_one();
}
}
for (auto& it : recvThreads_) {
it.second->thread.join();
}
}
void TcpBootstrap::Impl::getRemoteAddresses(Socket* listenSock, std::vector<SocketAddress>& rankAddresses,
@@ -404,6 +438,32 @@ void TcpBootstrap::Impl::allGather(void* allData, int size) {
TRACE(MSCCLPP_INIT, "rank %d nranks %d size %d - DONE", rank, nRanks, size);
}
/// Return the receive worker for the given (peer, tag) pair, creating it on first use.
///
/// The worker thread pops RecvTask entries from its queue in order, performs the blocking
/// netRecv() for each, and then fulfils the task's promise. A task whose socket is null makes
/// the worker exit. If netRecv() throws, the exception is stored in the task's promise (so it
/// surfaces from the caller's future) and the worker moves on to the next task instead of
/// letting the exception escape the thread.
///
/// @param peer The rank of the peer the data is received from.
/// @param tag The tag identifying the receive channel.
/// @return A non-owning pointer to the worker state; recvThreads_ keeps ownership.
RecvThreadData* TcpBootstrap::Impl::getRecvThreadData(int peer, int tag) {
  const auto key = std::make_pair(peer, tag);
  auto it = recvThreads_.find(key);
  if (it != recvThreads_.end()) {
    return it->second.get();
  }
  auto threadData = std::make_unique<RecvThreadData>();
  // The worker keeps a raw pointer only: the heap object does not move when the owning
  // unique_ptr is transferred into the map below.
  threadData->thread = std::thread([this, threadData = threadData.get()]() {
    for (;;) {
      RecvTask task;
      {
        std::unique_lock<std::mutex> lock(threadData->mutex);
        threadData->cond.wait(lock, [&]() { return !threadData->queue.empty(); });
        task = std::move(threadData->queue.front());
        threadData->queue.pop();
      }
      if (task.sock == nullptr) {
        break;
      }
      try {
        netRecv(task.sock.get(), task.data, task.size);
      } catch (...) {
        // Report the failure through the future rather than terminating the process.
        task.promise.set_exception(std::current_exception());
        continue;
      }
      task.promise.set_value();
    }
  });
  return recvThreads_.emplace(key, std::move(threadData)).first->second.get();
}
std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerSendSocket(int peer, int tag) {
auto it = peerSendSockets_.find(std::make_pair(peer, tag));
if (it != peerSendSockets_.end()) {
@@ -456,9 +516,17 @@ void TcpBootstrap::Impl::send(void* data, int size, int peer, int tag) {
netSend(sock.get(), data, size);
}
void TcpBootstrap::Impl::recv(void* data, int size, int peer, int tag) {
std::future<void> TcpBootstrap::Impl::recv(void* data, int size, int peer, int tag) {
auto sock = getPeerRecvSocket(peer, tag);
netRecv(sock.get(), data, size);
RecvTask task(sock, data, size);
auto future = task.promise.get_future();
auto threadData = getRecvThreadData(peer, tag);
{
std::lock_guard<std::mutex> lock(threadData->mutex);
threadData->queue.push(std::move(task));
threadData->cond.notify_one();
}
return future;
}
void TcpBootstrap::Impl::barrier() { allGather(barrierArr_.data(), sizeof(int)); }
@@ -478,16 +546,16 @@ MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return pimpl_->createU
MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); }
MSCCLPP_API_CPP int TcpBootstrap::getRank() { return pimpl_->getRank(); }
MSCCLPP_API_CPP int TcpBootstrap::rank() { return pimpl_->rank(); }
MSCCLPP_API_CPP int TcpBootstrap::getNranks() { return pimpl_->getNranks(); }
MSCCLPP_API_CPP int TcpBootstrap::size() { return pimpl_->size(); }
MSCCLPP_API_CPP void TcpBootstrap::send(void* data, int size, int peer, int tag) {
pimpl_->send(data, size, peer, tag);
}
MSCCLPP_API_CPP void TcpBootstrap::recv(void* data, int size, int peer, int tag) {
pimpl_->recv(data, size, peer, tag);
MSCCLPP_API_CPP std::future<void> TcpBootstrap::recv(void* data, int size, int peer, int tag) {
return pimpl_->recv(data, size, peer, tag);
}
MSCCLPP_API_CPP void TcpBootstrap::allGather(void* allData, int size) { pimpl_->allGather(allData, size); }

View File

@@ -30,79 +30,29 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
return context()->registerMemory(ptr, size, transports);
}
struct MemorySender : public Setuppable {
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
bootstrap->send(memory_.serialize(), remoteRank_, tag_);
}
RegisteredMemory memory_;
int remoteRank_;
int tag_;
};
MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) {
onSetup(std::make_shared<MemorySender>(memory, remoteRank, tag));
/// Serialize the given registered memory and send it to the remote rank through the
/// communicator's bootstrap.
///
/// @param memory The registered memory buffer to send information about.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send.
MSCCLPP_API_CPP void Communicator::sendMemory(RegisteredMemory memory, int remoteRank, int tag) {
  const std::vector<char> serialized = memory.serialize();
  pimpl_->bootstrap_->send(serialized, remoteRank, tag);
}
struct MemoryReceiver : public Setuppable {
MemoryReceiver(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {}
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override {
std::vector<char> data;
bootstrap->recv(data, remoteRank_, tag_);
memoryPromise_.set_value(RegisteredMemory::deserialize(data));
}
std::promise<RegisteredMemory> memoryPromise_;
int remoteRank_;
int tag_;
};
MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSetup(int remoteRank, int tag) {
auto memoryReceiver = std::make_shared<MemoryReceiver>(remoteRank, tag);
onSetup(memoryReceiver);
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
/// Start receiving a serialized registered memory from the remote rank.
///
/// The bootstrap receive is requested immediately; deserialization runs lazily when the
/// returned future is waited on (std::launch::deferred).
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the receive.
/// @return A future of the deserialized registered memory.
MSCCLPP_API_CPP std::future<RegisteredMemory> Communicator::recvMemory(int remoteRank, int tag) {
  // std::future is move-only, so the byte future is moved into the closure; the closure is
  // mutable because get() is a non-const operation.
  auto deserializeWhenReady = [serialized = pimpl_->bootstrap_->recv(remoteRank, tag)]() mutable {
    return RegisteredMemory::deserialize(serialized.get());
  };
  return std::async(std::launch::deferred, std::move(deserializeWhenReady));
}
struct Communicator::Impl::Connector : public Setuppable {
Connector(Communicator& comm, Communicator::Impl& commImpl_, int remoteRank, int tag, EndpointConfig localConfig)
: comm_(comm),
commImpl_(commImpl_),
remoteRank_(remoteRank),
tag_(tag),
localEndpoint_(comm.context()->createEndpoint(localConfig)) {}
void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
bootstrap->send(localEndpoint_.serialize(), remoteRank_, tag_);
}
void endSetup(std::shared_ptr<Bootstrap> bootstrap) override {
std::vector<char> data;
bootstrap->recv(data, remoteRank_, tag_);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = comm_.context()->connect(localEndpoint_, remoteEndpoint);
commImpl_.connectionInfos_[connection.get()] = {remoteRank_, tag_};
connectionPromise_.set_value(connection);
INFO(MSCCLPP_INIT, "Connection %d -> %d created (%s)", comm_.bootstrap()->getRank(), remoteRank_,
connection->getTransportName().c_str());
}
std::promise<std::shared_ptr<Connection>> connectionPromise_;
Communicator& comm_;
Communicator::Impl& commImpl_;
int remoteRank_;
int tag_;
Endpoint localEndpoint_;
};
MSCCLPP_API_CPP NonblockingFuture<std::shared_ptr<Connection>> Communicator::connectOnSetup(
int remoteRank, int tag, EndpointConfig localConfig) {
auto connector = std::make_shared<Communicator::Impl::Connector>(*this, *pimpl_, remoteRank, tag, localConfig);
onSetup(connector);
return NonblockingFuture<std::shared_ptr<Connection>>(connector->connectionPromise_.get_future());
/// Begin connecting to a remote rank.
///
/// A local endpoint is created and its serialized form is sent to the remote rank right away,
/// and the receive for the remote endpoint is requested. The connection itself is built, and
/// recorded in connectionInfos_, when the returned future is waited on (std::launch::deferred).
/// The deferred work uses this Communicator, so it must still be alive at that point.
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag of the connection for identifying it.
/// @param localConfig The configuration for the local endpoint.
/// @return A future of a shared pointer to the connection.
MSCCLPP_API_CPP std::future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
                                                                               EndpointConfig localConfig) {
  auto myEndpoint = context()->createEndpoint(localConfig);
  pimpl_->bootstrap_->send(myEndpoint.serialize(), remoteRank, tag);
  auto peerBytes = pimpl_->bootstrap_->recv(remoteRank, tag);

  auto finishConnect = [this, myEndpoint = std::move(myEndpoint), peerBytes = std::move(peerBytes), remoteRank,
                        tag]() mutable {
    auto peerEndpoint = Endpoint::deserialize(peerBytes.get());
    auto established = context()->connect(myEndpoint, peerEndpoint);
    // Remember which rank/tag this connection belongs to for remoteRankOf()/tagOf().
    pimpl_->connectionInfos_[established.get()] = {remoteRank, tag};
    return established;
  };
  return std::async(std::launch::deferred, std::move(finishConnect));
}
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {
@@ -113,18 +63,4 @@ MSCCLPP_API_CPP int Communicator::tagOf(const Connection& connection) {
return pimpl_->connectionInfos_.at(&connection).tag;
}
MSCCLPP_API_CPP void Communicator::onSetup(std::shared_ptr<Setuppable> setuppable) {
pimpl_->toSetup_.push_back(setuppable);
}
MSCCLPP_API_CPP void Communicator::setup() {
for (auto& setuppable : pimpl_->toSetup_) {
setuppable->beginSetup(pimpl_->bootstrap_);
}
for (auto& setuppable : pimpl_->toSetup_) {
setuppable->endSetup(pimpl_->bootstrap_);
}
pimpl_->toSetup_.clear();
}
} // namespace mscclpp

View File

@@ -18,9 +18,9 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
}
}
MSCCLPP_API_CPP Transport Endpoint::transport() { return pimpl_->transport_; }
MSCCLPP_API_CPP Transport Endpoint::transport() const { return pimpl_->transport_; }
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() const {
std::vector<char> data;
std::copy_n(reinterpret_cast<char*>(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data));
std::copy_n(reinterpret_cast<char*>(&pimpl_->hostHash_), sizeof(pimpl_->hostHash_), std::back_inserter(data));

View File

@@ -64,11 +64,11 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;
MSCCLPP_API_CPP void* RegisteredMemory::data() const { return pimpl_->data; }
MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl_->size; }
MSCCLPP_API_CPP size_t RegisteredMemory::size() const { return pimpl_->size; }
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl_->transports; }
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() const { return pimpl_->transports; }
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() const {
std::vector<char> result;
std::copy_n(reinterpret_cast<char*>(&pimpl_->originalDataPtr), sizeof(pimpl_->originalDataPtr),
std::back_inserter(result));

View File

@@ -9,14 +9,14 @@
namespace mscclpp {
static NonblockingFuture<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
void* localInboundSemaphoreId) {
static std::future<RegisteredMemory> setupInboundSemaphoreId(Communicator& communicator, Connection* connection,
void* localInboundSemaphoreId) {
auto localInboundSemaphoreIdsRegMem =
communicator.registerMemory(localInboundSemaphoreId, sizeof(uint64_t), connection->transport());
int remoteRank = communicator.remoteRankOf(*connection);
int tag = communicator.tagOf(*connection);
communicator.sendMemoryOnSetup(localInboundSemaphoreIdsRegMem, remoteRank, tag);
return communicator.recvMemoryOnSetup(remoteRank, tag);
communicator.sendMemory(localInboundSemaphoreIdsRegMem, remoteRank, tag);
return communicator.recvMemory(remoteRank, tag);
}
MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator,
@@ -24,7 +24,7 @@ MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communi
: BaseSemaphore(allocUniqueCuda<uint64_t>(), allocUniqueCuda<uint64_t>(), std::make_unique<uint64_t>()),
connection_(connection) {
INFO(MSCCLPP_INIT, "Creating a Host2Device semaphore for %s transport from %d to %d",
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
connection->getTransportName().c_str(), communicator.bootstrap()->rank(),
communicator.remoteRankOf(*connection));
remoteInboundSemaphoreIdsRegMem_ =
setupInboundSemaphoreId(communicator, connection.get(), localInboundSemaphore_.get());
@@ -49,7 +49,7 @@ MSCCLPP_API_CPP Host2HostSemaphore::Host2HostSemaphore(Communicator& communicato
: BaseSemaphore(std::make_unique<uint64_t>(), std::make_unique<uint64_t>(), std::make_unique<uint64_t>()),
connection_(connection) {
INFO(MSCCLPP_INIT, "Creating a Host2Host semaphore for %s transport from %d to %d",
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
connection->getTransportName().c_str(), communicator.bootstrap()->rank(),
communicator.remoteRankOf(*connection));
if (connection->transport() == Transport::CudaIpc) {
@@ -88,7 +88,7 @@ MSCCLPP_API_CPP SmDevice2DeviceSemaphore::SmDevice2DeviceSemaphore(Communicator&
std::shared_ptr<Connection> connection)
: BaseSemaphore(allocUniqueCuda<uint64_t>(), allocUniqueCuda<uint64_t>(), allocUniqueCuda<uint64_t>()) {
INFO(MSCCLPP_INIT, "Creating a Device2Device semaphore for %s transport from %d to %d",
connection->getTransportName().c_str(), communicator.bootstrap()->getRank(),
connection->getTransportName().c_str(), communicator.bootstrap()->rank(),
communicator.remoteRankOf(*connection));
if (connection->transport() == Transport::CudaIpc) {
remoteInboundSemaphoreIdsRegMem_ =

View File

@@ -214,8 +214,8 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
std::vector<mscclpp::SemaphoreId> semaphoreIds;
std::vector<mscclpp::RegisteredMemory> localMemories;
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
std::vector<std::future<std::shared_ptr<mscclpp::Connection>>> connections(world_size);
std::vector<std::future<mscclpp::RegisteredMemory>> remoteMemories;
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;
@@ -226,22 +226,18 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
transport = ibTransport;
}
// Connect with all other ranks
connections[r] = comm.connectOnSetup(r, 0, transport);
connections[r] = comm.connect(r, 0, transport);
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
localMemories.push_back(memory);
comm.sendMemoryOnSetup(memory, r, 0);
remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0));
comm.sendMemory(memory, r, 0);
remoteMemories.push_back(comm.recvMemory(r, 0));
}
comm.setup();
for (int r = 0; r < world_size; ++r) {
if (r == rank) continue;
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm, connections[r].get()));
}
comm.setup();
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(

View File

@@ -116,8 +116,8 @@ class MyProxyService {
int cudaNum = rankToLocalRank(rank);
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr);
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionsFuture(world_size);
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemoriesFuture(world_size);
std::vector<std::future<std::shared_ptr<mscclpp::Connection>>> connectionsFuture(world_size);
std::vector<std::future<mscclpp::RegisteredMemory>> remoteMemoriesFuture(world_size);
localMemory_ = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
for (int r = 0; r < world_size; ++r) {
@@ -134,14 +134,12 @@ class MyProxyService {
transport = ibTransport;
}
// Connect with all other ranks
connectionsFuture[r] = comm.connectOnSetup(r, 0, transport);
comm.sendMemoryOnSetup(localMemory_, r, 0);
connectionsFuture[r] = comm.connect(r, 0, transport);
comm.sendMemory(localMemory_, r, 0);
remoteMemoriesFuture[r] = comm.recvMemoryOnSetup(r, 0);
remoteMemoriesFuture[r] = comm.recvMemory(r, 0);
}
comm.setup();
for (int r = 0; r < world_size; ++r) {
if (r == rank) {
continue;
@@ -156,8 +154,6 @@ class MyProxyService {
deviceSemaphores2_.emplace_back(std::make_shared<mscclpp::Host2DeviceSemaphore>(comm, connections_[r]));
remoteMemories_[r] = remoteMemoriesFuture[r].get();
}
comm.setup();
}
void bindThread() {

View File

@@ -6,10 +6,10 @@
#include "mp_unit_tests.hpp"
void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
std::vector<int> tmp(bootstrap->getNranks(), 0);
tmp[bootstrap->getRank()] = bootstrap->getRank() + 1;
std::vector<int> tmp(bootstrap->size(), 0);
tmp[bootstrap->rank()] = bootstrap->rank() + 1;
bootstrap->allGather(tmp.data(), sizeof(int));
for (int i = 0; i < bootstrap->getNranks(); ++i) {
for (int i = 0; i < bootstrap->size(); ++i) {
EXPECT_EQ(tmp[i], i + 1);
}
}
@@ -17,25 +17,25 @@ void BootstrapTest::bootstrapTestAllGather(std::shared_ptr<mscclpp::Bootstrap> b
void BootstrapTest::bootstrapTestBarrier(std::shared_ptr<mscclpp::Bootstrap> bootstrap) { bootstrap->barrier(); }
void BootstrapTest::bootstrapTestSendRecv(std::shared_ptr<mscclpp::Bootstrap> bootstrap) {
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (bootstrap->getRank() == i) continue;
int msg1 = (bootstrap->getRank() + 1) * 3;
int msg2 = (bootstrap->getRank() + 1) * 3 + 1;
int msg3 = (bootstrap->getRank() + 1) * 3 + 2;
for (int i = 0; i < bootstrap->size(); i++) {
if (bootstrap->rank() == i) continue;
int msg1 = (bootstrap->rank() + 1) * 3;
int msg2 = (bootstrap->rank() + 1) * 3 + 1;
int msg3 = (bootstrap->rank() + 1) * 3 + 2;
bootstrap->send(&msg1, sizeof(int), i, 0);
bootstrap->send(&msg2, sizeof(int), i, 1);
bootstrap->send(&msg3, sizeof(int), i, 2);
}
for (int i = 0; i < bootstrap->getNranks(); i++) {
if (bootstrap->getRank() == i) continue;
for (int i = 0; i < bootstrap->size(); i++) {
if (bootstrap->rank() == i) continue;
int msg1 = 0;
int msg2 = 0;
int msg3 = 0;
// recv them in the opposite order to check correctness
bootstrap->recv(&msg2, sizeof(int), i, 1);
bootstrap->recv(&msg3, sizeof(int), i, 2);
bootstrap->recv(&msg1, sizeof(int), i, 0);
bootstrap->recv(&msg2, sizeof(int), i, 1).wait();
bootstrap->recv(&msg3, sizeof(int), i, 2).wait();
bootstrap->recv(&msg1, sizeof(int), i, 0).wait();
EXPECT_EQ(msg1, (i + 1) * 3);
EXPECT_EQ(msg2, (i + 1) * 3 + 1);
EXPECT_EQ(msg3, (i + 1) * 3 + 2);
@@ -51,7 +51,7 @@ void BootstrapTest::bootstrapTestAll(std::shared_ptr<mscclpp::Bootstrap> bootstr
TEST_F(BootstrapTest, WithId) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
if (bootstrap->rank() == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->initialize(id);
bootstrapTestAll(bootstrap);
@@ -70,7 +70,7 @@ TEST_F(BootstrapTest, ResumeWithId) {
for (int i = 0; i < 3000; ++i) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, gEnv->worldSize);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
if (bootstrap->rank() == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->initialize(id, 300);
}
@@ -110,12 +110,12 @@ TEST_F(BootstrapTest, TimeoutWithId) {
class MPIBootstrap : public mscclpp::Bootstrap {
public:
MPIBootstrap() : Bootstrap() {}
int getRank() override {
int rank() override {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
return rank;
}
int getNranks() override {
int size() override {
int worldSize;
MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
return worldSize;
@@ -125,10 +125,14 @@ class MPIBootstrap : public mscclpp::Bootstrap {
}
void barrier() override { MPI_Barrier(MPI_COMM_WORLD); }
void send(void* sendbuf, int size, int dest, int tag) override {
MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
MPI_Request request;
MPI_Isend(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD, &request);
MPI_Wait(&request, MPI_STATUS_IGNORE);
}
void recv(void* recvbuf, int size, int source, int tag) override {
MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
std::future<void> recv(void* recvbuf, int size, int source, int tag) override {
MPI_Request request;
MPI_Irecv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, &request);
return std::async(std::launch::deferred, [request]() mutable { MPI_Wait(&request, MPI_STATUS_IGNORE); });
}
};

View File

@@ -43,17 +43,16 @@ void CommunicatorTestBase::TearDown() {
void CommunicatorTestBase::setNumRanksToUse(int num) { numRanksToUse = num; }
void CommunicatorTestBase::connectMesh(bool useIbOnly) {
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
std::vector<std::future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(numRanksToUse);
for (int i = 0; i < numRanksToUse; i++) {
if (i != gEnv->rank) {
if ((rankToNode(i) == rankToNode(gEnv->rank)) && !useIbOnly) {
connectionFutures[i] = communicator->connectOnSetup(i, 0, mscclpp::Transport::CudaIpc);
connectionFutures[i] = communicator->connect(i, 0, mscclpp::Transport::CudaIpc);
} else {
connectionFutures[i] = communicator->connectOnSetup(i, 0, ibTransport);
connectionFutures[i] = communicator->connect(i, 0, ibTransport);
}
}
}
communicator->setup();
for (int i = 0; i < numRanksToUse; i++) {
if (i != gEnv->rank) {
connections[i] = connectionFutures[i].get();
@@ -67,16 +66,15 @@ void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, mscc
mscclpp::RegisteredMemory& localMemory,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemories) {
localMemory = communicator->registerMemory(buff, buffSize, transport);
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemories;
std::unordered_map<int, std::future<mscclpp::RegisteredMemory>> futureRemoteMemories;
for (int remoteRank : remoteRanks) {
if (remoteRank != communicator->bootstrap()->getRank()) {
communicator->sendMemoryOnSetup(localMemory, remoteRank, tag);
futureRemoteMemories[remoteRank] = communicator->recvMemoryOnSetup(remoteRank, tag);
if (remoteRank != communicator->bootstrap()->rank()) {
communicator->sendMemory(localMemory, remoteRank, tag);
futureRemoteMemories[remoteRank] = communicator->recvMemory(remoteRank, tag);
}
}
communicator->setup();
for (int remoteRank : remoteRanks) {
if (remoteRank != communicator->bootstrap()->getRank()) {
if (remoteRank != communicator->bootstrap()->rank()) {
remoteMemories[remoteRank] = futureRemoteMemories[remoteRank].get();
}
}
@@ -206,7 +204,6 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
auto& conn = entry.second;
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2DeviceSemaphore>(*communicator.get(), conn)});
}
communicator->setup();
communicator->bootstrap()->barrier();
deviceBufferInit();
@@ -247,7 +244,6 @@ TEST_F(CommunicatorTest, WriteWithHostSemaphores) {
if (conn->transport() == mscclpp::Transport::CudaIpc) continue;
semaphores.insert({entry.first, std::make_shared<mscclpp::Host2HostSemaphore>(*communicator.get(), conn)});
}
communicator->setup();
communicator->bootstrap()->barrier();
deviceBufferInit();

View File

@@ -25,7 +25,7 @@ void IbPeerToPeerTest::SetUp() {
if (gEnv->rank < 2) {
// This test needs only two ranks
bootstrap = std::make_shared<mscclpp::TcpBootstrap>(gEnv->rank, 2);
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
if (bootstrap->rank() == 0) id = bootstrap->createUniqueId();
}
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
if (gEnv->rank >= 2) {
@@ -48,7 +48,7 @@ void IbPeerToPeerTest::registerBufferAndConnect(void* buf, size_t size) {
mrInfo[gEnv->rank] = mr->getInfo();
bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
for (int i = 0; i < bootstrap->getNranks(); ++i) {
for (int i = 0; i < bootstrap->size(); ++i) {
if (i == gEnv->rank) continue;
qp->rtr(qpInfo[i]);
qp->rts();

View File

@@ -18,13 +18,13 @@ void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleProxyChannel>& proxyChannels,
bool useIbOnly, void* sendBuff, size_t sendBuffBytes,
void* recvBuff, size_t recvBuffBytes) {
const int rank = communicator->bootstrap()->getRank();
const int worldSize = communicator->bootstrap()->getNranks();
const int rank = communicator->bootstrap()->rank();
const int worldSize = communicator->bootstrap()->size();
const bool isInPlace = (recvBuff == nullptr);
mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport);
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
std::vector<std::future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
std::vector<std::future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, transport);
mscclpp::RegisteredMemory recvBufRegMem;
@@ -37,21 +37,19 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleP
continue;
}
if ((rankToNode(r) == rankToNode(gEnv->rank)) && !useIbOnly) {
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc);
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
} else {
connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport);
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
}
if (isInPlace) {
communicator->sendMemoryOnSetup(sendBufRegMem, r, 0);
communicator->sendMemory(sendBufRegMem, r, 0);
} else {
communicator->sendMemoryOnSetup(recvBufRegMem, r, 0);
communicator->sendMemory(recvBufRegMem, r, 0);
}
remoteMemFutures[r] = communicator->recvMemoryOnSetup(r, 0);
remoteMemFutures[r] = communicator->recvMemory(r, 0);
}
communicator->setup();
for (int r = 0; r < worldSize; r++) {
if (r == rank) {
continue;
@@ -61,8 +59,6 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleP
proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemFutures[r].get()),
proxyService->addMemory(sendBufRegMem));
}
communicator->setup();
}
__constant__ DeviceHandle<mscclpp::SimpleProxyChannel> gChannelOneToOneTestConstProxyChans;

View File

@@ -19,13 +19,13 @@ void SmChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
void SmChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes) {
const int rank = communicator->bootstrap()->getRank();
const int worldSize = communicator->bootstrap()->getNranks();
const int rank = communicator->bootstrap()->rank();
const int worldSize = communicator->bootstrap()->size();
const bool isInPlace = (outputBuff == nullptr);
mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc | ibTransport;
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
std::vector<std::future<std::shared_ptr<mscclpp::Connection>>> connectionFutures(worldSize);
std::vector<std::future<mscclpp::RegisteredMemory>> remoteMemFutures(worldSize);
mscclpp::RegisteredMemory inputBufRegMem = communicator->registerMemory(inputBuff, inputBuffBytes, transport);
mscclpp::RegisteredMemory outputBufRegMem;
@@ -38,21 +38,19 @@ void SmChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SmChannel>
continue;
}
if (rankToNode(r) == rankToNode(gEnv->rank)) {
connectionFutures[r] = communicator->connectOnSetup(r, 0, mscclpp::Transport::CudaIpc);
connectionFutures[r] = communicator->connect(r, 0, mscclpp::Transport::CudaIpc);
} else {
connectionFutures[r] = communicator->connectOnSetup(r, 0, ibTransport);
connectionFutures[r] = communicator->connect(r, 0, ibTransport);
}
if (isInPlace) {
communicator->sendMemoryOnSetup(inputBufRegMem, r, 0);
communicator->sendMemory(inputBufRegMem, r, 0);
} else {
communicator->sendMemoryOnSetup(outputBufRegMem, r, 0);
communicator->sendMemory(outputBufRegMem, r, 0);
}
remoteMemFutures[r] = communicator->recvMemoryOnSetup(r, 0);
remoteMemFutures[r] = communicator->recvMemory(r, 0);
}
communicator->setup();
for (int r = 0; r < worldSize; r++) {
if (r == rank) {
continue;
@@ -64,8 +62,6 @@ void SmChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SmChannel>
smChannels.emplace_back(smSemaphores[r], remoteMemFutures[r].get(), inputBufRegMem.data(),
(isInPlace ? nullptr : outputBufRegMem.data()));
}
communicator->setup();
}
__constant__ DeviceHandle<mscclpp::SmChannel> gChannelOneToOneTestConstSmChans;

View File

@@ -497,7 +497,7 @@ void AllGatherTestEngine::setupConnections() {
auto service = std::dynamic_pointer_cast<AllGatherProxyService>(chanService_);
setupMeshConnections(devProxyChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
[&](std::vector<std::shared_ptr<mscclpp::Connection>> conns,
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteMemories,
std::vector<std::future<mscclpp::RegisteredMemory>>& remoteMemories,
const mscclpp::RegisteredMemory& localMemory) {
std::vector<mscclpp::SemaphoreId> semaphoreIds;
for (size_t i = 0; i < conns.size(); ++i) {
@@ -505,7 +505,6 @@ void AllGatherTestEngine::setupConnections() {
service->addRemoteMemory(remoteMemories[i].get());
}
service->setLocalMemory(localMemory);
comm_->setup();
});
auto proxyChannels = service->proxyChannels();
if (proxyChannels.size() > sizeof(constRawProxyChan) / sizeof(DeviceHandle<mscclpp::ProxyChannel>)) {

View File

@@ -327,7 +327,7 @@ void BaseTestEngine::runTest() {
void BaseTestEngine::bootstrap() {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(args_.rank, args_.totalRanks);
mscclpp::UniqueId id;
if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId();
if (bootstrap->rank() == 0) id = bootstrap->createUniqueId();
MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
bootstrap->initialize(id);
comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
@@ -362,13 +362,13 @@ std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createProxyService()
void BaseTestEngine::setupMeshConnectionsInternal(
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& localRegMemory,
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteRegMemories, bool addConnections) {
std::vector<std::future<mscclpp::RegisteredMemory>>& remoteRegMemories, bool addConnections) {
const int worldSize = args_.totalRanks;
const int rank = args_.rank;
const int nRanksPerNode = args_.nRanksPerNode;
const int thisNode = rank / nRanksPerNode;
const mscclpp::Transport ibTransport = IBs[args_.gpuNum];
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
std::vector<std::future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
auto rankToNode = [&](int rank) { return rank / nRanksPerNode; };
for (int r = 0; r < worldSize; r++) {
@@ -383,16 +383,13 @@ void BaseTestEngine::setupMeshConnectionsInternal(
transport = ibTransport;
}
// Connect with all other ranks
connectionFutures.push_back(comm_->connectOnSetup(r, 0, transport));
connectionFutures.push_back(comm_->connect(r, 0, transport));
}
comm_->sendMemoryOnSetup(localRegMemory, r, 0);
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
remoteRegMemories.push_back(remoteMemory);
comm_->sendMemory(localRegMemory, r, 0);
remoteRegMemories.push_back(comm_->recvMemory(r, 0));
}
comm_->setup();
std::transform(
connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](const mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>& future) { return future.get(); });
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](std::future<std::shared_ptr<mscclpp::Connection>>& future) { return future.get(); });
}
// Create mesh connections between all ranks. If recvBuff is nullptr, assume in-place.
@@ -408,7 +405,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
std::vector<std::future<mscclpp::RegisteredMemory>> remoteRegMemories;
mscclpp::RegisteredMemory& localRegMemory = (outputBuff) ? outputBufRegMem : inputBufRegMem;
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
@@ -423,8 +420,6 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
service->addMemory(remoteRegMemories[i].get()), service->addMemory(inputBufRegMem))));
}
}
comm_->setup();
}
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
@@ -439,7 +434,7 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
std::vector<std::future<mscclpp::RegisteredMemory>> remoteRegMemories;
mscclpp::RegisteredMemory& localRegMemory =
(outputBuff && semantic == ChannelSemantic::PUT) ? outputBufRegMem : inputBufRegMem;
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
@@ -450,7 +445,6 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
smSemaphores.emplace(cid, std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, connections[cid]));
}
}
comm_->setup();
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
@@ -482,13 +476,13 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
}
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemories;
std::vector<std::future<mscclpp::RegisteredMemory>> remoteRegMemories;
mscclpp::RegisteredMemory& localRegMemory =
(getPacketBuff) ? getPacketBufRegMem : ((outputBuff) ? outputBufRegMem : inputBufRegMem);
setupMeshConnectionsInternal(connections, localRegMemory, remoteRegMemories);
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteRegMemoriesOutput;
std::vector<std::future<mscclpp::RegisteredMemory>> remoteRegMemoriesOutput;
if (outputBuff) {
setupMeshConnectionsInternal(connections, outputBufRegMem, remoteRegMemoriesOutput, false);
}
@@ -504,7 +498,6 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
connIdToSemId[cid] = service->buildAndAddSemaphore(*comm_, connections[cid]);
}
}
comm_->setup();
for (size_t cid = 0; cid < connections.size(); ++cid) {
if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) {

View File

@@ -102,15 +102,15 @@ class BaseTestEngine {
double benchTime();
void setupMeshConnectionsInternal(
std::vector<std::shared_ptr<mscclpp::Connection>>& connections, mscclpp::RegisteredMemory& localMemory,
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteRegMemories,
bool addConnections = true);
void setupMeshConnectionsInternal(std::vector<std::shared_ptr<mscclpp::Connection>>& connections,
mscclpp::RegisteredMemory& localMemory,
std::vector<std::future<mscclpp::RegisteredMemory>>& remoteRegMemories,
bool addConnections = true);
protected:
using SetupChannelFunc = std::function<void(std::vector<std::shared_ptr<mscclpp::Connection>>,
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>&,
const mscclpp::RegisteredMemory&)>;
using SetupChannelFunc =
std::function<void(std::vector<std::shared_ptr<mscclpp::Connection>>,
std::vector<std::future<mscclpp::RegisteredMemory>>&, const mscclpp::RegisteredMemory&)>;
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
void setupMeshConnections(std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels, void* inputBuff,

View File

@@ -155,29 +155,25 @@ void SendRecvTestEngine::setupConnections() {
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> smSemaphores;
auto sendConnFuture =
comm_->connectOnSetup(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice));
comm_->connect(sendToRank, 0, getTransport(args_.rank, sendToRank, args_.nRanksPerNode, ibDevice));
if (recvFromRank != sendToRank) {
auto recvConnFuture =
comm_->connectOnSetup(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice));
comm_->setup();
comm_->connect(recvFromRank, 0, getTransport(args_.rank, recvFromRank, args_.nRanksPerNode, ibDevice));
smSemaphores.push_back(std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
smSemaphores.push_back(std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, recvConnFuture.get()));
} else {
comm_->setup();
smSemaphores.push_back(std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_, sendConnFuture.get()));
smSemaphores.push_back(smSemaphores[0]); // reuse the send channel if worldSize is 2
}
comm_->setup();
std::vector<mscclpp::RegisteredMemory> localMemories;
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemory;
std::vector<std::future<mscclpp::RegisteredMemory>> futureRemoteMemory;
for (int i : {0, 1}) {
auto regMem = comm_->registerMemory(devicePtrs_[i].get(), args_.maxBytes, mscclpp::Transport::CudaIpc | ibDevice);
comm_->sendMemoryOnSetup(regMem, ranks[i], 0);
comm_->sendMemory(regMem, ranks[i], 0);
localMemories.push_back(regMem);
futureRemoteMemory.push_back(comm_->recvMemoryOnSetup(ranks[1 - i], 0));
comm_->setup();
futureRemoteMemory.push_back(comm_->recvMemory(ranks[1 - i], 0));
}
// swap to make sure devicePtrs_[0] in local rank write to devicePtrs_[1] in remote rank

View File

@@ -24,14 +24,6 @@ class MockSetuppable : public mscclpp::Setuppable {
MOCK_METHOD(void, endSetup, (std::shared_ptr<mscclpp::Bootstrap> bootstrap), (override));
};
TEST_F(LocalCommunicatorTest, OnSetup) {
auto mockSetuppable = std::make_shared<MockSetuppable>();
comm->onSetup(mockSetuppable);
EXPECT_CALL(*mockSetuppable, beginSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
EXPECT_CALL(*mockSetuppable, endSetup(std::dynamic_pointer_cast<mscclpp::Bootstrap>(bootstrap)));
comm->setup();
}
TEST_F(LocalCommunicatorTest, RegisterMemory) {
int dummy[42];
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);
@@ -43,9 +35,8 @@ TEST_F(LocalCommunicatorTest, RegisterMemory) {
TEST_F(LocalCommunicatorTest, SendMemoryToSelf) {
int dummy[42];
auto memory = comm->registerMemory(&dummy, sizeof(dummy), mscclpp::NoTransports);
comm->sendMemoryOnSetup(memory, 0, 0);
auto memoryFuture = comm->recvMemoryOnSetup(0, 0);
comm->setup();
comm->sendMemory(memory, 0, 0);
auto memoryFuture = comm->recvMemory(0, 0);
auto sameMemory = memoryFuture.get();
EXPECT_EQ(sameMemory.data(), memory.data());
EXPECT_EQ(sameMemory.size(), memory.size());