diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h
index 9ba60b4c..3827f9fd 100644
--- a/src/include/mscclpp.h
+++ b/src/include/mscclpp.h
@@ -8,15 +8,6 @@
 #endif
 #include
 
-#define MSCCLPPCHECK(call) do { \
-  mscclppResult_t res = call; \
-  if (res != mscclppSuccess && res != mscclppInProgress) { \
-    /* Print the back trace*/ \
-    printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \
-    return res; \
-  } \
-} while (0)
-
 #define MSCCLPP_MAJOR 0
 #define MSCCLPP_MINOR 1
 #define MSCCLPP_PROXY_FIFO_SIZE 8
diff --git a/src/init.cc b/src/init.cc
index 8743b2e3..50ee81a0 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -170,7 +170,7 @@ mscclppResult_t mscclppCommDestroy(mscclppComm_t comm){
 
 MSCCLPP_API(mscclppResult_t, mscclppConnect, mscclppComm_t comm, mscclppDevConn* devConnOut, int remoteRank, void* localBuff, size_t buffSize, int tag, mscclppTransport_t transportType, const char *ibDev);
 mscclppResult_t mscclppConnect(mscclppComm_t comm, mscclppDevConn* devConnOut, int remoteRank, void* localBuff, size_t buffSize,
-                               int tag, mscclppTransport_t transportType, const char *ibDev/*=NULL*/)
+                               int tag, mscclppTransport_t transportType, const char *ibDev)
 {
   if (comm->nConns == MAXCONNECTIONS) {
     WARN("Too many connections made");
diff --git a/tests/allgather_test.cu b/tests/allgather_test.cu
index 0e1c96af..911a2bfa 100644
--- a/tests/allgather_test.cu
+++ b/tests/allgather_test.cu
@@ -1,4 +1,5 @@
 #include "mscclpp.h"
+
 #ifdef MSCCLPP_USE_MPI_FOR_TESTS
 #include "mpi.h"
 #endif // MSCCLPP_USE_MPI_FOR_TESTS
@@ -9,6 +10,18 @@
 
 #define RANKS_PER_NODE 8
 
+// Propagate errors up
+
+#define MSCCLPPCHECK(call) do { \
+  mscclppResult_t res = call; \
+  if (res != mscclppSuccess && res != mscclppInProgress) { \
+    /* Print the back trace*/ \
+    printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \
+    return res; \
+  } \
+} while (0)
+
+
 // Check CUDA RT calls
 #define CUDACHECK(cmd) do { \
   cudaError_t err = cmd; \
@@ -29,6 +42,8 @@ static double getTime(void)
   return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec;
 }
 
+mscclppComm_t comm;
+mscclppDevConn_t devConns[16];
 __constant__ mscclppDevConn_t constDevConns[16];
 
 __global__ void kernel(int rank, int world_size, int nelemsPerGPU)
@@ -103,6 +118,31 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t data_si
   CUDACHECK(cudaMemcpy(*data_d, *data_h, data_size, cudaMemcpyHostToDevice));
 }
 
+mscclppResult_t setupMscclppConnections(int rank, int world_size, mscclppComm_t comm, int* data_d, size_t data_size){
+  int thisNode = rankToNode(rank);
+  int cudaNum = rankToLocalRank(rank);
+  std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
+
+  for (int r = 0; r < world_size; ++r) {
+    if (r == rank) continue;
+    mscclppTransport_t transportType;
+    const char* ibDev = ibDevStr.c_str();
+    if (rankToNode(r) == thisNode){
+      ibDev = NULL;
+      transportType = mscclppTransportP2P;
+    } else {
+      transportType = mscclppTransportIB;
+    }
+    // Connect with all other ranks
+    MSCCLPPCHECK(mscclppConnect(comm, &devConns[r], r, data_d, data_size, 0, transportType, ibDev));
+  }
+
+  MSCCLPPCHECK(mscclppConnectionSetup(comm));
+  CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns, sizeof(mscclppDevConn_t) * world_size));
+
+  return mscclppSuccess;
+}
+
 int main(int argc, const char *argv[])
 {
 #ifdef MSCCLPP_USE_MPI_FOR_TESTS
@@ -133,11 +173,8 @@ int main(int argc, const char *argv[])
   int thisNode = rankToNode(rank);
   int cudaNum = rankToLocalRank(rank);
-
   CUDACHECK(cudaSetDevice(cudaNum));
 
-  std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
 
-  mscclppComm_t comm;
   MSCCLPPCHECK(mscclppCommInitRank(&comm, world_size, rank, ip_port));
 
   int *data_d;
@@ -147,31 +184,13 @@ int main(int argc, const char *argv[])
   initializeAndAllocateAllGatherData(rank, world_size, data_size, nelemsPerGPU, &data_h, &data_d);
 
-  mscclppDevConn_t devConns[16];
-  for (int r = 0; r < world_size; ++r) {
-    if (r == rank) continue;
-    mscclppTransport_t transportType;
-    const char* ibDev = ibDevStr.c_str();
-    if (rankToNode(r) == thisNode){
-      ibDev = NULL;
-      transportType = mscclppTransportP2P;
-    } else {
-      transportType = mscclppTransportIB;
-    }
-    // Connect with all other ranks
-    MSCCLPPCHECK(mscclppConnect(comm, &devConns[r], r, data_d, data_size, 0, transportType, ibDev));
-  }
-
-  MSCCLPPCHECK(mscclppConnectionSetup(comm));
+  MSCCLPPCHECK(setupMscclppConnections(rank, world_size, comm, data_d, data_size));
 
   MSCCLPPCHECK(mscclppProxyLaunch(comm));
 
-  CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns, sizeof(mscclppDevConn_t) * world_size));
-
   cudaStream_t stream;
   CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
 
-  CUDACHECK(cudaDeviceSynchronize());
   kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU);
 
   CUDACHECK(cudaDeviceSynchronize());
diff --git a/tests/allgather_test2.cu b/tests/allgather_test2.cu
index 80df881a..6b65c880 100644
--- a/tests/allgather_test2.cu
+++ b/tests/allgather_test2.cu
@@ -9,15 +9,6 @@
 
 #define RANKS_PER_NODE 8
 
-#define MSCCLPPCHECK(call) do { \
-  mscclppResult_t res = call; \
-  if (res != mscclppSuccess && res != mscclppInProgress) { \
-    /* Print the back trace*/ \
-    printf("Failure at %s:%d -> %d\n", __FILE__, __LINE__, res); \
-    return res; \
-  } \
-} while (0);
-
 // Check CUDA RT calls
 #define CUDACHECK(cmd) do { \
   cudaError_t err = cmd; \
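Note on the relocated MSCCLPPCHECK macro: it expands to "return res;" on failure, so it can only be used inside a function whose return type is mscclppResult_t. That is why setupMscclppConnections() returns mscclppResult_t and why main() wraps the call in MSCCLPPCHECK again to keep propagating the error. A minimal sketch of the pattern, assuming the macro definition added above (the helper name runSetup is hypothetical, not part of this diff):

// Hypothetical caller showing the MSCCLPPCHECK propagation pattern: on any
// result other than mscclppSuccess or mscclppInProgress, the macro prints
// the failing file/line and early-returns the error code, so the enclosing
// function must itself return mscclppResult_t.
static mscclppResult_t runSetup(mscclppComm_t comm)
{
  MSCCLPPCHECK(mscclppConnectionSetup(comm)); // early return on failure
  MSCCLPPCHECK(mscclppProxyLaunch(comm));     // same propagation
  return mscclppSuccess;
}

Note also that the copy added to allgather_test.cu ends in "} while (0)" without a trailing semicolon, unlike the old copy removed from allgather_test2.cu, which ended in "} while (0);". The semicolon-free form is the safer one: the caller supplies the semicolon in "MSCCLPPCHECK(...);", so the macro behaves like a single statement inside if/else chains.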