Implement mscclpp::Communicator using C-style API

This commit is contained in:
Olli Saarikivi
2023-04-14 14:21:53 +00:00
parent c042d9af54
commit 45172bec88
2 changed files with 53 additions and 52 deletions

View File

@@ -3,49 +3,83 @@
namespace mscclpp {
mscclppTransport_t transportTypeToCStyle(TransportType type) {
switch (type) {
case TransportType::IB:
return mscclppTransportIB;
case TransportType::P2P:
return mscclppTransportP2P;
default:
throw std::runtime_error("Unknown transport type");
}
}
struct Communicator::impl {
mscclppComm_t comm;
std::vector<std::shared_ptr<HostConnection>> connections;
impl() : comm(nullptr) {}
~impl() {
if (comm) {
mscclppCommDestroy(comm);
}
}
};
void Communicator::initRank(int nranks, const char* ipPortPair, int rank) {
if (pimpl) {
throw std::runtime_error("Communicator already initialized");
}
pimpl = std::make_unique<impl>();
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
}
void Communicator::initRankFromId(int nranks, UniqueId id, int rank) {
if (pimpl) {
throw std::runtime_error("Communicator already initialized");
}
pimpl = std::make_unique<impl>();
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
}
void Communicator::bootstrapAllGather(void* data, int size) {
mscclppBootstrapAllGather(pimpl->comm, data, size);
}
void Communicator::bootstrapBarrier() {
mscclppBootstrapBarrier(pimpl->comm);
}
std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag, void* localBuff, uint64_t buffSize,
TransportType transportType, const char* ibDev = 0) {
std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
TransportType transportType, const char* ibDev = 0) {
mscclppConnect(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
auto conn = std::make_shared<HostConnection>();
auto connId = pimpl->connections.size();
conn->pimpl->init(connId);
pimpl->connections.push_back(conn);
return conn;
}
void Communicator::connectionSetup() {
}
void Communicator::destroy() {
mscclppConnectionSetup(pimpl->comm);
for (int connIdx = 0; connIdx < pimpl->connections.size(); ++connIdx) {
pimpl->connections[connIdx]->pimpl->setup();
}
}
int Communicator::rank() {
int result;
mscclppCommRank(pimpl->comm, &result);
return result;
}
int Communicator::size() {
}
void Communicator::setBootstrapConnTimeout(unsigned timeout) {
int result;
mscclppCommSize(pimpl->comm, &result);
return result;
}
} // namespace mscclpp

View File

@@ -331,21 +331,15 @@ public:
* remoteRank: the rank of the remote process
* tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be
* used to identify the connection inside a GPU kernel.
* localBuff: the local send/receive buffer
* buffSize: the size of the local buffer
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
*/
std::shared_ptr<HostConnection> connect(int remoteRank, int tag, void* localBuff, uint64_t buffSize,
TransportType transportType, const char* ibDev = 0);
std::shared_ptr<HostConnection> connect(int remoteRank, int tag, TransportType transportType, const char* ibDev = 0);
/* Establish all connections created by mscclppConnect(). This function must be called after all mscclppConnect()
* calls are made. This function ensures that all remote ranks are ready to communicate when it returns.
*/
void connectionSetup();
/* Destroy the communicator. */
void destroy();
/* Return the rank of the calling process.
*
@@ -361,38 +355,11 @@ public:
*/
int size();
/* Set the timeout for the bootstrap connection.
*
* Inputs:
* timeout: the timeout in seconds
*/
void setBootstrapConnTimeout(unsigned timeout);
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
/* Log handler type which is a callback function for
* however user likes to handle the log messages. Once set,
* the logger will just call this function with msg.
*/
typedef void (*LogHandler)(const char* msg);
/* The default log handler.
*
* Inputs:
* msg: the log message
*/
void defaultLogHandler(const char* msg);
/* Set a custom log handler.
*
* Inputs:
* handler: the log handler function
*/
void setLogHandler(LogHandler handler);
} // namespace mscclpp
#endif // MSCCLPP_H_