Fix performance issue introduced in PR #499 (#505)

1. Use `fence+relaxed` instead of `release` for the FIFO. `fence+relaxed` is
more efficient on A100.
2. Update the deviceSyncer. The previous implementation could not correctly
handle a change in the number of thread blocks. Use three counters to solve
this issue: reset the previous counter before syncing on the current counter.
3. Introduce relaxedWait, which can be used with relaxedSignal in cases that
do not need to guarantee memory visibility.
This commit is contained in:
Binyang Li
2025-04-22 14:03:37 -07:00
committed by GitHub
parent e412804eab
commit 06f31994dc
6 changed files with 45 additions and 21 deletions

View File

@@ -134,16 +134,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans,
int4* buff4 = (int4*)buff;
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
// synchronize everyone
if (tid == 0) {
__threadfence_system();
}
__syncthreads();
if (tid < nPeer) {
memChans[tid].relaxedSignal();
}
if (tid >= nPeer && tid < nPeer * 2) {
memChans[tid - nPeer].wait();
memChans[tid - nPeer].relaxedWait();
}
deviceSyncer.sync(gridDim.x);
@@ -193,15 +188,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans,
// synchronize everyone again
deviceSyncer.sync(gridDim.x);
if (tid == 0) {
__threadfence_system();
}
__syncthreads();
if (tid < nPeer) {
memChans[tid].relaxedSignal();
}
if (tid >= nPeer && tid < nPeer * 2) {
memChans[tid - nPeer].wait();
memChans[tid - nPeer].relaxedWait();
}
if (READ_ONLY) {
@@ -437,7 +428,7 @@ __device__ void localReduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memCha
memChans[peerIdx].relaxedSignal();
}
for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
memChans[peerIdx].wait();
memChans[peerIdx].relaxedWait();
}
reduceScatterDeviceSyncer.sync(nBlocks);
@@ -497,7 +488,7 @@ __device__ void localAllGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans,
}
if (threadIdx.x == 0 && peerLocalBlockIdx == 0) {
memChans[peerIdx].relaxedSignal();
memChans[peerIdx].wait();
memChans[peerIdx].relaxedWait();
}
allGatherDeviceSyncer.sync(nBlocks);
size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;