From 45172bec886279ff5bc4bdd908381ead3e09eb90 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Fri, 14 Apr 2023 14:21:53 +0000 Subject: [PATCH] Implement mscclpp::Communicator using C-style API --- src/communicator.cc | 70 ++++++++++++++++++++++++++++++----------- src/include/mscclpp.hpp | 35 +-------------------- 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 5519b9c5..3272987d 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -3,49 +3,83 @@ namespace mscclpp { +mscclppTransport_t transportTypeToCStyle(TransportType type) { + switch (type) { + case TransportType::IB: + return mscclppTransportIB; + case TransportType::P2P: + return mscclppTransportP2P; + default: + throw std::runtime_error("Unknown transport type"); + } +} + struct Communicator::impl { mscclppComm_t comm; + std::vector> connections; + + impl() : comm(nullptr) {} + + ~impl() { + if (comm) { + mscclppCommDestroy(comm); + } + } }; void Communicator::initRank(int nranks, const char* ipPortPair, int rank) { - + if (pimpl) { + throw std::runtime_error("Communicator already initialized"); + } + pimpl = std::make_unique(); + mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank); } void Communicator::initRankFromId(int nranks, UniqueId id, int rank) { - + if (pimpl) { + throw std::runtime_error("Communicator already initialized"); + } + pimpl = std::make_unique(); + static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch"); + mscclppUniqueId *cstyle_id = reinterpret_cast(&id); + mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank); } void Communicator::bootstrapAllGather(void* data, int size) { - + mscclppBootstrapAllGather(pimpl->comm, data, size); } void Communicator::bootstrapBarrier() { - + mscclppBootstrapBarrier(pimpl->comm); } -std::shared_ptr Communicator::connect(int remoteRank, int tag, void* localBuff, uint64_t buffSize, - TransportType transportType, const char* ibDev = 0) { - +std::shared_ptr Communicator::connect(int remoteRank, int tag, + TransportType transportType, const char* ibDev = 0) { + mscclppConnect(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev); + auto conn = std::make_shared(); + auto connId = pimpl->connections.size(); + conn->pimpl->init(connId); + pimpl->connections.push_back(conn); + return conn; } void Communicator::connectionSetup() { - -} - -void Communicator::destroy() { - + mscclppConnectionSetup(pimpl->comm); + for (int connIdx = 0; connIdx < pimpl->connections.size(); ++connIdx) { + pimpl->connections[connIdx]->pimpl->setup(); + } } int Communicator::rank() { - + int result; + mscclppCommRank(pimpl->comm, &result); + return result; } int Communicator::size() { - -} - -void Communicator::setBootstrapConnTimeout(unsigned timeout) { - + int result; + mscclppCommSize(pimpl->comm, &result); + return result; } } // namespace mscclpp \ No newline at end of file diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index d44b04b9..85aa22f8 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -331,21 +331,15 @@ public: * remoteRank: the rank of the remote process * tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be * used to identify the connection inside a GPU kernel. - * localBuff: the local send/receive buffer - * buffSize: the size of the local buffer * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. */ - std::shared_ptr connect(int remoteRank, int tag, void* localBuff, uint64_t buffSize, - TransportType transportType, const char* ibDev = 0); + std::shared_ptr connect(int remoteRank, int tag, TransportType transportType, const char* ibDev = 0); /* Establish all connections created by mscclppConnect(). This function must be called after all mscclppConnect() * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. */ void connectionSetup(); - - /* Destroy the communicator. */ - void destroy(); /* Return the rank of the calling process. * @@ -361,38 +355,11 @@ public: */ int size(); - /* Set the timeout for the bootstrap connection. - * - * Inputs: - * timeout: the timeout in seconds - */ - void setBootstrapConnTimeout(unsigned timeout); - private: struct impl; std::unique_ptr pimpl; }; -/* Log handler type which is a callback function for - * however user likes to handle the log messages. Once set, - * the logger will just call this function with msg. - */ -typedef void (*LogHandler)(const char* msg); - -/* The default log handler. - * - * Inputs: - * msg: the log message - */ -void defaultLogHandler(const char* msg); - -/* Set a custom log handler. - * - * Inputs: - * handler: the log handler function - */ -void setLogHandler(LogHandler handler); - } // namespace mscclpp #endif // MSCCLPP_H_