From 9df2bdb2bf71f529e6c2516bf5b6cb89ec98ec58 Mon Sep 17 00:00:00 2001 From: Nusrat Islam Date: Thu, 24 Apr 2025 18:43:47 -0500 Subject: [PATCH] apps/nccl: fix a bug in allreduce kernels for graph mode (#502) `allreduce7` and `allreduceAllpairs` kernels were updating the LL protocol flag on the host side. So, it was not properly captured in graph mode. This PR fixes the issue by updating the flag in the kernels. --- apps/nccl/src/allreduce.hpp | 45 ++++++++++++++++++++++++++++--------- apps/nccl/src/nccl.cu | 23 +++++++++++++++++-- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index e1521b67..eebc648a 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -168,11 +168,17 @@ template __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* memoryChannels, size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, - int worldSize, size_t nelems, uint32_t flag) { + int worldSize, size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff) { // This version of allreduce only works for single nodes if (worldSize != nRanksPerNode) return; if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); const int nPeers = nRanksPerNode - 1; + + uint32_t flag = deviceFlag[blockIdx.x]; + + size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0; + channelScratchOffset = scratchBaseOffset; + const int nBlocksPerPeer = gridDim.x / nPeers; const int localBlockIdx = blockIdx.x % nBlocksPerPeer; const int tid = threadIdx.x + localBlockIdx * blockDim.x; @@ -198,13 +204,17 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, } dst[idx] = data; } + __syncthreads(); + if (threadIdx.x == 0) { + deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1; + } } template __global__ void __launch_bounds__(1024, 1) allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* memoryChannels, size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, - size_t nelems, uint32_t flag + size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff #if defined(ENABLE_NPKIT) , NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) { @@ -247,6 +257,11 @@ __global__ void __launch_bounds__(1024, 1) const int nPeers = nRanksPerNode - 1; const size_t nPkts = nelems / 2; + uint32_t flag = (uint32_t)deviceFlag[blockIdx.x]; + + size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0; + channelScratchOffset = scratchBaseOffset; + int nelemsPerRank = nelems / worldSize; if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T); @@ -309,6 +324,8 @@ __global__ void __launch_bounds__(1024, 1) result[idx].x = data.x; result[idx].y = data.y; } + + __syncthreads(); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \ defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT) NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer, @@ -319,6 +336,9 @@ __global__ void __launch_bounds__(1024, 1) #if defined(ENABLE_NPKIT) NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head); #endif + if (threadIdx.x == 0) { + deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1; + } } template @@ -462,37 +482,40 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff, mscclpp::DeviceHandle* memoryChannels, mscclpp::DeviceHandle* memoryOutChannels, size_t channelInOffset, size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, - size_t nelems, cudaStream_t stream) { - static uint32_t flag = 1; - + size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28, + uint32_t* deviceFlag56, uint32_t numScratchBuff) { + uint32_t* deviceFlag; if (sizeof(T) * nelems < worldSize * sizeof(int)) { int nBlocks = 7; int nThreadsPerBlock = 32; allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++); + nRanksPerNode, worldSize, nelems, deviceFlag7, numScratchBuff); } else if (sizeof(T) * nelems <= (1 << 14)) { int nBlocks = 28; int nThreadsPerBlock = 512; allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++); + nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff); } else if (sizeof(T) * nelems <= (1 << 20)) { int nBlocks = 28; int nThreadsPerBlock = 1024; + deviceFlag = deviceFlag28; if (nelems >= 8192) { nBlocks = 56; nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024; + deviceFlag = deviceFlag56; } #if defined(ENABLE_NPKIT) size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent); allreduce7<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); + nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff, NpKit::GetGpuEventCollectContexts(), + NpKit::GetCpuTimestamp()); #else - allreduce7<<>>((T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, - channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++); + allreduce7<<>>( + (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, + nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff); #endif } else { int nBlocks = 35; diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 1f42dbfa..739d3ec8 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -195,6 +195,10 @@ struct ncclComm { uint32_t numScratchBuff; uint32_t buffFlag; + std::shared_ptr deviceFlag7; + std::shared_ptr deviceFlag28; + std::shared_ptr deviceFlag56; + void* mscclppNcclComm; }; @@ -383,7 +387,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, Op reduceOp = getReduceOp(op); std::function*, mscclpp::DeviceHandle*, size_t, size_t, size_t, int, int, int, - size_t, cudaStream_t)> + size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)> allreduceFunc; if (reduceOp == SUM) { if (datatype == ncclFloat16) { @@ -414,7 +418,9 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, } CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE, - comm->comm->bootstrap()->getNranks(), count, stream)); + comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(), + (uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(), + comm->numScratchBuff)); return ncclSuccess; } @@ -533,6 +539,19 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt commPtr->scratchBuff = mscclpp::GpuBuffer(SCRATCH_SIZE).memory(); commPtr->remoteScratchRegMemories = setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); + + commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared(7); + commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared(28); + commPtr->deviceFlag56 = mscclpp::detail::gpuCallocShared(56); + + std::vector initFlag(56); + for (int i = 0; i < 56; ++i) { + initFlag[i] = 1; + } + + mscclpp::gpuMemcpy(commPtr->deviceFlag7.get(), initFlag.data(), 7, cudaMemcpyHostToDevice); + mscclpp::gpuMemcpy(commPtr->deviceFlag28.get(), initFlag.data(), 28, cudaMemcpyHostToDevice); + mscclpp::gpuMemcpy(commPtr->deviceFlag56.get(), initFlag.data(), 56, cudaMemcpyHostToDevice); } NCCL_API ncclResult_t ncclGetVersion(int* version) {