diff --git a/tests/allgather_test.cu b/tests/allgather_test.cu index 48e085b6..e004e72a 100644 --- a/tests/allgather_test.cu +++ b/tests/allgather_test.cu @@ -82,7 +82,44 @@ __device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, i devConn.wait(); } -__global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel) +__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank, int nelemsPerGPU) +{ + if (remoteRank % nranksPerNode == rank % nranksPerNode){ + devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); + for (int i = 1; i < nranksPerNode; i++) + __syncthreads(); + devConn.wait(); + devConn.flush(); + __syncthreads(); + } else if (remoteRank / nranksPerNode == rank / nranksPerNode) { + remoteRank = remoteRank % nranksPerNode; + for (int i = 1; i < nranksPerNode; i++) { + if (remoteRank == ((rank + i) % nranksPerNode)){ + // put your data to GPU (rank+i) % world_size and signal all in one call + devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); + } + __syncthreads(); + } + // all connections wait for the signal from the sender + devConn.wait(); + + __syncthreads(); + int nodeNghr = (rank + nranksPerNode) % world_size; + for (int i = 1; i < nranksPerNode; i++) { + if (remoteRank == ((rank + i) % nranksPerNode)){ + // put your data to GPU (rank+i) % world_size and signal all in one call + devConn.putWithSignalAndFlush(nodeNghr * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); + } + __syncthreads(); + } + // all connections wait for the signal from the sender + devConn.wait(); + } else { + return; + } +} + +__global__ void kernel(int rank, int world_size, int nranksPerNode, int nelemsPerGPU, int kernel) { // only use a single thread from each warp if (threadIdx.x % 32 != 0) @@ -98,6 +135,8 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel) allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU); else if (kernel == 1) allgather1(devConn, rank, world_size, remoteRank, nelemsPerGPU); + else if (kernel == 2) + allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU); } int rankToLocalRank(int rank) @@ -141,6 +180,7 @@ mscclppResult_t setupMscclppConnections(int rank, int world_size, mscclppComm_t { int thisNode = rankToNode(rank); int cudaNum = rankToLocalRank(rank); + int map[8] = {2,0,6,4,3,1,7,5}; std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); for (int r = 0; r < world_size; ++r) { @@ -338,7 +378,7 @@ int main(int argc, const char* argv[]) cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaDeviceSynchronize()); - kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum); + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); CUDACHECK(cudaDeviceSynchronize()); CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost)); CUDACHECK(cudaDeviceSynchronize()); @@ -361,7 +401,7 @@ int main(int argc, const char* argv[]) if (rank == 0) printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph); for (int i = 0; i < iterwithoutcudagraph; ++i) { - kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum); + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); } CUDACHECK(cudaDeviceSynchronize()); MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); @@ -374,7 +414,7 @@ int main(int argc, const char* argv[]) cudaGraphExec_t instance; cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); for (int i = 0; i < cudagraphiter; ++i) { - kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum); + kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); } cudaStreamEndCapture(stream, &graph); cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);