From 38c3bf56eb80d50b8961839288739a3dc8713db3 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Mon, 6 Feb 2023 23:04:03 +0000 Subject: [PATCH] works without bcast --- src/bootstrap/bootstrap.cc | 13 ++++++++++--- src/bootstrap/bootstrap_test.cc | 10 ++++++---- src/include/bootstrap.h | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 504ec72b..c1724a55 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -185,10 +185,13 @@ mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool return mscclppSuccess; } -mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle) { - memset(handle, 0, sizeof(mscclppBootstrapHandle)); - MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic))); +// #include +// #include +mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot) { + memset(handle, 0, sizeof(mscclppBootstrapHandle)); + // MSCCLPPCHECK(getRandomData(&handle->magic, sizeof(handle->magic))); + handle->magic = 0xdeadbeef; char* env = getenv("MSCCLPP_COMM_ID"); if (env) { INFO(MSCCLPP_ENV, "MSCCLPP_COMM_ID set by environment to %s", env); @@ -196,10 +199,14 @@ mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle) { WARN("Invalid MSCCLPP_COMM_ID, please use format: : or []: or :"); return mscclppInvalidArgument; } + if (isRoot) + MSCCLPPCHECK(bootstrapCreateRoot(handle, false)); } else { memcpy(&handle->addr, &bootstrapNetIfAddr, sizeof(union mscclppSocketAddress)); MSCCLPPCHECK(bootstrapCreateRoot(handle, false)); } + // printf("addr = %s port = %d\n", inet_ntoa(handle->addr.sin.sin_addr), (int)ntohs(handle->addr.sin.sin_port)); + // printf("addr = %s\n", inet_ntoa((*(struct sockaddr_in*)&handle->addr.sa).sin_addr)); return mscclppSuccess; } diff --git a/src/bootstrap/bootstrap_test.cc b/src/bootstrap/bootstrap_test.cc index 1035e88c..f02fa7c7 100644 --- a/src/bootstrap/bootstrap_test.cc +++ b/src/bootstrap/bootstrap_test.cc @@ -11,6 +11,8 @@ int main() int world_size; MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &world_size); + // int a; + // scanf("%d", &a); mscclppResult_t res = bootstrapNetInit(); if (res != mscclppSuccess) { @@ -19,15 +21,15 @@ int main() } mscclppBootstrapHandle handle; - if (rank == 0) { - res = bootstrapGetUniqueId(&handle); + if (true || rank == 0) { + res = bootstrapGetUniqueId(&handle, rank == 0); if (res != mscclppSuccess) { printf("bootstrapGetUniqueId failed\n"); return -1; } } - MPI_Bcast(&handle, sizeof(mscclppBootstrapHandle), MPI_BYTE, 0, MPI_COMM_WORLD); + // MPI_Bcast(&handle, sizeof(mscclppBootstrapHandle), MPI_BYTE, 0, MPI_COMM_WORLD); mscclppComm *comm; res = mscclppCalloc(&comm, 1); @@ -80,6 +82,6 @@ int main() MPI_Finalize(); - printf("Succeeded!\n"); + printf("Succeeded! %d\n", rank); return 0; } diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index 2d2b36b0..82868434 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -20,7 +20,7 @@ static_assert(sizeof(struct mscclppBootstrapHandle) <= sizeof(mscclppUniqueId), mscclppResult_t bootstrapNetInit(); mscclppResult_t bootstrapCreateRoot(struct mscclppBootstrapHandle* handle, bool idFromEnv); -mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle); +mscclppResult_t bootstrapGetUniqueId(struct mscclppBootstrapHandle* handle, bool isRoot = true); mscclppResult_t bootstrapInit(struct mscclppBootstrapHandle* handle, struct mscclppComm* comm); mscclppResult_t bootstrapAllGather(void* commState, void* allData, int size); mscclppResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);