diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index 95320b07..25dbe51c 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -20,6 +20,20 @@ struct mscclppBootstrapHandle static_assert(sizeof(struct mscclppBootstrapHandle) <= sizeof(mscclppUniqueId), "Bootstrap handle is too large to fit inside MSCCLPP unique ID"); +class mscclppBootstrap : Bootstrap { +public: + mscclppBootstrap(std::string ip_port_pair, int rank, int nranks); + mscclppBootstrap(mscclppBootstrapHandle handle, int rank, int nranks); + mscclppBootstrapHandle mscclppGetUniqueId(); + void Send(void* data, int size, int peer, int tag); + void Recv(void* data, int size, int peer, int tag); + void AllGather(void* allData, int size); + void Barrier(); +private: + struct impl; + std::unique_ptr pimpl; +}; + mscclppResult_t bootstrapNetInit(const char* ip_port_pair = NULL); mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle); mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true, diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index e48eaaf8..e10b8e4f 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -247,6 +247,17 @@ typedef enum mscclppNumResults = 8 } mscclppResult_t; + +class Bootstrap { +public: + Bootstrap(){}; + virtual ~Bootstrap() = 0; + virtual void Send(void* data, int size, int peer, int tag) = 0; + virtual void Recv(void* data, int size, int peer, int tag) = 0; + virtual void AllGather(void* allData, int size) = 0; + virtual void Barrier() = 0; +}; + /* Create a unique ID for communication. Only needs to be called by one process. * Use with mscclppCommInitRankFromId(). * All processes need to provide the same ID to mscclppCommInitRankFromId().