diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index e1521b67..eebc648a 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -168,11 +168,17 @@ template __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* memoryChannels, size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, - int worldSize, size_t nelems, uint32_t flag) { + int worldSize, size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff) { // This version of allreduce only works for single nodes if (worldSize != nRanksPerNode) return; if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); const int nPeers = nRanksPerNode - 1; + + uint32_t flag = deviceFlag[blockIdx.x]; + + size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0; + channelScratchOffset = scratchBaseOffset; + const int nBlocksPerPeer = gridDim.x / nPeers; const int localBlockIdx = blockIdx.x % nBlocksPerPeer; const int tid = threadIdx.x + localBlockIdx * blockDim.x; @@ -198,13 +204,17 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, } dst[idx] = data; } + __syncthreads(); + if (threadIdx.x == 0) { + deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1; + } } template __global__ void __launch_bounds__(1024, 1) allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* memoryChannels, size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, - size_t nelems, uint32_t flag + size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff #if defined(ENABLE_NPKIT) , NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) { @@ -247,6 +257,11 @@ __global__ void __launch_bounds__(1024, 1) const int nPeers = nRanksPerNode - 1; const size_t nPkts = nelems / 2; + uint32_t flag = (uint32_t)deviceFlag[blockIdx.x]; + + size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0; + channelScratchOffset = scratchBaseOffset; + int nelemsPerRank = nelems / worldSize; if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T); @@ -309,6 +324,8 @@ __global__ void __launch_bounds__(1024, 1) result[idx].x = data.x; result[idx].y = data.y; } + + __syncthreads(); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \ defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT) NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer, @@ -319,6 +336,9 @@ __global__ void __launch_bounds__(1024, 1) #if defined(ENABLE_NPKIT) NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head); #endif + if (threadIdx.x == 0) { + deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1; + } } template @@ -462,37 +482,40 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff, mscclpp::DeviceHandle* memoryChannels, mscclpp::DeviceHandle* memoryOutChannels, size_t channelInOffset, size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, - size_t nelems, cudaStream_t stream) { - static uint32_t flag = 1; - + size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28, + uint32_t* deviceFlag56, uint32_t numScratchBuff) { + uint32_t* deviceFlag; if (sizeof(T) * nelems < worldSize * sizeof(int)) { int nBlocks = 7; int nThreadsPerBlock = 32; allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++); + nRanksPerNode, worldSize, nelems, deviceFlag7, numScratchBuff); } else if (sizeof(T) * nelems <= (1 << 14)) { int nBlocks = 28; int nThreadsPerBlock = 512; allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++); + nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff); } else if (sizeof(T) * nelems <= (1 << 20)) { int nBlocks = 28; int nThreadsPerBlock = 1024; + deviceFlag = deviceFlag28; if (nelems >= 8192) { nBlocks = 56; nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024; + deviceFlag = deviceFlag56; } #if defined(ENABLE_NPKIT) size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent); allreduce7<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); + nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff, NpKit::GetGpuEventCollectContexts(), + NpKit::GetCpuTimestamp()); #else - allreduce7<<>>((T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, - channelInOffset, channelScratchOffset, rank, - nRanksPerNode, worldSize, nelems, flag++); + allreduce7<<>>( + (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, + nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff); #endif } else { int nBlocks = 35; diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 1f42dbfa..739d3ec8 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -195,6 +195,10 @@ struct ncclComm { uint32_t numScratchBuff; uint32_t buffFlag; + std::shared_ptr deviceFlag7; + std::shared_ptr deviceFlag28; + std::shared_ptr deviceFlag56; + void* mscclppNcclComm; }; @@ -383,7 +387,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, Op reduceOp = getReduceOp(op); std::function*, mscclpp::DeviceHandle*, size_t, size_t, size_t, int, int, int, - size_t, cudaStream_t)> + size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)> allreduceFunc; if (reduceOp == SUM) { if (datatype == ncclFloat16) { @@ -414,7 +418,9 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, } CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE, - comm->comm->bootstrap()->getNranks(), count, stream)); + comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(), + (uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(), + comm->numScratchBuff)); return ncclSuccess; } @@ -533,6 +539,19 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt commPtr->scratchBuff = mscclpp::GpuBuffer(SCRATCH_SIZE).memory(); commPtr->remoteScratchRegMemories = setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); + + commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared(7); + commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared(28); + commPtr->deviceFlag56 = mscclpp::detail::gpuCallocShared(56); + + std::vector initFlag(56); + for (int i = 0; i < 56; ++i) { + initFlag[i] = 1; + } + + mscclpp::gpuMemcpy(commPtr->deviceFlag7.get(), initFlag.data(), 7, cudaMemcpyHostToDevice); + mscclpp::gpuMemcpy(commPtr->deviceFlag28.get(), initFlag.data(), 28, cudaMemcpyHostToDevice); + mscclpp::gpuMemcpy(commPtr->deviceFlag56.get(), initFlag.data(), 56, cudaMemcpyHostToDevice); } NCCL_API ncclResult_t ncclGetVersion(int* version) {