mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 17:26:04 +00:00
Fix performance issue introduced in PR #499 (#505)
1. Use `fence+relaxed` instead of `release` for the FIFO; `fence+relaxed` is more efficient on A100. 2. Update the deviceSyncer. The previous implementation could not correctly handle a change in the number of thread blocks. Use three counters to solve this issue: reset the previous counter before syncing on the current counter. 3. Introduce relaxedWait, which can be paired with relaxedSignal in cases that do not need a memory-visibility guarantee.
This commit is contained in:
@@ -134,16 +134,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
int4* buff4 = (int4*)buff;
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
// synchronize everyone
|
||||
if (tid == 0) {
|
||||
__threadfence_system();
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid < nPeer) {
|
||||
memChans[tid].relaxedSignal();
|
||||
}
|
||||
if (tid >= nPeer && tid < nPeer * 2) {
|
||||
memChans[tid - nPeer].wait();
|
||||
memChans[tid - nPeer].relaxedWait();
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
@@ -193,15 +188,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
|
||||
// synchronize everyone again
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
if (tid == 0) {
|
||||
__threadfence_system();
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid < nPeer) {
|
||||
memChans[tid].relaxedSignal();
|
||||
}
|
||||
if (tid >= nPeer && tid < nPeer * 2) {
|
||||
memChans[tid - nPeer].wait();
|
||||
memChans[tid - nPeer].relaxedWait();
|
||||
}
|
||||
|
||||
if (READ_ONLY) {
|
||||
@@ -437,7 +428,7 @@ __device__ void localReduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memCha
|
||||
memChans[peerIdx].relaxedSignal();
|
||||
}
|
||||
for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) {
|
||||
memChans[peerIdx].wait();
|
||||
memChans[peerIdx].relaxedWait();
|
||||
}
|
||||
reduceScatterDeviceSyncer.sync(nBlocks);
|
||||
|
||||
@@ -497,7 +488,7 @@ __device__ void localAllGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
}
|
||||
if (threadIdx.x == 0 && peerLocalBlockIdx == 0) {
|
||||
memChans[peerIdx].relaxedSignal();
|
||||
memChans[peerIdx].wait();
|
||||
memChans[peerIdx].relaxedWait();
|
||||
}
|
||||
allGatherDeviceSyncer.sync(nBlocks);
|
||||
size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk;
|
||||
|
||||
Reference in New Issue
Block a user