diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 0b4db9d5..de6edbc3 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -241,6 +241,9 @@ mscclppResult_t mscclppCommInitRankFromId(mscclppComm_t* comm, int nranks, msccl */ mscclppResult_t mscclppBootstrapAllGather(mscclppComm_t comm, void* data, int size); +/* A barrier that synchronizes all processes via a bootstrap allgather of one dummy int per rank; the gathered values are discarded. */ +mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm); + /* Destroy a communicator. * * Inputs: diff --git a/src/init.cc b/src/init.cc index 645a6683..8199073e 100644 --- a/src/init.cc +++ b/src/init.cc @@ -505,6 +505,9 @@ mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) MSCCLPPCHECK(mscclppIbConnectionSetupEnd(&cInfo, conn)); } } + + // a barrier to ensure setup is done on all GPUs before we return to the user + MSCCLPPCHECK(mscclppBootstrapBarrier(comm)); return mscclppSuccess; } @@ -515,12 +518,21 @@ mscclppResult_t mscclppProxyLaunch(mscclppComm_t comm) return mscclppSuccess; } +MSCCLPP_API(mscclppResult_t, mscclppBootstrapBarrier, mscclppComm_t comm); +mscclppResult_t mscclppBootstrapBarrier(mscclppComm_t comm) +{ + int* tmp = new int[comm->nRanks]; + MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); + delete[] tmp; + return mscclppSuccess; +} + + MSCCLPP_API(mscclppResult_t, mscclppProxyStop, mscclppComm_t comm); mscclppResult_t mscclppProxyStop(mscclppComm_t comm) { // a barrier to make sure all ranks are done with their work before stopping the proxy - int* tmp = new int[comm->nRanks]; - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); + MSCCLPPCHECK(mscclppBootstrapBarrier(comm)); MSCCLPPCHECK(mscclppProxyDestroy(comm)); return mscclppSuccess;