apps/nccl: fix a bug in allreduce kernels for graph mode (#502)

`allreduce7` and `allreduceAllpairs` kernels were updating the LL
protocol flag on the host side. So, it was not properly captured in
graph mode. This PR fixes the issue by updating the flag in the kernels.
This commit is contained in:
Nusrat Islam
2025-04-24 18:43:47 -05:00
committed by GitHub
parent cbdcf9064c
commit 9df2bdb2bf
2 changed files with 55 additions and 13 deletions

View File

@@ -168,11 +168,17 @@ template <Op OpType, typename T>
__global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode,
int worldSize, size_t nelems, uint32_t flag) {
int worldSize, size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff) {
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
const int nPeers = nRanksPerNode - 1;
uint32_t flag = deviceFlag[blockIdx.x];
size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0;
channelScratchOffset = scratchBaseOffset;
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
@@ -198,13 +204,17 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
}
dst[idx] = data;
}
__syncthreads();
if (threadIdx.x == 0) {
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
}
}
template <Op OpType, typename T>
__global__ void __launch_bounds__(1024, 1)
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, uint32_t flag
size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff
#if defined(ENABLE_NPKIT)
,
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
@@ -247,6 +257,11 @@ __global__ void __launch_bounds__(1024, 1)
const int nPeers = nRanksPerNode - 1;
const size_t nPkts = nelems / 2;
uint32_t flag = (uint32_t)deviceFlag[blockIdx.x];
size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0;
channelScratchOffset = scratchBaseOffset;
int nelemsPerRank = nelems / worldSize;
if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);
@@ -309,6 +324,8 @@ __global__ void __launch_bounds__(1024, 1)
result[idx].x = data.x;
result[idx].y = data.y;
}
__syncthreads();
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
@@ -319,6 +336,9 @@ __global__ void __launch_bounds__(1024, 1)
#if defined(ENABLE_NPKIT)
NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
#endif
if (threadIdx.x == 0) {
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
}
}
template <Op OpType, typename T>
@@ -462,37 +482,40 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, cudaStream_t stream) {
static uint32_t flag = 1;
size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28,
uint32_t* deviceFlag56, uint32_t numScratchBuff) {
uint32_t* deviceFlag;
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
int nBlocks = 7;
int nThreadsPerBlock = 32;
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
nRanksPerNode, worldSize, nelems, deviceFlag7, numScratchBuff);
} else if (sizeof(T) * nelems <= (1 << 14)) {
int nBlocks = 28;
int nThreadsPerBlock = 512;
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff);
} else if (sizeof(T) * nelems <= (1 << 20)) {
int nBlocks = 28;
int nThreadsPerBlock = 1024;
deviceFlag = deviceFlag28;
if (nelems >= 8192) {
nBlocks = 56;
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
deviceFlag = deviceFlag56;
}
#if defined(ENABLE_NPKIT)
size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff, NpKit::GetGpuEventCollectContexts(),
NpKit::GetCpuTimestamp());
#else
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels,
channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff);
#endif
} else {
int nBlocks = 35;

View File

@@ -195,6 +195,10 @@ struct ncclComm {
uint32_t numScratchBuff;
uint32_t buffFlag;
std::shared_ptr<uint32_t> deviceFlag7;
std::shared_ptr<uint32_t> deviceFlag28;
std::shared_ptr<uint32_t> deviceFlag56;
void* mscclppNcclComm;
};
@@ -383,7 +387,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
Op reduceOp = getReduceOp(op);
std::function<cudaError_t(const void*, void*, void*, mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>*, size_t, size_t, size_t, int, int, int,
size_t, cudaStream_t)>
size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
allreduceFunc;
if (reduceOp == SUM) {
if (datatype == ncclFloat16) {
@@ -414,7 +418,9 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
}
CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn,
offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(),
(uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(),
comm->numScratchBuff));
return ncclSuccess;
}
@@ -533,6 +539,19 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
commPtr->scratchBuff = mscclpp::GpuBuffer<char>(SCRATCH_SIZE).memory();
commPtr->remoteScratchRegMemories =
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared<uint32_t>(7);
commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared<uint32_t>(28);
commPtr->deviceFlag56 = mscclpp::detail::gpuCallocShared<uint32_t>(56);
std::vector<uint32_t> initFlag(56);
for (int i = 0; i < 56; ++i) {
initFlag[i] = 1;
}
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag7.get(), initFlag.data(), 7, cudaMemcpyHostToDevice);
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag28.get(), initFlag.data(), 28, cudaMemcpyHostToDevice);
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag56.get(), initFlag.data(), 56, cudaMemcpyHostToDevice);
}
NCCL_API ncclResult_t ncclGetVersion(int* version) {