diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index c1724a55..3b2ec6ab 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -24,11 +24,16 @@ static union mscclppSocketAddress bootstrapNetIfAddr; static int bootstrapNetInitDone = 0; pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER; -mscclppResult_t bootstrapNetInit() { +mscclppResult_t bootstrapNetInit(char* ip_port_pair) { if (bootstrapNetInitDone == 0) { pthread_mutex_lock(&bootstrapNetLock); if (bootstrapNetInitDone == 0) { - char* env = getenv("MSCCLPP_COMM_ID"); + char* env; + if (ip_port_pair) { + env = ip_port_pair; + } else { + env = getenv("MSCCLPP_COMM_ID"); + } if (env) { union mscclppSocketAddress remoteAddr; if (mscclppSocketGetAddrFromString(&remoteAddr, env) != mscclppSuccess) { @@ -188,11 +193,16 @@ mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool // #include // #include -mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot) { +mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot, char* ip_port_pair) { memset(handle, 0, sizeof(mscclppBootstrapHandle)); // MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic))); handle->magic = 0xdeadbeef; - char* env = getenv("MSCCLPP_COMM_ID"); + char* env; + if (ip_port_pair) { + env = ip_port_pair; + } else { + env = getenv("MSCCLPP_COMM_ID"); + } if (env) { INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", env); if (mscclppSocketGetAddrFromString(&handle->addr, env) != mscclppSuccess) { diff --git a/src/bootstrap/bootstrap_test.cc b/src/bootstrap/bootstrap_test.cc index f02fa7c7..b3dc9954 100644 --- a/src/bootstrap/bootstrap_test.cc +++ b/src/bootstrap/bootstrap_test.cc @@ -1,4 +1,4 @@ -#include "bootstrap.h" +#include "mscclpp.h" #include "alloc.h" #include "mpi.h" #include @@ -11,49 +11,10 @@ int main() int world_size; MPI_Comm_rank(MPI_COMM_WORLD, &rank); 
MPI_Comm_size(MPI_COMM_WORLD, &world_size); - // int a; - // scanf("%d", &a); - mscclppResult_t res = bootstrapNetInit(); - if (res != mscclppSuccess) { - printf("bootstrapNetInit failed\n"); - return -1; - } - - mscclppBootstrapHandle handle; - if (true || rank == 0) { - res = bootstrapGetUniqueId(&handle, rank == 0); - if (res != mscclppSuccess) { - printf("bootstrapGetUniqueId failed\n"); - return -1; - } - } - - // MPI_Bcast(&handle, sizeof(mscclppBootstrapHandle), MPI_BYTE, 0, MPI_COMM_WORLD); - - mscclppComm *comm; - res = mscclppCalloc(&comm, 1); - if (res != mscclppSuccess) { - printf("mscclppCalloc failed\n"); - return -1; - } - - comm->magic = 0xdeadbeef; - comm->rank = rank; - comm->nRanks = world_size; - res = mscclppCudaHostCalloc((uint32_t **)&comm->abortFlag, 1); - if (res != mscclppSuccess) { - printf("mscclppCudaHostCalloc failed\n"); - return -1; - } - - res = bootstrapInit(&handle, comm); - if (res != mscclppSuccess) { - printf("bootstrapInit failed\n"); - return -1; - } - - printf("bootstrapInit done\n"); + mscclppComm_t comm; + char ip_port[] = "192.168.0.32:50000"; + mscclppCommInitRank(&comm, world_size, rank, ip_port); int *buf = (int *)calloc(world_size, sizeof(int)); if (buf == nullptr) { @@ -61,7 +22,7 @@ int main() return -1; } buf[rank] = rank; - res = bootstrapAllGather(comm->bootstrap, buf, sizeof(int)); + mscclppResult_t res = mscclppBootStrapAllGather(comm, buf, sizeof(int)); if (res != mscclppSuccess) { printf("bootstrapAllGather failed\n"); return -1; @@ -74,9 +35,9 @@ int main() } } - res = bootstrapClose(comm->bootstrap); + res = mscclppCommDestroy(comm); if (res != mscclppSuccess) { - printf("bootstrapClose failed\n"); + printf("mscclppDestroy failed\n"); return -1; } diff --git a/src/bootstrap/init.cc b/src/bootstrap/init.cc index a70f3295..81207d52 100644 --- a/src/bootstrap/init.cc +++ b/src/bootstrap/init.cc @@ -43,4 +43,47 @@ mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) { mscclppResult_t res = 
bootstrapGetUniqueId((struct mscclppBootstrapHandle*)out); TRACE_CALL("mscclppGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out)); return res; +} + +mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size){ + MSCCLPPCHECK(bootstrapAllGather(comm->bootstrap, data, size)); + return mscclppSuccess; +} + + +mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, int rank, char* ip_port_pair){ /* Allocates a communicator for `rank` of `nranks`, bootstraps it from ip_port_pair (MSCCLPP_COMM_ID is consulted when it is NULL) and stores it in *comm; on any failure the partially built communicator is released and *comm is set to NULL. */ + mscclppResult_t res = mscclppSuccess; + mscclppComm_t _comm = NULL; + MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail); + _comm->rank = rank; + _comm->nRanks = nranks; + + MSCCLPPCHECKGOTO(bootstrapNetInit(ip_port_pair), res, fail); /* was MSCCLPPCHECK; NOTE(review): assumes MSCCLPPCHECK returns early, which would skip the cleanup below - confirm */ + mscclppBootstrapHandle handle; + MSCCLPPCHECKGOTO(bootstrapGetUniqueId(&handle, rank == 0, ip_port_pair), res, fail); + _comm->magic = handle.magic; + + MSCCLPPCHECKGOTO(mscclppCudaHostCalloc((uint32_t **)&_comm->abortFlag, 1), res, fail); + MSCCLPPCHECKGOTO(bootstrapInit(&handle, _comm), res, fail); + *comm = _comm; + return res; +fail: /* single exit for every failing step so _comm and abortFlag are not leaked */ + if (_comm) { + if (_comm->abortFlag) mscclppCudaHostFree((void *)_comm->abortFlag); + free(_comm); + } + if (comm) *comm = NULL; + return res; +} + +mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){ + if (comm == NULL) + return mscclppSuccess; + + if (comm->bootstrap) + MSCCLPPCHECK(bootstrapClose(comm->bootstrap)); + + mscclppCudaHostFree((void *)comm->abortFlag); + free(comm); + return mscclppSuccess; } \ No newline at end of file diff --git a/src/bootstrap/socket.cc b/src/bootstrap/socket.cc index 4f0cede5..48c05752 100644 --- a/src/bootstrap/socket.cc +++ b/src/bootstrap/socket.cc @@ -339,8 +339,8 @@ int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress *ifAddrs, in if (nIfs == 0) { char* commId = getenv("MSCCLPP_COMM_ID"); if (commId && strlen(commId) > 1) { - INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", commId); - // Try to find interface that is in the same subnet as the IP in comm id + INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", commId); + 
// Try to find interface that is in the same subnet as the IP in comm id union mscclppSocketAddress idAddr; mscclppSocketGetAddrFromString(&idAddr, commId); nIfs = mscclppFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs); diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index 82868434..183c9b1f 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -18,9 +18,9 @@ struct mscclppBootstrapHandle { }; static_assert(sizeof(struct mscclppBootstrapHandle) <= sizeof(mscclppUniqueId), "Bootstrap handle is too large to fit inside MSCCLPP unique ID"); -mscclppResult_t bootstrapNetInit(); +mscclppResult_t bootstrapNetInit(char* ip_port_pair = NULL); mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool idFromEnv); -mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true); +mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true, char* ip_port_pair = NULL); mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm); mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size); mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size); diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 0683f03f..cc0534d6 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -12,6 +12,8 @@ #define MSCCLPP_VERSION (MSCCLPP_MAJOR * 100 + MSCCLPP_MINOR) +typedef struct mscclppComm* mscclppComm_t; + #define MSCCLPP_UNIQUE_ID_BYTES 128 typedef struct { char internal[MSCCLPP_UNIQUE_ID_BYTES]; } mscclppUniqueId; @@ -67,7 +69,11 @@ typedef enum { mscclppInt8 = 0, mscclppChar = 0, } mscclppDataType_t; +mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, int rank, char* ip_port_pair); + +mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size); + //mscclppResult_t mscclppCommInitRank(mscclppComm_t* 
comm, int nranks, mscclppUniqueId commId, int rank); -//mscclppResult_t mscclppCommDestroy(mscclppComm_t comm); +mscclppResult_t mscclppCommDestroy(mscclppComm_t comm); #endif // MSCCLPP_H_