cleaned up the mess

This commit is contained in:
lambda7xx
2023-02-07 04:42:58 +00:00
parent 38c3bf56eb
commit fe7d8097d6
6 changed files with 75 additions and 55 deletions

View File

@@ -24,11 +24,16 @@ static union mscclppSocketAddress bootstrapNetIfAddr;
static int bootstrapNetInitDone = 0;
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;
mscclppResult_t bootstrapNetInit() {
mscclppResult_t bootstrapNetInit(char* ip_port_pair) {
if (bootstrapNetInitDone == 0) {
pthread_mutex_lock(&bootstrapNetLock);
if (bootstrapNetInitDone == 0) {
char* env = getenv("MSCCLPP_COMM_ID");
char* env;
if (ip_port_pair) {
env = ip_port_pair;
} else {
env = getenv("MSCCLPP_COMM_ID");
}
if (env) {
union mscclppSocketAddress remoteAddr;
if (mscclppSocketGetAddrFromString(&remoteAddr, env) != mscclppSuccess) {
@@ -188,11 +193,16 @@ mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool
// #include <netinet/in.h>
// #include <arpa/inet.h>
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot) {
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot, char* ip_port_pair) {
memset(handle, 0, sizeof(mscclppBootstrapHandle));
// MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic)));
handle->magic = 0xdeadbeef;
char* env = getenv("MSCCLPP_COMM_ID");
char* env;
if (ip_port_pair) {
env = ip_port_pair;
} else {
env = getenv("MSCCLPP_COMM_ID");
}
if (env) {
INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", env);
if (mscclppSocketGetAddrFromString(&handle->addr, env) != mscclppSuccess) {

View File

@@ -1,4 +1,4 @@
#include "bootstrap.h"
#include "mscclpp.h"
#include "alloc.h"
#include "mpi.h"
#include <stdio.h>
@@ -11,49 +11,10 @@ int main()
int world_size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
// int a;
// scanf("%d", &a);
mscclppResult_t res = bootstrapNetInit();
if (res != mscclppSuccess) {
printf("bootstrapNetInit failed\n");
return -1;
}
mscclppBootstrapHandle handle;
if (true || rank == 0) {
res = bootstrapGetUniqueId(&handle, rank == 0);
if (res != mscclppSuccess) {
printf("bootstrapGetUniqueId failed\n");
return -1;
}
}
// MPI_Bcast(&handle, sizeof(mscclppBootstrapHandle), MPI_BYTE, 0, MPI_COMM_WORLD);
mscclppComm *comm;
res = mscclppCalloc(&comm, 1);
if (res != mscclppSuccess) {
printf("mscclppCalloc failed\n");
return -1;
}
comm->magic = 0xdeadbeef;
comm->rank = rank;
comm->nRanks = world_size;
res = mscclppCudaHostCalloc((uint32_t **)&comm->abortFlag, 1);
if (res != mscclppSuccess) {
printf("mscclppCudaHostCalloc failed\n");
return -1;
}
res = bootstrapInit(&handle, comm);
if (res != mscclppSuccess) {
printf("bootstrapInit failed\n");
return -1;
}
printf("bootstrapInit done\n");
mscclppComm_t comm;
char ip_port[] = "192.168.0.32:50000";
mscclppCommInitRank(&comm, world_size, rank, ip_port);
int *buf = (int *)calloc(world_size, sizeof(int));
if (buf == nullptr) {
@@ -61,7 +22,7 @@ int main()
return -1;
}
buf[rank] = rank;
res = bootstrapAllGather(comm->bootstrap, buf, sizeof(int));
mscclppResult_t res = mscclppBootStrapAllGather(comm, buf, sizeof(int));
if (res != mscclppSuccess) {
printf("bootstrapAllGather failed\n");
return -1;
@@ -74,9 +35,9 @@ int main()
}
}
res = bootstrapClose(comm->bootstrap);
res = mscclppCommDestroy(comm);
if (res != mscclppSuccess) {
printf("bootstrapClose failed\n");
printf("mscclppDestroy failed\n");
return -1;
}

View File

@@ -43,4 +43,47 @@ mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) {
mscclppResult_t res = bootstrapGetUniqueId((struct mscclppBootstrapHandle*)out);
TRACE_CALL("mscclppGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out));
return res;
}
// Public all-gather entry point: forwards `data` and `size` to
// bootstrapAllGather using the bootstrap state stored on the communicator.
//
// comm  communicator whose `bootstrap` field is handed to the bootstrap layer;
//       it is dereferenced unconditionally, so it must not be NULL.
// data  buffer passed through unchanged (presumably holds one `size`-byte
//       slot per rank -- TODO confirm against bootstrapAllGather).
// size  byte count passed through unchanged.
//
// Returns mscclppSuccess once the bootstrap call has passed MSCCLPPCHECK.
mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size){
  void* bootstrapState = comm->bootstrap;
  MSCCLPPCHECK(bootstrapAllGather(bootstrapState, data, size));
  return mscclppSuccess;
}
// Allocates a communicator for one rank and runs the bootstrap sequence on it.
//
// comm          [out] receives the new communicator on success. On any failure
//               it is set to NULL (when the pointer itself is non-NULL).
// nranks        total number of ranks; stored in the communicator's nRanks.
// rank          this caller's rank; stored in the communicator, and `rank == 0`
//               is passed to bootstrapGetUniqueId as its isRoot argument.
// ip_port_pair  forwarded unchanged to bootstrapNetInit and
//               bootstrapGetUniqueId.
//
// Returns mscclppSuccess, or the result code of the first step that failed.
mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, int rank, char* ip_port_pair){
  mscclppResult_t res = mscclppSuccess;
  mscclppComm_t _comm = NULL;
  // Declared ahead of the first MSCCLPPCHECKGOTO so that no jump to `fail`
  // crosses this declaration.
  mscclppBootstrapHandle handle;

  MSCCLPPCHECKGOTO(mscclppCalloc(&_comm, 1), res, fail);
  _comm->rank = rank;
  _comm->nRanks = nranks;

  // From here on `_comm` is owned by this function, so every failing step has
  // to route through `fail`; returning directly would leak the communicator.
  MSCCLPPCHECKGOTO(bootstrapNetInit(ip_port_pair), res, fail);
  MSCCLPPCHECKGOTO(bootstrapGetUniqueId(&handle, rank == 0, ip_port_pair), res, fail);
  _comm->magic = handle.magic;

  MSCCLPPCHECKGOTO(mscclppCudaHostCalloc((uint32_t **)&_comm->abortFlag, 1), res, fail);
  // NOTE(review): if bootstrapInit can fail after attaching state to
  // _comm->bootstrap, that state is not released below -- confirm.
  MSCCLPPCHECKGOTO(bootstrapInit(&handle, _comm), res, fail);

  *comm = _comm;
  return res;

fail:
  if (_comm) {
    if (_comm->abortFlag) mscclppCudaHostFree((void *)_comm->abortFlag);
    free(_comm);
  }
  if (comm) *comm = NULL;
  return res;
}
// Releases a communicator: closes its bootstrap state (when present), frees
// the abort flag, then frees the communicator itself.
//
// comm  communicator to destroy; NULL is accepted and is a no-op.
//
// Returns mscclppSuccess when the teardown runs to the end. The bootstrapClose
// call is wrapped in MSCCLPPCHECK (defined elsewhere); if that macro returns
// early on error, the abort flag and communicator are left unfreed.
mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){
  if (comm != NULL) {
    if (comm->bootstrap) {
      MSCCLPPCHECK(bootstrapClose(comm->bootstrap));
    }
    mscclppCudaHostFree((void *)comm->abortFlag);
    free(comm);
  }
  return mscclppSuccess;
}

View File

@@ -339,8 +339,8 @@ int mscclppFindInterfaces(char* ifNames, union mscclppSocketAddress *ifAddrs, in
if (nIfs == 0) {
char* commId = getenv("MSCCLPP_COMM_ID");
if (commId && strlen(commId) > 1) {
INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", commId);
// Try to find interface that is in the same subnet as the IP in comm id
INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", commId);
// Try to find interface that is in the same subnet as the IP in comm id
union mscclppSocketAddress idAddr;
mscclppSocketGetAddrFromString(&idAddr, commId);
nIfs = mscclppFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);

View File

@@ -18,9 +18,9 @@ struct mscclppBootstrapHandle {
};
static_assert(sizeof(struct mscclppBootstrapHandle) <= sizeof(mscclppUniqueId), "Bootstrap handle is too large to fit inside MSCCLPP unique ID");
mscclppResult_t bootstrapNetInit();
mscclppResult_t bootstrapNetInit(char* ip_port_pair = NULL);
mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool idFromEnv);
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true);
mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true, char* ip_port_pair = NULL);
mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm);
mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size);
mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);

View File

@@ -12,6 +12,8 @@
#define MSCCLPP_VERSION (MSCCLPP_MAJOR * 100 + MSCCLPP_MINOR)
typedef struct mscclppComm* mscclppComm_t;
#define MSCCLPP_UNIQUE_ID_BYTES 128
typedef struct { char internal[MSCCLPP_UNIQUE_ID_BYTES]; } mscclppUniqueId;
@@ -67,7 +69,11 @@ typedef enum { mscclppInt8 = 0, mscclppChar = 0,
} mscclppDataType_t;
mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, int rank, char* ip_port_pair);
mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size);
//mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, mscclppUniqueId commId, int rank);
//mscclppResult_t mscclppCommDestroy(mscclppComm_t comm);
mscclppResult_t mscclppCommDestroy(mscclppComm_t comm);
#endif // MSCCLPP_H_