diff --git a/tests/p2p_test_mpi.cu b/tests/p2p_test_mpi.cu index 217fedc3..4005d111 100644 --- a/tests/p2p_test_mpi.cu +++ b/tests/p2p_test_mpi.cu @@ -4,6 +4,15 @@ #include #include +#define MSCCLPPCHECK(call) do { \ + mscclppResult_t res = call; \ + if (res != mscclppSuccess && res != mscclppInProgress) { \ + /* Print the back trace*/ \ + printf("Failure at %s:%d -> %d", __FILE__, __LINE__, res); \ + return res; \ + } \ +} while (0); + // Check CUDA RT calls #define CUDACHECK(cmd) do { \ cudaError_t err = cmd; \ @@ -82,42 +91,20 @@ int main(int argc, const char *argv[]) mscclppResult_t res; - // if (rank == 0) - // sleep(10); - // else - // sleep(10); - mscclppDevConn_t devConns[8]; // Read from all other ranks for (int r = 0; r < world_size; ++r) { if (r == rank) continue; int tag = 0; - res = mscclppConnect(comm, &devConns[r], r, data_d, flag_d, tag, mscclppTransportP2P); - if (res != mscclppSuccess) { - printf("mscclppConnect failed\n"); - return -1; - } - } - - // Let others read from me - // for (int r = 0; r < world_size; ++r) { - // if (r == rank) continue; - // int tag = r * world_size + rank; - // res = mscclppConnect(comm, r, rank, data_d, flag_d, tag, mscclppTransportP2P); - // if (res != mscclppSuccess) { - // printf("mscclppConnect failed\n"); - // return -1; - // } - // } - - res = mscclppConnectionSetup(comm); - if (res != mscclppSuccess) { - printf("mscclppConnectionSetup failed\n"); - return -1; + MSCCLPPCHECK(mscclppConnect(comm, &devConns[r], r, data_d, flag_d, tag, mscclppTransportP2P)); } + MSCCLPPCHECK(mscclppConnectionSetup(comm)); CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns, sizeof(mscclppDevConn_t) * world_size)); + + kernel<<<1, 1>>>(rank, world_size); + CUDACHECK(cudaDeviceSynchronize()); int *buf = (int *)calloc(world_size, sizeof(int)); @@ -134,11 +121,7 @@ int main(int argc, const char *argv[]) } } - res = mscclppCommDestroy(comm); - if (res != mscclppSuccess) { - printf("mscclppDestroy failed\n"); - return -1; - } + MSCCLPPCHECK(mscclppCommDestroy(comm)); MPI_Finalize();