From 5ff64d36f42942e672652ff189924c7307851d46 Mon Sep 17 00:00:00 2001
From: Saeed Maleki
Date: Sun, 2 Apr 2023 03:36:22 +0000
Subject: [PATCH] documents for allgather2 + refactoring local allgather

---
Review notes (this section is not part of the commit message):

  * allgather0: the three per-warp guards added by this patch now read
    `(threadIdx.x % 32) == 0`. The submitted revision tested `!= 0`, which
    made every lane except lane 0 call putWithSignal/flush/wait. The code
    being removed (`if (threadIdx.x % 32 != 0) return;`) let only lane 0
    reach those calls, and localAllGather/allgather2 guard the same calls
    with `== 0`, so the inverted test was a regression.
  * allgather2: fixed the `piplineSize` typo in the newly added comment.
  * Only added (`+`) lines were altered, so hunk sizes and the diffstat are
    unchanged. The post-image blob id on the `index` line still refers to
    the originally submitted revision.
  * Line breaks and indentation were reconstructed from a copy whose
    whitespace had been collapsed; confirm that context lines match the
    target file byte-for-byte before applying.

 tests/allgather_test.cu | 117 ++++++++++++++++++++--------------------
 1 file changed, 57 insertions(+), 60 deletions(-)

diff --git a/tests/allgather_test.cu b/tests/allgather_test.cu
index 0e159783..05cfd6de 100644
--- a/tests/allgather_test.cu
+++ b/tests/allgather_test.cu
@@ -49,61 +49,76 @@ __constant__ mscclppDevConn_t constDevConns[16];
 
 __device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
 {
-  if (threadIdx.x % 32 != 0)
-    return;
   // this allgather is really simple and implemented as an alltoall
 
   // this thread's role is a sender role
   // put your data asynchronously
-  devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
+  if ((threadIdx.x % 32) == 0)
+    devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
   // make sure everyone is put their data before some thread randomly blocks everyone else in signal
   __syncthreads();
   // push with flag and sync to make sure the data is received
-  devConn.flush();
+  if ((threadIdx.x % 32) == 0)
+    devConn.flush();
 
   // this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
-  devConn.wait();
+  if ((threadIdx.x % 32) == 0)
+    devConn.wait();
 }
 
-__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
+__device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
+                               uint64_t offset, uint64_t size)
 {
-  if (threadIdx.x % 32 != 0)
-    return;
   // this allgather algorithm works as follows:
-  // Step 1: GPU rank i sends data to GPU rank (i+1) % world_size
-  // Step 2: GPU rank i waits for data from GPU rank (i+2) % world_size
+  // Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
+  // and waits for data from GPU rank (i-1) % nranksPerNode
+  // Step 2: GPU rank i sends data to GPU rank (i+2) % nranksPerNode
   // ...
   // This order is much better for DMA engine for NVLinks
-
-  for (int i = 1; i < world_size; i++) {
-    __syncthreads();
-    if (remoteRank != ((rank + i) % world_size))
-      continue;
-    // put your data to GPU (rank+i) % world_size and signal all in one call
-    devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
+  for (int i = 1; i < nranksPerNode; i++) {
+    if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
+      // put your data to GPU (rank+i) % nranksPerNode and signal in one call
+      if ((threadIdx.x % 32) == 0)
+        devConn.putWithSignalAndFlush(offset, size);
+    }
+    // wait for the data from GPU (rank-i) % nranksPerNode to arrive
+    if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
+      if ((threadIdx.x % 32) == 0)
+        devConn.wait();
+    }
+    asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
   }
-  // all connections wait for the signal from the sender
-  devConn.wait();
+}
+
+__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
+                           size_t nelemsPerGPU)
+{
+  localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
+                 nelemsPerGPU * sizeof(int));
 }
 
 __device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
                            size_t nelemsPerGPU)
 {
+  // this allgather is a pipelined and hierarchical one and only works for two nodes
+  // it is implemented as follows:
+  // Step 1: each node does a local allgather and concurrently,
+  // local GPU i exchange (pipelineSize-1)/pipelineSize portion of their data with
+  // its cross-node neighbor (local GPU i on the other node) via IB
+  // Step 2: each node does a local allgather again with the data just received from its
+  // cross-node neighbor in step 1, and concurrently, exchange the rest of the data with
+  // its cross-node neighbor
+  // Step 3: each node does a local allgather for the last time with the rest of the data
+
   int pipelineSize = 3;
 
+  // Step 1
+  // local allgather
   if (remoteRank / nranksPerNode == rank / nranksPerNode) {
-    for (int i = 1; i < nranksPerNode; i++) {
-      if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
-        if ((threadIdx.x % 32) == 0)
-          devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
-      }
-      if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode))
-        if ((threadIdx.x % 32) == 0)
-          devConn.wait();
-      asm volatile("bar.sync %0, %1;" ::"r"(10), "r"((nranksPerNode - 1) * 32) : "memory");
-    }
+    localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
+                   nelemsPerGPU * sizeof(int));
   }
-
+  // cross-node exchange
   if (remoteRank % nranksPerNode == rank % nranksPerNode) {
     // opposite side
     if ((threadIdx.x % 32) == 0)
@@ -115,21 +130,15 @@ __device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, i
 
   __syncthreads();
 
+  // Step 2
+  // local allgather
+  int otherNghr = (rank + nranksPerNode) % world_size;
   if (remoteRank / nranksPerNode == rank / nranksPerNode) {
-    for (int i = 1; i < nranksPerNode; i++) {
-      int otherNghr = (rank + nranksPerNode) % world_size;
-      if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
-        if ((threadIdx.x % 32) == 0)
-          devConn.putWithSignalAndFlush(otherNghr * nelemsPerGPU * sizeof(int),
-                                        (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
-      }
-      if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode))
-        if ((threadIdx.x % 32) == 0)
-          devConn.wait();
-      asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
-    }
+    localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
+                   (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
   }
 
+  // cross-node exchange
   if (remoteRank % nranksPerNode == rank % nranksPerNode) {
     // opposite side
     if ((threadIdx.x % 32) == 0)
@@ -142,29 +151,17 @@ __device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, i
 
   __syncthreads();
 
+  // Step 3
+  // local allgather
   if (remoteRank / nranksPerNode == rank / nranksPerNode) {
-    for (int i = 1; i < nranksPerNode; i++) {
-      int otherNghr = (rank + nranksPerNode) % world_size;
-      if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
-        if ((threadIdx.x % 32) == 0)
-          devConn.putWithSignalAndFlush(
-            (otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
-            nelemsPerGPU / pipelineSize * sizeof(int));
-      }
-      if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode))
-        if ((threadIdx.x % 32) == 0)
-          devConn.wait();
-      asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
-    }
+    localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank,
+                   (otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
+                   nelemsPerGPU / pipelineSize * sizeof(int));
   }
 }
 
 __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel)
 {
-  // only use a single thread from each warp
-  // if (threadIdx.x % 32 != 0)
-  //   return;
-
   // find the mapping between remoteRank and devConns
   int warpId = threadIdx.x / 32;
   int remoteRank = (warpId < rank) ? warpId : warpId + 1;
@@ -174,7 +171,7 @@ __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelem
   if (kernel == 0)
     allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
   else if (kernel == 1)
-    allgather1(devConn, rank, world_size, remoteRank, nelemsPerGPU);
+    allgather1(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
   else if (kernel == 2)
     allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
 }