mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 09:46:00 +00:00
Implement mscclpp::Communicator using C-style API
This commit is contained in:
@@ -3,49 +3,83 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
// Translate a C++ TransportType enumerator into the matching C-API transport
// constant. Throws std::runtime_error for any enumerator other than IB or P2P.
mscclppTransport_t transportTypeToCStyle(TransportType type) {
  if (type == TransportType::IB) {
    return mscclppTransportIB;
  }
  if (type == TransportType::P2P) {
    return mscclppTransportP2P;
  }
  throw std::runtime_error("Unknown transport type");
}
|
||||
|
||||
// Private state behind Communicator: the C-API communicator handle and the
// host connections that have been stored alongside it.
struct Communicator::impl {
  // C-API communicator handle; null until something stores a handle in it.
  mscclppComm_t comm = nullptr;
  // Host connections kept alive by this communicator.
  std::vector<std::shared_ptr<HostConnection>> connections;

  impl() = default;

  // The destructor releases `comm`, so a copied impl would release the same
  // handle a second time. Deleting the copy operations also suppresses the
  // implicit move operations.
  impl(const impl&) = delete;
  impl& operator=(const impl&) = delete;

  ~impl() {
    if (comm) {
      // Whatever mscclppCommDestroy() reports is not inspected here; a
      // destructor has no good way to surface it.
      mscclppCommDestroy(comm);
    }
  }
};
|
||||
|
||||
void Communicator::initRank(int nranks, const char* ipPortPair, int rank) {
|
||||
|
||||
if (pimpl) {
|
||||
throw std::runtime_error("Communicator already initialized");
|
||||
}
|
||||
pimpl = std::make_unique<impl>();
|
||||
mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank);
|
||||
}
|
||||
|
||||
void Communicator::initRankFromId(int nranks, UniqueId id, int rank) {
|
||||
|
||||
if (pimpl) {
|
||||
throw std::runtime_error("Communicator already initialized");
|
||||
}
|
||||
pimpl = std::make_unique<impl>();
|
||||
static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch");
|
||||
mscclppUniqueId *cstyle_id = reinterpret_cast<mscclppUniqueId*>(&id);
|
||||
mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank);
|
||||
}
|
||||
|
||||
void Communicator::bootstrapAllGather(void* data, int size) {
|
||||
|
||||
mscclppBootstrapAllGather(pimpl->comm, data, size);
|
||||
}
|
||||
|
||||
void Communicator::bootstrapBarrier() {
|
||||
|
||||
mscclppBootstrapBarrier(pimpl->comm);
|
||||
}
|
||||
|
||||
std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag, void* localBuff, uint64_t buffSize,
|
||||
TransportType transportType, const char* ibDev = 0) {
|
||||
|
||||
std::shared_ptr<HostConnection> Communicator::connect(int remoteRank, int tag,
|
||||
TransportType transportType, const char* ibDev = 0) {
|
||||
mscclppConnect(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev);
|
||||
auto conn = std::make_shared<HostConnection>();
|
||||
auto connId = pimpl->connections.size();
|
||||
conn->pimpl->init(connId);
|
||||
pimpl->connections.push_back(conn);
|
||||
return conn;
|
||||
}
|
||||
|
||||
void Communicator::connectionSetup() {
|
||||
|
||||
}
|
||||
|
||||
void Communicator::destroy() {
|
||||
|
||||
mscclppConnectionSetup(pimpl->comm);
|
||||
for (int connIdx = 0; connIdx < pimpl->connections.size(); ++connIdx) {
|
||||
pimpl->connections[connIdx]->pimpl->setup();
|
||||
}
|
||||
}
|
||||
|
||||
int Communicator::rank() {
|
||||
|
||||
int result;
|
||||
mscclppCommRank(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
int Communicator::size() {
|
||||
|
||||
}
|
||||
|
||||
void Communicator::setBootstrapConnTimeout(unsigned timeout) {
|
||||
|
||||
int result;
|
||||
mscclppCommSize(pimpl->comm, &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
@@ -331,21 +331,15 @@ public:
|
||||
* remoteRank: the rank of the remote process
|
||||
* tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be
|
||||
* used to identify the connection inside a GPU kernel.
|
||||
* localBuff: the local send/receive buffer
|
||||
* buffSize: the size of the local buffer
|
||||
* transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB)
|
||||
* ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P.
|
||||
*/
|
||||
std::shared_ptr<HostConnection> connect(int remoteRank, int tag, void* localBuff, uint64_t buffSize,
|
||||
TransportType transportType, const char* ibDev = 0);
|
||||
std::shared_ptr<HostConnection> connect(int remoteRank, int tag, TransportType transportType, const char* ibDev = 0);
|
||||
|
||||
/* Establish all connections created by connect(). This function must be called after all connect() calls
 * are made. This function ensures that all remote ranks are ready to communicate when it returns.
 */
|
||||
void connectionSetup();
|
||||
|
||||
/* Destroy the communicator. */
|
||||
void destroy();
|
||||
|
||||
/* Return the rank of the calling process.
|
||||
*
|
||||
@@ -361,38 +355,11 @@ public:
|
||||
*/
|
||||
int size();
|
||||
|
||||
/* Set the timeout for the bootstrap connection.
|
||||
*
|
||||
* Inputs:
|
||||
* timeout: the timeout in seconds
|
||||
*/
|
||||
void setBootstrapConnTimeout(unsigned timeout);
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
/* Signature of a user-supplied log callback. Once a handler is set, the
 * logger simply calls it with each message as `msg`, leaving it to the
 * user to decide how the message is handled.
 */
using LogHandler = void (*)(const char* msg);
|
||||
|
||||
/* The default log handler.
|
||||
*
|
||||
* Inputs:
|
||||
* msg: the log message
|
||||
*/
|
||||
void defaultLogHandler(const char* msg);
|
||||
|
||||
/* Set a custom log handler.
|
||||
*
|
||||
* Inputs:
|
||||
* handler: the log handler function
|
||||
*/
|
||||
void setLogHandler(LogHandler handler);
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_H_
|
||||
|
||||
Reference in New Issue
Block a user