mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
cleaned up the mess
This commit is contained in:
@@ -24,11 +24,16 @@ static union mscclppSocketAddress bootstrapNetIfAddr;
|
||||
static int bootstrapNetInitDone = 0;
|
||||
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
|
||||
|
||||
mscclppResult_t bootstrapNetInit() {
|
||||
mscclppResult_t bootstrapNetInit(char* ip_port_pair) {
|
||||
if (bootstrapNetInitDone == 0) {
|
||||
pthread_mutex_lock(&bootstrapNetLock);
|
||||
if (bootstrapNetInitDone == 0) {
|
||||
char* env = getenv("MSCCLPP_COMM_ID");
|
||||
char* env;
|
||||
if (ip_port_pair) {
|
||||
env = ip_port_pair;
|
||||
} else {
|
||||
env = getenv("MSCCLPP_COMM_ID");
|
||||
}
|
||||
if (env) {
|
||||
union mscclppSocketAddress remoteAddr;
|
||||
if (mscclppSocketGetAddrFromString(&remoteAddr, env) != mscclppSuccess) {
|
||||
@@ -188,11 +193,16 @@ mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool
|
||||
// #include <netinet/in.h>
|
||||
// #include <arpa/inet.h>
|
||||
|
||||
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot) {
|
||||
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot, char* ip_port_pair) {
|
||||
memset(handle, 0, sizeof(mscclppBootstrapHandle));
|
||||
// MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic)));
|
||||
handle->magic = 0xdeadbeef;
|
||||
char* env = getenv("MSCCLPP_COMM_ID");
|
||||
char* env;
|
||||
if (ip_port_pair) {
|
||||
env = ip_port_pair;
|
||||
} else {
|
||||
env = getenv("MSCCLPP_COMM_ID");
|
||||
}
|
||||
if (env) {
|
||||
INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", env);
|
||||
if (mscclppSocketGetAddrFromString(&handle->addr, env) != mscclppSuccess) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "bootstrap.h"
|
||||
#include "mscclpp.h"
|
||||
#include "alloc.h"
|
||||
#include "mpi.h"
|
||||
#include <stdio.h>
|
||||
@@ -11,49 +11,10 @@ int main()
|
||||
int world_size;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
// int a;
|
||||
// scanf("%d", &a);
|
||||
|
||||
mscclppResult_t res = bootstrapNetInit();
|
||||
if (res != mscclppSuccess) {
|
||||
printf("bootstrapNetInit failed\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
mscclppBootstrapHandle handle;
|
||||
if (true || rank == 0) {
|
||||
res = bootstrapGetUniqueId(&handle, rank == 0);
|
||||
if (res != mscclppSuccess) {
|
||||
printf("bootstrapGetUniqueId failed\n");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// MPI_Bcast(&handle, sizeof(mscclppBootstrapHandle), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
|
||||
mscclppComm *comm;
|
||||
res = mscclppCalloc(&comm, 1);
|
||||
if (res != mscclppSuccess) {
|
||||
printf("mscclppCalloc failed\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
comm->magic = 0xdeadbeef;
|
||||
comm->rank = rank;
|
||||
comm->nRanks = world_size;
|
||||
res = mscclppCudaHostCalloc((uint32_t **)&comm->abortFlag, 1);
|
||||
if (res != mscclppSuccess) {
|
||||
printf("mscclppCudaHostCalloc failed\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
res = bootstrapInit(&handle, comm);
|
||||
if (res != mscclppSuccess) {
|
||||
printf("bootstrapInit failed\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
printf("bootstrapInit done\n");
|
||||
mscclppComm_t comm;
|
||||
char ip_port[] = "192.168.0.32:50000";
|
||||
mscclppCommInitRank(&comm, world_size, rank, ip_port);
|
||||
|
||||
int *buf = (int *)calloc(world_size, sizeof(int));
|
||||
if (buf == nullptr) {
|
||||
@@ -61,7 +22,7 @@ int main()
|
||||
return -1;
|
||||
}
|
||||
buf[rank] = rank;
|
||||
res = bootstrapAllGather(comm->bootstrap, buf, sizeof(int));
|
||||
mscclppResult_t res = mscclppBootStrapAllGather(comm, buf, sizeof(int));
|
||||
if (res != mscclppSuccess) {
|
||||
printf("bootstrapAllGather failed\n");
|
||||
return -1;
|
||||
@@ -74,9 +35,9 @@ int main()
|
||||
}
|
||||
}
|
||||
|
||||
res = bootstrapClose(comm->bootstrap);
|
||||
res = mscclppCommDestroy(comm);
|
||||
if (res != mscclppSuccess) {
|
||||
printf("bootstrapClose failed\n");
|
||||
printf("mscclppDestroy failed\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
@@ -43,4 +43,47 @@ mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) {
|
||||
mscclppResult_t res = bootstrapGetUniqueId((struct mscclppBootstrapHandle*)out);
|
||||
TRACE_CALL("mscclppGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out));
|
||||
return res;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size){
|
||||
MSCCLPPCHECK(bootstrapAllGather(comm->bootstrap, data, size));
|
||||
return mscclppSuccess;
|
||||
}
|
||||
|
||||
|
||||
mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, int rank, char* ip_port_pair){
|
||||
mscclppResult_t res = mscclppSuccess;
|
||||
mscclppComm_t _comm = NULL;
|
||||
MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail);
|
||||
_comm->rank = rank;
|
||||
_comm->nRanks = nranks;
|
||||
|
||||
MSCCLPPCHECK(bootstrapNetInit(ip_port_pair));
|
||||
mscclppBootstrapHandle handle;
|
||||
MSCCLPPCHECK(bootstrapGetUniqueId(&handle, rank == 0, ip_port_pair));
|
||||
_comm->magic = handle.magic;
|
||||
|
||||
MSCCLPPCHECKGOTO(mscclppCudaHostCalloc((uint32_t **)&_comm->abortFlag, 1), res, fail);
|
||||
MSCCLPPCHECK(bootstrapInit(&handle, _comm));
|
||||
*comm = _comm;
|
||||
return res;
|
||||
fail:
|
||||
if (_comm) {
|
||||
if (_comm->abortFlag) mscclppCudaHostFree((void *)_comm->abortFlag);
|
||||
free(_comm);
|
||||
}
|
||||
if (comm) *comm = NULL;
|
||||
return res;
|
||||
}
|
||||
|
||||
mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){
|
||||
if (comm == NULL)
|
||||
return mscclppSuccess;
|
||||
|
||||
if (comm->bootstrap)
|
||||
MSCCLPPCHECK(bootstrapClose(comm->bootstrap));
|
||||
|
||||
mscclppCudaHostFree((void *)comm->abortFlag);
|
||||
free(comm);
|
||||
return mscclppSuccess;
|
||||
}
|
||||
@@ -339,8 +339,8 @@ int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress *ifAddrs, in
|
||||
if (nIfs == 0) {
|
||||
char* commId = getenv("MSCCLPP_COMM_ID");
|
||||
if (commId && strlen(commId) > 1) {
|
||||
INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", commId);
|
||||
// Try to find interface that is in the same subnet as the IP in comm id
|
||||
INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", commId);
|
||||
// Try to find interface that is in the same subnet as the IP in comm id
|
||||
union mscclppSocketAddress idAddr;
|
||||
mscclppSocketGetAddrFromString(&idAddr, commId);
|
||||
nIfs = mscclppFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
|
||||
|
||||
@@ -18,9 +18,9 @@ struct mscclppBootstrapHandle {
|
||||
};
|
||||
static_assert(sizeof(struct mscclppBootstrapHandle) <= sizeof(mscclppUniqueId), "Bootstrap handle is too large to fit inside MSCCLPP unique ID");
|
||||
|
||||
mscclppResult_t bootstrapNetInit();
|
||||
mscclppResult_t bootstrapNetInit(char* ip_port_pair = NULL);
|
||||
mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool idFromEnv);
|
||||
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true);
|
||||
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true, char* ip_port_pair = NULL);
|
||||
mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm);
|
||||
mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size);
|
||||
mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
|
||||
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 100 + MSCCLPP_MINOR)
|
||||
|
||||
typedef struct mscclppComm* mscclppComm_t;
|
||||
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
typedef struct { char internal[MSCCLPP_UNIQUE_ID_BYTES]; } mscclppUniqueId;
|
||||
@@ -67,7 +69,11 @@ typedef enum { mscclppInt8 = 0, mscclppChar = 0,
|
||||
} mscclppDataType_t;
|
||||
|
||||
|
||||
mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, int rank, char* ip_port_pair);
|
||||
|
||||
mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size);
|
||||
|
||||
//mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, mscclppUniqueId commId, int rank);
|
||||
//mscclppResult_t mscclppCommDestroy(mscclppComm_t comm);
|
||||
mscclppResult_t mscclppCommDestroy(mscclppComm_t comm);
|
||||
|
||||
#endif // MSCCLPP_H_
|
||||
|
||||
Reference in New Issue
Block a user