a third kernel for allgather cross-node

This commit is contained in:
Saeed Maleki
2023-03-30 23:24:04 +00:00
parent b58eae4037
commit fef0bff945

View File

@@ -82,7 +82,44 @@ __device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, i
devConn.wait();
}
__global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel)
// Allgather variant selected with kernel == 2; it treats the connection to a
// rank on the same node differently from the connection to a rank on another
// node. Each calling thread drives exactly one connection (devConn) to
// remoteRank.
//
// Parameters:
//   devConn       - device-side connection to remoteRank
//   rank          - rank of the calling GPU
//   world_size    - total number of ranks
//   nranksPerNode - number of ranks per node; the code derives the node as
//                   rank / nranksPerNode and the local index as
//                   rank % nranksPerNode
//   remoteRank    - rank at the other end of devConn
//   nelemsPerGPU  - number of int elements contributed by each rank
//
// Barrier usage: the same-local-index branch executes nranksPerNode
// __syncthreads() calls, the same-node branch executes 2 * nranksPerNode - 1,
// and every other thread returns without executing any.
// NOTE(review): the CUDA programming guide requires __syncthreads() to be
// reached uniformly by the block; this function relies on threads that have
// already returned not blocking the barrier - confirm this is acceptable on
// the targeted architectures.
__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank,
                           int nelemsPerGPU)
{
  // Size in bytes of one rank's chunk. sizeof() yields size_t, so this product
  // and every offset derived from it below are computed in size_t. The
  // previous form, rank * nelemsPerGPU * sizeof(int), multiplied two ints
  // first and could overflow before being widened.
  const auto chunkBytes = nelemsPerGPU * sizeof(int);
  const auto ownOffset = rank * chunkBytes;

  if (remoteRank % nranksPerNode == rank % nranksPerNode) {
    // The remote rank has the same local index as this rank (presumably on a
    // different node, assuming remoteRank != rank - TODO confirm with caller).
    // Send this rank's chunk and signal the remote side.
    devConn.putWithSignal(ownOffset, chunkBytes);

    // Take part in the nranksPerNode - 1 barriers that the same-node threads
    // execute in their first put loop.
    for (int i = 1; i < nranksPerNode; i++)
      __syncthreads();

    // Wait for the remote side's signal, then flush this connection.
    devConn.wait();
    devConn.flush();

    // Presumably pairs with the barrier that the same-node threads execute
    // between their two phases, so forwarding starts only after the
    // cross-node chunk has arrived - TODO confirm.
    __syncthreads();
  } else if (remoteRank / nranksPerNode == rank / nranksPerNode) {
    // The remote rank is on the same node; work with its local index.
    const int remoteLocalRank = remoteRank % nranksPerNode;

    // Phase 1: in step i only the thread whose peer is local index
    // (rank + i) % nranksPerNode sends; all same-node threads hit the barrier
    // after every step.
    for (int i = 1; i < nranksPerNode; i++) {
      if (remoteLocalRank == ((rank + i) % nranksPerNode)) {
        // put this rank's own chunk to the peer, signal and flush in one call
        devConn.putWithSignalAndFlush(ownOffset, chunkBytes);
      }
      __syncthreads();
    }
    // all connections wait for the signal from the sender
    devConn.wait();
    __syncthreads();

    // Phase 2: forward the chunk that belongs to the rank nranksPerNode
    // positions ahead (wrapping at world_size), i.e. the same local index on
    // the next node.
    // NOTE(review): only one such chunk is forwarded, which looks like it
    // assumes exactly two nodes - confirm.
    const int nodeNghr = (rank + nranksPerNode) % world_size;
    const auto nghrOffset = nodeNghr * chunkBytes;
    for (int i = 1; i < nranksPerNode; i++) {
      if (remoteLocalRank == ((rank + i) % nranksPerNode)) {
        // put nodeNghr's chunk to the peer, signal and flush in one call
        devConn.putWithSignalAndFlush(nghrOffset, chunkBytes);
      }
      __syncthreads();
    }
    // all connections wait for the signal from the sender
    devConn.wait();
  } else {
    // Remote rank is on another node with a different local index: this
    // connection is not used by this kernel.
    return;
  }
}
__global__ void kernel(int rank, int world_size, int nranksPerNode, int nelemsPerGPU, int kernel)
{
// only use a single thread from each warp
if (threadIdx.x % 32 != 0)
@@ -98,6 +135,8 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel)
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
else if (kernel == 1)
allgather1(devConn, rank, world_size, remoteRank, nelemsPerGPU);
else if (kernel == 2)
allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
}
int rankToLocalRank(int rank)
@@ -141,6 +180,7 @@ mscclppResult_t setupMscclppConnections(int rank, int world_size, mscclppComm_t
{
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
int map[8] = {2,0,6,4,3,1,7,5};
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
for (int r = 0; r < world_size; ++r) {
@@ -338,7 +378,7 @@ int main(int argc, const char* argv[])
cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaDeviceSynchronize());
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum);
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));
CUDACHECK(cudaDeviceSynchronize());
@@ -361,7 +401,7 @@ int main(int argc, const char* argv[])
if (rank == 0)
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
for (int i = 0; i < iterwithoutcudagraph; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum);
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
CUDACHECK(cudaDeviceSynchronize());
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
@@ -374,7 +414,7 @@ int main(int argc, const char* argv[])
cudaGraphExec_t instance;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
for (int i = 0; i < cudagraphiter; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum);
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);