diff --git a/src/bootstrap/init.cc b/src/bootstrap/init.cc index 5cfee437..56d5110b 100644 --- a/src/bootstrap/init.cc +++ b/src/bootstrap/init.cc @@ -45,8 +45,8 @@ mscclppResult_t mscclppGetUniqueId(mscclppUniqueId* out) { return res; } -MSCCLPP_API(mscclppResult_t, mscclppBootStrapAllGather, mscclppComm_t comm, void* data, int size); -mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size){ +MSCCLPP_API(mscclppResult_t, mscclppBootatrapAllGather, mscclppComm_t comm, void* data, int size); +mscclppResult_t mscclppBootatrapAllGather(mscclppComm_t comm, void* data, int size){ MSCCLPPCHECK(bootstrapAllGather(comm->bootstrap, data, size)); return mscclppSuccess; } diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index ce392736..576adb87 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -75,7 +75,7 @@ typedef enum { mscclppInt8 = 0, mscclppChar = 0, mscclppResult_t mscclppCommInitRank(mscclppComm_t* comm, int nranks, int rank, const char* ip_port_pair); -mscclppResult_t mscclppBootStrapAllGather(mscclppComm_t comm, void* data, int size); +mscclppResult_t mscclppBootatrapAllGather(mscclppComm_t comm, void* data, int size); mscclppResult_t mscclppCommDestroy(mscclppComm_t comm); diff --git a/tests/bootstrap_test.cc b/tests/bootstrap_test.cc index 42ad0ac3..a16d2345 100644 --- a/tests/bootstrap_test.cc +++ b/tests/bootstrap_test.cc @@ -3,6 +3,14 @@ #include #include +#define MSCCLPPCHECK(call) do { \ + mscclppResult_t res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + /* Print the back trace*/ \ + printf("Failure at %s:%d -> %d", __FILE__, __LINE__, res); \ + return res; \ + } \ +} while (0); void print_usage(const char *prog) { @@ -21,22 +29,20 @@ int main(int argc, const char *argv[]) int rank = atoi(argv[2]); int world_size = atoi(argv[3]); - // sleep(10); - - mscclppCommInitRank(&comm, world_size, rank, ip_port); + MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, rank, ip_port)); + // allocate some test buffer int *buf = (int *)calloc(world_size, sizeof(int)); if (buf == nullptr) { printf("calloc failed\n"); return -1; } + // each rank sets one element in the array buf[rank] = rank; - mscclppResult_t res = mscclppBootStrapAllGather(comm, buf, sizeof(int)); - if (res != mscclppSuccess) { - printf("bootstrapAllGather failed\n"); - return -1; - } + MSCCLPPCHECK(mscclppBootatrapAllGather(comm, buf, sizeof(int))); + + // check the correctness of all elements in the output of AllGather for (int i = 0; i < world_size; ++i) { if (buf[i] != i) { printf("wrong data: %d, expected %d\n", buf[i], i); @@ -44,12 +50,8 @@ int main(int argc, const char *argv[]) } } - res = mscclppCommDestroy(comm); - if (res != mscclppSuccess) { - printf("mscclppDestroy failed\n"); - return -1; - } + MSCCLPPCHECK(mscclppCommDestroy(comm)); - printf("Succeeded! %d\n", rank); + printf("Rank %d Succeeded\n", rank); return 0; } diff --git a/tests/bootstrap_test_mpi.cc b/tests/bootstrap_test_mpi.cc index 3ebb064d..12be5f3f 100644 --- a/tests/bootstrap_test_mpi.cc +++ b/tests/bootstrap_test_mpi.cc @@ -32,7 +32,7 @@ int main(int argc, const char *argv[]) return -1; } buf[rank] = rank; - mscclppResult_t res = mscclppBootStrapAllGather(comm, buf, sizeof(int)); + mscclppResult_t res = mscclppBootatrapAllGather(comm, buf, sizeof(int)); if (res != mscclppSuccess) { printf("bootstrapAllGather failed\n"); return -1;