mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-21 21:39:21 +00:00
documents for allgather2 + refactoring local allgather
This commit is contained in:
@@ -49,61 +49,76 @@ __constant__ mscclppDevConn_t constDevConns[16];
|
||||
|
||||
__device__ void allgather0(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
|
||||
{
|
||||
if (threadIdx.x % 32 != 0)
|
||||
return;
|
||||
// this allgather is really simple and implemented as an alltoall
|
||||
|
||||
// this thread's role is a sender role
|
||||
// put your data asynchronously
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
if (threadIdx.x % 32 != 0)
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
// make sure everyone is put their data before some thread randomly blocks everyone else in signal
|
||||
__syncthreads();
|
||||
// push with flag and sync to make sure the data is received
|
||||
devConn.flush();
|
||||
if (threadIdx.x % 32 != 0)
|
||||
devConn.flush();
|
||||
|
||||
// this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
|
||||
devConn.wait();
|
||||
if (threadIdx.x % 32 != 0)
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU)
|
||||
__device__ void localAllGather(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
uint64_t offset, uint64_t size)
|
||||
{
|
||||
if (threadIdx.x % 32 != 0)
|
||||
return;
|
||||
// this allgather algorithm works as follows:
|
||||
// Step 1: GPU rank i sends data to GPU rank (i+1) % world_size
|
||||
// Step 2: GPU rank i waits for data from GPU rank (i+2) % world_size
|
||||
// Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
|
||||
// and waits for data from GPU rank (i-1) % nranksPerNode
|
||||
// Step 2: GPU rank i sends data to GPU rank (i+2) % nranksPerNode
|
||||
// ...
|
||||
// This order is much better for DMA engine for NVLinks
|
||||
|
||||
for (int i = 1; i < world_size; i++) {
|
||||
__syncthreads();
|
||||
if (remoteRank != ((rank + i) % world_size))
|
||||
continue;
|
||||
// put your data to GPU (rank+i) % world_size and signal all in one call
|
||||
devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
for (int i = 1; i < nranksPerNode; i++) {
|
||||
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
|
||||
// put your data to GPU (rank+i) % nranksPerNode and signal in one call
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush(offset, size);
|
||||
}
|
||||
// wait for the data from GPU (rank-i) % nranksPerNode to arrive
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
}
|
||||
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
|
||||
}
|
||||
// all connections wait for the signal from the sender
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
size_t nelemsPerGPU)
|
||||
{
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
nelemsPerGPU * sizeof(int));
|
||||
}
|
||||
|
||||
__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
|
||||
size_t nelemsPerGPU)
|
||||
{
|
||||
// this allgather is a pipelined and hierarchical one and only works for two nodes
|
||||
// it is implemented as follows:
|
||||
// Step 1: each node does a local allgather and concurrently,
|
||||
// local GPU i exchange (piplineSize-1)/pipelineSize portion of their data with
|
||||
// its cross-node neighbor (local GPU i on the other node) via IB
|
||||
// Step 2: each node does a local allgather again with the data just received from its
|
||||
// cross-node neighbor in step 1, and concurrently, exchange the rest of the data with
|
||||
// its cross-node neighbor
|
||||
// Step 3: each node does a local allgather for the last time with the rest of the data
|
||||
|
||||
int pipelineSize = 3;
|
||||
|
||||
// Step 1
|
||||
// local allgather
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
|
||||
for (int i = 1; i < nranksPerNode; i++) {
|
||||
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
}
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode))
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
asm volatile("bar.sync %0, %1;" ::"r"(10), "r"((nranksPerNode - 1) * 32) : "memory");
|
||||
}
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
|
||||
nelemsPerGPU * sizeof(int));
|
||||
}
|
||||
|
||||
// cross-node exchange
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
@@ -115,21 +130,15 @@ __device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, i
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Step 2
|
||||
// local allgather
|
||||
int otherNghr = (rank + nranksPerNode) % world_size;
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
|
||||
for (int i = 1; i < nranksPerNode; i++) {
|
||||
int otherNghr = (rank + nranksPerNode) % world_size;
|
||||
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush(otherNghr * nelemsPerGPU * sizeof(int),
|
||||
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
|
||||
}
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode))
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
|
||||
}
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
|
||||
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
|
||||
}
|
||||
|
||||
// cross-node exchange
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
|
||||
// opposite side
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
@@ -142,29 +151,17 @@ __device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, i
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Step 3
|
||||
// local allgather
|
||||
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
|
||||
for (int i = 1; i < nranksPerNode; i++) {
|
||||
int otherNghr = (rank + nranksPerNode) % world_size;
|
||||
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.putWithSignalAndFlush(
|
||||
(otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
}
|
||||
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode))
|
||||
if ((threadIdx.x % 32) == 0)
|
||||
devConn.wait();
|
||||
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
|
||||
}
|
||||
localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank,
|
||||
(otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
|
||||
nelemsPerGPU / pipelineSize * sizeof(int));
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel)
|
||||
{
|
||||
// only use a single thread from each warp
|
||||
// if (threadIdx.x % 32 != 0)
|
||||
// return;
|
||||
|
||||
// find the mapping between remoteRank and devConns
|
||||
int warpId = threadIdx.x / 32;
|
||||
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
|
||||
@@ -174,7 +171,7 @@ __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelem
|
||||
if (kernel == 0)
|
||||
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 1)
|
||||
allgather1(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
allgather1(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 2)
|
||||
allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user