mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-08 15:30:41 +00:00
apps/nccl: fix a bug in allreduce kernels for graph mode (#502)
`allreduce7` and `allreduceAllpairs` kernels were updating the LL protocol flag on the host side. So, it was not properly captured in graph mode. This PR fixes the issue by updating the flag in the kernels.
This commit is contained in:
@@ -168,11 +168,17 @@ template <Op OpType, typename T>
|
||||
__global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
|
||||
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
|
||||
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode,
|
||||
int worldSize, size_t nelems, uint32_t flag) {
|
||||
int worldSize, size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff) {
|
||||
// This version of allreduce only works for single nodes
|
||||
if (worldSize != nRanksPerNode) return;
|
||||
if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
|
||||
const int nPeers = nRanksPerNode - 1;
|
||||
|
||||
uint32_t flag = deviceFlag[blockIdx.x];
|
||||
|
||||
size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0;
|
||||
channelScratchOffset = scratchBaseOffset;
|
||||
|
||||
const int nBlocksPerPeer = gridDim.x / nPeers;
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
@@ -198,13 +204,17 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
|
||||
}
|
||||
dst[idx] = data;
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <Op OpType, typename T>
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
|
||||
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems, uint32_t flag
|
||||
size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
|
||||
@@ -247,6 +257,11 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
const int nPeers = nRanksPerNode - 1;
|
||||
const size_t nPkts = nelems / 2;
|
||||
|
||||
uint32_t flag = (uint32_t)deviceFlag[blockIdx.x];
|
||||
|
||||
size_t scratchBaseOffset = (flag % numScratchBuff) ? SCRATCH_SIZE / numScratchBuff : 0;
|
||||
channelScratchOffset = scratchBaseOffset;
|
||||
|
||||
int nelemsPerRank = nelems / worldSize;
|
||||
if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);
|
||||
|
||||
@@ -309,6 +324,8 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
result[idx].x = data.x;
|
||||
result[idx].y = data.y;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
|
||||
defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
|
||||
NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
|
||||
@@ -319,6 +336,9 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
#if defined(ENABLE_NPKIT)
|
||||
NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
|
||||
#endif
|
||||
if (threadIdx.x == 0) {
|
||||
deviceFlag[blockIdx.x] = deviceFlag[blockIdx.x] + 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <Op OpType, typename T>
|
||||
@@ -462,37 +482,40 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
|
||||
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
|
||||
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
|
||||
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems, cudaStream_t stream) {
|
||||
static uint32_t flag = 1;
|
||||
|
||||
size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28,
|
||||
uint32_t* deviceFlag56, uint32_t numScratchBuff) {
|
||||
uint32_t* deviceFlag;
|
||||
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
|
||||
int nBlocks = 7;
|
||||
int nThreadsPerBlock = 32;
|
||||
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
|
||||
nRanksPerNode, worldSize, nelems, flag++);
|
||||
nRanksPerNode, worldSize, nelems, deviceFlag7, numScratchBuff);
|
||||
} else if (sizeof(T) * nelems <= (1 << 14)) {
|
||||
int nBlocks = 28;
|
||||
int nThreadsPerBlock = 512;
|
||||
allreduceAllPairs<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
|
||||
nRanksPerNode, worldSize, nelems, flag++);
|
||||
nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff);
|
||||
} else if (sizeof(T) * nelems <= (1 << 20)) {
|
||||
int nBlocks = 28;
|
||||
int nThreadsPerBlock = 1024;
|
||||
deviceFlag = deviceFlag28;
|
||||
if (nelems >= 8192) {
|
||||
nBlocks = 56;
|
||||
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
|
||||
deviceFlag = deviceFlag56;
|
||||
}
|
||||
#if defined(ENABLE_NPKIT)
|
||||
size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
|
||||
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
|
||||
nRanksPerNode, worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff, NpKit::GetGpuEventCollectContexts(),
|
||||
NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels,
|
||||
channelInOffset, channelScratchOffset, rank,
|
||||
nRanksPerNode, worldSize, nelems, flag++);
|
||||
allreduce7<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
|
||||
nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff);
|
||||
#endif
|
||||
} else {
|
||||
int nBlocks = 35;
|
||||
|
||||
@@ -195,6 +195,10 @@ struct ncclComm {
|
||||
uint32_t numScratchBuff;
|
||||
uint32_t buffFlag;
|
||||
|
||||
std::shared_ptr<uint32_t> deviceFlag7;
|
||||
std::shared_ptr<uint32_t> deviceFlag28;
|
||||
std::shared_ptr<uint32_t> deviceFlag56;
|
||||
|
||||
void* mscclppNcclComm;
|
||||
};
|
||||
|
||||
@@ -383,7 +387,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
|
||||
Op reduceOp = getReduceOp(op);
|
||||
std::function<cudaError_t(const void*, void*, void*, mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
|
||||
mscclpp::DeviceHandle<mscclpp::MemoryChannel>*, size_t, size_t, size_t, int, int, int,
|
||||
size_t, cudaStream_t)>
|
||||
size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
|
||||
allreduceFunc;
|
||||
if (reduceOp == SUM) {
|
||||
if (datatype == ncclFloat16) {
|
||||
@@ -414,7 +418,9 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
|
||||
}
|
||||
CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn,
|
||||
offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
|
||||
comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(),
|
||||
(uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(),
|
||||
comm->numScratchBuff));
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -533,6 +539,19 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
|
||||
commPtr->scratchBuff = mscclpp::GpuBuffer<char>(SCRATCH_SIZE).memory();
|
||||
commPtr->remoteScratchRegMemories =
|
||||
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
|
||||
|
||||
commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared<uint32_t>(7);
|
||||
commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared<uint32_t>(28);
|
||||
commPtr->deviceFlag56 = mscclpp::detail::gpuCallocShared<uint32_t>(56);
|
||||
|
||||
std::vector<uint32_t> initFlag(56);
|
||||
for (int i = 0; i < 56; ++i) {
|
||||
initFlag[i] = 1;
|
||||
}
|
||||
|
||||
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag7.get(), initFlag.data(), 7, cudaMemcpyHostToDevice);
|
||||
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag28.get(), initFlag.data(), 28, cudaMemcpyHostToDevice);
|
||||
mscclpp::gpuMemcpy<uint32_t>(commPtr->deviceFlag56.get(), initFlag.data(), 56, cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
NCCL_API ncclResult_t ncclGetVersion(int* version) {
|
||||
|
||||
Reference in New Issue
Block a user