mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 23:06:17 +00:00
a third kernel for allgather cross-node
This commit is contained in:
@@ -82,7 +82,44 @@ __device__ void allgather1(mscclppDevConn_t devConn, int rank, int world_size, i
|
||||
devConn.wait();
|
||||
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel)
|
||||
__device__ void allgather2(mscclppDevConn_t devConn, int rank, int world_size, int nranksPerNode, int remoteRank, int nelemsPerGPU)
|
||||
{
|
||||
if (remoteRank % nranksPerNode == rank % nranksPerNode){
|
||||
devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
for (int i = 1; i < nranksPerNode; i++)
|
||||
__syncthreads();
|
||||
devConn.wait();
|
||||
devConn.flush();
|
||||
__syncthreads();
|
||||
} else if (remoteRank / nranksPerNode == rank / nranksPerNode) {
|
||||
remoteRank = remoteRank % nranksPerNode;
|
||||
for (int i = 1; i < nranksPerNode; i++) {
|
||||
if (remoteRank == ((rank + i) % nranksPerNode)){
|
||||
// put your data to GPU (rank+i) % world_size and signal all in one call
|
||||
devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// all connections wait for the signal from the sender
|
||||
devConn.wait();
|
||||
|
||||
__syncthreads();
|
||||
int nodeNghr = (rank + nranksPerNode) % world_size;
|
||||
for (int i = 1; i < nranksPerNode; i++) {
|
||||
if (remoteRank == ((rank + i) % nranksPerNode)){
|
||||
// put your data to GPU (rank+i) % world_size and signal all in one call
|
||||
devConn.putWithSignalAndFlush(nodeNghr * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// all connections wait for the signal from the sender
|
||||
devConn.wait();
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel(int rank, int world_size, int nranksPerNode, int nelemsPerGPU, int kernel)
|
||||
{
|
||||
// only use a single thread from each warp
|
||||
if (threadIdx.x % 32 != 0)
|
||||
@@ -98,6 +135,8 @@ __global__ void kernel(int rank, int world_size, int nelemsPerGPU, int kernel)
|
||||
allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 1)
|
||||
allgather1(devConn, rank, world_size, remoteRank, nelemsPerGPU);
|
||||
else if (kernel == 2)
|
||||
allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU);
|
||||
}
|
||||
|
||||
int rankToLocalRank(int rank)
|
||||
@@ -141,6 +180,7 @@ mscclppResult_t setupMscclppConnections(int rank, int world_size, mscclppComm_t
|
||||
{
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
int map[8] = {2,0,6,4,3,1,7,5};
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
|
||||
for (int r = 0; r < world_size; ++r) {
|
||||
@@ -338,7 +378,7 @@ int main(int argc, const char* argv[])
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum);
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
@@ -361,7 +401,7 @@ int main(int argc, const char* argv[])
|
||||
if (rank == 0)
|
||||
printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph);
|
||||
for (int i = 0; i < iterwithoutcudagraph; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum);
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
CUDACHECK(cudaDeviceSynchronize());
|
||||
MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int)));
|
||||
@@ -374,7 +414,7 @@ int main(int argc, const char* argv[])
|
||||
cudaGraphExec_t instance;
|
||||
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
|
||||
for (int i = 0; i < cudagraphiter; ++i) {
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nelemsPerGPU, kernelNum);
|
||||
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
|
||||
}
|
||||
cudaStreamEndCapture(stream, &graph);
|
||||
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
|
||||
|
||||
Reference in New Issue
Block a user